#!/usr/bin/env python
#
# Copyright (c) 2008-2011 Wind River Systems, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the Lesser GNU General Public License version 2.1 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the Lesser GNU General Public License for more details.
#
# You should have received a copy of the Lesser GNU General Public License
# version 2.1 along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 
#
"""convert wrapfuncs.in to wrapper function stubs and tables"""

import datetime
import glob
import sys
import re
import os.path
import string
import subprocess
from templatefile import TemplateFile

class ArgumentList:
    """A (possibly empty) list of arguments.

    Attributes set by the constructor:
      args            -- the parsed Argument objects, in declaration order
      variadic        -- True only when the last argument is a '...{decl}'
                         vararg wrapping a real declaration
      variadic_arg    -- the trailing vararg Argument, or None
      last_fixed_arg  -- name of the argument just before the vararg, or None
      variadic_decl   -- C text declaring 'va_list ap' (and the wrapped
                         argument, if any); "" for non-variadic lists
      variadic_start  -- C text starting (and, for a wrapped argument,
                         finishing) the va_list; "" for non-variadic lists
    """

    def __init__(self, text):
        """Parse a comma-separated argument list (including function prototypes).

        text is the text between the parentheses of a prototype, with
        arguments separated by ', '.  Raises Exception when parentheses
        do not balance or when '...' has no fixed argument before it.
        """
        self.args = []
        self.variadic = False
        self.variadic_arg = None
        self.last_fixed_arg = None
        self.variadic_decl = ""
        self.variadic_start = ""
        # (void) is an empty list, not a list of a single argument which is
        # void; an empty () likewise has nothing to parse
        if text.strip() in ("", "void"):
            return

        # A ', ' inside a function-pointer prototype does not separate
        # arguments, so glue pieces back together until the ()s balance.
        depth = 0
        pending = ''
        for piece in text.split(', '):
            depth += piece.count('(') - piece.count(')')
            if depth > 0:
                pending += piece + ', '
            else:
                self.args.append(Argument(pending + piece))
                pending = ''
        if depth != 0:
            raise Exception("mismatched ()s while parsing '%s'" % text)

        last = self.args[-1]
        if not last.vararg:
            return

        # va_start() needs the name of the last fixed argument
        if len(self.args) < 2:
            raise Exception("variadic argument list '%s' has no fixed "
                            "argument before '...'" % text)

        self.variadic = True
        self.variadic_arg = last
        self.last_fixed_arg = self.args[-2].name
        self.variadic_decl = "va_list ap;\n"
        self.variadic_start = "va_start(ap, %s);\n" % self.last_fixed_arg
        wrapped = last.vararg_wraps
        if wrapped:
            self.variadic_decl += "\t%s;\n" % wrapped.decl()
            self.variadic_start += ("\t%s = va_arg(ap, %s);"
                                    "\n\tva_end(ap);\n") % \
                                   (last.name, last.type)
        else:
            # lie blatantly; we don't handle this case
            self.variadic = False

    # for a wrap function, the outer foo() wrapper will convert to a va_list,
    # but the inner wrap_foo() just passes the va_list through.
    def maybe_variadic_start(self):
        """Use va_arg() to grab an optional argument if needed."""
        if self.variadic and self.variadic_arg.vararg_wraps:
            return self.variadic_start
        return ""

    def maybe_variadic_decl(self):
        """Declare va_list ap and optional argument, if needed."""
        if self.variadic and self.variadic_arg.vararg_wraps:
            return self.variadic_decl
        return ""

    def decl(self, comment=False, wrap=False):
        """Produce the declaration form of this argument list.

        comment and wrap are passed through to each Argument.decl();
        an empty list is declared as "void".
        """
        if not self.args:
            return "void"

        return ', '.join(x.decl(comment=comment, wrap=wrap) for x in self.args)

    def call(self):
        """Produce the calling form of this argument list."""
        if not self.args:
            return ""
        return ', '.join([x.call() for x in self.args])

    def __repr__(self):
        if not self.args:
            return "no arguments"
        return '::'.join([x.decl() for x in self.args])

    
