"""
massfft.py

Caching Fourier transforms of NMR signals
"""

import numpy, nmr, flib, scopelog, os, gc, logging
from urllib import quote

class FFTSettings(object):
    """
    Fourier transform settings.

    str(settings) yields a string encoding all settings, used as the name
    of the cache directory holding spectra computed with these settings.
    """
    # Each preset is a tuple of positional arguments for __init__:
    # (truncate, Npad, fmin, fmax, BGmethod, BGavertime, T2filter
    #  [, ringing_fmin, ringing_fmax, ringing_no_fmin, ringing_no_fmax, ringing_f0, ringing_Q])
    __presets__ = {
        'a': (252e-6, 2**24, 1e6, 1.1e6, 2, 0.1, 0),
        'b': (260e-6, 2**24, 1e6, 1.1e6, 2, 0.1, 0),
        'c': (25.252e-3, 2**24, 1e6, 1.1e6, 2, 0.1, 0),
        'd': (50.252e-3, 2**24, 1e6, 1.1e6, 2, 0.1, 0),
        'e': (5.252e-3, 2**24, 1e6, 1.1e6, 2, 0.1, 0),
        'f': (10.252e-3, 2**24, 1e6, 1.1e6, 2, 0.1, 0),
        'n': (252e-6, 2**21, 1e6, 1.1e6, 2, 0.1, 0),
        'N': (300e-6, 2**20, 0.1e6, 2.0e6, 2, 0.1, 0),
 #       'b100_20': (165e-6, 2**24, 1e6, 1.1e6, 2, 0.1, 20e-6),
 #       'b100_10': (165e-6, 2**24, 1e6, 1.1e6, 2, 0.1, 10e-6),
        'A': (252e-6, 2**23, 1.047e6, 1.057e6, 2, 0.1, 0, 1e6, 1.1e6, 1.04e6, 1.065e6, 1046100, 32.2),
        'A_': (260e-6, 2**23, 1.047e6, 1.057e6, 2, 0.1, 0, 1e6, 1.1e6, 1.04e6, 1.065e6, 1046100, 32.2),
        'A_270us': (270e-6, 2**23, 1.047e6, 1.057e6, 2, 0.1, 0, 0, 0, 0, 0, 1046100, 32.2),
        'A_270us_no_tuned_corr': (270e-6, 2**23, 1.047e6, 1.057e6, 2, 0.1, 0, 0, 0, 0, 0, 1046100, 32.2),
        'A_10ms_270us': (270e-6, 2**23, 1.047e6, 1.057e6, 2, 0.1, 10e-3, 0, 0, 0, 0, 1046100, 32.2),
        'A_10ms_270us_no_tuned_corr': (270e-6, 2**23, 1.047e6, 1.057e6, 2, 0.1, 10e-3, 0, 0, 0, 0, 0, 0),
#        'A24': (252e-6, 2**23, 1.047e6, 1.057e6, 2, 0.1, 0, 1e6, 1.1e6, 1.04e6, 1.065e6, 1046100, 32.2),
        'B': (160e-6, 2**23, 1.050e6, 1.062e6, 2, 0.1, 0, 1e6, 1.1e6, 1.04e6, 1.075e6, 1046100, 32.2),
        'B24': (160e-6, 2**24, 1.050e6, 1.062e6, 2, 0.1, 0, 1e6, 1.1e6, 1.04e6, 1.075e6, 1046100, 32.2),
        'B_10ms': (160e-6, 2**21, 1.050e6, 1.062e6, 2, 0.1, 10e-3, 1e6, 1.1e6, 1.04e6, 1.075e6, 1046100, 32.2),
        'B_10ms_180us': (180e-6, 2**21, 1.050e6, 1.062e6, 2, 0.1, 10e-3, 0, 0, 0, 0, 1046100, 32.2),
        'B_10ms_180us_no_tuned_corr': (180e-6, 2**21, 1.050e6, 1.062e6, 2, 0.1, 10e-3, 0, 0, 0, 0, 0, 0),
        'B_10ms_270us': (270e-6, 2**21, 1.050e6, 1.062e6, 2, 0.1, 10e-3, 0, 0, 0, 0, 1046100, 32.2),
        'B_10ms_270us_no_tuned_corr': (270e-6, 2**21, 1.050e6, 1.062e6, 2, 0.1, 10e-3, 0, 0, 0, 0, 0, 0),
        'B_noT2f_270us_no_tuned_corr': (270e-6, 2**21, 1.050e6, 1.062e6, 2, 0.1, 0, 0, 0, 0, 0, 0, 0),
        'B_noT2f_180us_no_tuned_corr': (180e-6, 2**21, 1.050e6, 1.062e6, 2, 0.1, 0, 0, 0, 0, 0, 0, 0),
        'N16': (300e-6, 2**16, 10e3, 3e6, 1, 0, 0),
        'N16_0.15s': (0.15, 2**16, 1.0e6, 1.1e6, 1, 0, 0)
    }

    def __init__(self, truncate, Npad, fmin=1e6, fmax=1.1e6, BGmethod=2, BGavertime=0.1, T2filter=0,\
                ringing_fmin=0, ringing_fmax=0, ringing_no_fmin=0, ringing_no_fmax=0, ringing_f0=0, ringing_Q=0):
        """
        truncate - time domain trace truncation from the start [s]
        Npad - number of points fed to the FFT, achieved by zero padding or by truncation from the end
        fmin, fmax - frequency range of the saved spectrum [Hz]
        BGmethod - transient background subtraction method:
            0 - none
            1 - subtract a mean of samples
            2 - subtract a mean of samples in the final portion of the signal 'BGavertime' [s] long.
                This method uses the end of the available signal even if Npad is low enough
                to cause truncation for the FFT.
        BGavertime - length of the signal portion used for transient background subtraction [s]
        T2filter - T2* filter time constant [s], not applied if <= 0
        ringing_fmin, ringing_fmax - frequency range used for subtracting tuned circuit ringing [Hz],
            not applied if ringing_fmax <= 0
        ringing_no_fmin, ringing_no_fmax - frequency range excluded from evaluation of the tuned
            circuit ringing amplitude [Hz]
        ringing_f0, ringing_Q - parameters of the tuned circuit for ringing subtraction and
            tuned spectrometer correction
        """
        self.truncate = truncate
        self.Npad = Npad
        self.fmin = fmin
        self.fmax = fmax
        self.BGmethod = BGmethod
        self.BGavertime = BGavertime
        self.T2filter = T2filter
        self.ringing_fmin = ringing_fmin
        self.ringing_fmax = ringing_fmax
        self.ringing_no_fmin = ringing_no_fmin
        self.ringing_no_fmax = ringing_no_fmax
        self.ringing_f0 = ringing_f0
        self.ringing_Q = ringing_Q

    def useT2filter(self):
        "Return True if settings call for a T2* filter"
        return self.T2filter > 0

    def subtract_ringing(self):
        "Return True if settings call for ringing subtraction"
        return self.ringing_fmax > 0

    def correct_tuned(self):
        "Return True if settings call for correction for gain of a tuned spectrometer"
        return self.ringing_f0 > 0 and self.ringing_Q > 0

    def totuple(self):
        "Return all settings as a tuple in the order of __init__ arguments"
        return (self.truncate, self.Npad, self.fmin, self.fmax, self.BGmethod, self.BGavertime, self.T2filter,
                self.ringing_fmin, self.ringing_fmax, self.ringing_no_fmin, self.ringing_no_fmax,
                self.ringing_f0, self.ringing_Q)

    def __str__(self):
        """
        Create a string representation of the settings.
        Used as a folder name for saving spectra.
        """
