#! /usr/bin/python

# PSPP - a program for statistical analysis.
# Copyright (C) 2017, 2018, 2019 Free Software Foundation, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import getopt
import os
import re
import struct
import sys

# Running count of diagnostics reported through error().
n_errors = 0

def error(msg):
    """Write MSG to stderr, prefixed with the current file name and line
    number, and count it in the module-level n_errors."""
    global n_errors
    fields = (file_name, line_number, msg)
    sys.stderr.write("%s:%d: %s\n" % fields)
    n_errors = n_errors + 1


def fatal(msg):
    """Report MSG through error() and terminate with exit status 1."""
    error(msg)
    raise SystemExit(1)


def get_line():
    """Read the next line of input_file into the global `line`.

    Anything from a '#' up to (not including) the line's newline is
    replaced by a newline character, and the global line counter is
    advanced by one.
    """
    global line, line_number
    raw = input_file.readline()
    line = re.sub('#.*', '\n', raw)
    line_number = line_number + 1


def expect(type):
    """Abort with a syntax error unless the current token's kind is TYPE."""
    if token[0] == type:
        return
    fatal("syntax error expecting %s" % type)


def match(type):
    """If the current token's kind is TYPE, consume it and return True;
    otherwise leave it alone and return False."""
    if token[0] != type:
        return False
    get_token()
    return True


def must_match(type):
    """Consume the current token, which must be of kind TYPE; a mismatch
    is reported as a fatal syntax error."""
    if token[0] != type:
        fatal("syntax error expecting %s" % type)
    get_token()


def match_id(id_):
    """If the current token is the identifier ID_, consume it and return
    True; otherwise return False without consuming anything."""
    found = token == ('id', id_)
    if found:
        get_token()
    return found


def is_idchar(c):
    """Return True if C may appear in an identifier: an alphanumeric
    character, a hyphen, or an underscore."""
    if c.isalnum():
        return True
    return c in '-_'


def get_token():
    """Advance the global `token` to the next token of the grammar input.

    Tokens are tuples whose first element is the kind:
      ('eof',)      end of input
      (';',)        a line that is exactly a newline after get_line()
      ('=>',)       the production arrow
      (c,)          one of the punctuators [ ] ( ) ? | * + = :
      ('id', name)  a run of identifier characters (see is_idchar())

    Reading past ('eof',) or meeting any other character is a fatal error.
    """
    global token
    global line
    # NOTE(review): `prev` is assigned but never read in this function.
    prev = token
    while True:
        # The current line is used up: fetch another one.
        if line == "":
            if token == ('eof', ):
                fatal("unexpected end of input")
            get_line()
            if not line:
                # readline() produced nothing more: end of input.
                token = ('eof', )
                break
            elif line == '\n':
                # A line holding only a newline acts as a ';' separator.
                token = (';', )
                break

        # Skip leading whitespace; if nothing is left, go read another line.
        line = line.lstrip()
        if line == "":
            continue

        if line.startswith('=>'):
            # Must be tested before the single-character '=' case below.
            token = (line[:2],)
            line = line[2:]
        elif line[0] in '[]()?|*+=:':
            token = (line[0],)
            line = line[1:]
        elif is_idchar(line[0]):
            # Take the longest run of identifier characters.
            n = 1
            while n < len(line) and is_idchar(line[n]):
                n += 1
            s = line[:n]
            token = ('id', s)
            line = line[n:]
        else:
            fatal("unknown character %c" % line[0])
        break


def usage():
    """Print the command-line synopsis on stdout and exit with status 0."""
    program = os.path.basename(sys.argv[0])
    help_text = '''\
%(argv0)s, parser generator for SPV XML members
usage: %(argv0)s GRAMMAR header PREFIX
       %(argv0)s GRAMMAR code PREFIX HEADER_NAME
  where GRAMMAR contains grammar definitions\
''' % {"argv0": program}
    print(help_text)
    sys.exit(0)


def parse_term():
    """Parse one term: a parenthesized alternation or a named nonterminal.

    Returns the sub-grammar for a parenthesized group, or a
    'nonterminal' dictionary for a name.  A name that is entirely
    upper case is rejected as an unknown terminal.
    """
    if match('('):
        inner = parse_alternation()
        must_match(')')
        return inner

    member_name, nonterminal_name = parse_name()
    if member_name.isupper():
        fatal('%s; unknown terminal' % member_name)
        return None
    return {'type': 'nonterminal',
            'nonterminal_name': nonterminal_name,
            'member_name': member_name}


def parse_quantified():
    """Parse a term optionally followed by one of the quantifiers
    '*', '+' or '?'.  A quantified term is wrapped in a dictionary whose
    'type' is the quantifier and whose 'item' is the term."""
    term = parse_term()
    quantifier = token[0]
    if quantifier not in ('*', '+', '?'):
        return term
    get_token()
    return {'type': quantifier, 'item': term}


def parse_sequence():
    """Parse a sequence of quantified terms.

    The keyword EMPTY yields {'type': 'empty'}.  Otherwise terms are
    collected until the current token is '|', ';', ')' or end of input.
    A single term is returned as-is; several terms are returned as
    {'type': 'sequence', 'items': [...]}.

    A term that is itself a sequence (from a parenthesized group) is
    spliced into this sequence rather than nested.
    """
    if match_id('EMPTY'):
        return {'type': 'empty'}
    items = []
    while True:
        sub = parse_quantified()
        if sub['type'] == 'sequence':
            # Sequences are dictionaries, so the nested terms live under
            # 'items'.  (Slicing the dictionary, as was done before,
            # raises TypeError.)
            items.extend(sub['items'])
        else:
            items.append(sub)
        if token[0] in ('|', ';', ')', 'eof'):
            break
    return {'type': 'sequence', 'items': items} if len(items) > 1 else items[0]


def parse_alternation():
    """Parse one or more sequences separated by '|'.

    A lone sequence is returned unchanged; several are wrapped as
    {'type': '|', 'items': [...]}.
    """
    branches = [parse_sequence()]
    while match('|'):
        branches.append(parse_sequence())
    if len(branches) == 1:
        return branches[0]
    return {'type': '|', 'items': branches}


