#!/usr/bin/python

# Extract remote descriptions from annotated source code
# Copyright (C) 2008 Openismus GmbH (www.openismus.com)
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
#
# Authors:
#   Mathias Hasselmann <mathias@openismus.com>
#
# Usage:
#   build-receiver-list [SRCDIR]
#
import gzip, logging, os, pwd, re, sys

from ConfigParser import SafeConfigParser
from datetime     import datetime

def find_srcdir():
    """Return the LIRC source directory given on the command line.

    The directory is the first command line argument and defaults to ''
    (meaning the current working directory).  It is only accepted when it
    contains daemons/lircd.c.

    Raises SystemExit with an explanatory message otherwise.
    """
    srcdir = sys.argv[1] if len(sys.argv) > 1 else ''
    filename = os.path.join(srcdir, 'daemons', 'lircd.c')

    if not os.path.isfile(filename):
        # report the directory actually searched, even when none was given:
        raise SystemExit('No LIRC code found at %s.' % (srcdir or os.getcwd()))

    return srcdir

def drop_hidden_names(filenames):
    """Remove every name starting with '.' from the list *filenames*, in place.

    The list object itself is modified (nothing is returned), so a caller can
    pass the ``subdirs`` list produced by os.walk() to stop the walk from
    descending into hidden directories.
    """
    filenames[:] = [name for name in filenames if not name.startswith('.')]

def parse_comments(text):
    """Collect 'key: value' property dictionaries from C comments in *text*.

    Every /* ... */ comment containing at least one colon yields one dict,
    built from the comment lines that match the ``re_properties`` pattern;
    later duplicate keys win.  Comments without a colon are ignored.

    Relies on the module-level ``re_comments`` and ``re_properties`` patterns
    compiled in the ``__main__`` section.
    """
    receivers = []

    for comment in re_comments.findall(text):
        # comments without any colon cannot carry properties:
        if ':' not in comment:
            continue

        properties = {}

        for line in comment.splitlines():
            found = re_properties.match(line)

            if found:
                key, value = found.groups()
                properties[key] = value

        receivers.append(properties)

    return receivers

def scan_sources(path, scanner):
    """Call *scanner* with the full path of every '.c' file below *path*.

    Hidden directories are pruned in place (via drop_hidden_names) so that
    os.walk() does not descend into them; hidden files are skipped too.
    A *path* that does not exist simply yields no calls.
    """
    for dirpath, subdirs, files in os.walk(path):
        drop_hidden_names(subdirs)

        for name in files:
            if name.startswith('.') or not name.endswith('.c'):
                continue

            scanner(os.path.join(dirpath, name))

def scan_userspace_driver(filename):
    driver_code = open(filename).read()

    for declaration in re_hardware.finditer(driver_code):
        declaration = declaration.group(1)

        # convert decaration text into list of C expressions:
        expressions = re_comments.sub('', declaration).split(',')
        expressions = [value.strip() for value in expressions]

        # last expression is the driver name:
        driver = expressions[-1].strip('"')
        if not driver: continue

        # extract receiver information from comments:
        receivers = parse_comments(declaration)

        if not receivers:
            logging.warning(
                'No receivers declared for userspace driver %s.',
                driver)

            continue

        # print receiver information for current driver:
        print '# receivers supported by %s userspace driver' % driver
        print '# %s' % ('-' * 70)

        for receiver in receivers:
            receiver['lirc-driver'] = driver
            print_receiver_details(receiver)

def expand_symbols(symbols, text, _expanding=frozenset()):
    """Recursively replace every word of *text* found in the *symbols* table.

    symbols    -- mapping of preprocessor symbol names to their raw values
    text       -- the expression to expand
    _expanding -- internal: names currently being expanded; a symbol is not
                  expanded again inside its own expansion, so self- or
                  mutually-referential definitions terminate (like the C
                  preprocessor) instead of recursing forever

    Symbols with an empty value are left unexpanded.
    """
    def replace_symbol(match):
        word = match.group(0)

        # stop at symbols already being expanded further up the stack:
        if word in _expanding:
            return word

        # lookup word in symbol table:
        expansion = symbols.get(word)

        if expansion:
            # expand symbol recursively when found:
            return expand_symbols(symbols, expansion,
                                  _expanding | frozenset([word]))

        return word

    return re.sub(r'\b\w+\b', replace_symbol, text)

def override_name(overrides, name):
    """Return the preferred spelling of *name* from *overrides*.

    The lookup key is 'name-' followed by the lower-cased name; when no such
    key exists, *name* is returned unchanged.
    """
    return overrides.get('name-' + name.lower(), name)