#        for preset, settings in FFTSettings.__presets__.items():
#            if self.totuple() == settings:
#                return preset
        name = '%.8fs_%d_%d-%dHz_%d_%.8fs' % (self.truncate, self.Npad, self.fmin, self.fmax,
                                               self.BGmethod, self.BGavertime)
        if self.useT2filter():
            name += '_%.8fs' % self.T2filter

        if self.subtract_ringing():
            name += '_%d-%dHz_no%d-%dHz_f%dHz_Q%.3f' % (self.ringing_fmin, self.ringing_fmax,
                    self.ringing_no_fmin, self.ringing_no_fmax, self.ringing_f0, self.ringing_Q)
        elif self.correct_tuned():
            # keep the same field layout as with ringing subtraction
            name += '_0-0Hz_no0-0Hz_f%dHz_Q%.3f' % (self.ringing_f0, self.ringing_Q)

        return name

    def __repr__(self):
        return "FFTSettings(%g us truncate, %d points, %s T2 filter, %g-%g kHz)" % (self.truncate * 1e6, self.Npad,
                "%g ms" % (self.T2filter*1e3) if self.useT2filter() else "no",
                self.fmin/1e3, self.fmax/1e3)

    @staticmethod
    def preset(settings):
        """
        Convert a preset into FFTSettings and keep FFTSettings untouched.

        If the argument is a string, load FFT settings from a preset
        (an unknown preset name raises KeyError).
        If the argument is an 'FFTSettings' object, return the argument itself.
        Otherwise raise a ValueError.
        """
        if isinstance(settings, basestring):
            return FFTSettings(*FFTSettings.__presets__[settings])
        elif isinstance(settings, FFTSettings):
            return settings
        else:
            raise ValueError("can not convert argument into FFTSettings")

    def report_basic_settings(self):
        "Return a one-line summary of the point count, truncation and T2* filter"
        if self.useT2filter():
            t2 = '%.3fms' % (self.T2filter * 1e3)
        else:
            t2 = 'no'
        return 'FFT Settings: %d points, %.2f us truncation, %s T2 filter.' % (self.Npad, self.truncate * 1e6, t2)

    def report_ringing_settings(self):
        "Return a one-line summary of tuned circuit correction / ringing subtraction"
        if self.subtract_ringing():
            # ringing is fitted in (fmin, no_fmin) and (no_fmax, fmax)
            return 'Tuned circuit ringing subtraction: f0 = %.2f kHz, Q = %.2f, fit to frequency ranges %.3f-%.3f and %.3f-%.3f kHz.' % \
                    (self.ringing_f0 / 1e3, self.ringing_Q, self.ringing_fmin / 1e3, self.ringing_no_fmin / 1e3,
                    self.ringing_no_fmax / 1e3, self.ringing_fmax / 1e3)
        elif self.correct_tuned():
            return 'Tuned circuit gain correction: f0 = %.2f kHz, Q = %.2f' % (self.ringing_f0 / 1e3, self.ringing_Q)
        else:
            return 'No tuned circuit correction or ringing subtraction'

    def report_settings(self, prefix = '# '):
        "Return a (possibly multi-line) report of the settings, each line starting with 'prefix'"
        text = prefix + self.report_basic_settings()
        if self.correct_tuned() or self.subtract_ringing():
            text += '\n' + prefix + self.report_ringing_settings()
        return text