def parse_name():
    """Parse `xml_name` or `xml_name[unique_name]`.

    Returns the pair (unique_name, xml_name).  Without a bracketed
    disambiguating name, both elements are the XML name.
    """
    # The XML attribute or element name always leads.
    expect('id')
    xml_name = token[1]
    get_token()

    # No bracketed alternative: the XML name doubles as the unique name.
    if not match('['):
        return xml_name, xml_name

    # A name in brackets disambiguates XML names reused across contexts.
    expect('id')
    unique_name = token[1]
    get_token()
    must_match(']')
    return unique_name, xml_name


# Maps an enumerated attribute's unique name to its set of allowed values.
# Filled in by parse_production().
enums = {}
def parse_production():
    """Parse one production: NAME {:ATTRIBUTE}... => RHS.

    Returns (unique_name, xml_name, attributes, rhs), where `attributes`
    maps each attribute's unique name to (xml_name, value_type, required)
    and `rhs` is {'type': 'text'}, {'type': 'etc'}, or the grammar
    returned by parse_alternation().

    value_type is one of: a set of allowed strings (enumeration or bool),
    'dimension', 'real', 'int', 'color', 'string', 'id', or a tuple
    ('ref', X) where X is a name, a set of names, or None.
    """
    unique_name, xml_name = parse_name()

    attr_xml_names = set()
    attributes = {}
    while match(':'):
        attr_unique_name, attr_xml_name = parse_name()
        if match('='):
            if match('('):
                # Enumeration: identifiers up to ')', each optionally
                # followed by '|'.
                attr_value = set()
                while not match(')'):
                    expect('id')
                    attr_value.add(token[1])
                    get_token()
                    match('|')

                # Enumerations sharing a unique name must have identical
                # value sets.
                global enums
                if attr_unique_name not in enums:
                    enums[attr_unique_name] = attr_value
                elif enums[attr_unique_name] != attr_value:
                    sys.stderr.write('%s: different enums with same name\n'
                                     % attr_unique_name)
                    sys.exit(1)
            elif match_id('bool'):
                attr_value = set(('true', 'false'))
            elif match_id('dimension'):
                attr_value = 'dimension'
            elif match_id('real'):
                attr_value = 'real'
            elif match_id('int'):
                attr_value = 'int'
            elif match_id('color'):
                attr_value = 'color'
            elif match_id('ref'):
                # 'ref' may be followed by one target name, by a
                # parenthesized list of target names, or by nothing.
                if token[0] == 'id':
                    ref_type = token[1]
                    attr_value = ('ref', ref_type)
                    get_token()
                elif match('('):
                    ref_types = set()
                    while not match(')'):
                        expect('id')
                        ref_types.add(token[1])
                        get_token()
                        match('|')
                    attr_value = ('ref', ref_types)
                else:
                    attr_value = ('ref', None)
            else:
                fatal("unknown attribute value type")
        else:
            # No '=TYPE' given: the attribute is a plain string.
            attr_value = 'string'
        # A trailing '?' marks the attribute as optional.
        attr_required = not match('?')

        # The XML attribute named "id" gets the special value type 'id'.
        if attr_xml_name == 'id':
            if attr_value != 'string':
                fatal("id attribute must have string type")
            attr_value = 'id'

        # Both the unique names and the XML names must be distinct within
        # one production.
        if attr_unique_name in attributes:
            fatal("production %s has two attributes %s" % (unique_name,
                                                           attr_unique_name))
        if attr_xml_name in attr_xml_names:
            fatal("production %s has two attributes %s" % (unique_name,
                                                           attr_xml_name))
        attr_xml_names.add(attr_xml_name)
        attributes[attr_unique_name] = (attr_xml_name,
                                        attr_value, attr_required)
    # Every production gets an optional "id" attribute unless one was
    # declared under that unique name.
    if 'id' not in attributes:
        attributes["id"] = ('id', 'id', False)

    must_match('=>')

    if match_id('TEXT'):
        rhs = {'type': 'text'}
    elif match_id('ETC'):
        rhs = {'type': 'etc'}
    else:
        rhs = parse_alternation()

    # Name every top-level term that is not a plain, optional, or repeated
    # nonterminal: 'seq', 'seq2', 'seq3', ...  print_free_members() and
    # print_recurse_members() recompute the same numbering by walking the
    # terms in the same order.
    n = 0
    for a in rhs['items'] if rhs['type'] == '|' else (rhs,):
        for term in a['items'] if a['type'] == 'sequence' else (a,):
            if term['type'] == 'empty':
                pass
            elif term['type'] == 'nonterminal':
                pass
            elif term['type'] == '?' and term['item']['type'] == 'nonterminal':
                pass
            elif (term['type'] in ('*', '+')
                  and term['item']['type'] == 'nonterminal'):
                pass
            else:
                n += 1
                term['seq_name'] = 'seq' if n == 1 else 'seq%d' % n

    return unique_name, xml_name, attributes, rhs