class Argument:
    """A function argument such as 'char *path' or 'char (*foo)(void)'

    Attributes set by the constructor:
      type, name        -- the two halves of the declaration (both None
                           for a bare '...')
      vararg            -- True when the text started with '...'
      vararg_wraps      -- Argument for the declaration inside '...{decl}',
                           otherwise None
      function_pointer  -- True for 'type (*name)(args)'
      args              -- the raw argument text of a function pointer,
                           otherwise None
      spacer            -- ' ' when type and name need a space between them
    """

    # 'type (*name)(args)'
    _FUNCTION_POINTER = re.compile(r'(.*)\(\*([a-zA-Z0-9$_]*)\)\((.*)\)')
    # 'type name' / 'type *name': the type runs up to the last space or '*'
    _PLAIN = re.compile(r'(.*[ *])\(?\*?([a-zA-Z0-9$_]*)\)?')
    _IDENTIFIER_CHAR = re.compile(r'[_a-zA-Z0-9]')

    def __init__(self, text):
        """Get the type and name of a trivial C declaration.

        Raises Exception for a bare 'void' and for text that cannot be
        split into a type and a name.
        """
        self.vararg = False
        self.vararg_wraps = None
        self.function_pointer = False
        self.args = None
        self.spacer = ''
        self.type, self.name = None, None

        if text == 'void':
            raise Exception("Tried to declare a nameless object of type void.")

        if text.startswith('...'):
            self.vararg = True
            if len(text) <= 3:
                # a bare '...': nothing more to parse
                return
            # we're a wrapper for something else, declared as
            # ...{real_decl}, as in the third argument to open(2)
            text = text[4:-1]
            # stash a copy of these values without the vararg flag, so
            # we can declare them prettily later
            self.vararg_wraps = Argument(text)

        # try for a function pointer
        match = self._FUNCTION_POINTER.match(text)
        if match:
            self.function_pointer = True
            self.args = match.group(3)
            self.type = match.group(1)
            self.name = match.group(2).rstrip()
        else:
            # plain declaration
            match = self._PLAIN.match(text)
            if match:
                self.type, self.name = match.group(1).rstrip(), match.group(2)

        # Without a type there is nothing sensible to declare; say so here
        # instead of failing later on an index into None or "".
        if not self.type:
            raise Exception("Unable to parse argument declaration '%s'." % text)

        # spacing between type and name, needed if type ends with a character
        # which could be part of an identifier
        if self._IDENTIFIER_CHAR.match(self.type[-1]):
            self.spacer = ' '

    def decl(self, comment=False, wrap=False):
        """Produce the declaration form of this argument.

        For a '...{decl}' vararg the wrapped declaration is shown as
        '... { decl }' when comment is true and '... /* decl */' otherwise.
        For a bare '...' the result is 'va_list ap' when wrap is true and
        '...' otherwise.
        """
        if self.function_pointer:
            decl = "%s%s(*%s)(%s)" % \
                   (self.type, self.spacer, self.name, self.args)
        else:
            decl = "%s%s%s" % (self.type, self.spacer, self.name)

        if not self.vararg:
            return decl

        if self.vararg_wraps:
            if comment:
                return "... { %s }" % decl
            return "... /* %s */" % decl
        if wrap:
            return "va_list ap"
        return "..."

    def call(self):
        """Produce the call form of this argument (usually its name)."""
        if self.type == 'void':
            return ''

        if self.vararg and not self.vararg_wraps:
            return "ap"

        return self.name

    def __str__(self):
        return self.decl()

    def __repr__(self):
        return self.decl()