# Percent-escape sequence for '/' ("%2f"), substituted for path separators so a
# whole signal path becomes a single cache file name (see path2filename).
slash_escape = '%' + format(ord('/'), '02x')

def path2filename(settings, signal_filename, mkdir=False, root='/data/exp/run20/sf/fft'):
    """
    Construct a filename for a spectrum bearing information both about
    time domain trace 'signal_filename' (as a filename)
    and Fourier transform settings (as a directory name).

    settings - FFTSettings object or preset name (see FFTSettings.preset)
    mkdir - if True, the settings directory is created if it does not exist
            (its parent 'root' must already exist)
    root - top directory of the spectrum cache
    """
    d = os.path.join(root, str(FFTSettings.preset(settings)))
    if mkdir and not os.path.isdir(d):
        try:
            os.mkdir(d, 0o755)
        except OSError:
            # another process may have created the directory in the meantime
            if not os.path.isdir(d):
                raise
    # quote() leaves '/' unescaped (it is "safe" by default) and already encodes
    # '\\' as '%5C', so the backslash replacement is effectively a no-op; it is
    # kept unchanged so existing cache file names stay valid.
    escaped = quote(signal_filename).replace('/', slash_escape).replace('\\', slash_escape)
    return os.path.join(d, escaped)

def dofft(settings, signal_filename, smplrate=None):
    """
    Compute Fourier transform of 'signal_filename' according to 'settings'.
    The resulting spectrum is both saved to the cache and returned.

    settings - FFTSettings object or preset name (see FFTSettings.preset)
    signal_filename - time domain trace loaded with scopelog.loadbinary
    smplrate - sampling rate [S/s], used if it is not stored along with the data

    A failure to save the spectrum to the cache is logged, not raised.
    """
    try:
        if isinstance(settings, basestring):
            settings = FFTSettings.preset(settings)

        s = scopelog.loadbinary(signal_filename, smplrate=smplrate)
        # numpy reductions avoid iterating a multi-million point trace in Python
        trace_min = numpy.min(s.s)
        trace_max = numpy.max(s.s)
        if settings.BGmethod == 2:
            # background = mean of the final 'BGavertime' seconds of the full trace
            s.s -= numpy.mean(s.period(start = s.timespan() - settings.BGavertime).s)
        # TODO(review): BGmethod == 1 ("subtract a mean of samples") is documented in
        # FFTSettings but not implemented here; it currently behaves like BGmethod == 0.
        s1 = s.period(settings.truncate)
        w = s1.fft(Npad=settings.Npad, T2filter=settings.T2filter)
        del s1  # drop the reference to the truncated trace early

        if settings.subtract_ringing() or settings.correct_tuned():
            ringing = process_tuned_spectrometer(w, settings.ringing_fmin, settings.ringing_fmax,
                            settings.ringing_no_fmin, settings.ringing_no_fmax,
                            settings.ringing_f0, settings.ringing_Q)
        else:
            ringing = 0

        w = w.frange(settings.fmin, settings.fmax)

        header = {
            'starttime': s.info.starttime,
            'finaltime': s.info.finaltime,
            'Nav': s.info.Nav,
            'delay': s.info.delay,
            'Npad': settings.Npad,
            'truncate': settings.truncate,
            'fmin': settings.fmin,
            'fmax': settings.fmax,
            'T2filter': settings.T2filter,
            'ringing_fmin': settings.ringing_fmin,
            'ringing_fmax': settings.ringing_fmax,
            'ringing_no_fmin': settings.ringing_no_fmin,
            'ringing_no_fmax': settings.ringing_no_fmax,
            'ringing_f0': settings.ringing_f0,
            'ringing_Q': settings.ringing_Q,
            'ringing_amplitude': ringing,
            'trace_max': trace_max,
            'trace_min': trace_min
        }

        # store the cropped spectrum in single precision
        w.f = numpy.asarray(w.f, dtype=numpy.float32)
        w.w = numpy.asarray(w.re(), dtype=numpy.float32) + 1j*numpy.asarray(w.im(), dtype=numpy.float32)

        # bound before the try so the warning below never hits an unbound name
        fft_path = None
        try:
            fft_path = path2filename(settings, signal_filename, True)
            w.savebinary(fft_path, header)
        except Exception as e:
            logging.warning("Error saving a spectrum to '%s': %s", fft_path, e)

        w.info = s.info
        w.starttime = s.info.starttime
        w.finaltime = s.info.finaltime
        w.Nav = s.info.Nav
        w.delay = s.info.delay
        w.truncate = settings.truncate
        w.Npad = settings.Npad
        w.T2filter = settings.T2filter
        w.ringing_fmin = settings.ringing_fmin
        w.ringing_fmax = settings.ringing_fmax
        w.ringing_no_fmin = settings.ringing_no_fmin
        w.ringing_no_fmax = settings.ringing_no_fmax
        w.ringing_f0 = settings.ringing_f0
        w.ringing_Q = settings.ringing_Q
        w.ringing_amplitude = ringing
        w.ringing_ampltidue = ringing  # misspelled alias kept for backward compatibility
        w.trace_max = trace_max
        w.trace_min = trace_min
        return w
    finally:
        gc.collect()