# Unique names of enumerations whose C `enum` declaration has already been
# printed by print_members().
used_enums = set()
def print_members(attributes, rhs, indent):
    """Print the C declarations for one production.

    Output, in order: an `enum` declaration (plus a *_to_string
    prototype) for each enumeration not printed before, then the opening
    of the production's `struct` with one member per attribute and one or
    two members per content term.  The closing brace is not printed here.

    attributes -- map of unique name to (xml_name, value, required)
    rhs        -- parsed right-hand side of the production
    indent     -- string prepended to each struct member line

    NOTE(review): the struct name is taken from `name`, which is not a
    parameter; it is read from module scope -- presumably the production
    currently being printed by the caller.  Confirm.
    """
    attrs = []
    new_enums = set()
    for unique_name, (xml_name, value, required) in sorted(attributes.items()):
        c_name = name_to_id(unique_name)
        if type(value) is set:
            if len(value) <= 1:
                # At most one allowed value: only presence is recorded,
                # and only when the attribute is optional.
                if not required:
                    attrs += [('bool %s_present;' % c_name,
                               'True if attribute present')]
            elif value == set(('true', 'false')):
                if required:
                    attrs += [('bool %s;' % c_name, None)]
                else:
                    attrs += [('int %s;' % c_name,
                               '-1 if not present, otherwise 0 or 1')]
            else:
                attrs += [('enum %s%s %s;' % (prefix, c_name, c_name),
                           'Always nonzero' if required else
                           'Zero if not present')]

                global used_enums
                if unique_name not in used_enums:
                    new_enums.add(unique_name)
        elif value == 'dimension' or value == 'real':
            # NOTE(review): the generated comment says "In inches." for
            # 'real' as well as 'dimension'.
            attrs += [('double %s;' % c_name,
                       'In inches.  ' + ('Always present' if required else
                                         'DBL_MAX if not present'))]
        elif value == 'int':
            attrs += [('int %s;' % c_name,
                       'Always present' if required
                       else 'INT_MIN if not present')]
        elif value == 'color':
            attrs += [('int %s;' % c_name,
                       'Always present' if required
                       else '-1 if not present')]
        elif value == 'string':
            attrs += [('char *%s;' % c_name,
                       'Always nonnull' if required else 'Possibly null')]
        elif value[0] == 'ref':
            # A reference with no target or several targets is a generic
            # node pointer; one with a single target is typed.
            # (This local name shadows the imported `struct` module within
            # this function.)
            struct = ('spvxml_node'
                      if value[1] is None or type(value[1]) is set
                      else '%s%s' % (prefix, name_to_id(value[1])))
            attrs += [('struct %s *%s;' % (struct, c_name),
                       'Always nonnull' if required else 'Possibly null')]
        elif value == 'id':
            # No member is declared for the id attribute.
            pass
        else:
            assert False

    # Declare each enumeration seen here for the first time.  The first
    # enumerator is given the value 1.
    for enum_name in sorted(new_enums):
        used_enums.add(enum_name)
        c_name = name_to_id(enum_name)
        print('\nenum %s%s {' % (prefix, c_name))
        i = 0
        for value in sorted(enums[enum_name]):
            print('    %s%s_%s%s,' % (prefix.upper(),
                                      c_name.upper(),
                                      name_to_id(value).upper(),
                                      ' = 1' if i == 0 else ''))
            i += 1
        print('};')
        print('const char *%s%s_to_string (enum %s%s);' % (
            prefix, c_name, prefix, c_name))

    print('\nstruct %s%s {' % (prefix, name_to_id(name)))
    print('%sstruct spvxml_node node_;' % indent)

    if attrs:
        print('\n%s/* Attributes. */' % indent)
        for decl, comment in attrs:
            # `line` is a local here, distinct from the tokenizer's global.
            line = '%s%s' % (indent, decl)
            if comment:
                # Pad so comments start at column 35 where the declaration
                # is short enough, with at least one space otherwise.
                n_spaces = max(35 - len(line), 1)
                line += '%s/* %s. */' % (' ' * n_spaces, comment)
            print(line)

    # ETC and EMPTY productions have no content members.
    if rhs['type'] == 'etc' or rhs['type'] == 'empty':
        return

    print('\n%s/* Content. */' % indent)
    if rhs['type'] == 'text':
        print('%schar *text; /* Always nonnull. */' % indent)
        return

    # One member (or an array plus a count) per top-level term of each
    # alternative.
    for a in rhs['items'] if rhs['type'] == '|' else (rhs,):
        for term in a['items'] if a['type'] == 'sequence' else (a,):
            if term['type'] == 'empty':
                pass
            elif term['type'] == 'nonterminal':
                nt_name = name_to_id(term['nonterminal_name'])
                member_name = name_to_id(term['member_name'])
                print('%sstruct %s%s *%s; /* Always nonnull. */' % (
                    indent, prefix, nt_name, member_name))
            elif term['type'] == '?' and term['item']['type'] == 'nonterminal':
                nt_name = name_to_id(term['item']['nonterminal_name'])
                member_name = name_to_id(term['item']['member_name'])
                print('%sstruct %s%s *%s; /* Possibly null. */' % (
                    indent, prefix, nt_name, member_name))
            elif (term['type'] in ('*', '+')
                  and term['item']['type'] == 'nonterminal'):
                nt_name = name_to_id(term['item']['nonterminal_name'])
                member_name = name_to_id(term['item']['member_name'])
                print('%sstruct %s%s **%s;' % (indent, prefix,
                                               nt_name, member_name))
                print('%ssize_t n_%s;' % (indent, member_name))
            else:
                # Anything more complex is stored as an array of generic
                # nodes under the 'seq_name' recorded on the term.
                seq_name = term['seq_name']
                print('%sstruct spvxml_node **%s;' % (indent, seq_name))
                print('%ssize_t n_%s;' % (indent, seq_name))


def bytes_to_hex(s):
    r"""Return S as a double-quoted string of \xNN escapes.

    S may be a bytes-like object or a str.  Iterating bytes in Python 3
    yields integers, which ord() rejects, so integer items are used
    directly and only str characters go through ord().

    >>> bytes_to_hex(b'\x00A')
    '"\\x00\\x41"'
    """
    codes = [x if isinstance(x, int) else ord(x) for x in s]
    return '"' + ''.join("\\x%02x" % code for code in codes) + '"'


class Parser_Context(object):
    """State shared while generating the content parser of one production:
    the productions table, the base function name, and the list of helper
    Function objects created so far."""

    def __init__(self, function_name, productions):
        self.function_name = function_name
        self.productions = productions
        self.functions = []

        self.suffixes = {}
        self.parsers = {}
        self.parser_index = 0
        self.bail = 'error'
        self.need_error_handler = False

    def gen_name(self, prefix):
        """Return PREFIX the first time it is requested, then PREFIX2,
        PREFIX3, ... on later requests."""
        count = self.suffixes.get(prefix, 0) + 1
        self.suffixes[prefix] = count
        if count > 1:
            return '%s%d' % (prefix, count)
        return prefix

    def new_function(self, type_name):
        """Create, register and return a new Function named
        <function_name>_<N>, where N counts from 1."""
        ordinal = len(self.functions) + 1
        f = Function('%s_%d' % (self.function_name, ordinal), type_name)
        self.functions.append(f)
        return f


def print_attribute_decls(name, attributes):
    """Print the C declarations of the ATTR_* index enum (when there are
    attributes), the `attrs` table of struct spvxml_attribute, and the
    N_ATTRS constant.  Attributes are listed in sorted order of their
    unique names.  NAME is not used."""
    rows = []
    for unique_name, (xml_name, value, required) in sorted(attributes.items()):
        rows.append((name_to_id(unique_name).upper(), xml_name, required))

    if attributes:
        print('    enum {')
        for constant, _, _ in rows:
            print('        ATTR_%s,' % constant)
        print('    };')

    print('    struct spvxml_attribute attrs[] = {')
    for constant, xml_name, required in rows:
        flag = 'true' if required else 'false'
        print('        [ATTR_%s] = { "%s", %s, NULL },'
              % (constant, xml_name, flag))
    print('    };')
    print('    enum { N_ATTRS = sizeof attrs / sizeof *attrs };')


