"""
comu.py

Implementation of the COMUnicator TCP/IP protocol
Only the client side has been implemented so far.

25 Jan 2009, Lev
"""

# changelog
#
# 25 Jan 2009 (Lev)
# Rewrite Config.talk. Add an option to send repeated queries only once per request
# Accept requests in the form of a string '<instrument_name><space><request>'

import os.path, socket

class ComuError(Exception):
    """
    A communicator error.

    'errorcode' - numeric error code (one of the constants below, or
                  OTHER_ERROR for free-form messages)
    'error'     - human-readable error message
    """
    def __init__(self, errorcode):
        """
        Create an error instance.
        'errorcode' argument is treated as a numeric/string representation
        of error code, or, if not recognised, as an error message.
        """
        self.errorcode, self.error = ComuError.parse(errorcode)
        Exception.__init__(self, self.error)

    @staticmethod
    def parse(errorcode):
        """
        Translate 'errorcode' into an '(errorcode, message)' tuple.

        'errorcode' is first interpreted as a number (an int or a numeric
        string). If that number is a known code, return it together with its
        message from 'ComuError.errors'. Otherwise (not a number, an unknown
        number, or e.g. None) return '(-1, str(errorcode))', i.e. treat the
        argument as a free-form error message.
        """
        try:
            errorcode = int(errorcode)
            return errorcode, ComuError.errors[errorcode]
        except (ValueError, TypeError, KeyError):
            # TypeError: int() rejects None and other non-numeric objects
            # with TypeError rather than ValueError.
            return -1, str(errorcode)

    def __repr__(self): return "%d: %s" % (self.errorcode, self.error)

    NO_ERROR = 0
    NOT_IN_LOCAL_LIST = 1
    NO_REMOTE_RESPONSE = 2
    NOT_IN_REMOTE_LIST = 3
    UNKNOWN_QUERY = 4
    OTHER_ERROR = -1

    # error messages
    errors = {
        0: 'no error',
        1: 'instrument not found in local list',
        2: 'no remote Communicator response',
        3: 'instrument not found in remote list',
        4: 'query keyword not found',
        -1: 'other error'
    }

class Request(object):
    """
    A Comunicator request to an instrument.
    Can be specified as '<instrument><space><message>' or as two separate strings.

    'instrument' - name of the addressed instrument
    'message'    - request text; a message starting with '?' is a query,
                   anything else is a command
    """
    def __init__(self, instrument, message = None):
        """
        Create a Comunicator request to an instrument.

        If 'message' is None, 'instrument' must be a string of the form
        '<instrument><space><message>'; it is split at the first space.
        Otherwise both arguments are used as given.
        Raises ValueError if the combined string contains no space.
        """
        if message is None:
            parts = instrument.split(' ', 1)
            if len(parts) != 2:
                raise ValueError("Comunicator Error: request '%s' is not of the form '<instrument> <message>'" % instrument)
            self.instrument, self.message = parts
        else:
            self.instrument = instrument
            self.message = message

    def __repr__(self): return "%s: '%s'" % (self.instrument, self.message)

    def encode(self):
        "Encode the request for transport: instrument name, then message"
        return Codec.encode_str(self.instrument) + Codec.encode_str(self.message)

    def is_query(self):
        "Return True if this request is a query (starts with '?')"
        # startswith, unlike self.message[0], is safe for an empty message
        return self.message.startswith('?')

    def same(self, other):
        "Return True if 'other' is the same request as self (case insensitive)"
        return self.instrument.upper() == other.instrument.upper() and self.message.upper() == other.message.upper()

class Response(object):
    """
    A Comunicator response to a single request.

    'response'           - response text
    'errorcode', 'error' - error code and message, as produced by 'ComuError.parse'

    A 'Response' evaluates as true only if it carries no error (error code 0).
    """
    def __init__(self, response, errorcode):
        """
        Create a response. 'errorcode' is a numeric/string error code or,
        if not recognised, an error message (see 'ComuError.parse').
        """
        self.response = response
        self.errorcode, self.error = ComuError.parse(errorcode)

    def encode(self):
        "Encode the response for transport: response text, then the error code as a string"
        return Codec.encode_str(self.response) + Codec.encode_str(str(self.errorcode))

    def __nonzero__(self): return self.errorcode == 0
    # Python 3 name of the truth-value hook; without it every Response would
    # be true under Python 3 and 'check' would never raise.
    __bool__ = __nonzero__

    def __repr__(self): return "'%s' (%d: %s)" % (self.response, self.errorcode, self.error)

    def check(self):
        "Raise a 'ComuError' if the response carries an error"
        if not self:
            # Pass the numeric code when it is a known one so the raised error
            # keeps it (passing the message text would turn it into -1);
            # free-form messages (code -1) are passed as text.
            raise ComuError(self.error if self.errorcode == -1 else self.errorcode)

    def copy(self):
        "Return a new 'Response' with the same text, error code and message"
        return Response(self.response, self.error if self.errorcode == -1 else self.errorcode)