def _passes_flux_jump_filter(spectrum, flux_jump_filter_Vpp):
    "Return False if the flux jump filter is active and rejects 'spectrum'"
    # written as 'swing <= limit' so that a NaN swing is rejected
    return flux_jump_filter_Vpp is None or spectrum.trace_swing_Vpp() <= flux_jump_filter_Vpp

def _copy_fft_settings(average, spectrum):
    "Copy FFT settings stored on 'spectrum' to the averaged spectrum 'average'"
    average.truncate = spectrum.truncate
    average.delay = spectrum.delay
    average.Npad = spectrum.Npad
    # spectra may lack some of these attributes; each missing one falls back to
    # "not applied" independently, so a missing amplitude does not erase f0/Q
    average.T2filter = getattr(spectrum, 'T2filter', 0)
    for name in ('ringing_fmin', 'ringing_fmax', 'ringing_no_fmin', 'ringing_no_fmax',
                 'ringing_f0', 'ringing_Q'):
        setattr(average, name, getattr(spectrum, name, 0))
    # older dofft versions stored the amplitude under a misspelled name
    average.ringing_amplitude = getattr(spectrum, 'ringing_amplitude',
                                        getattr(spectrum, 'ringing_ampltidue', 0 + 0j))

def getfft(settings, signal_filename, indices=None, smplrate=None, redo=False, flux_jump_filter_Vpp=None):
    """
    Return a Fourier transform of signal 'signal_filename' or an average
    of '[signal_filename % i for i in indices]' according to FFT settings
    'settings'. Sampling rate can be specified using 'smplrate' [S/s] if
    this information is not available along with the data.

    A cache of spectra is maintained. If a spectrum exists in the cache,
    Fourier transform is not performed, unless 'redo' is set to True.
    If an average is asked, signals are FFT'ed separately and spectra
    are averaged afterwards (weighted by their number of averages) for
    versatility in subsequent calls.

    If not None, 'flux_jump_filter_Vpp' is used to select only those files
    in which the voltage span is no more than the value of this argument.
    If all traces were discarded, a ValueError is raised.
    """
    try:
        if indices is None:
            filename = path2filename(settings, signal_filename, False)
            if os.path.exists(filename) and not redo:
                spectrum = nmr.Spectrum.loadbinary(filename)
                spectrum.info = scopelog.getinfo(signal_filename)
            else:
                spectrum = dofft(settings, signal_filename, smplrate)
            if not _passes_flux_jump_filter(spectrum, flux_jump_filter_Vpp):
                raise ValueError('All traces discarded by the flux jump filter')
            return spectrum

        indices = list(indices)
        if len(indices) < 1:
            raise ValueError('No filename indices given')
        kept_indices = []
        average = 0
        N = 0
        trace_max = None
        trace_min = None
        for i in indices:
            spectrum = getfft(settings, signal_filename % i, smplrate=smplrate, redo=redo)
            if not _passes_flux_jump_filter(spectrum, flux_jump_filter_Vpp):
                continue
            kept_indices.append(i)
            try:
                Nav = spectrum.info.safeNav()
            except Exception:
                Nav = 1
            average += Nav * spectrum
            N += Nav
            try:
                if trace_max is None or trace_min is None:
                    trace_max = spectrum.trace_max
                    trace_min = spectrum.trace_min
                else:
                    trace_max = numpy.maximum(trace_max, spectrum.trace_max)
                    trace_min = numpy.minimum(trace_min, spectrum.trace_min)
            except Exception:
                # once unknown, the extremes stay NaN for the whole average
                trace_max = trace_min = numpy.nan
        if not kept_indices:
            raise ValueError('All traces discarded by the flux jump filter')

        average /= N
        average.info = scopelog.getinfos(signal_filename, kept_indices)
        try:
            average.starttime = average.info.starttime
            average.finaltime = average.info.finaltime
        except Exception:
            average.starttime = numpy.nan
            average.finaltime = numpy.nan
        average.Nav = N
        average.trace_max = trace_max
        average.trace_min = trace_min
        # FFT settings are copied from the last loaded spectrum (all share 'settings')
        _copy_fft_settings(average, spectrum)
        return average
    finally:
        gc.collect()