def print_parser_for_attributes(name, attributes):
    """Print the C statements that parse a node's attributes into `p`.

    Emits the spvxml_parse_attributes() call, then one conversion per
    attribute (in sorted order of unique name) according to its value
    type, and finally an error check that frees `p` and returns false.
    With no attributes, only the spvxml_parse_attributes() call is printed.

    name       -- production name, used for the free function's name
    attributes -- map of unique name to (xml_name, value, required)
    """
    print('    /* Parse attributes. */')
    print('    spvxml_parse_attributes (&nctx);')

    if not attributes:
        return

    for unique_name, (xml_name, value, required) in sorted(attributes.items()):
        c_name = name_to_id(unique_name)
        params = '&nctx, &attrs[ATTR_%s]' % c_name.upper()
        if type(value) is set:
            if len(value) <= 1:
                # Fixed-value attribute: checked against its value; for an
                # optional attribute the result is stored in <name>_present.
                # NOTE(review): tuple(value)[0] raises IndexError for an
                # empty set.
                if required:
                    print('    spvxml_attr_parse_fixed (%s, "%s");'
                          % (params, tuple(value)[0]))
                else:
                    print('    p->%s_present = spvxml_attr_parse_fixed (\n'
                          '        %s, "%s");'
                          % (c_name, params, tuple(value)[0]))
            elif value == set(('true', 'false')):
                print('    p->%s = spvxml_attr_parse_bool (%s);'
                      % (c_name, params))
            else:
                # General enumeration: parsed through its <prefix><name>_map
                # table.
                map_name = '%s%s_map' % (prefix, c_name)
                print('    p->%s = spvxml_attr_parse_enum (\n'
                      '        %s, %s);'
                      % (c_name, params, map_name))
        elif value in ('real', 'dimension', 'int', 'color'):
            # The value type name doubles as the C helper's suffix.
            print('    p->%s = spvxml_attr_parse_%s (%s);'
                  % (c_name, value, params))
        elif value == 'string':
            # Move the string out of the attribute table into the node.
            print('    p->%s = attrs[ATTR_%s].value;\n'
                  '    attrs[ATTR_%s].value = NULL;'
                  % (c_name, c_name.upper(),
                     c_name.upper()))
        elif value == 'id':
            # The id is moved into the generic node header instead.
            print('    p->node_.id = attrs[ATTR_%s].value;\n'
                  '    attrs[ATTR_%s].value = NULL;'
                  % (c_name.upper(), c_name.upper()))
        elif value[0] == 'ref':
            # Nothing is emitted here for references; print_resolve_refs()
            # generates the code that assigns them.
            pass
        else:
            assert False
    print('''\
    if (ctx->error) {
        spvxml_node_context_uninit (&nctx);
        ctx->hard_error = true;
        %sfree_%s (p);
        return false;
    }'''
          % (prefix, name_to_id(name)))

class Function(object):
    """A generated static C helper function: its name, the struct type of
    its `p` parameter, and the list of body lines accumulated in `code`."""

    def __init__(self, function_name, type_name):
        self.function_name = function_name
        self.type_name = type_name
        self.suffixes = {}
        self.code = []

    def gen_name(self, prefix):
        """Return PREFIX on first use, then PREFIX2, PREFIX3, ..."""
        count = self.suffixes.get(prefix, 0) + 1
        self.suffixes[prefix] = count
        if count > 1:
            return '%s%d' % (prefix, count)
        return prefix

    def print_(self):
        """Print the complete C function.  Leading empty lines are dropped
        from `code` first; the body always ends with `return true;`."""
        header = '''
static bool
%s (struct spvxml_node_context *nctx, xmlNode **input, struct %s *p)
{''' % (self.function_name, self.type_name)
        print(header)

        leading_blanks = 0
        while (leading_blanks < len(self.code)
               and self.code[leading_blanks] == ''):
            leading_blanks += 1
        if leading_blanks:
            self.code = self.code[leading_blanks:]

        for statement in self.code:
            if statement:
                print('    %s' % statement)
            else:
                print('')
        print('    return true;')
        print('}')

# Nesting context passed down through generate_content_parser().  It
# determines where a parsed nonterminal is stored:
#   START, ALTERNATION, SEQUENCE, OPTIONAL -> directly in p-><member>
#   REPETITION -> appended to the p-><member> array
#   GENERAL    -> appended, as a generic node, to the p-><seq_name> array
STATE_START = 0
STATE_ALTERNATION = 1
STATE_SEQUENCE = 2
STATE_REPETITION = 3
STATE_OPTIONAL = 4
STATE_GENERAL = 5