class ConfigEntry(dict):
    """
    Config file entry about a certain server:
    ip - IP address of the computer
    port - IP port
    name - the name of the server (also known as 'type')
    comment - free-form comment from the server definition line

    A 'ConfigEntry' inherits from a 'dict' class, a dictionary represents
    instruments served by this server.
    instruments - a dictionary with instrument names as keys and comments as values
    """
    def __init__(self, ip, port, name, comment, instruments):
        """
        Define a config entry about a certain server
        ip - IP address of the computer
        port - IP port
        name - the name of the server (also known as 'type')
        comment - free-form comment about the server
        instruments - a dictionary with instrument names as keys and comments as values
        """
        self.ip = ip
        self.port = port
        self.name = name
        self.comment = comment
        dict.__init__(self, instruments)

    def register_instrument(self, name, comment):
        "Add (or replace) instrument 'name' served by this server, with its 'comment'"
        self[name] = comment

    def instruments(self):
        "Return a list of instruments"
        return list(self.keys())

    def islocal(self, local_name):
        "Return True if this entry is about a server on the local computer and bears a specified name"
        # Compare names first: it is cheap, and when it fails we skip two DNS
        # lookups that can be slow or raise socket errors.
        if self.name.lower() != local_name.lower():
            return False
        return socket.gethostbyname(self.ip) == socket.gethostbyname(socket.gethostname())

    def __repr__(self):
        lines = ['%s @ %s:%d (%s)' % (self.name, self.ip, self.port, self.comment)]
        lines.extend(['\t%s (%s)' % (i, self[i]) for i in self.instruments()])
        return '\n'.join(lines)