def process_tuned_spectrometer(spectrum, ringing_fmin=1.0e6, ringing_fmax=1.1e6, \
        ringing_no_fmin=1.04e6, ringing_no_fmax=1.065e6, ringing_f0=1046100, ringing_Q=32.2):
    """
    Process a 'spectrum' acquired in a tuned spectrometer. The spectrum is modified in place.

    If ringing_f0 > 0 and ringing_Q > 0, the spectrum is divided by the instrumental
    function of the tuned circuit, 1 / (1 + 2jQ(f/f0 - 1)), i.e. multiplied by its
    denominator. Otherwise this correction is skipped (f0 = 0 would divide by zero).

    Then, if 'ringing_fmax' is positive, the ringing amplitude is evaluated:
    a mean value in the range (ringing_fmin, ringing_fmax), excluding the range
    (ringing_no_fmin, ringing_no_fmax) when ringing_no_fmax > 0, is calculated and
    subtracted from the spectrum. This is ringing in the input circuit.

    Returns the (complex) ringing amplitude, or 0.0 if ringing is not evaluated.
    Raises ValueError if the ringing frequency range contains no points.
    """
    if ringing_f0 > 0 and ringing_Q > 0:
        spectrum.w *= (1 + 2j*ringing_Q*(spectrum.f/ringing_f0 - 1))
    if ringing_fmax <= 0:
        return 0.0
    wr = spectrum.frange(ringing_fmin, ringing_fmax).removepeaks()
    if ringing_no_fmax > 0:
        wr = wr.nofrange(ringing_no_fmin, ringing_no_fmax)
    if len(wr) <= 0:
        raise ValueError('FFT settings define empty frequency range for tuned circuit ringing subtraction')
    ringing = numpy.mean(wr.w)
    spectrum.w -= ringing
    return ringing