class Function:
    """A function signature and additional data about how the function works"""

    # table of known default return values, keyed by return type
    _DEFAULT_VALUES = {
        'gid_t': '0',
        'uid_t': '0',
        'int': '-1',
        'long': '-1',
        'ssize_t': '-1',
    }

    def __init__(self, port, line):
        """Parse one prototype line of the form 'type name(args); /* mods */'.

        port is stored as given.  The optional trailing comment holds
        'key=value' modifiers separated by ', '; each one is set as an
        attribute on this object, overriding what was parsed.

        Raises ValueError when the line has no ';', Exception when no
        prototype can be found, and KeyError when the return type has no
        known default value.
        """
        self.dirfd = 'AT_FDCWD'
        self.flags = '0'
        self.port = port
        self.directory = ''
        # On Darwin, some functions are SECRETLY converted to foo$INODE64
        # when called.  So we have to look those up for real_*
        self.inode64 = None
        self.real_func = None
        self.paths_to_munge = []
        self.hand_wrapped = None
        # used for the copyright date when creating stub functions
        self.date = datetime.date.today().year

        # Only the first ';' ends the prototype; any later one belongs to
        # the trailing comment.
        function, comments = line.split(';', 1)
        comment = re.search(r'/\* *(.*) *\*/', comments)
        if comment:
            self.comments = comment.group(1)
        else:
            self.comments = None

        bits = re.match(r'([^(]*)\((.*)\)', function)
        if not bits:
            raise Exception("no function prototype found in '%s'" % function)
        type_and_name = Argument(bits.group(1))
        self.type, self.name = type_and_name.type, type_and_name.name
        # convenient to have this declared here so we can use its .decl later
        if self.type != 'void':
            self.return_code = Argument("%s rc" % self.type)

        # Some args get special treatment:
        # * If the arg has a name ending in 'path', we will canonicalize it.
        # * If the arg is named 'dirfd' or 'flags', it becomes the default
        #   values for the dirfd and flags arguments when canonicalizing.
        # * Note that the "comments" field (/* ... */ after the decl) can
        #   override the dirfd/flags values.
        self.args = ArgumentList(bits.group(2))
        for arg in self.args.args:
            # ignore varargs, they never get these special treatments
            if arg.vararg:
                continue
            if arg.name == 'dirfd':
                self.dirfd = 'dirfd'
            elif arg.name == 'flags':
                self.flags = 'flags'
            elif arg.name[-4:] == 'path':
                self.paths_to_munge.append(arg.name)

        self.default_value = self._pick_default_value()
        self._apply_modifiers()

    def _pick_default_value(self):
        """Return the C expression used as this function's failure value."""
        if self.type == 'void':
            return ''
        if self.type[-1:] == '*':
            return 'NULL'
        try:
            return self._DEFAULT_VALUES[self.type]
        except KeyError:
            raise KeyError("Function %s has return type %s, "
                           "for which there is no default value." %
                           (self.name, self.type))

    def _apply_modifiers(self):
        """Apply 'key=value' comment modifiers, such as flags=AT_SYMLINK_NOFOLLOW."""
        if not self.comments:
            return
        for mod in self.comments.split(', '):
            # split on the first '=' only, so a value may itself contain '='
            key, value = mod.split('=', 1)
            setattr(self, key.strip(), value.rstrip())

    def maybe_inode64(self):
        """Return the "$INODE64" suffix when inode64 is set and we run on Darwin."""
        if self.inode64 and os.uname()[0] == 'Darwin':
            return "$INODE64"
        return ""

    def end_maybe_skip(self):
        """Close the '#if 0' opened by maybe_skip() for hand-wrapped functions."""
        if self.hand_wrapped:
            return """/* Hand-written wrapper for this function. */
#endif
"""
        return ""

    def maybe_skip(self):
        """Open an '#if 0' around the generated code for hand-wrapped functions."""
        if self.hand_wrapped:
            return """/* Hand-written wrapper for this function. */
#if 0
"""
        return ""

    def comment(self):
        """declare self (in a comment)"""
        return self.decl(comment=True)

    def decl(self, comment=False, wrap=True):
        """declare self"""
        if self.type[-1:] == '*':
            spacer = ''
        else:
            spacer = ' '
        return "%s%s%s(%s)" % \
                (self.type, spacer, self.name, self.args.decl(comment, wrap))

    def decl_args(self):
        """declare argument list"""
        return self.args.decl()

    def wrap_args(self):
        """declare argument list for wrap_foo() variant"""
        return self.args.decl(wrap=True)

    def call_args(self):
        """present argument list for a function call"""
        return self.args.call()

    def alloc_paths(self):
        """create/allocate canonical paths"""
        return "\n\t\t\t".join(
            ["%s = pseudo_root_path(__func__, __LINE__, %s, %s, %s);" %
             (path, self.dirfd, path, self.flags)
             for path in self.paths_to_munge])

    def free_paths(self):
        """free any allocated paths"""
        # the cast is here because free's argument isn't const qualified, but
        # the original path may have been -- but we only GET here if the path
        # has been overwritten.
        return "\n\t\t\t".join(
            ["free((void *) %s);" % path for path in self.paths_to_munge])

    def real_predecl(self):
        """Declare real_func with this function's signature, if one is set."""
        if self.real_func:
            return self.decl().replace(self.name, self.real_func, 1) + ";"
        return ""

    def real_init(self):
        """Return real_func if set, otherwise "NULL"."""
        if self.real_func:
            return self.real_func
        return "NULL"

    def rc_return(self):
        """return rc (or just return)"""
        if self.type == 'void':
            return "return;"
        return "return rc;"

    def rc_decl(self):
        """declare rc (if needed)"""
        if self.type == 'void':
            return ""
        return "%s = %s;" % (self.return_code.decl(), self.default_value)

    def rc_assign(self):
        """assign something to rc (or discard it)"""
        if self.type == 'void':
            return "(void)"
        return "rc ="

    def def_return(self):
        """return default value (or just return)"""
        if self.type == 'void':
            return "return;"
        return "return %s;" % self.default_value

    def __getitem__(self, key):
        """Make this object look like a dict for Templates to use"""
        try:
            attr = getattr(self, key)
        except AttributeError:
            # There's a few attributes that are handled inside the args
            # object, so check there too...
            attr = getattr(self.args, key)

        # methods are called so a template key yields their result
        if callable(attr):
            return attr()
        return attr

    def __repr__(self):
        pretty = "%(name)s returns %(type)s and takes " % self
        pretty += repr(self.args)
        if self.comments:
            pretty += ' (%s)' % self.comments
        return pretty