def generate_content_parser(nonterminal, rhs, function, ctx, state, seq_name):
    """Append to FUNCTION.code the C statements that parse RHS.

    nonterminal -- name of the production being generated; used to name
                   its try_parse helper
    rhs         -- grammar node to generate code for
    function    -- Function object that receives the generated lines
    ctx         -- Parser_Context supplying the productions table and new
                   helper functions
    state       -- one of the STATE_* constants describing the nesting
    seq_name    -- name of the enclosing generic sequence member, or a
                   false value to take it from RHS's own 'seq_name'

    Alternatives, repetitions and options are generated into separate
    helper functions obtained from CTX and invoked through try_parse.
    """
    seq_name = seq_name if seq_name else rhs.get('seq_name')
    ctx.parser_index += 1

    if rhs['type'] == 'etc':
        function.code += ['spvxml_content_parse_etc (input);']
    elif rhs['type'] == 'text':
        function.code += ['if (!spvxml_content_parse_text (nctx, input, &p->text))',
                          '    return false;']
    elif rhs['type'] == '|':
        # Each choice becomes its own helper; they are tried in order in
        # one `if (!try (a) && !try (b) ...)` chain that reports a syntax
        # error when none of them succeeds.
        for i in range(len(rhs['items'])):
            choice = rhs['items'][i]
            subfunc = ctx.new_function(function.type_name)
            generate_content_parser(nonterminal, choice, subfunc, ctx,
                                    STATE_ALTERNATION
                                    if state == STATE_START
                                    else STATE_GENERAL, seq_name)
            function.code += ['%(start)s!%(tryfunc)s (nctx, input, p, %(subfunc)s)%(end)s'
                               % {'start': 'if (' if i == 0 else '    && ',
                                  'subfunc': subfunc.function_name,
                                  'tryfunc': '%stry_parse_%s'
                                  % (prefix, name_to_id(nonterminal)),
                                  'end': ')' if i == len(rhs['items']) - 1 else ''}]
        function.code += ['  {',
                          '    spvxml_content_error (nctx, *input, "Syntax error.");',
                          '    return false;',
                          '  }']
    elif rhs['type'] == 'sequence':
        # Elements of a sequence are generated inline, one after another.
        for element in rhs['items']:
            generate_content_parser(nonterminal, element, function, ctx,
                                    STATE_SEQUENCE
                                    if state in (STATE_START,
                                                 STATE_ALTERNATION)
                                    else STATE_GENERAL, seq_name)
    elif rhs['type'] == 'empty':
        # Reference every parameter so the generated function has no
        # unused ones.
        function.code += ['(void) nctx;']
        function.code += ['(void) input;']
        function.code += ['(void) p;']
    elif rhs['type'] in ('*', '+', '?'):
        # The quantified item goes into a helper function.
        subfunc = ctx.new_function(function.type_name)
        generate_content_parser(nonterminal, rhs['item'], subfunc, ctx,
                                (STATE_OPTIONAL
                                 if rhs['type'] == '?'
                                 else STATE_REPETITION)
                                if state in (STATE_START,
                                             STATE_ALTERNATION,
                                             STATE_SEQUENCE)
                                else STATE_GENERAL, seq_name)
        # NOTE(review): next_name is never used below; the call only
        # advances the function's 'next' name counter.
        next_name = function.gen_name('next')
        args = {'subfunc': subfunc.function_name,
                'tryfunc': '%stry_parse_%s' % (prefix,
                                               name_to_id (nonterminal))}
        if rhs['type'] == '?':
            # Optional: try once and ignore the result.
            function.code += [
                '%(tryfunc)s (nctx, input, p, %(subfunc)s);' % args]
        else:
            if rhs['type'] == '+':
                # One or more: the first occurrence is mandatory.
                function.code += ['if (!%(subfunc)s (nctx, input, p))' % args,
                                  '    return false;']
            function.code += [
                'while (%(tryfunc)s (nctx, input, p, %(subfunc)s))' % args,
                '    continue;']
    elif rhs['type'] == 'nonterminal':
        node_name = function.gen_name('node')
        function.code += [
            '',
            'xmlNode *%s;' % node_name,
            'if (!spvxml_content_parse_element (nctx, input, "%s", &%s))'
            % (ctx.productions[rhs['nonterminal_name']][0], node_name),
            '    return false;']
        if state in (STATE_START,
                     STATE_ALTERNATION,
                     STATE_SEQUENCE,
                     STATE_OPTIONAL):
            # Single-valued member: parse straight into p-><member>.
            target = '&p->%s' % name_to_id(rhs['member_name'])
        else:
            assert state in (STATE_REPETITION, STATE_GENERAL)
            # Array member: parse into a local, then append it below.
            member = name_to_id(rhs['member_name']) if state == STATE_REPETITION else seq_name
            function.code += ['struct %s%s *%s;' % (
                prefix, name_to_id(rhs['nonterminal_name']), member)]
            target = '&%s' % member
        function.code += [
            'if (!%sparse_%s (nctx->up, %s, %s))'
            % (prefix, name_to_id(rhs['nonterminal_name']), node_name, target),
            '    return false;']
        if state in (STATE_REPETITION, STATE_GENERAL):
            # Grow the array by one and store the new element; generic
            # sequences store a pointer to the element's node_ header.
            function.code += [
                'p->%s = xrealloc (p->%s, sizeof *p->%s * (p->n_%s + 1));'
                % (member, member, member, member),
                'p->%s[p->n_%s++] = %s;' % (member, member,
                                            '&%s->node_' % member
                                            if state == STATE_GENERAL
                                            else member)]
    else:
        assert False

def print_parser(name, production, productions, indent):
    """Print the C parsing code for one production.

    Output, in order: the static try_parse_<name> wrapper, the static
    content-parsing helpers (unless the content is EMPTY or ETC), and the
    public parse_<name> function that allocates the node, parses its
    attributes and content, and stores the result in *p_.

    name        -- unique name of the production
    production  -- tuple (xml_name, attributes, rhs)
    productions -- table of all productions, handed to the Parser_Context
    indent      -- not used
    """
    xml_name, attributes, rhs = production

    # try_parse runs `sub` on a copy of the input position, committing the
    # position only on success and clearing any soft error on failure.
    print('''
static bool UNUSED
%(prefix)stry_parse_%(name)s (
    struct spvxml_node_context *nctx, xmlNode **input,
    struct %(prefix)s%(name)s *p,
    bool (*sub) (struct spvxml_node_context *,
                 xmlNode **,
                 struct %(prefix)s%(name)s *))
{
    xmlNode *next = *input;
    bool ok = sub (nctx, &next, p);
    if (ok)
        *input = next;
    else if (!nctx->up->hard_error) {
        free (nctx->up->error);
        nctx->up->error = NULL;
    }
    return ok;
}'''
          % {'prefix': prefix,
             'name': name_to_id(name)})

    ctx = Parser_Context('%sparse_%s' % (prefix, name_to_id(name)),
                         productions)
    if rhs['type'] not in ('empty', 'etc'):
        # `function` is the top-level content parser.  Helpers are printed
        # in reverse creation order, so each is defined before the
        # function that refers to it.
        function = ctx.new_function('%s%s' % (prefix, name_to_id(name)))
        generate_content_parser(name, rhs, function, ctx, 0, None)
        for f in reversed(ctx.functions):
            f.print_()

    print('''
bool
%(prefix)sparse_%(name)s (
    struct spvxml_context *ctx, xmlNode *input,
    struct %(prefix)s%(name)s **p_)
{'''
          % {'prefix': prefix,
             'name': name_to_id(name)})

    print_attribute_decls(name, attributes)

    print('    struct spvxml_node_context nctx = {')
    print('        .up = ctx,')
    print('        .parent = input,')
    print('        .attrs = attrs,')
    print('        .n_attrs = N_ATTRS,')
    print('    };')
    print('')
    print('    *p_ = NULL;')
    print('    struct %(prefix)s%(name)s *p = xzalloc (sizeof *p);'
          % {'prefix': prefix,
             'name': name_to_id(name)})
    print('    p->node_.raw = input;')
    print('    p->node_.class_ = &%(prefix)s%(name)s_class;'
          % {'prefix': prefix,
             'name': name_to_id(name)})
    print('')

    print_parser_for_attributes(name, attributes)

    if rhs['type'] == 'empty':
        # No content is allowed: only check that the element is empty.
        print('''
    /* Parse content. */
    if (!spvxml_content_parse_end (&nctx, input->children)) {
        ctx->hard_error = true;
        spvxml_node_context_uninit (&nctx);
        %sfree_%s (p);
        return false;
    }'''
              % (prefix, name_to_id(name)))
    elif rhs['type'] == 'etc':
        print('''
    /* Ignore content. */
''')
    else:
        # `function` was bound above, under the same condition on rhs.
        print('''
    /* Parse content. */
    input = input->children;
    if (!%s (&nctx, &input, p)
        || !spvxml_content_parse_end (&nctx, input)) {
        ctx->hard_error = true;
        spvxml_node_context_uninit (&nctx);
        %sfree_%s (p);
        return false;
    }'''
              % (function.function_name,
                 prefix, name_to_id(name)))

    print('''
    spvxml_node_context_uninit (&nctx);
    *p_ = p;
    return true;''')

    print("}")