class Config(dict):
    """
    A loaded configuration.
    'config_file' - path to the configuration file.
    'config_modified' - timestamp of last config modification
    'local_name' - name of the server to be addressed locally. Not supported yet.
    'entries' - server entries
    'timeout' - timeout in seconds for communication by TCP

    'Config' inherits from 'dict'. The dictionary contains servers by instrument names.

    Config file lines, as parsed by 'read_config':
        :<ip>:<port> [<comment>]          server named 'LABVIEW'
        :<ip>:<port>:<name> [<comment>]   server with an explicit name
        <instrument> [<comment>]          instrument served by the last server above
    Blank lines are skipped; a line starting with '#' ends parsing.
    """
    def __init__(self, config_file, local_name = 'python', timeout = 0.2):
        self.config_file = config_file
        self.local_name = local_name
        self.entries = []
        self.timeout = timeout
        self.read_config()

    def refresh_config(self):
        "Re-read the config file if it was modified since it was last read"
        if self.isoutdated():
            self.read_config()

    def read_config(self):
        """
        (Re)load server entries and the instrument map from 'config_file'.

        Raises ValueError (naming the file and line number) on a malformed
        line. The new configuration is committed only after the whole file
        has been parsed, so on error the previously loaded one is kept.
        """
        # Take the timestamp before reading, so that a modification made
        # while we read is still reported by 'isoutdated' later.
        modified = os.path.getmtime(self.config_file)
        entries = []
        instruments = {}
        entry = None
        handle = open(self.config_file, 'r')
        try:
            lineno = 0
            for line in handle:
                lineno += 1
                line = line.strip()
                if not line:
                    continue
                if line[0] == '#':
                    # NOTE(review): a '#' line stops parsing of the rest of the
                    # file (it is not merely skipped) - confirm this is intended.
                    break
                if line[0] == ':':
                    entry = self._parse_server(line, lineno)
                    entries.append(entry)
                else:
                    if entry is None:
                        raise ValueError("Error parsing '%s': instrument defined before any server in line %d" % (self.config_file, lineno))
                    name, _, comment = line.partition(' ')
                    entry.register_instrument(name, comment)
                    instruments[name] = entry
        finally:
            handle.close()

        # Commit; 'entries' is updated in place to keep its identity, as before.
        del self.entries[:]
        self.entries.extend(entries)
        self.clear()
        self.update(instruments)
        self.config_modified = modified

    def _parse_server(self, line, lineno):
        """
        Parse a server definition ':<ip>:<port>[:<name>] [<comment>]' into a
        new, empty 'ConfigEntry' (name defaults to 'LABVIEW', comment to '').
        Raises ValueError on a malformed line.
        """
        tokens = line.split(':', 3)
        if len(tokens) < 3:
            raise ValueError("Error parsing '%s': illegal server definition in line %d" % (self.config_file, lineno))
        ip = tokens[1]
        if len(tokens) == 3:
            port, _, comment = tokens[2].partition(' ')
            name = 'LABVIEW'
        else:
            port = tokens[2]
            name, _, comment = tokens[3].partition(' ')
        try:
            port = int(port)
        except ValueError:
            raise ValueError("Error parsing '%s': non-numeric port number in line %d" % (self.config_file, lineno))
        return ConfigEntry(ip, port, name, comment, {})

    def isoutdated(self):
        "Return True if the config file on disk is newer than this object"
        return os.path.getmtime(self.config_file) > self.config_modified

    def __repr__(self):
        if len(self.entries) < 1:
            return 'Empty Comunicator Config'
        return '\n'.join(['Comunicator Config:'] + [repr(e) for e in self.entries])

    def talk(self, requests, timeout = None, reduce_repeated_queries=True):
        """
        A client side communication.

        'requests' can contain 'comu.Request' objects or strings of a form '<instrument><space><request>'.

        If specified, timeout overrides the default config timeout.
        If 'reduce_repeated_queries' is True, similar queries (see
        'Request.same') are asked only once per server and the response is
        copied to every repetition. Commands (requests that do not start
        with '?') are always sent, once per occurrence. Requests are sent to
        each server in the order they appear in 'requests'.

        Returns a list of 'Response's, one per request, in the same order.
        Requests to instruments missing from the config get a
        NOT_IN_LOCAL_LIST response. An exception raised while talking to a
        server is turned into error responses for that server's requests.
        A summary is printed to stdout for every server contacted.
        """
        import sys
        try:
            string_types = basestring  # Python 2
        except NameError:
            string_types = str  # Python 3

        self.refresh_config()

        if timeout is None:
            timeout = self.timeout

        requests = [Request(r) if isinstance(r, string_types) else r for r in requests]
        responses = [None] * len(requests)

        for server in self.entries:
            groups = self._group_requests(server, requests, reduce_repeated_queries)
            if not groups:
                continue

            # now talk to the server: one request per group, in request order
            requests_subset = [requests[group[0]] for group in groups]
            try:
                responses_subset = self.talk_to(server, requests_subset, timeout)
            except Exception:
                # sys.exc_info() instead of 'except ... as e' keeps this valid
                # on every Python 2 release as well as on Python 3
                responses_subset = [Response('', str(sys.exc_info()[1]))] * len(requests_subset)

            # a short (e.g. truncated) reply must not abort the whole call:
            # requests left without an answer get NO_REMOTE_RESPONSE
            missing = len(requests_subset) - len(responses_subset)
            if missing > 0:
                responses_subset = list(responses_subset) + [Response('', ComuError.NO_REMOTE_RESPONSE)] * missing

            # NOTE(review): diagnostic output on every call; consider logging
            print('Server %s' % server.ip)

            # now fill the corresponding elements with all the responses
            for group, resp in zip(groups, responses_subset):
                for k in group:
                    responses[k] = resp.copy()
                print('  request: %s, response: %s, indices: %s' % (requests[group[0]], resp, group))

        # now add responses to requests to unregistered instruments.
        for n, resp in enumerate(responses):
            if resp is None:
                responses[n] = Response('', ComuError.NOT_IN_LOCAL_LIST)

        return responses

    def _group_requests(self, server, requests, reduce_repeated_queries):
        """
        Select the requests addressed to instruments served by 'server'.

        Return a list of index groups in request order. Each group starts
        with the index of the request to send; if 'reduce_repeated_queries'
        is True, indices of later repetitions of a query are appended to the
        group of its first occurrence instead of starting a new group.
        """
        groups = []
        instruments = server.instruments()
        for n, req in enumerate(requests):
            if req.instrument not in instruments:
                continue
            if reduce_repeated_queries and req.is_query():
                for group in groups:
                    if req.same(requests[group[0]]):
                        group.append(n)
                        break
                else:
                    groups.append([n])
            else:
                groups.append([n])
        return groups

    def talk_to(self, server, requests, timeout):
        """
        Communicate with a server. The transport is determined automatically:
        local servers (see 'ConfigEntry.islocal') use 'talk_to_local',
        all others 'talk_by_tcp'.
        """
        if server.islocal(self.local_name):
            return talk_to_local(server, requests, timeout)
        else:
            return talk_by_tcp(server, requests, timeout)