def get_sf_fft(fft_settings, filename, indices, smplrate=None, \
        ringing_fmin=1.0e6, ringing_fmax=1.1e6, \
        ringing_no_fmin=1.04e6, ringing_no_fmax=1.065e6, ringing_f0=1046100, ringing_Q=32.2, \
        redo=False, flux_jump_filter_Vpp=None):
    """
    Load a spectrum (see 'getfft') and, if no tuned circuit instrumental function
    deconvolution and ringing subtraction was a part of the FFT (stored ringing_f0
    missing or <= 0), do it now with the given ringing parameters
    (see 'process_tuned_spectrometer').
    Otherwise no extra processing is performed on the signal.

    'redo' and 'flux_jump_filter_Vpp' are passed on to 'getfft'.
    """
    w = getfft(fft_settings, filename, indices, smplrate, redo=redo,
               flux_jump_filter_Vpp=flux_jump_filter_Vpp)
    # spectra without the attribute were not tuned-corrected
    if getattr(w, 'ringing_f0', 0) <= 0:
        ringing = process_tuned_spectrometer(w, ringing_fmin, ringing_fmax,
                            ringing_no_fmin, ringing_no_fmax, ringing_f0, ringing_Q)
        w.ringing_fmin = ringing_fmin
        w.ringing_fmax = ringing_fmax
        w.ringing_no_fmin = ringing_no_fmin
        w.ringing_no_fmax = ringing_no_fmax
        w.ringing_f0 = ringing_f0
        w.ringing_Q = ringing_Q
        w.ringing_amplitude = ringing
    return w

##def walkdir(settings, topdir, rewrite=False, minsize=2**16):
##    """
##    FFT and save spectra for all signals in a given directory 'topdir' including any subdirectories.
##    'rewrite' controls rewriting existing spectra.
##    """
##
##    for root, dirs, files in os.walk(topdir):
##        for file in files:
##            try:
##                filename = os.path.join(root, file)
##                if os.path.getsize(filename) >= minsize:
##                    (dofft if rewrite else getfft)(settings, filename)
##            except Exception, e:
##                print 'Error %s while processing %s' % (e, filename)
##            gc.collect()

def walkdir(settings, topdir, redo=False, ramble=False, minsize=2**20, skip_filenames=('logT.dat',)):
    """
    FFT and save spectra for all signals in a given directory 'topdir' including any subdirectories.
    Boolean argument 'redo' controls recalculating existing spectra.
    If 'ramble' is true, keep processing if new files appear inside 'topdir' while working.
    Files shorter than 'minsize' bytes or with names listed in 'skip_filenames'
    (case insensitive) are skipped. Errors are printed and do not stop the walk.

    Return a list of files visited (including skipped ones), in visiting order.
    """
    skip = set(name.lower() for name in skip_filenames)

    processed = []
    seen = set()  # O(1) membership test; 'processed' keeps the visiting order

    count = 1
    round_no = 1
    while count > 0 and (ramble or round_no == 1):
        count = 0
        for root, dirs, files in os.walk(topdir):
            for name in files:
                filename = os.path.join(root, name)
                if filename in seen:
                    continue
                seen.add(filename)
                processed.append(filename)
                count += 1
                try:
                    if name.lower() in skip or os.path.getsize(filename) < minsize:
                        print('skip %s' % filename)
                        continue
                    getfft(settings, filename, redo=redo)
                except Exception as e:
                    print('Error %s while processing %s' % (e, filename))
                gc.collect()
        print('round %d: %d files processed' % (round_no, count))
        round_no += 1

    return processed