def print_free_members(attributes, rhs, indent):
    """Print the C statements that release everything owned by `p`:
    string attributes, content members, the node id, and `p` itself.

    Generic sequence members are numbered 'seq', 'seq2', ... in the order
    their terms are met.  INDENT is not used.
    """
    scalar_kinds = ('dimension', 'real', 'int', 'color', 'id')
    for unique_name, (xml_name, value, required) in attributes.items():
        c_name = name_to_id(unique_name)
        owns_nothing = (type(value) is set
                        or value in scalar_kinds
                        or value[0] == 'ref')
        if owns_nothing:
            continue
        if value == 'string':
            print('    free (p->%s);' % c_name)
        else:
            assert False

    kind = rhs['type']
    if kind == 'text':
        print('    free (p->text);')
    elif kind not in ('etc', 'empty'):
        seq_count = 0
        alternatives = rhs['items'] if kind == '|' else (rhs,)
        for alt in alternatives:
            terms = alt['items'] if alt['type'] == 'sequence' else (alt,)
            for term in terms:
                term_kind = term['type']
                if term_kind == 'empty':
                    continue

                # Only quantified terms carry an 'item'.
                inner = term['item'] if term_kind in ('?', '*', '+') else None
                wraps_nonterminal = (inner is not None
                                     and inner['type'] == 'nonterminal')

                if term_kind == 'nonterminal' or (term_kind == '?'
                                                  and wraps_nonterminal):
                    # Single (possibly null) child node.
                    target = term if term_kind == 'nonterminal' else inner
                    nt_name = name_to_id(target['nonterminal_name'])
                    member_name = name_to_id(target['member_name'])
                    print('    %sfree_%s (p->%s);' % (prefix, nt_name,
                                                      member_name))
                elif term_kind in ('*', '+') and wraps_nonterminal:
                    # Typed array of child nodes.
                    nt_name = name_to_id(inner['nonterminal_name'])
                    member_name = name_to_id(inner['member_name'])
                    print('''\
    for (size_t i = 0; i < p->n_%s; i++)
        %sfree_%s (p->%s[i]);
    free (p->%s);'''
                          % (member_name,
                             prefix, nt_name, member_name,
                             member_name))
                else:
                    # Generic array of nodes, freed through their class.
                    seq_count += 1
                    if seq_count == 1:
                        seq_name = 'seq'
                    else:
                        seq_name = 'seq%d' % seq_count
                    print('''\
    for (size_t i = 0; i < p->n_%s; i++)
        p->%s[i]->class_->spvxml_node_free (p->%s[i]);
    free (p->%s);'''
                          % (seq_name,
                             seq_name, seq_name,
                             seq_name))
    print('    free (p->node_.id);')
    print('    free (p);')


def print_free(name, production, indent):
    """Print the C function <prefix>free_<name>, which does nothing for a
    null pointer and otherwise releases the node and what it owns.
    INDENT is not used; the body is always indented by four spaces."""
    _xml_name, attributes, rhs = production
    substitutions = {'prefix': prefix, 'name': name_to_id(name)}
    opening = '''
void
%(prefix)sfree_%(name)s (struct %(prefix)s%(name)s *p)
{
    if (!p)
        return;
''' % substitutions
    print(opening)

    print_free_members(attributes, rhs, ' ' * 4)

    print('}')

def name_to_id(s):
    """Convert the camelCase or hyphenated name S to a C identifier.

    The first character is lower-cased.  In the rest of the name every
    upper-case letter becomes '_' plus its lower-case form, and every
    '-' becomes '_'.
    """
    head, tail = s[0], s[1:]
    pieces = []
    for ch in tail:
        if ch.isupper():
            pieces.append('_%c' % ch.lower())
        else:
            pieces.append(ch)
    return head.lower() + ''.join(pieces).replace('-', '_')