def derive_name(lookup, match, overrides=None):
    """Combine a database record name with a name derived from a C symbol.

    lookup    -- record with a ``name`` attribute (e.g. a
                 DeviceDatabase.Record), or None
    match     -- regex match whose group(1) is a name fragment taken from a
                 C symbol, or None
    overrides -- mapping passed to override_name() to fix the spelling of
                 the derived fragment; defaults to no overrides

    Returns lookup.name when the derived fragment is already contained in
    it, 'lookup.name/derived' when both exist but differ, whichever of the
    two exists otherwise, or None when neither does.
    """
    if overrides is None:
        overrides = {}

    derived = match.group(1).title() if match else None

    if derived:
        # presumably short fragments are acronyms ('Ati' -> 'ATI'):
        if len(derived) < 4:
            derived = derived.upper()

        derived = override_name(overrides, derived)

    if lookup is None:
        return derived or None

    if derived and derived.lower() not in lookup.name.lower():
        return '%s/%s' % (lookup.name, derived)

    return lookup.name

def scan_kernel_driver(filename):
    overrides = SafeConfigParser()
    overrides.read('data/overrides.conf')

    vendor_overrides = dict(overrides.items('usb-vendors'))
    product_overrides = dict(overrides.items('usb-products'))

    overrides = None

    srcname = filename[len(srcdir):].strip(os.sep)
    driver_code = open(filename).read()

    def identify_usb_vendor(vendor_id):
        vendor_match = re_usb_vendor.match(vendor_id)
        vendor_id = int(expand_symbols(symbols, vendor_id), 0)

        vendor_lookup = usb_ids.get(vendor_id)
        vendor_name = derive_name(vendor_lookup, vendor_match, vendor_overrides)

        return vendor_id, (vendor_name or 'Unknown Vendor (USB-Id: %04X)' % vendor_id)

    def identify_usb_product(vendor_id, product_id):
        product_match = re_usb_product.match(product_id)
        product_id = int(expand_symbols(symbols, product_id), 0)

        product_table = usb_ids.get(vendor_id)
        product_lookup = product_table.get(product_id) if product_table else None
        product_name = derive_name(product_lookup, product_match, product_overrides)

        if product_name is None and device_block:
            product_name = override_name(product_overrides, device_block)

        return product_id, product_name

    # naively parse preprocessor symbols:
    symbols = dict()

    for declaration in re_define.finditer(driver_code):
        name, value = declaration.groups()
        symbols[name] = value

    # resolve driver name, from symbol table or filename:
    driver_name = symbols.get('DRIVER_NAME')

    if not driver_name:
        dirname     = os.path.dirname(filename)
        driver_name = os.path.basename(dirname)

    else:
        driver_name = driver_name.strip('"')

    # iterate source code lines:
    device_block = None

    for line, text in enumerate(driver_code.splitlines()):
        match = re_usb_device_block_begin.search(text)

        if match:
            device_block = match.group(1)
            continue

        match = re_usb_device_block_end.search(text)

        if match:
            device_block = None
            continue

        match = re_usb_device.search(text)

        if match:
            vendor_id, product_id = match.groups()

            vendor_id, vendor_name = identify_usb_vendor(vendor_id)
            product_id, product_name = identify_usb_product(vendor_id, product_id)

            vendor_name = product_overrides.get(
                '%04x-%04x-vendor' % (vendor_id, product_id),
                vendor_name)
            product_name = product_overrides.get(
                '%04x-%04x-product' % (vendor_id, product_id),
                product_name)

            if not product_name:
                logging.warning('%s:%d: Unknown USB device %04x:%04x',
                                srcname, line + 1, vendor_id, product_id)
                product_name = 'Unknown Device (USB-Id: %04X)' % product_id

            print '[%s: %s]' % (vendor_name, product_name)
            print 'source-code = %s, line %d' % (srcname, line + 1)
            print 'kernel-module = %s' % driver_name
            print 'product_id = 0x%04x' % product_id
            print 'vendor-id = 0x%04x' % vendor_id
            print

def print_database_header():
    realname = pwd.getpwuid(os.getuid()).pw_gecos.split(',')[0]

    print '# LIRC Receiver Database'
    print '# %s' % ('=' * 70)
    print '# Generated on %s' % datetime.now().strftime('%c')
    print '# from %s' % os.path.realpath(srcdir)
    print '# by %s' % realname
    print '# %s' % ('=' * 70)
    print