class Port:
    """
    A Port is a set of function declarations and code providing
    details specific to a specific host environment, such as Linux.
    Ports can override each other, and each port can indicate
    additional ports to include.

    Everything belonging to a port lives under ports/<name>/.
    """

    def __init__(self, port, sources):
        """Load port `port`, then every port its preports/subports scripts name.

        Each object in sources gets emit('port', self) called once, before
        the pre- and subports are loaded.  Raises Exception when a
        preports or subports script exits with a non-zero status.
        """
        self.name = port
        self.subports = []
        self.preports = []

        self.wrappers = self._optional_portfile("pseudo_wrappers.c")
        self.portdef_file = self._optional_portfile("portdefs.h")

        if os.path.exists(self.portfile("wrapfuncs.in")):
            self.funcs = process_wrapfuncs(port)
        else:
            self.funcs = {}

        for source in sources:
            source.emit('port', self)

        self.preports = self._related_ports("preports", sources)
        self.subports = self._related_ports("subports", sources)

    def _optional_portfile(self, name):
        """Return the path of this port's file `name`, or None if it is absent."""
        path = self.portfile(name)
        if os.path.exists(path):
            return path
        return None

    def _related_ports(self, script, sources):
        """Run this port's `script`, if present, and load each port it names.

        The script is called with the port name as its only argument and
        its output is split on whitespace into port names.
        """
        path = self.portfile(script)
        if not os.path.exists(path):
            return []

        # universal_newlines makes the output text on Python 2 and 3 alike
        proc = subprocess.Popen([path, self.name], stdout=subprocess.PIPE,
                                universal_newlines=True)
        portlist = proc.communicate()[0]
        if proc.returncode:
            raise Exception("%s script failed for port %s" %
                            (script, self.name))

        return [Port(name, sources) for name in portlist.split()]

    def _merge_into(self, merged, funcs, owner):
        """Copy funcs into merged, warning about each name already there."""
        for name in funcs:
            if name in merged:
                print("Warning: %s from %s overriding %s" %
                      (name, owner, merged[name].port))
            merged[name] = funcs[name]

    def functions(self):
        """Return one dict of functions: preports, then this port, then subports.

        Later entries replace earlier ones of the same name, with a warning.
        """
        merged = {}
        for pre in self.preports:
            self._merge_into(merged, pre.functions(), pre.name)
        self._merge_into(merged, self.funcs, self.name)
        for sub in self.subports:
            self._merge_into(merged, sub.functions(), sub.name)
        return merged

    def define(self):
        """Return the '#define PSEUDO_PORT_<NAME> 1' line for this port."""
        return '#define PSEUDO_PORT_%s 1' % self.name.upper().replace('/', '_')

    def portdefs(self):
        """Return an #include for portdefs.h, or a comment if there is none."""
        if self.portdef_file:
            return '#include "%s"' % self.portdef_file
        return '/* no portdefs for %s */' % self.name

    def include(self):
        """Return an #include for pseudo_wrappers.c, or a comment if there is none."""
        if self.wrappers:
            return '#include "%s"' % self.wrappers
        return '/* no #include for %s */' % self.name

    def portfile(self, name):
        """Return the path of file `name` inside this port's directory."""
        return "ports/%s/%s" % (self.name, name)

    def __getitem__(self, key):
        """Make this object look like a dict for Templates to use"""
        # unknown keys yield None; methods are called for their result
        attr = getattr(self, key, None)
        if callable(attr):
            return attr()
        return attr


