"""
lib/simfit.py

Simultaneous Lorentzian Fit

Lev, 8 May 2009
"""

import numpy, nmr, nmrfit, math
import scipy.optimize
from pylab import *

def delayed_lorentzian(f0, A, T2, phi, delay, f):
    """
    Evaluate, at the frequencies 'f', the spectrum of the Lorentzian
    (f0, A, T2, phi) as it appears after a time shift of 'delay'.

    The shift damps the amplitude by exp(-delay/T2) and advances the phase
    by 2*pi*f0*delay; the resulting line is computed by nmrfit.Lorentzian.l.
    'delay' must be in the same time unit as T2 and the inverse of f0's unit.
    """
    damped_amplitude = A * math.exp(-delay / T2)
    shifted_phase = phi + 2 * math.pi * f0 * delay
    return nmrfit.Lorentzian.l(f0, damped_amplitude, T2, shifted_phase, f)

def sim_lorentzian_residuals(param, spectra):
    """
    Calculate residuals for the simultaneous fit.

    'param' is the parameter vector (f0, A, T2, phi), as passed in by
    scipy.optimize.leastsq. 'spectra' maps a delay to a spectrum object
    exposing frequencies 'f' and values 'w' (presumably complex -- confirm).

    For every spectrum, take the modulus of the difference between the data
    and the delayed Lorentzian model. Join the per-spectrum residuals into
    one flat array, ordered by increasing delay so the layout is
    deterministic. An empty 'spectra' yields an empty array.
    """
    f0, A, T2, phi = param
    parts = []
    for delay in sorted(spectra):
        spectrum = spectra[delay]
        model = delayed_lorentzian(f0, A, T2, phi, delay, spectrum.f)
        parts.append(abs(spectrum.w - model))
    if not parts:
        return numpy.array([])
    # Concatenate once: growing the array inside the loop copies every
    # previously collected residual on each iteration.
    return numpy.concatenate(parts)

def simfit_lorentzian(trace, N, T0, deltaT, Npad=2**15, fmin=240e3, fmax=330e3, fwing=5e3):
    """
    Fit a Lorentzian simultaneously to 'N' spectra of 'trace'.

    The spectra are 'Npad'-long FFTs of the slices of 'trace' starting at
    T0, T0 + deltaT, ..., T0 + (N-1)*deltaT.

    First a guess is made from the spectrum at T0 in the range 'fmin' to
    'fmax'. The frequency window is then narrowed to 'guessed f0' +- 'fwing'.
    A single-spectrum fit at T0 is extrapolated back to zero delay, and that
    result seeds the simultaneous least-squares fit.

    Returns (fit, spectra):
    - fit: the fitted nmrfit.Lorentzian, referenced to zero delay.
    - spectra: a dict mapping each slice start time to its spectrum.
    The slices are also plotted with plot_slices.

    Raises FitError(fit) if scipy.optimize.leastsq reports that no solution
    was found.
    """
    g = nmrfit.Lorentzian.guess(trace.period(T0).fft(Npad=Npad, fmin=fmin, fmax=fmax))
    fmin = g.f0 - fwing
    fmax = g.f0 + fwing

    spectra = {}
    for i in range(N):
        start = T0 + i*deltaT
        spectra[start] = trace.period(start).fft(Npad=Npad, fmin=fmin, fmax=fmax)

    g = nmrfit.Lorentzian.fit(spectra[T0])
    # Undo the decay and phase advance accumulated over T0 (the inverse of
    # delayed_lorentzian) so the starting point refers to zero delay.
    g.A *= math.exp(T0 / g.T2)
    g.phi = (g.phi - 2*math.pi*g.f0*T0) % (2*math.pi)

    x, ier = scipy.optimize.leastsq(sim_lorentzian_residuals, [g.f0, g.A, g.T2, g.phi],
                                    args=(spectra,), xtol=1e-12, maxfev=int(1e5))
    fit = nmrfit.Lorentzian(f0=x[0], A=x[1], T2=x[2], phi=x[3])
    # leastsq signals success with ier in 1..4; each value names a different
    # convergence criterion. Accepting only 1 rejected valid fits, e.g. the
    # xtol criterion (ier == 2).
    if ier not in (1, 2, 3, 4):
        # NOTE(review): FitError is neither defined nor imported in this
        # file -- confirm it is provided elsewhere, otherwise this raises
        # NameError.
        raise FitError(fit)
    plot_slices(fit, spectra)
    return fit, spectra

def plot_slices(fit, spectra, modulus=False):
    """
    Plot every spectrum in 'spectra' together with the model 'fit'.

    Opens a new figure and walks the spectra in order of increasing delay.
    Each data series is drawn as dots, and the model evaluated at the same
    delay as a solid line in the same colour.

    If 'modulus' is true, magnitudes are drawn on a single axes. Otherwise
    the real part goes to subplot 211 and the imaginary part to subplot 212.
    Data series are labelled with their delay in whole microseconds.
    """
    def model_at(delay, freqs):
        # Fitted Lorentzian as it appears after 'delay'.
        return delayed_lorentzian(fit.f0, fit.A, fit.T2, fit.phi, delay, freqs)

    figure()
    for delay, spectrum in sorted(spectra.items(), key=lambda item: item[0]):
        label = '%d us' % (delay*1e6)
        if modulus:
            dots, = plot(spectrum.f, spectrum.abs(), '.', label=label)
            plot(spectrum.f, abs(model_at(delay, spectrum.f)), '-', color=dots.get_color())
            continue

        subplot(211)
        dots, = plot(spectrum.f, spectrum.re(), '.', label=label)
        colour = dots.get_color()
        plot(spectrum.f, real(model_at(delay, spectrum.f)), '-', color=colour)

        subplot(212)
        plot(spectrum.f, spectrum.im(), '.', label=label, color=colour)
        plot(spectrum.f, imag(model_at(delay, spectrum.f)), '-', color=colour)