def print_recurse_members(attributes, rhs, function):
    """Print C statements that apply the operation FUNCTION (for example
    'collect_ids' or 'resolve_refs') to every child node of `p`.

    Nothing is printed for ETC, EMPTY or TEXT content.  Generic sequence
    members are numbered 'seq', 'seq2', ... in the order their terms are
    met.  ATTRIBUTES is not used.
    """
    kind = rhs['type']
    if kind in ('etc', 'empty', 'text'):
        return

    seq_count = 0
    alternatives = rhs['items'] if kind == '|' else (rhs,)
    for alt in alternatives:
        terms = alt['items'] if alt['type'] == 'sequence' else (alt,)
        for term in terms:
            term_kind = term['type']
            if term_kind == 'empty':
                continue

            # Only quantified terms carry an 'item'.
            inner = term['item'] if term_kind in ('?', '*', '+') else None
            wraps_nonterminal = (inner is not None
                                 and inner['type'] == 'nonterminal')

            if term_kind == 'nonterminal' or (term_kind == '?'
                                              and wraps_nonterminal):
                # Single (possibly null) child node.
                target = term if term_kind == 'nonterminal' else inner
                nt_name = name_to_id(target['nonterminal_name'])
                member_name = name_to_id(target['member_name'])
                print('    %s%s_%s (ctx, p->%s);'
                      % (prefix, function, nt_name, member_name))
            elif term_kind in ('*', '+') and wraps_nonterminal:
                # Typed array of child nodes.
                nt_name = name_to_id(inner['nonterminal_name'])
                member_name = name_to_id(inner['member_name'])
                print('''\
    for (size_t i = 0; i < p->n_%s; i++)
        %s%s_%s (ctx, p->%s[i]);'''
                      % (member_name,
                         prefix, function, nt_name, member_name))
            else:
                # Generic array of nodes, dispatched through their class.
                seq_count += 1
                if seq_count == 1:
                    seq_name = 'seq'
                else:
                    seq_name = 'seq%d' % seq_count
                print('''\
    for (size_t i = 0; i < p->n_%s; i++)
        p->%s[i]->class_->spvxml_node_%s (ctx, p->%s[i]);'''
                      % (seq_name,
                         seq_name, function, seq_name))


def print_collect_ids(name, production):
    """Print the C function <prefix>collect_ids_<name>, which registers
    the node's own id and then recurses into its child nodes."""
    _xml_name, attributes, rhs = production
    substitutions = {'prefix': prefix, 'name': name_to_id(name)}
    opening = '''
void
%(prefix)scollect_ids_%(name)s (struct spvxml_context *ctx, struct %(prefix)s%(name)s *p)
{
    if (!p)
        return;

    spvxml_node_collect_id (ctx, &p->node_);
''' % substitutions
    print(opening)

    print_recurse_members(attributes, rhs, 'collect_ids')

    print('}')


def print_resolve_refs(name, production):
    """Print the C functions is_<name>, cast_<name> and
    resolve_refs_<name> for one production.

    resolve_refs_<name> assigns every 'ref' attribute of the node by
    calling spvxml_node_resolve_ref(), restricted to the allowed target
    classes when the grammar names them, and then recurses into the
    node's children.

    name       -- unique name of the production
    production -- tuple (xml_name, attributes, rhs)
    """
    xml_name, attributes, rhs = production

    print('''
bool
%(prefix)sis_%(name)s (const struct spvxml_node *node)
{
    return node->class_ == &%(prefix)s%(name)s_class;
}

struct %(prefix)s%(name)s *
%(prefix)scast_%(name)s (const struct spvxml_node *node)
{
    return (node && %(prefix)sis_%(name)s (node)
            ? UP_CAST (node, struct %(prefix)s%(name)s, node_)
            : NULL);
}

void
%(prefix)sresolve_refs_%(name)s (struct spvxml_context *ctx UNUSED, struct %(prefix)s%(name)s *p UNUSED)
{
    if (!p)
        return;
''' % {'prefix': prefix,
       'name': name_to_id(name)})

    # Counts the references that need a class table, so that each table
    # gets a distinct C name: classes, classes2, ...
    i = 0
    for unique_name, (xml_name, value, required) in sorted(attributes.items()):
        c_name = name_to_id(unique_name)
        # Only ('ref', ...) values are handled here.
        if type(value) is set or value[0] != 'ref':
            continue

        if value[1] is None:
            # Unrestricted reference: any node class is accepted.
            print('    p->%s = spvxml_node_resolve_ref (ctx, p->node_.raw, \"%s\", NULL, 0);'
                  % (c_name, xml_name))
        else:
            i += 1
            # From here on `name` is the C variable name of the class
            # table, no longer the production name passed in.
            name = 'classes'
            if i > 1:
                name += '%d' % i
            if type(value[1]) is set:
                # Several allowed target classes: emit an array of them.
                print('    static const struct spvxml_node_class *const %s[] = {' % name)
                for ref_type in sorted(value[1]):
                    print('        &%(prefix)s%(ref_type)s_class,'
                         % {'prefix': prefix,
                            'ref_type': name_to_id(ref_type)})
                print('    };');
                print('    const size_t n_%s = sizeof %s / sizeof *%s;'
                      % (name, name, name))
                print('    p->%(member)s = spvxml_node_resolve_ref (ctx, p->node_.raw, \"%(attr)s\", %(name)s, n_%(name)s);'
                      % {"member": c_name,
                         "attr": xml_name,
                         'prefix': prefix,
                         'name': name
                         })
            else:
                # Exactly one allowed target class: the result is cast to
                # that class's struct type.
                print('    static const struct spvxml_node_class *const %s' % name)
                print('        = &%(prefix)s%(ref_type)s_class;'
                      % {'prefix': prefix,
                         'ref_type': name_to_id(value[1])})
                print('    p->%(member)s = %(prefix)scast_%(ref_type)s (spvxml_node_resolve_ref (ctx, p->node_.raw, \"%(attr)s\", &%(name)s, 1));'
                      % {"member": c_name,
                         "attr": xml_name,
                         'prefix': prefix,
                         'name': name,
                         'ref_type': name_to_id(value[1])})


    print_recurse_members(attributes, rhs, 'resolve_refs')

    print('}')


def name_to_id(s):
    """Return S rewritten as a lower-case, underscore-separated identifier.

    The first character is lowercased and otherwise kept as is.  In the
    rest of the string, every uppercase character is replaced by '_'
    followed by its lowercase form, and every '-' is replaced by '_'
    (e.g. 'fooBar' and 'foo-bar' both become 'foo_bar').

    S must be non-empty; an empty string raises IndexError.
    """
    pieces = []
    for ch in s[1:]:
        if ch.isupper():
            pieces.append('_%c' % ch.lower())
        else:
            pieces.append(ch)
    # Hyphens are only translated in the tail; the first character is
    # passed through lower() alone.
    tail = ''.join(pieces).replace('-', '_')
    return s[0].lower() + tail


