"""
nmrfit.py

Fit NMR spectra

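SpectrumModel - base class for frequency-domain fit models
FilteredStep - models a spectrum of a step function passed through
a high-pass filter
Lorentzian, ProperLorentzian, PhaselessLorentzian, DoubleLorentzian,
NoisePowerLorentzian, SincSpectrum - models of NMR signal and noise spectra
TraceModel - base class for time-domain fit models
CosineWave - models a cosine wave in the time domain

FitError - exception, raised when a fit diverges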

12 September 2008, Lev, based on Michael's stunning code
"""

# Changelog
#
# Initial version 18 June 2008
#
# June 26, 2008
# Add NoisePowerLorentzian fit
#
# September 12, 2008
# Add Q and linewidth to derived properties of a Lorentzian fit
#
# January 24, 2010
# add 'cut_and_fit' method to DoubleLorentzian
#
# January 28, 2010
# add an option to supply a T2 guess to Lorentzian, ProperLorentzian and DoubleLorentzian fits.
# add 'peak_height()' method to Lorentzian
#
# Feb 15, 2010
# add a 'SincSpectrum' spectrum for a sine wave pulse
# modify the convergence check of the fits for all models: accept 'ier=1,2,3,4'
# when using 'scipy.optimize.leastsq'

import flib
import nmr
import numpy
import math
import scipy.optimize

pi = math.pi

class FitError(Exception):
    """
    Raised when a least-squares fit does not converge.

    The optional 'fit' argument is the model built from the final fit
    parameters (or None); it is kept as 'self.fit' so that callers can
    still inspect the unconverged result.
    """

    def __init__(self, fit = None):
        self.fit = fit  # may be None when no fit result is available

    def __repr__(self):
        return 'Fit Error'

class SpectrumModel(object):
    """
    Base class for models used to fit spectra.

    Subclasses provide 'calculate'; 'model' and 'subtract' are built on it.
    """

    def model(self, f):
        """
        Return an nmr.Spectrum holding the fit evaluated at the given
        frequencies. 'f' is either an array of frequencies or an
        nmr.Spectrum object, whose frequencies are then used.
        """
        freqs = f.f if isinstance(f, nmr.Spectrum) else f
        return nmr.Spectrum(freqs, self.calculate(freqs))

    def subtract(self, spectrum):
        """
        Return 'spectrum' (an nmr.Spectrum) minus the fit model evaluated
        at the same frequencies.
        """
        fitted = self.model(spectrum)
        return spectrum - fitted

    def calculate(self, f):
        """
        Return the model's frequency-domain amplitudes for the array of
        frequencies 'f'. Child classes override this; the base version
        returns an array of zeros with one entry per frequency.
        """
        return numpy.zeros(len(f))

class FilteredStep(SpectrumModel):
    """
    Spectrum of a step function that has passed through a high-pass filter:

    w(f) = B/(f1 + i*f)

    B  - step amplitude (assumed real by the fit)
    f1 - filter -3dB frequency
    i  - imaginary unit

    'fit_f1' records whether f1 was inferred from the data (True) or
    supplied by the caller (False).
    """

    def __init__(self, B, f1, fit_f1):
        self.B, self.f1, self.fit_f1 = B, f1, fit_f1

    def calculate(self, f):
        denominator = self.f1 + 1j*f
        return self.B / denominator

    @staticmethod
    def fit(spectrum, f1 = None):
        """
        Fit a FilteredStep to 'spectrum'. When the filter cutoff 'f1' is not
        given, it is estimated from the data as well.
        """
        fit_f1 = f1 is None
        if fit_f1:
            # for w = B/(f1 + i*f): Re(w)/Im(w) = -f1/f, so every point
            # yields an estimate of f1; average them
            f1 = -numpy.mean(spectrum.f * spectrum.re() / spectrum.im())

        # f1*Re(w) - f*Im(w) equals B at every point; average the estimates
        B = numpy.mean(f1 * spectrum.re() - spectrum.f * spectrum.im())

        return FilteredStep(B, f1, fit_f1)

    @staticmethod
    def fitRun20LF(spectrum, usePXI=True):
        """
        Fit the transient caused by a filtered FLL reset in the ULT Run 20
        Lowfield NMR, using the 500 Hz - 100 kHz part of 'spectrum' and a
        fixed filter cutoff that depends on 'usePXI'.
        """
        cutoff = 2000 if usePXI else 8100
        band = spectrum.frange(fmin=500, fmax=1e5)
        return FilteredStep.fit(band, f1 = cutoff)

    def __repr__(self):
        return 'FilteredStep(f1=%g, B=%g)' % (self.f1, self.B)

def distilRun20LF(trace, usePXI=True, truncate = 30e-6, Npad=2**17):
    """
    Prepare a ULT Run 20 Lowfield NMR trace and return its spectrum with
    the eddy-current transient subtracted.

    The trace is cut with trace.period(truncate) and .range(0, Npad), then
    subtract_linearbg(1) is applied (to take care of the mismatch between
    the start and the end of the data) before .fft(Npad=Npad). A
    FilteredStep fitted with FilteredStep.fitRun20LF (cutoff selected by
    'usePXI') is removed from the resulting spectrum.
    """
    cropped = trace.period(truncate).range(0, Npad)
    spectrum = cropped.subtract_linearbg(1).fft(Npad=Npad)
    return FilteredStep.fitRun20LF(spectrum, usePXI).subtract(spectrum)

class Lorentzian(SpectrumModel):
    """
    This model describes a phase-rich Lorentzian:
    
    w(f) = 1/2 * A * T2 * exp(i*phi) / (1 + 2*pi*i*(f-f0)*T2)

    f0 - resonance frequency
    A - amplitude
    T2 - relaxation time
    phi - phase
    
    w(f) is the positive-frequency part of the Fourier Transform of:
    
    u(t) = A * cos(2*pi*f0 * t + phi) * exp(-t / T2) for t >= 0, 0 for t < 0
    
    Parameters derived from f0, A, T2 and phi are available as phase_deg(),
    linewidth(), Q() and peak_height()
    """

    def __init__(self, f0, A, T2, phi):
        if A < 0: # if a negative amplitude is supplied, make it positive and shift phase by 180 degrees
            A = -A
            phi += pi 
        
        self.f0 = f0
        self.A = A
        self.T2 = T2
        self.phi = phi % (2*pi) # shift phase to the [0, 2pi) range

    def calculate(self, f): return Lorentzian.l(self.f0, self.A, self.T2, self.phi, f)
    
    @staticmethod
    def l(f0, A, T2, phi, f):
        """
        For a frequency or an array of frequencies 'f' return a lorentzian
        defined by:
            f0  - resonance frequency
            A   - amplitude
            T2  - relaxation time
            phi - phase
        """
        return 0.5*A*T2*numpy.exp(1j*phi) / (1 + 2j*pi*(f-f0)*T2)

    @staticmethod
    def guess(spectrum, T2=None):
        """
        Make an initial guess for Lorentzian fit parameters.
        If 'T2' is specified, it is kept as the guess for T2.
        """
        f0, peak = spectrum.findpeak()
        
        # guess for phi: tan(phi) = Im( w(f0) )/Re( w(f0) )
        phi = flib.phase(peak) % (2*pi) # shift phase to the [0, 2pi) range
        
        # guess for T2:
        # |w(f)| = |w(f0)| / sqrt(1 + (2*pi*(f-f0)*T2)**2), so the absolute
        # amplitude drops to half the peak height at |f - f0| = sqrt(3)/(2*pi*T2).
        # The half-height point is searched for (rather than the 1/sqrt(2) point,
        # which lies closer to the peak); the formula below matches that choice.
        if T2 is None:
            K = (spectrum.abs() - abs(peak)/2)**2
            f_half = spectrum.f[numpy.argmin(K)]
            T2 = math.sqrt(3)/(2*pi*abs(f_half - f0))
        
        # guess for A: the peak height is |w(f0)| = 1/2 * A * T2
        A = 2 * abs(peak) / T2
        
        return Lorentzian(f0=f0, A=A, T2=T2, phi=phi)

    @staticmethod
    def residuals(param, f, w):
        """
        Calculate residuals for the fit
        """
        f0, A, T2, phi = param
        return abs(w - Lorentzian.l(f0, A, T2, phi, f))

    @staticmethod    
    def fit(spectrum, T2=None):
        """
        fit a Lorentzian to the spectrum. if 'T2' is specified, it is used as
        the initial guess of T2.
        Raises FitError (carrying the final parameters) if leastsq reports failure.
        """
        g = Lorentzian.guess(spectrum, T2)
        p, ier = scipy.optimize.leastsq(Lorentzian.residuals, [g.f0, g.A, g.T2, g.phi], args=(spectrum.f, spectrum.w), xtol=1e-12, maxfev=int(1e5))
        f = Lorentzian(f0=p[0], A=p[1], T2=p[2], phi=p[3])
        # leastsq signals that a solution was found with ier in 1..4
        if 1 <= ier <= 4:
            return f
        raise FitError(f)

    def phase_deg(self): "phase in degrees"; return self.phi * 180 / pi
    def Q(self): "quality factor of resonance"; return pi * self.f0 * self.T2
    def linewidth(self): "linewidth"; return 1 / (pi * self.T2)
    def peak_height(self): "peak height in a spectrum"; return 0.5*self.A*self.T2
    def __repr__(self): return 'Lorentzian(f0 = %.2f, A = %g, T2 = %g, phase = %g degrees)' % (self.f0, self.A, self.T2, self.phase_deg())

    def normalise(self):
        """
        Return a normalised Lorentzian with a peak height of 1 and zero phase.
        This is to adjust phases of signals acquired with
        tuned sensors/amplifiers.
        """
        return Lorentzian(self.f0, 2 / self.T2, self.T2, 0)

class ProperLorentzian(Lorentzian):
    """
    This model describes a phase-rich Lorentzian together with its mirror
    peak at negative frequencies:
    
    w(f) = 1/2 * A * T2 * exp(i*phi) / (1 + 2*pi*i*(f-f0)*T2) + 1/2 * A * T2 * exp(-i*phi) / (1 + 2*pi*i*(f+f0)*T2)

    f0 - resonance frequency
    A - amplitude
    T2 - relaxation time
    phi - phase
    
    w(f) is a Fourier Transform of:
    
    u(t) = A * cos(2*pi*f0 * t + phi) * exp(-t / T2) for t >= 0, 0 for t < 0
    
    Parameters derived from f0, A, T2 and phi are available as phase_deg(), linewidth() and Q()
    """

    def calculate(self, f): return ProperLorentzian.l(self.f0, self.A, self.T2, self.phi, f)
    
    @staticmethod
    def l(f0, A, T2, phi, f):
        """
        For a frequency or an array of frequencies 'f' return a lorentzian
        (including the negative-frequency peak) defined by:
            f0  - resonance frequency
            A   - amplitude
            T2  - relaxation time
            phi - phase
        """
        return 0.5*A*T2*numpy.exp(1j*phi) / (1 + 2j*pi*(f-f0)*T2) \
                + 0.5*A*T2*numpy.exp(-1j*phi) / (1 + 2j*pi*(f+f0)*T2)

    @staticmethod
    def guess(spectrum, T2=None):
        """
        Make an initial guess for Lorentzian fit parameters.
        If 'T2' is specified, it is kept as the guess for T2; otherwise it is
        estimated from the peak width. The negative-frequency term is ignored
        here, so the guess is approximate.
        """
        f0, peak = spectrum.findpeak()
        
        # guess for phi: tan(phi) = Im( w(f0) )/Re( w(f0) )
        phi = flib.phase(peak) % (2*pi) # shift phase to the [0, 2pi) range
        
        # guess for T2 (only when the caller did not supply one):
        # |w(f)| drops to half the peak height at |f - f0| = sqrt(3)/(2*pi*T2)
        if T2 is None:
            K = (spectrum.abs() - abs(peak)/2)**2
            f_half = spectrum.f[numpy.argmin(K)]
            T2 = math.sqrt(3)/(2*pi*abs(f_half - f0))
        
        # guess for A: the peak height is approximately 1/2 * A * T2
        A = 2 * abs(peak) / T2
        
        return ProperLorentzian(f0=f0, A=A, T2=T2, phi=phi)

    @staticmethod
    def residuals(param, f, w):
        """
        Calculate residuals for the fit
        """
        f0, A, T2, phi = param
        return abs(w - ProperLorentzian.l(f0, A, T2, phi, f))

    @staticmethod    
    def fit(spectrum, T2=None):
        """
        fit a ProperLorentzian to the spectrum. if 'T2' is specified, it is
        used as the initial guess of T2.
        Raises FitError (carrying the final parameters) if leastsq reports failure.
        """
        g = ProperLorentzian.guess(spectrum, T2)
        p, ier = scipy.optimize.leastsq(ProperLorentzian.residuals, [g.f0, g.A, g.T2, g.phi], args=(spectrum.f, spectrum.w), xtol=1e-12, maxfev=int(1e5))
        f = ProperLorentzian(f0=p[0], A=p[1], T2=p[2], phi=p[3])
        # leastsq signals that a solution was found with ier in 1..4
        if 1 <= ier <= 4:
            return f
        raise FitError(f)

    def normalise(self):
        """
        Return a normalised Lorentzian with a peak height of 1 and zero phase.
        This is to adjust phases of signals acquired with
        tuned sensors/amplifiers.
        """
        return ProperLorentzian(self.f0, 2 / self.T2, self.T2, 0)


class PhaselessLorentzian(SpectrumModel):
    """
    This model describes a phase-poor Lorentzian, i.e. only the absolute
    value of the spectrum is modelled:
    
    |w(f)| = 1/2 * A * T2 / |1 + 2*pi*i*(f-f0)*T2|

    f0 - resonance frequency
    A - amplitude
    T2 - relaxation time
        
    Parameters derived from f0, A, T2 are available as linewidth() and Q()
    """

    def __init__(self, f0, A, T2):
        if A < 0: # the model describes a magnitude, so only |A| is kept
            A = -A
        
        self.f0 = f0
        self.A = A
        self.T2 = T2

    def calculate(self, f): return PhaselessLorentzian.l(self.f0, self.A, self.T2, f)
    
    @staticmethod
    def l(f0, A, T2, f):
        """
        For a frequency or an array of frequencies 'f' return the absolute
        value of a lorentzian defined by:
            f0  - resonance frequency
            A   - amplitude
            T2  - relaxation time
        """
        return 0.5*A*T2 / numpy.abs(1 + 2j*pi*(f-f0)*T2)

    @staticmethod
    def guess(spectrum):
        """
        Make an initial guess for Lorentzian fit parameters.
        """
        f0, peak = spectrum.findpeak()
        
        # guess for T2:
        # |w(f)| = |w(f0)| / sqrt(1 + (2*pi*(f-f0)*T2)**2), so the absolute
        # amplitude drops to half the peak height at |f - f0| = sqrt(3)/(2*pi*T2)
        K = (spectrum.abs() - abs(peak)/2)**2
        f_half = spectrum.f[numpy.argmin(K)]
        T2 = math.sqrt(3)/(2*pi*abs(f_half - f0))
        
        # guess for A: the peak height is 1/2 * A * T2
        A = 2 * abs(peak) / T2
        
        return PhaselessLorentzian(f0=f0, A=A, T2=T2)

    @staticmethod
    def residuals(param, f, w):
        """
        Calculate residuals for the fit ('w' holds absolute amplitudes)
        """
        f0, A, T2 = param
        return abs(w - PhaselessLorentzian.l(f0, A, T2, f))

    @staticmethod    
    def fit(spectrum):
        """
        Fit a PhaselessLorentzian to the absolute value of the spectrum.
        Raises FitError (carrying the final parameters) if leastsq reports failure.
        """
        g = PhaselessLorentzian.guess(spectrum)
        p, ier = scipy.optimize.leastsq(PhaselessLorentzian.residuals, [g.f0, g.A, g.T2], args=(spectrum.f, abs(spectrum.w)), xtol=1e-12, maxfev=int(1e5))
        f = PhaselessLorentzian(f0=p[0], A=p[1], T2=p[2])
        # leastsq signals that a solution was found with ier in 1..4
        if 1 <= ier <= 4:
            return f
        raise FitError(f)

    def Q(self): "quality factor of resonance"; return pi * self.f0 * self.T2
    def linewidth(self): "linewidth"; return 1 / (pi * self.T2)
    def __repr__(self): return 'PhaselessLorentzian(f0 = %.2f, A = %g, T2 = %g)' % (self.f0, self.A, self.T2)

##class Burst(SpectrumModel):
##    """
##    This model describes a burst of sinewave:
##
##    w(f) = A * exp(i*phi) / 2 * (1 - exp(-2i*pi*(f-f0)*T)) / (2i*pi*(f-f0))
##    w(f0) = A * exp(i*phi) * T / 2
##
##    f0 - resonance frequency
##    A - amplitude
##    T - burst length
##    phi - phase
##    
##    w(f) is a Fourier Transform of:
##    
##    u(t) = A * cos(2*pi*f0 * t + phase) for 0 <= t <= T, 0 otherwise.    
##    """
##
##    def __init__(self, f0, A, T, phi):
##        self.f0 = f0
##        self.A = A
##        self.T = T
##        self.phi = phi % (2*pi) # shift phase to the [0, 2pi) range
##
##    def calculate(self, f): return Burst.b(self.f0, self.A, self.T, self.phi, f)
##    
##    @staticmethod
##    def b(f0, A, T, phi, f):
##        """
##        For a frequency or an array of frequencies 'f' return a lorentzian
##        defined by:
##            f0  - resonance frequency
##            A   - amplitude
##            T   - length
##            phi - phase
##        """
##        s = 0.5*A*numpy.exp(1j*phi) * (1 - numpy.exp(-2j*pi*(f-f0)*T)) / (2j*pi*(f-f0))
##        s[f == f0] = 0.5*A*numpy.exp(1j*phi)*T
##        return s
##
##    @staticmethod
##    def guess(spectrum, T):
##        """
##        Make an initial guess for Lorentzian fit parameters.
##        """
##        f0, peak = spectrum.findpeak()
##        
##        # guess for phi: tan(phi) = Im( w(f0) )/Re( w(f0) )
##        phi = flib.phase(peak) % (2*pi) # shift phase to the [0, 2pi) range
##                
##        # guess for A:
##        A = 2 * abs(peak) / T
##        
##        return Burst(f0=f0, A=A, T=T, phi=phi)
##
##    @staticmethod
##    def residuals(param, f, w, T):
##        """
##        Calculate residuals for the fit
##        """
##        f0, A, phi = param
##        return abs(w - Burst.b(f0, A, T, phi, f))
##
##    @staticmethod    
##    def fit(spectrum, T):
##        g = Burst.guess(spectrum, T)
##        res = scipy.optimize.leastsq(Burst.residuals, [g.f0, g.A, g.phi], args=(spectrum.f, spectrum.w, T), xtol=1e-12, maxfev=int(1e5))
##        f = Burst(f0=res[0][0], A=res[0][1], T=T, phi=res[0][2])
##        # if res[1] == 1:
##        if res[1] >= 1 and res[1] <= 4:
##            return f
##        else:
##            raise FitError(f)
##
##    def phase_deg(self): "phase in degrees"; return self.phi * 180 / pi
##    def __repr__(self): return 'Burst(f0 = %.2f, A = %g, T = %g, phase = %g degrees)' % (self.f0, self.A, self.T, self.phase_deg())

class DoubleLorentzian(SpectrumModel):
    """
    This model describes the sum of a pair of phase-rich Lorentzians
    'a' and 'b' (Lorentzian objects):

    w(f) = a.calculate(f) + b.calculate(f)

    a.f0,  b.f0  - resonance frequency
    a.A,   b.A   - amplitude
    a.T2,  b.T2  - relaxation time
    a.phi, b.phi - phase
    
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def calculate(self, f): return self.a.calculate(f) + self.b.calculate(f)

    @staticmethod
    def residuals(param, f, w):
        """
        Calculate residuals for the fit; 'param' holds the (f0, A, T2, phi)
        of Lorentzian a followed by those of Lorentzian b.
        """
        f0a, Aa, T2a, phia, f0b, Ab, T2b, phib = param
        return abs(w - Lorentzian.l(f0a, Aa, T2a, phia, f) - Lorentzian.l(f0b, Ab, T2b, phib, f))

    @staticmethod    
    def fit(spectrum, guess_a, guess_b, T2a=None, T2b=None):
        """
        fit a double lorentzian to the 'spectrum'.
        'guess_a' and 'guess_b' are either initial guesses
        for the lorentzians (objects of 'Lorentzian' type)
        or (fmin, fmax) sequences giving the frequency ranges to run
        the single Lorentzian fits on in order to obtain these guesses.
        
        If either of or both 'T2a' and 'T2b' are specified, they are used as
        initial guess of the T2 of the corresponding single Lorentzian fit.

        The resulting components are plain Lorentzian objects. Raises
        FitError if a preliminary single fit or the final fit fails.
        """
        if not isinstance(guess_a, Lorentzian):
            guess_a = Lorentzian.fit(spectrum.frange(*guess_a), T2a)
        if not isinstance(guess_b, Lorentzian):
            guess_b = Lorentzian.fit(spectrum.frange(*guess_b), T2b)
        
        guess = [guess_a.f0, guess_a.A, guess_a.T2, guess_a.phi, guess_b.f0, guess_b.A, guess_b.T2, guess_b.phi]
        
        p, ier = scipy.optimize.leastsq(DoubleLorentzian.residuals, guess, args=(spectrum.f, spectrum.w), xtol=1e-12, maxfev=int(1e5))
        a = Lorentzian(f0=p[0], A=p[1], T2=p[2], phi=p[3])
        b = Lorentzian(f0=p[4], A=p[5], T2=p[6], phi=p[7])
        f = DoubleLorentzian(a, b)
        # leastsq signals that a solution was found with ier in 1..4
        if 1 <= ier <= 4:
            return f
        raise FitError(f)
    
    @staticmethod    
    def cut_and_fit(spectrum, f_cut, T2a=None, T2b=None):
        """
        fit a double lorentzian to the 'spectrum'.
        Initial guesses are obtained by fitting
        single Lorentzians on parts of the spectrum below (a) and above (b)
        the frequency 'f_cut'.
        If either of or both 'T2a' and 'T2b' are specified, they are used as
        initial guess of the T2 of the two lorentzians.
        """
        return DoubleLorentzian.fit(spectrum, [spectrum.fmin(), f_cut], [f_cut, spectrum.fmax()], T2a, T2b)
    
    def __repr__(self):
        return "%s, %s" % (self.a, self.b)

class NoisePowerLorentzian(SpectrumModel):
    """
    This model describes a Lorentzian peak in the noise power:
    
    S(f) = Sp / (1 + {2*Q*(1 - f/f0)}^2) + S0 

    Sp - peak noise level above S0 (S(f0) = Sp + S0)
    f0 - resonance frequency
    Q - quality factor of the resonance
    S0 - offresonance noise level
    """

    def __init__(self, Sp, f0, Q, S0):
        # numpy.mean(numpy.real(x)) accepts Python scalars, numpy scalars and
        # arrays alike; numpy.real(x).mean() breaks for plain Python floats,
        # for which numpy.real returns a float that has no .mean() method.
        self.Sp = numpy.mean(numpy.real(Sp))
        self.f0 = numpy.mean(numpy.real(f0))
        self.Q = abs(Q)
        self.S0 = S0

    def calculate(self, f): return NoisePowerLorentzian.pl(self.Sp, self.f0, self.Q, self.S0, f)
    
    @staticmethod
    def pl(Sp, f0, Q, S0, f):
        """
        For a frequency or an array of frequencies 'f' return a power lorentzian
        defined by:
            Sp - peak noise level
            f0 - resonance frequency
            Q - quality factor of the resonance
            S0 - offresonance noise level
        """
        return Sp / (1 + (2*Q*(1 - f/f0))**2) + S0

    @staticmethod
    def guess(spectrum, allowOffLevel = True):
        """
        Make an initial guess for noise power Lorentzian fit parameters.
        With 'allowOffLevel', S0 is estimated as the mean of the first and last
        spectrum points and removed before the peak is analysed; otherwise S0 = 0.
        """
        if allowOffLevel:
            S0 = numpy.mean((spectrum.w[0], spectrum.w[-1]))
            spectrum = spectrum - S0
        else:
            S0 = 0.0
            
        f0, Sp = spectrum.findpeak()
        
        # S(f) - S0 falls to Sp/2 where 2*Q*|1 - f/f0| = 1,
        # i.e. at |f - f0| = f0/(2*Q)
        K = abs(spectrum.w - Sp/2)
        f_half = spectrum.f[numpy.argmin(K)]
        Q = f0 / (2*abs(f_half - f0))
        
        return NoisePowerLorentzian(f0=f0, Q=Q, Sp=Sp, S0=S0)

    @staticmethod
    def residuals(param, f, w):
        """
        Calculate residuals for the fit; 'param' is (Sp, f0, Q) with S0 = 0,
        or (Sp, f0, Q, S0).
        """
        if len(param) == 3:
            Sp, f0, Q = param
            S0 = 0.0
        else:
            Sp, f0, Q, S0 = param
        return abs(w - NoisePowerLorentzian.pl(Sp, f0, Q, S0, f))
        
    @staticmethod    
    def fit(spectrum, allowOffLevel = True):
        """
        Fit a (noise) power Lorentzian to the data.
        
        allowOffLevel (True/False) determines if off-resonance spectrum is non-zero.
        Raises FitError (carrying the final parameters) if leastsq reports failure.
        """
        g = NoisePowerLorentzian.guess(spectrum, allowOffLevel)
        p0 = [g.Sp, g.f0, g.Q, g.S0] if allowOffLevel else [g.Sp, g.f0, g.Q]
        p, ier = scipy.optimize.leastsq(NoisePowerLorentzian.residuals, p0, args=(spectrum.f, spectrum.w), xtol=1e-12, maxfev=int(1e5))
        S0 = p[3] if allowOffLevel else 0.0
        l = NoisePowerLorentzian(Sp=p[0], f0=p[1], Q=p[2], S0=S0)
        # leastsq signals that a solution was found with ier in 1..4
        if 1 <= ier <= 4:
            return l
        raise FitError(l)

    def __repr__(self): return 'NoisePowerLorentzian(Sp = %g, f0 = %.2f, Q = %g, S0 = %g)' % (self.Sp, self.f0, self.Q, self.S0)

class SincSpectrum(SpectrumModel):
    """
    Phase-rich spectrum of a sine-wave pulse:

    w(f) = 1/2 * A * T * exp( i*phi - i*pi*(f-f0)*T) sinc((f-f0)*T)
         + 1/2 * A * T * exp(-i*phi - i*pi*(f+f0)*T) sinc((f+f0)*T)

    f0  - ringing frequency
    A   - amplitude in the time domain
    T   - pulse length (the pulse runs from t=0 to t=T)
    phi - phase

    This is the Fourier transform of
    u(t) = A * cos(2*pi*f0 * t + phi) for 0 <= t <= T, 0 otherwise.
    sinc is numpy.sinc, i.e. sin(pi*x)/(pi*x).
    """

    def __init__(self, T, f0, A, phi):
        # store a negative amplitude as a positive one with the phase moved
        # by 180 degrees, since -cos(x) == cos(x + pi)
        if A < 0:
            A = -A
            phi += pi
        self.T = T
        self.f0 = f0
        self.A = A
        self.phi = phi % (2*pi)  # keep the phase within [0, 2pi)

    def calculate(self, f):
        return SincSpectrum.s(self.T, self.f0, self.A, self.phi, f)

    @staticmethod
    def s(T, f0, A, phi, f):
        """
        Evaluate the sinc spectrum at a frequency or an array of
        frequencies 'f' for:
            f0  - ringing frequency
            A   - amplitude
            T   - pulse duration (pulse is from t=0 to t=T)
            phi - phase
        """
        scale = 0.5*A*T
        below = f - f0
        above = f + f0
        positive = scale*numpy.exp(1j*( phi - pi*below*T))*numpy.sinc(below*T)
        negative = scale*numpy.exp(1j*(-phi - pi*above*T))*numpy.sinc(above*T)
        return positive + negative

    @staticmethod
    def guess(spectrum, T):
        """
        Initial parameter guess from the spectrum peak. The pulse length 'T'
        is not estimated; it is taken as given.
        """
        f0, peak = spectrum.findpeak()
        phase = flib.phase(peak) % (2*pi)
        # the peak height is 1/2 * A * T
        amplitude = 2 * abs(peak) / T
        return SincSpectrum(T=T, f0=f0, A=amplitude, phi=phase)

    @staticmethod
    def residuals(param, f, w, T):
        """
        Residuals for the fit; 'param' is (f0, A, phi) and 'T' is fixed.
        """
        f0, A, phi = param
        model = SincSpectrum.s(T, f0, A, phi, f)
        return abs(w - model)

    @staticmethod
    def fit(spectrum, T):
        """
        Fit a sinc spectrum with fixed pulse length 'T' (which must be
        supplied, it is not fitted). Raises FitError if leastsq reports
        failure.
        """
        start = SincSpectrum.guess(spectrum, T)
        best, ier = scipy.optimize.leastsq(SincSpectrum.residuals,
                                           [start.f0, start.A, start.phi],
                                           args=(spectrum.f, spectrum.w, start.T),
                                           xtol=1e-12, maxfev=int(1e5))
        result = SincSpectrum(T=T, f0=best[0], A=best[1], phi=best[2])
        if not 1 <= ier <= 4:
            raise FitError(result)
        return result

    def phase_deg(self):
        "phase in degrees"
        return self.phi * 180 / pi

    def peak_height(self):
        "peak height in a spectrum"
        return 0.5*self.A*self.T

    def __repr__(self):
        return 'SincSpectrum(T = %g, f0 = %.2f, A = %g, phase = %g degrees)' % (self.T, self.f0, self.A, self.phase_deg())

class TraceModel(object):
    """
    A base class for models used to fit time domain signals.
    Subclasses provide 'calculate'; 'model' and 'subtract' are built on it.
    """

    def model(self, t):
        """
        Construct an nmr.Trace representation of the fit at given
        sampling times. 't' is expected to be either an array of sampling times
        or an nmr.Trace object, whose sampling times t.t() are then used.
        """
        if isinstance(t, nmr.Trace):
            t = t.t()
        return nmr.Trace(t, self.calculate(t))

    def subtract(self, trace):
        """
        Return the difference between an nmr.Trace object 'trace' and
        the fit model, i.e. subtract the fit FROM the data.
        """
        return trace - self.model(trace)

    def calculate(self, t):
        """
        This function is to be overridden in child classes. It returns an array
        of signal samples at given times 't'; the base version returns zeros,
        one per sampling time.
        """
        return numpy.zeros(len(t))


class CosineWave(TraceModel):
    """
    A cosine wave fit in time domain

    A - amplitude,
    f0 - frequency,
    phi - initial phase
    
    S(t) = A * cos(2*pi*f0*t + phi)

    """
    def __init__(self, f0, A, phi):
        self.f0 = f0
        self.A = A
        self.phi = phi % (2*pi) # shift phase to the [0, 2pi) range

    @staticmethod
    def w(f0, A, phi, t): return A * numpy.cos(2*pi*f0*t + phi)

    def calculate(self, t): return CosineWave.w(self.f0, self.A, self.phi, t)

    @staticmethod
    def guess(trace, fmin=None, fmax=None):
        """
        Make a guess. Optional arguments fmin and fmax restrict the frequency
        range for the sinewave.

        The amplitude guess uses the same spectrum normalisation as
        SincSpectrum.guess: a cosine of amplitude A lasting a time T has a
        spectral peak of height A*T/2, so A = 2*|peak|/T with T the trace
        time span.
        """
        w = trace.fft().frange(fmin, fmax)
        peak = w.peak()
        return CosineWave(w.peakf(), 2 * abs(peak) / trace.timespan(), flib.phase(peak))

    @staticmethod    
    def fit(trace, fmin=None, fmax=None):
        """
        Fit a cosine wave through a trace. 'fmin' and 'fmax' restrict the
        frequency range used for the initial guess.
        Raises FitError (carrying the final parameters) if leastsq reports failure.
        """
        g = CosineWave.guess(trace, fmin, fmax)
        p, ier = scipy.optimize.leastsq(CosineWave.residuals, [g.f0, g.A, g.phi], args=(trace.t(), trace.s), xtol=1e-12, maxfev=int(1e5))
        f = CosineWave(f0=p[0], A=p[1], phi=p[2])
        # leastsq signals that a solution was found with ier in 1..4
        if 1 <= ier <= 4:
            return f
        raise FitError(f)

    @staticmethod
    def residuals(param, f, w):
        """
        Calculate residuals for the fit. Despite its name, 'f' holds the
        sampling times and 'w' the signal samples.
        """
        f0, A, phi = param
        return abs(w - CosineWave.w(f0, A, phi, f))

    def phase_deg(self): "phase in degrees"; return self.phi * 180 / pi
    def __repr__(self): return 'CosineWave(f0 = %.2f, A = %g, phase = %g degrees)' % (self.f0, self.A, self.phase_deg())