class Codec(object):
    """
    Here are static methods for encoding/decoding data for transport over TCP/IP.

    A message is 'comu' followed by one encoded string whose payload is a
    sequence of encoded strings taken in pairs: (instrument, message) for
    requests, (response, errorcode) for responses. A string is encoded as
    its length in 8 zero-padded decimal digits followed by the string itself.
    """

    magic = 'comu'
    header_length = 12 # 'comu' + 8-digit length
    length_length = 8 # length of encoded length

    @staticmethod
    def encode_length(l):
        "Encode a length as 8 zero-padded decimal digits; ValueError if it does not fit"
        if l < 0 or l > 99999999:
            raise ValueError('Comunicator Error: length %d is out of bounds and can not be encoded' % l)
        return "%08d" % l

    @staticmethod
    def encode_str(s):
        "encode a string: its 8-digit length followed by the string itself"
        return Codec.encode_length(len(s)) + s

    @staticmethod
    def decode_length(str):
        "Decode a length field (the parameter name shadows the builtin; kept for compatibility)"
        return int(str)

    @staticmethod
    def encode_pairs(pairs):
        "Encode a list of either 'Request's or 'Response's into a string, header included"
        # join avoids the quadratic cost of repeated string concatenation
        return Codec.magic + Codec.encode_str(''.join([p.encode() for p in pairs]))

    @staticmethod
    def determine_length(header):
        """
        Parse header and determine the length of the rest of messages
        A 'ValueError' is raised if header is not recognised.
        """
        if len(header) != Codec.header_length or header[:4] != Codec.magic:
            raise ValueError('Unsupported data format')
        return Codec.decode_length(header[4:])

    @staticmethod
    def decode_pairs(message, pairtype = Response):
        """
        Parse the message (the header should be truncated) and return a list
        of pairs of the given type (either 'Response' or 'Request'), built as
        'pairtype(first_string, second_string)'.
        A trailing first string without a second one is dropped.
        """
        n = Codec.length_length
        pairs = []
        # '>=' rather than '>': a field of exactly 'length_length' characters
        # is a valid encoding of the empty string and must not be dropped.
        while len(message) >= n:
            l = Codec.decode_length(message[:n])
            s1 = message[n:n + l]
            message = message[n + l:]
            if len(message) >= n:
                l = Codec.decode_length(message[:n])
                s2 = message[n:n + l]
                message = message[n + l:]
                pairs.append(pairtype(s1, s2))
        return pairs

def _recv_exactly(sock, n):
    "Receive exactly 'n' bytes from 'sock'; raise socket.error if the peer closes first"
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise socket.error('connection closed after %d of %d bytes' % (n - remaining, n))
        chunks.append(chunk)
        remaining -= len(chunk)
    return ''.join(chunks)

def talk_by_tcp(server, requests, timeout = 0.200):
    """
    communicate with a given 'server' via TCP.
    'requests' is an array of 'Request's.
    'timeout' is TCP timeout in seconds (see socket.settimeout).
    Returns the list of 'Response's decoded from the server's reply. On any
    failure (connection, timeout, malformed reply) returns one
    NO_REMOTE_RESPONSE 'Response' per request instead.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.settimeout(timeout)
        s.connect((server.ip, server.port))
        # a single send() may transmit only part of the data
        s.sendall(Codec.encode_pairs(requests))
        # TCP is a byte stream: recv(n) may return fewer than n bytes, so
        # read the fixed-size header and then the whole body explicitly
        header = _recv_exactly(s, Codec.header_length)
        body = _recv_exactly(s, Codec.determine_length(header))
        return Codec.decode_pairs(body, pairtype=Response)
    except Exception:
        # deliberate best effort: any failure is reported as "no response"
        return [Response('', ComuError.NO_REMOTE_RESPONSE)] * len(requests)
    finally:
        s.close()

def talk_to_local(server, requests, timeout = 0.200):
    """
    Placeholder for talking to a server on the local computer.
    Local transport is not implemented yet: the result holds one and the
    same error 'Response' for each of 'requests'; 'server' and 'timeout'
    are ignored.
    """
    unsupported = Response('', 'Local communication is not supported yet')
    count = len(requests)
    return count * [unsupported]