def print_receiver_details(properties):
    main_properties = (
        'compatible-remotes', 'device-ids',
        'kernel-modules', 'lirc-driver')

    vendor, product = map(properties.pop, ('vendor-name', 'product-name', ))
    indent = max(map(len, properties.keys()))

    print '[%s: %s]' % (vendor, product)

    for key in main_properties:
        value = properties.pop(key, None)

        if not value:
            continue

        print '%*s = %s' % (-indent, key, value)

    for key, value in properties.items():
        print '%*s = %s' % (-indent, key, value)

    print

class DeviceDatabase(dict):
    """Index of a usb.ids / pci.ids style, tab-indented device list.

    Maps top-level (vendor) codes to Record objects.  Each vendor Record maps
    one-tab (device) codes to Records, which in turn map two-tab codes to
    Records.  Codes are 4-digit hex numbers stored as ints.
    """
    # leading tabs (nesting depth), 4-digit hex code, name without trailing
    # whitespace (the lazy group keeps '\r' and blanks out of the name):
    re_record = re.compile(r'^(\t*)([0-9A-Fa-f]{4})\s+(.*?)\s*$')

    class Record(dict):
        """A named database entry that also maps child codes to Records."""

        def __init__(self, code, name):
            super(DeviceDatabase.Record, self).__init__()

            self.__code = code
            self.__name = name

        def __str__(self):
            return self.name

        def __repr__(self):
            return '<%s: %r>' % (self.name, dict(self.items()))

        code = property(lambda self: self.__code)
        name = property(lambda self: self.__name)

    def __init__(self, fileobj):
        """Parse every line of *fileobj* (any iterable of text lines).

        Lines without a 4-digit hex code are ignored, as are nested lines
        that appear before any parent record they could belong to.
        """
        super(DeviceDatabase, self).__init__()
        vendor, device = None, None

        for line in fileobj:
            match = self.re_record.match(line)

            if not match:
                continue

            prefix, code, name = match.groups()
            code = int(code, 16)
            depth = len(prefix)

            if 0 == depth:
                vendor, device = self.Record(code, name), None
                self[code] = vendor

            elif 1 == depth and vendor is not None:
                device = self.Record(code, name)
                vendor[code] = device

            elif 2 == depth and device is not None:
                device[code] = self.Record(code, name)

if '__main__' == __name__:
    # initialize logging facilities:
    logging.basicConfig(format='%(levelname)s: %(message)s')

    # find lirc sources:
    srcdir = find_srcdir()

    # declare some frequently used regular expressions:
    re_hardware = r'struct\s+hardware\s+hw_\w+\s*=\s*{(.*?)};'
    re_hardware = re.compile(re_hardware, re.DOTALL)

    re_comments = r'/\*\s*(.*?)\s*\*/'
    re_comments = re.compile(re_comments, re.DOTALL)

    re_properties = r'^(?:\s|\*)*(\S+)\s*:\s*(.*?)(?:\s|\*)*$'
    re_properties = re.compile(re_properties)

    re_define = r'^#\s*define\s+(\w+)\s+(.*?)\s*$'
    re_define = re.compile(re_define, re.MULTILINE)

    re_usb_device_block_begin = r'/\*\s*USB Device ID for (.*) USB Control Board\s\*/'
    re_usb_device_block_begin = re.compile(re_usb_device_block_begin)

    re_usb_device_block_end = r'{\s*}'
    re_usb_device_block_end = re.compile(re_usb_device_block_end)

    re_usb_device = r'USB_DEVICE\s*\(\s*([^,]*),\s*(.*?)\s*\)'
    re_usb_device = re.compile(re_usb_device)

    re_usb_vendor = r'^(?:USB_|VENDOR_)?([A-Z]+?)[0-9]*(?:_VENDOR_ID)?$'
    re_usb_vendor = re.compile(re_usb_vendor)

    re_usb_product = r'^(?:USB_|PRODUCT_)?([A-Z]+?)[0-9]*(?:_PRODUCT_ID)?$'
    re_usb_product = re.compile(re_usb_product)

    # load the id databases, closing each file once it has been parsed:
    ids_file = gzip.open('/usr/share/misc/usb.ids')
    try:
        usb_ids = DeviceDatabase(ids_file)
    finally:
        ids_file.close()

    # NOTE(review): pci_ids is not referenced by the scanners above;
    # confirm it is still needed before removing this load.
    ids_file = open('/usr/share/misc/pci.ids')
    try:
        pci_ids = DeviceDatabase(ids_file)
    finally:
        ids_file.close()

    # scan source code for receiver information,
    # and dump this information immediately:
    print_database_header()

    # TODO: merge information from data/receivers.conf
    # TODO: scan_sources(os.path.join(srcdir, 'daemons'), scan_userspace_driver)
    scan_sources(os.path.join(srcdir, 'drivers'), scan_kernel_driver)