if __name__ == "__main__":
    # Command-line driver.  Positional arguments are:
    #
    #     FILE code PREFIX HEADER     -- emit the C parser source
    #     FILE header PREFIX          -- emit the matching C header
    #
    # The result is written to stdout; diagnostics go to stderr and any
    # usage or syntax error yields a nonzero exit status.
    argv0 = sys.argv[0]
    try:
        options, args = getopt.gnu_getopt(sys.argv[1:], 'h', ['help'])
    except getopt.GetoptError as e:
        sys.stderr.write("%s: %s\n" % (argv0, e.msg))
        sys.exit(1)

    for key, value in options:
        if key in ['-h', '--help']:
            usage()
        else:
            sys.exit(0)

    if len(args) < 3:
        sys.stderr.write("%s: bad usage (use --help for help)\n" % argv0)
        sys.exit(1)

    # file_name, prefix, line, line_number, token and input_file are module
    # globals that the helper functions in this file read and update.  (No
    # 'global' statement is needed at module level.)
    file_name, output_type, prefix = args[:3]

    # Reject a bad output type or argument count before reading the grammar
    # or writing anything to stdout, so that a usage error never leaves a
    # partial output file behind and always exits with a failure status.
    if not ((output_type == 'code' and len(args) == 4)
            or (output_type == 'header' and len(args) == 3)):
        sys.stderr.write("%s: bad usage (use --help for help)\n" % argv0)
        sys.exit(1)

    prefix = '%s_' % prefix

    line = ""
    line_number = 0

    productions = {}

    # Parse the whole grammar.  The 'with' block guarantees that the input
    # file is closed even when fatal() terminates the script via sys.exit().
    with open(file_name) as input_file:
        token = ('start', )
        get_token()
        while True:
            # Skip blank lines, which the lexer reports as ';' tokens.
            while match(';'):
                pass
            if token[0] == 'eof':
                break

            name, xml_name, attributes, rhs = parse_production()
            if name in productions:
                fatal("%s: duplicate production" % name)
            productions[name] = (xml_name, attributes, rhs)

    print('/* Generated automatically -- do not modify!    -*- buffer-read-only: t -*- */')
    if output_type == 'code':
        # HEADER is substituted verbatim after '#include', so the caller
        # supplies it complete with its quotes or angle brackets.
        header_name = args[3]

        print("""\
#include <config.h>
#include %s
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include "libpspp/cast.h"
#include "libpspp/str.h"
#include "gl/xalloc.h"

""" % header_name)
        # String <-> enum tables for every enumeration with more than one
        # possible value.
        for enum_name, values in sorted(enums.items()):
            if len(values) <= 1:
                continue

            c_name = name_to_id(enum_name)
            print('\nstatic const struct spvxml_enum %s%s_map[] = {'
                  % (prefix, c_name))
            for value in sorted(values):
                print('    { "%s", %s%s_%s },' % (value, prefix.upper(),
                                                 c_name.upper(),
                                                 name_to_id(value).upper()))
            print('    { NULL, 0 },')
            print('};')
            print('\nconst char *')
            print('%s%s_to_string (enum %s%s %s)'
                  % (prefix, c_name, prefix, c_name, c_name))
            print('{')
            print('    switch (%s) {' % c_name)
            for value in sorted(values):
                print('    case %s%s_%s: return "%s";'
                       % (prefix.upper(), c_name.upper(),
                          name_to_id(value).upper(), value))
            print('    default: return NULL;')
            print('    }')
            print('}')

        # Forward declarations, so that the per-production functions below
        # may refer to each other regardless of emission order.
        for name, (xml_name, attributes, rhs) in sorted(productions.items()):
            print('static void %(prefix)scollect_ids_%(name)s (struct spvxml_context *, struct %(prefix)s%(name)s *);\n'
                  'static void %(prefix)sresolve_refs_%(name)s (struct spvxml_context *ctx UNUSED, struct %(prefix)s%(name)s *p UNUSED);\n'
                  % {'prefix': prefix,
                     'name': name_to_id(name)})
        for name, production in sorted(productions.items()):
            print_parser(name, production, productions, ' ' * 4)
            print_free(name, production, ' ' * 4)
            print_collect_ids(name, production)
            print_resolve_refs(name, production)
            # Generic-node adapters plus the class descriptor that ties them
            # together.  The class name shows the XML element name in
            # parentheses when it differs from the production name.
            print('''
static void
%(prefix)sdo_free_%(name)s (struct spvxml_node *node)
{
    %(prefix)sfree_%(name)s (UP_CAST (node, struct %(prefix)s%(name)s, node_));
}

static void
%(prefix)sdo_collect_ids_%(name)s (struct spvxml_context *ctx, struct spvxml_node *node)
{
    %(prefix)scollect_ids_%(name)s (ctx, UP_CAST (node, struct %(prefix)s%(name)s, node_));
}

static void
%(prefix)sdo_resolve_refs_%(name)s (struct spvxml_context *ctx, struct spvxml_node *node)
{
    %(prefix)sresolve_refs_%(name)s (ctx, UP_CAST (node, struct %(prefix)s%(name)s, node_));
}

struct spvxml_node_class %(prefix)s%(name)s_class = {
    "%(class)s",
    %(prefix)sdo_free_%(name)s,
    %(prefix)sdo_collect_ids_%(name)s,
    %(prefix)sdo_resolve_refs_%(name)s,
};
'''
            % {'prefix': prefix,
               'name': name_to_id(name),
               'class': (name if name == production[0]
                         else '%s (%s)' % (name, production[0]))})
    else:
        # output_type == 'header' (guaranteed by the validation above).
        print("""\
#ifndef %(PREFIX)sPARSER_H
#define %(PREFIX)sPARSER_H

#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#include "output/spv/spvxml-helpers.h"\
""" % {'PREFIX': prefix.upper()})
        # NOTE(review): this branch prints the closing '};' of each struct
        # but no 'struct ... {' opener or enum declarations are emitted
        # here -- presumably print_members() is responsible; confirm.
        for name, (xml_name, attributes, rhs) in sorted(productions.items()):
            print_members(attributes, rhs, ' ' * 4)
            print('''};

extern struct spvxml_node_class %(prefix)s%(name)s_class;

bool %(prefix)sparse_%(name)s (struct spvxml_context *, xmlNode *input, struct %(prefix)s%(name)s **);
void %(prefix)sfree_%(name)s (struct %(prefix)s%(name)s *);
bool %(prefix)sis_%(name)s (const struct spvxml_node *);
struct %(prefix)s%(name)s *%(prefix)scast_%(name)s (const struct spvxml_node *);'''
                  % {'prefix': prefix,
                     'name': name_to_id(name)})
        print("""\

#endif /* %(PREFIX)sPARSER_H */""" % {'PREFIX': prefix.upper()})