def process_wrapfuncs(port):
    """Process a wrapfuncs.in file, generating a list of prototypes.

    Reads ports/<port>/wrapfuncs.in, skipping blank lines and lines that
    start with '#', and returns a dict mapping each function name to its
    Function.  Progress is written to stdout, one '.' per function.  On
    the first line that fails to parse, the error is printed and the
    program exits with status 1.
    """
    filename = "ports/%s/wrapfuncs.in" % port
    funcs = {}
    directory = os.path.dirname(filename)
    sys.stdout.write("%s: " % filename)
    # the with block closes the file on the error exit as well
    with open(filename) as funclist:
        for line in funclist:
            line = line.rstrip()
            if line.startswith('#') or not line:
                continue
            try:
                func = Function(port, line)
                func.directory = directory
                funcs[func.name] = func
                sys.stdout.write(".")
            except Exception as e:
                print("Parsing failed: %s" % e)
                sys.exit(1)
    print("")
    return funcs

def main():
    """Read in function definitions, write out files based on templates.

    Loads guts/COPYRIGHT, opens every template under templates/, loads the
    'common' port (and whatever ports it pulls in), then emits each
    function, in name order, to every template.  Exits with status 1 when
    a template cannot be read or the port cannot be set up.
    """
    sources = []

    # error checking helpfully provided by the exception handler
    with open('guts/COPYRIGHT') as copyright_file:
        TemplateFile.copyright = copyright_file.read()

    for path in glob.glob('templates/*'):
        try:
            source = TemplateFile(path)
            source.emit('copyright')
            source.emit('header')
            sources.append(source)
        except IOError:
            print("Invalid or malformed template %s.  Aborting." % path)
            sys.exit(1)

    try:
        port = Port('common', sources)
    except KeyError as err:
        # Without a port there is nothing to write, so stop here rather
        # than carry on with `port` unbound.
        print("Unable to set up port 'common': %s" % err)
        sys.exit(1)

    # the per-function stuff
    # (no newline: warnings or "done." continue on the same line)
    sys.stdout.write("Writing functions... ")
    all_funcs = port.functions()
    for name in sorted(all_funcs.keys()):
        # populate various tables and files with each function
        for source in sources:
            source.emit('body', all_funcs[name])
    print("done.  Cleaning up.")

    for source in sources:
        # clean up files
        source.emit('footer')
        source.close()

# Run the generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
