"""
nmr.py - a library of nmr routines

classes: nmr.Trace, nmr.Spectrum, nmr.SignalMismatchError
function: nmr.rfftfreq

7 Dec 2009, Lev
"""
#
# Incomplete Changelog
# ...
# 30 May 2008
# Add 'Trace.loadallbinaries', 'Trace.loadallasciis' functions.
#
# 25 June 2008
# Add [] behaviour to Trace and Spectrum classes (__getitem__/__setitem__)
# Add unary +/- to Trace and Spectrum classes
#
# 17 July 2008
# Add Spectrum.ifft, fix string representations
#
# 4 September 2008
# add 'y0' optional argument to Spectrum.plotabs/plotri routines
# add 'Spectrum.mhz' routine
#
# 3 Dec 2008
# migrate from __div__ to __truediv__ in 'Spectrum' and 'Trace'
#
# ~ May 2009
# add 'nmr.Trace.downsample'
#
# 10 Jul 2009
# update savebinary/loadbinary to the new format
#
# 7 Dec 2009
# add 'nmr.Spectrum.trace_swing_Vpp' method

import numpy
import flib
import pylab
import math
import glob

class SignalMismatchError(Exception):
    """
    Raised when two signals that do not match each other are combined.

    The offending signals, when supplied to the constructor, are kept in
    'self.signal1' and 'self.signal2'; both are None otherwise.
    """

    def __init__(self, signal1=None, signal2=None):
        """
        Remember the (optional) pair of signals that failed to match.
        """
        self.signal1, self.signal2 = signal1, signal2

    def __str__(self):
        # The message is fixed; the signals are only reachable as attributes.
        return 'Signal mismatch'

class Trace(object):
    """
    Time domain signal trace.

    An array of samples is stored at 's', 'dt' is the sampling interval.
    Timeline is returned by 't()'.

    Arithmetic operations +,-,*,/ are supported. Other traces or numbers can be
    added and subtracted. A trace can be multiplied or divided by a number.

    '&' appends samples to a trace. An array of samples can be concatenated,
    or another trace. The latter must have the same 'dt' as the trace being
    extended. Also samples can be added to the beginning of a trace.

    '[t]' returns a sample at the moment 't', time is rounded down. Negative 't'
    is interpreted as time relative to the end of the trace. IndexError is raised
    if t is out of range.

    'info', set by default to 'None' is intended for metadata related to the
    trace, such as a record from a scope log.
    """

    # "Any kind of string" on both Python 2 (str and unicode) and Python 3;
    # used by the bulk loaders to tell a wildcard from a list of file names.
    try:
        _string_types = basestring
    except NameError:
        _string_types = str

    def __init__(self, arg1, arg2 = None):
        """
        Create a time domain signal trace. Takes 1 or 2 arguments.

        If a single argument is supplied, it is expected to be a two-column
        array, first column - timeline, second column - samples.

        If two arguments are supplied, the first is a timeline if it is an
        array or a timebase if it is a scalar, the second argument is an array
        of samples. A ValueError is raised if a timeline and the samples
        have different lengths.

        The sampling interval is derived from the first and the last entries
        of a timeline, which is therefore assumed to be equally spaced.

        All supplied arrays are copied (so a modification of .s will not affect
        an array used to create a Trace). If complex numbers are given,
        imaginary part is discarded.
        """
        if arg2 is None:
            table = numpy.asarray(arg1)
            # float() keeps the division exact for integer timelines
            dt = (table[-1, 0] - table[0, 0]) / float(len(table) - 1)
            samples = table[:, 1]
        else:
            samples = arg2
            if numpy.isscalar(arg1):
                dt = arg1
            else:
                if len(arg1) != len(arg2):
                    raise ValueError('time and sample arrays are of different lenght')
                dt = (arg1[-1] - arg1[0]) / float(len(arg1) - 1)
        # numpy.array (unlike numpy.asarray) always copies, as documented above
        self.s = numpy.array(numpy.real(samples), dtype=numpy.float64)
        self.dt = numpy.float64(numpy.real(dt))
        self.info = None

    def __len__(self):
        'Return number of samples'
        return len(self.s)

    def timespan(self):
        """
        Return time over which samples span [s].
        This includes one 'dt' after the last sample, or one might think of it
        as 'dt/2' before the first and after the last sample.
        """
        return self.dt * len(self)

    def t(self):
        'Return timeline array, starting at 0 with a step of dt'
        return numpy.arange(len(self.s)) * self.dt

    def srate(self):
        'Return sampling rate [Hz]'
        return 1./self.dt

    def compatible(self, other):
        'Return True if two traces have same timelines, False otherwise'
        return isinstance(other, Trace) and (len(self) == len(other)) and (self.dt == other.dt)

    def check(self, other):
        """
        Raise a ValueError, if other is not a Trace, a SignalMismatchError if
        two traces are not compatible.
        """
        if not isinstance(other, Trace):
            raise ValueError('a Trace is expected, got %s' % type(other).__name__)
        if not self.compatible(other):
            raise SignalMismatchError(self, other)

    def copy(self):
        "Return a copy of the trace ('info' is not carried over)"
        return Trace(self.dt, self.s)

    def __eq__(self, other):
        'Traces are equal if they are compatible and all samples are equal'
        return self.compatible(other) and bool(numpy.all(self.s == other.s))

    def __ne__(self, other): return not (self == other)

    def __neg__(self): return Trace(self.dt, -self.s)
    def __pos__(self): return self

    def __add__(self, other):
        'Add a number or a compatible trace'
        if numpy.isscalar(other):
            return Trace(self.dt, self.s + other)
        self.check(other)
        return Trace(self.dt, self.s + other.s)

    def __sub__(self, other):
        'Subtract a number or a compatible trace'
        if numpy.isscalar(other):
            return Trace(self.dt, self.s - other)
        self.check(other)
        return Trace(self.dt, self.s - other.s)

    def __mul__(self, other):
        'Multiply by a number; ValueError is raised for anything else'
        if numpy.isscalar(other):
            return Trace(self.dt, self.s * other)
        raise ValueError('a trace can be only multiplied by a number')

    def __truediv__(self, other):
        'Divide by a number; ValueError is raised for anything else'
        if numpy.isscalar(other):
            return Trace(self.dt, self.s / other)
        raise ValueError('a trace can be only divided by a number')

    # Python 2 without 'from __future__ import division' calls __div__
    __div__ = __truediv__

    # addition and multiplication are commutative
    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        "Handle 'number - trace'"
        if numpy.isscalar(other):
            return Trace(self.dt, other - self.s)
        raise ValueError('only a number can be reduced by a trace')

    def __and__(self, other):
        """
        Return a new trace with 'other' appended. 'other' is either a Trace
        with the same 'dt' (SignalMismatchError is raised otherwise) or an
        array of samples, of which only the real part is used.
        """
        if isinstance(other, Trace):
            if self.dt != other.dt:
                raise SignalMismatchError(self, other)
            return Trace(self.dt, numpy.concatenate((self.s, other.s)))
        tail = numpy.array(numpy.real(other), dtype=numpy.float64)
        return Trace(self.dt, numpy.concatenate((self.s, tail)))

    def __rand__(self, other):
        """
        Return a new trace with samples 'other' placed before this trace.
        'None & trace' returns the trace itself.
        """
        if other is None:
            return self
        head = numpy.array(numpy.real(other), dtype=numpy.float64)
        return Trace(self.dt, numpy.concatenate((head, self.s)))

    def __repr__(self): return 'Trace(dt=%.3es,%d points: %s)%s'%(self.dt, len(self), self.s, '' if self.info is None else ' ' + repr(self.info))

    def __hash__(self):
        """
        Hash of the sampling interval and the samples.
        numpy arrays are not hashable, so the samples are hashed as a tuple
        of floats; equal traces therefore have equal hashes. The hash changes
        if the trace is modified.
        """
        return hash((float(self.dt), tuple(self.s.tolist())))

    def __getitem__(self, t): return self.s[int(math.floor(t / self.dt))]
    def __setitem__(self, t, value): self.s[int(math.floor(t / self.dt))] = value
    def __contains__(self, t): return t >= 0 and t <= self.dt * (len(self) - 1)
    def __iter__(self): return iter(self.t())

    def fft(self, Nstart=0, Nstop=0, Npad=0, fmin=0, fmax=0, bgpower=0, T2filter=0):
        """
        (fast) Fourier transform a signal. Return an nmr.Spectrum object.

        Nstart, Nstop - subset of samples to use for FFT. Nstart defaults to 0,
                        Nstop to number of samples.
        Npad          - length the subset is zero-padded (or cropped) to,
                        defaults to the smallest power of two that holds
                        the subset.
        fmin, fmax    - frequency range to return FFT in, fmax defaults to
                        the highest frequency available.
        bgpower       - power of a polynomial to subtract as a background.
                        The polynomial is fitted to the complete trace.
                        Subtraction is not carried out if bgpower is less
                        than 0. Note the default of 0 subtracts a constant.
        T2filter      - if positive, samples are multiplied by exp(-t/T2filter).

        ValueError is raised if no samples fall between Nstart and Nstop.
        """
        s = self.s.copy()
        t = self.t()

        if Nstop < 1 or Nstop > len(s):
            Nstop = len(s)

        if Nstart < 0:
            Nstart = 0

        if Nstart >= Nstop:
            raise ValueError('no samples selected for FFT')

        if Npad < 1:
            # Smallest power of two >= number of samples. Integer arithmetic
            # avoids rounding errors of a floating point logarithm.
            Npad = 1
            while Npad < Nstop - Nstart:
                Npad *= 2

        if bgpower >= 0:
            s -= numpy.polyval(numpy.polyfit(t, s, bgpower), t)

        if T2filter > 0:
            s *= numpy.exp(-t / T2filter)

        # Discrete Fourier transform:
        # Trace s(t) is real -> don't calculate negative frequency spectrum
        # Normalise transform S(f) by dt (to approximate continuous FT)
        ww = numpy.fft.rfft(s[Nstart:Nstop], Npad) * self.dt
        ff = rfftfreq(Npad, self.dt)

        if fmax <= 0:
            fmax = ff.max()

        ii = (ff >= fmin) & (ff <= fmax)

        return Spectrum(ff[ii], ww[ii])

    def subset(self, condition):
        "Return a subset of the trace based upon 'condition' argument."
        return Trace(self.dt, self.s[condition])

    def period(self, start = None, stop = None):
        """
        Return a subset of the trace happening between start and stop
        (in seconds, both boundaries included), start defaults to the
        beginning of the trace, stop - to the end.
        """
        if start is None:
            start = 0
        if stop is None:
            stop = self.timespan()

        t = self.t()

        return self.subset((t >= start) & (t <= stop))

    def range(self, Nstart = 0, Nstop = None):
        """
        Return a subset of the trace starting with 'Nstart'th sample
        and up to 'Nstop'-1 sample. If 'Nstop' is not specified, it
        defaults to the end of the trace.
        """
        # a slice bound of None means "up to the end"
        return Trace(self.dt, self.s[Nstart:Nstop])

    @staticmethod
    def loadbinary(filename, samplingrate, channel='channel0'):
        """
        Load a trace from a binary file. Sampling rate is specified explicitly.
        If the file is stored in binary format 2 or higher, 'channel' can be
        used to determine which section of the file to use. Defaults
        to 'channel0', which is the only section available in the earlier
        formats.
        """
        v,h,d = flib.loadbinary(filename, full=True)
        return Trace(1./samplingrate, d[channel])

    @staticmethod
    def loadbinaries(filenameTemplate, indices, samplingrate, channel='channel0'):
        """
        Load traces from binary files and average them. File names are
        obtained as 'filenameTemplate % i' for every i in 'indices'.
        Sampling rate is specified explicitly.
        If the file is stored in binary format 2 or higher, 'channel' can be
        used to determine which section of the file to use. Defaults
        to 'channel0', which is the only section available in the earlier
        formats.
        """
        return Trace.loadallbinaries([filenameTemplate%i for i in indices], samplingrate, channel)

    @staticmethod
    def loadallbinaries(filenames, samplingrate, channel='channel0'):
        """
        Load traces from binary files and average them. Sampling rate is specified explicitly.
        If 'filenames' is a string it is treated as a wildcard (see glob), otherwise it is supposed
        to be a list of strings, each pointing to an existing filename.
        None is returned if there are no files to load.

        If the file is stored in binary format 2 or higher, 'channel' can be
        used to determine which section of the file to use. Defaults
        to 'channel0', which is the only section available in the earlier
        formats.
        """
        if isinstance(filenames, Trace._string_types):
            filenames = glob.glob(filenames)

        if len(filenames) < 1:
            return None

        total = 0 # 0 + TraceX is TraceX
        for filename in filenames:
            total = total + Trace.loadbinary(filename, samplingrate, channel)
        return total / float(len(filenames))

    @staticmethod
    def loadascii(filename):
        """
        Load trace in ASCII format:
        (1) Time (sec)
        (2) Signal amplitude
        """
        return Trace(flib.loadascii(filename, usecols=(0,1)))

    @staticmethod
    def loadasciis(filenameTemplate, indices):
        """
        Load traces from ascii files and average them. File names are
        obtained as 'filenameTemplate % i' for every i in 'indices'.
        """
        return Trace.loadallasciis([filenameTemplate%i for i in indices])

    @staticmethod
    def loadallasciis(filenames):
        """
        Load traces from ascii files and average them.
        If 'filenames' is a string it is treated as a wildcard (see glob), otherwise it is supposed
        to be a list of strings, each pointing to an existing filename.
        None is returned if there are no files to load.
        """
        if isinstance(filenames, Trace._string_types):
            filenames = glob.glob(filenames)

        if len(filenames) < 1:
            return None

        total = 0 # 0 + TraceX is TraceX
        for filename in filenames:
            total = total + Trace.loadascii(filename)
        return total / float(len(filenames))

    def saveascii(self, filename, header=None):
        """
        Save trace in ASCII format: (all in %.15e, tab-delimited)
        (1) Timebase
        (2) Samples
        """
        flib.saveascii(filename, (self.t(), self.s), fmt='%.15e', header=header, pack=True)

    def savebinary(self, filename, header='', dtype=numpy.float32):
        'Save a trace in a binary format, sampling rate is not saved.'
        flib.savebinary(filename, {'header': repr(header), 'channel0': numpy.asarray(self.s, dtype=dtype)})

    def envelope(self, Nwindow = None, Twindow = None, Nstep = None, Tstep = None):
        """
        Return an envelope of the trace, as an nmr.Trace object.
        Envelope contains the peak-to-peak amplitude of a moving window 'Nwindow'
        points or 'Twindow' seconds long. Either 'Nwindow' or 'Twindow' must be
        specified, but not both. The window is limited to between one sample
        and the complete trace.

        In a similar way 'Nstep' or 'Tstep' determines a step by which the
        window moves, it defaults to window length and is at least one sample.

        ValueError is raised if the window is not specified or if the window
        or the step is specified twice.
        """

        if Nwindow is None and Twindow is None:
            raise ValueError("'Nwindow' or 'Twindow' argument must be specified")
        elif Nwindow is not None and Twindow is not None:
            raise ValueError("Only one of 'Nwindow' and 'Twindow' arguments can be specified")
        elif Nwindow is None:
            Nwindow = Twindow / self.dt

        N = len(self)

        # sample counts are used as array sizes and slice bounds: keep them integer
        Nwindow = int(Nwindow)
        if Nwindow > N:
            Nwindow = N
        if Nwindow < 1:
            Nwindow = 1

        if Nstep is None and Tstep is None:
            Nstep = Nwindow
        elif Nstep is not None and Tstep is not None:
            raise ValueError("Only one of 'Nstep' and 'Tstep' arguments can be specified")
        elif Nstep is None:
            Nstep = Tstep / self.dt

        Nstep = int(Nstep)
        if Nstep < 1:
            Nstep = 1

        e = numpy.zeros(N // Nstep)
        Nleft = Nwindow // 2
        Nright = Nwindow - Nleft

        for n in range(len(e)):
            # window is centred on sample n*Nstep and shifted to fit into the trace
            Nstart = n*Nstep - Nleft
            Nstop = n*Nstep + Nright
            if Nstart < 0:
                Nstart = 0
                Nstop = Nwindow
            elif Nstop >= N:
                Nstart = N - Nwindow
                Nstop = N
            window = self.s[Nstart:Nstop]
            e[n] = window.max() - window.min()

        return Trace(self.dt * Nstep, e)

    def linearbg(self, window = None, Twindow = None):
        """
        Return an nmr.Trace object containing a straight line through the mean levels of the first
        'window' and last 'window' points (or first 'Twindow' and last 'Twindow' seconds). This can
        be subtracted from the signal to avoid ripple due to start and end level mismatch. If neither
        'window' nor 'Twindow' is specified 'window' defaults to 1. 'window' is limited to half of
        the trace.
        """
        if window is None and Twindow is None:
            window = 1
        elif window is None:
            window = Twindow / self.dt

        # 'window' is used as a slice bound, so it must be an integer
        window = int(window)
        if window < 1:
            window = 1
        if window > len(self) // 2:
            window = len(self) // 2

        # mean levels at both ends, assigned to the middles of the two windows
        start = numpy.mean(self.s[:window])
        tstart = window/2.0 * self.dt
        end   = numpy.mean(self.s[-window:])
        tend = self.timespan() - window/2.0 * self.dt

        k = (end - start) / (tend - tstart)
        b = start - k*tstart
        return Trace(self.dt, b + k*self.t())

    def subtract_linearbg(self, window=None, Twindow=None):
        """
        Return a difference between the trace and its linear background,
        see 'linearbg()'.
        """
        return self - self.linearbg(window, Twindow)

    def log(self):
        """
        Take a natural logarithm of every point of the trace
        """
        return Trace(self.dt, numpy.log(self.s))

    def plot(self, *args, **vargs):
        """
        Plot the trace, only finite points. Extra arguments are passed
        to pylab.plot, whose result is returned.
        """
        cond = numpy.isfinite(self.s)
        return pylab.plot(self.t()[cond], self.s[cond], *args, **vargs)

    def logplot(self, *args, **vargs):
        """
        Plot the trace with log scale on y axis, only finite positive points.
        Extra arguments are passed to pylab.semilogy, whose result is returned.
        """
        cond = numpy.isfinite(self.s) & (self.s > 0)
        return pylab.semilogy(self.t()[cond], self.s[cond], *args, **vargs)

    def downsample(self, Nstep=None, newDT=None, Nstart=0):
        """
        Return a subset of the trace at a different sampling rate: every
        'Nstep'th sample is taken, starting with sample 'Nstart'.
        If 'Nstep' is not given it is derived from the desired sampling
        interval 'newDT', rounded to a whole number of samples.

        ValueError is raised if neither 'Nstep' nor 'newDT' is specified, or
        if 'newDT' rounds to less than one sample.
        """
        if Nstep is None:
            if newDT is None:
                raise ValueError('Downsampling is not specified')
            # a slice step must be an integer, numpy.round returns a float
            Nstep = int(round(newDT / self.dt))
            if Nstep < 1:
                raise ValueError("'newDT' is shorter than the sampling interval of the trace")

        return Trace(self.dt * Nstep, self.s[Nstart::Nstep])

class Spectrum(object):
    """
    A Fourier Transform of a signal. A spectrum can contain only a subset of frequencies
    to reduce amount of data. Frequencies are stored as 'f', complex spectral densities as 'w'.

    Spectra with identical sets of frequency can be added to, subtracted from,
    multiplied by and divided by each other using '+-*/'. Numbers can be used
    in these operations as well.

    'self[f]' is the amplitude at frequency 'f'. IndexError is raised if f is not present
    in the spectrum.

    '&' can be used to combine two spectra without overlapping frequencies. ValueError is raised
    if there is an overlap.
    """

    def __init__(self, f, w=None):
        """
        Creates a Fourier Transform
        f - array of frequencies, only real component is used
        w - array of complex spectral densities.
        If a single argument is given, f[:,0] is frequencies, f[:,1] is spectral densities.
        ValueError is raised if 'f' and 'w' have different lengths.
        """
        if w is None:
            table = numpy.asarray(f)
            self.f = numpy.asarray(numpy.real(table[:,0]), dtype=numpy.float64)
            self.w = numpy.asarray(table[:,1], dtype=numpy.complex128)
        else:
            if len(f) != len(w):
                raise ValueError('frequency and spectrum arrays are different length')
            self.f = numpy.asarray(numpy.real(f), dtype=numpy.float64)
            self.w = numpy.asarray(w, dtype=numpy.complex128)

    def copy(self):
        "Return a copy of the spectrum; frequency and density arrays are duplicated"
        return Spectrum(self.f.copy(), self.w.copy())

    def subset(self, condition):
        """
        Return a subset of the Spectrum based upon 'condition' argument.
        """
        return Spectrum(self.f[condition], self.w[condition])

    def frange(self, fmin=None, fmax=None):
        """
        Return a spectrum in given frequency range, boundaries included.
        If fmin is not specified, minimum frequency of the spectrum is used.
        If fmax is not specified, maximum frequency of the spectrum is used.
        """
        if fmin is None: fmin = self.fmin()
        if fmax is None: fmax = self.fmax()
        return self.subset((self.f >= fmin) & (self.f <= fmax))

    def nofrange(self, fmin=None, fmax=None):
        """
        Return a subset of the spectrum excluding a frequency range limited
        by [fmin, fmax], boundaries being discarded.
        fmin defaults to the minimum frequency,
        fmax defaults to the maximum frequency.
        """
        if fmin is None: fmin = self.fmin()
        if fmax is None: fmax = self.fmax()
        return self.subset((self.f < fmin) | (self.f > fmax))

    def comb(self, freq_blacklist):
        """
        Return a subset of a spectrum, lacking frequencies mentioned in 'freq_blacklist'.
        The list should contain numbers - the frequencies to discard, and (fmin, fmax)
        two-number pairs, which are interpreted as frequency range boundaries
        (boundaries are discarded too).
        The spectrum itself is returned if the list is empty.
        """
        if len(freq_blacklist) < 1:
            return self

        keep = numpy.ones(len(self.f), dtype=bool)

        for f in freq_blacklist:
            if numpy.isscalar(f):
                keep &= (self.f != f)
            else:
                keep &= (self.f < f[0]) | (self.f > f[1])

        return self.subset(keep)

    def isinrange(self, f):
        """
        Return True if f is in the range of frequencies of the spectrum,
        False otherwise
        """
        return (f >= self.fmin()) and (f <= self.fmax())

    def fmin(self):
        "minimum frequency"
        return min(self.f)

    def fmax(self):
        "maximum frequency"
        return max(self.f)

    def df(self):
        "frequency step, evaluated as if the frequencies were equally spaced"
        return (self.fmax() - self.fmin()) / (len(self) - 1)

    def re(self):
        "real part of spectrum"
        return numpy.real(self.w)

    def im(self):
        "imaginary part of spectrum"
        return numpy.imag(self.w)

    def abs(self):
        "absolute value of spectrum"
        return numpy.abs(self.w)

    def __len__(self): return len(self.f)

    def compatible(self, other):
        """
        Return True if two spectrums have same frequency arrays, False otherwise
        """
        # array_equal is False (rather than an error) for arrays of different length
        return isinstance(other, Spectrum) and bool(numpy.array_equal(self.f, other.f))

    def check(self, other):
        """
        Raise a SignalMismatchError if two spectrums are not compatible.
        """
        if not self.compatible(other):
            raise SignalMismatchError(self, other)

    def findpeak(self):
        """
        Return a tuple containing the frequency and complex amplitude of the
        peak, defined as a point in the spectrum with maximum absolute value
        of the amplitude. The first such point is returned if there are several.
        """
        n = numpy.argmax(self.abs())
        return (self.f[n], self.w[n])

    def peakf(self):
        "Return the frequency of a peak found by 'findpeak()'"
        return self.findpeak()[0]

    def peak(self):
        "Return the complex amplitude of the peak found by 'findpeak()'"
        return self.findpeak()[1]

    def findmin(self):
        """
        Return a tuple containing the frequency and complex amplitude of the
        point in the spectrum with minimum absolute value of the amplitude.
        The first such point is returned if there are several.
        """
        n = numpy.argmin(self.abs())
        return (self.f[n], self.w[n])

    def __eq__(self, other):
        'Spectra are equal if both frequencies and spectral densities are equal'
        return self.compatible(other) and bool(numpy.array_equal(self.w, other.w))

    def __ne__(self, other): return not (self == other)

    def __neg__(self): return Spectrum(self.f, -self.w)
    def __pos__(self): return self

    def __add__(self, other):
        'Add a number or a compatible spectrum'
        if numpy.isscalar(other):
            return Spectrum(self.f, self.w + other)
        self.check(other)
        return Spectrum(self.f, self.w + other.w)

    def __sub__(self, other):
        'Subtract a number or a compatible spectrum'
        if numpy.isscalar(other):
            return Spectrum(self.f, self.w - other)
        self.check(other)
        return Spectrum(self.f, self.w - other.w)

    def __mul__(self, other):
        'Multiply by a number or, point by point, by a compatible spectrum'
        if numpy.isscalar(other):
            return Spectrum(self.f, self.w * other)
        self.check(other)
        return Spectrum(self.f, self.w * other.w)

    def __truediv__(self, other):
        'Divide by a number or, point by point, by a compatible spectrum'
        if numpy.isscalar(other):
            return Spectrum(self.f, self.w / other)
        self.check(other)
        return Spectrum(self.f, self.w / other.w)

    # Python 2 without 'from __future__ import division' calls __div__
    __div__ = __truediv__

    # addition and multiplication are commutative
    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        "Handle 'number - spectrum'"
        if numpy.isscalar(other):
            return Spectrum(self.f, other - self.w)
        raise ValueError('only a number can be reduced by a spectrum')

    def __and__(self, other):
        """
        Return a spectrum containing points of both spectra, those of 'other'
        following those of this spectrum. ValueError is raised if 'other' is
        not a Spectrum or if any frequency is present in both spectra.
        """
        if not isinstance(other, Spectrum):
            raise ValueError('a spectrum can be only merged with another spectrum')
        # overlap is a property of the frequency sets, not of the point order
        if len(numpy.intersect1d(self.f, other.f)) > 0:
            raise ValueError("Can't to merge two spectra with overlapping frequency sets")
        return Spectrum(numpy.concatenate((self.f, other.f)), numpy.concatenate((self.w, other.w)))

    def __repr__(self):
        if len(self) > 0:
            return 'Spectrum [%d points in %g-%g Hz range]' % (len(self), self.fmin(), self.fmax())
        else:
            return 'Empty Spectrum'

    def __getitem__(self, f):
        if f not in self:
            raise IndexError('frequency is not present in the spectrum')
        return self.w[self.f == f][0]

    def __setitem__(self, f, value):
        if f not in self:
            raise IndexError('frequency is not present in the spectrum')
        self.w[self.f == f] = value

    def __contains__(self, f): return f in self.f
    def __iter__(self): return iter(self.f)

    @staticmethod
    def loadascii(filename):
        """
        Load spectrum in ASCII format:
        (1) Frequency [Hz]
        (2) Abs amplitude *** disregarded ***
        (3) Real part of amplitude
        (4) Imaginary part of amplitude
        """
        f,r,i = flib.loadascii(filename, usecols=(0,2,3), unpack=True, keep2D=True)
        return Spectrum(f, r + 1j * i)

    def saveascii(self, filename, header=None):
        """
        Save spectrum in ASCII format: (all in %.15e, tab-delimited)
        (1) Frequency [Hz]
        (2) Abs amplitude
        (3) Real part of amplitude
        (4) Imaginary part of amplitude
        """
        flib.saveascii(filename, (self.f, self.abs(), self.re(), self.im()), fmt='%.15e', header=header, pack=True)

    @staticmethod
    def loadbinary(filename):
        """
        Load a spectrum from a binary file.
        The dataset is assumed to be 3*N long, first N points - frequencies,
        next go the real parts and finally the imaginary parts of the spectrum.

        If the file version is above 0 and the header evaluates to a
        dictionary, its entries are set as attributes of the spectrum, unless
        an attribute of that name exists already. A header that can not be
        evaluated is ignored.
        """
        version, header, sections = flib.loadbinary(filename, full=True)
        data = sections['channel0']
        N = len(data) // 3
        s = Spectrum(data[:N], data[N:2*N] + 1j*data[2*N:3*N])

        attributes = {}
        if version > 0:
            # NOTE(security): the header is evaluated as a Python expression,
            # only load files from trusted sources.
            try:
                attributes = eval(header, {'nan': numpy.nan})
            except Exception:
                attributes = {}
        if not isinstance(attributes, dict):
            # e.g. a plain string header, as written by savebinary by default
            attributes = {}

        for key, value in attributes.items():
            try:
                present = hasattr(s, key)
            except TypeError:
                # not a valid attribute name
                continue
            if not present:
                setattr(s, key, value)
        return s

    def savebinary(self, filename, header='', dtype=numpy.float32):
        """
        Save the spectrum into a binary file.
        First all frequencies, then real and finally imaginary parts are stored
        in a single section of the file called 'channel0'.

        Optionally a header is saved (its repr() is stored)

        Data format can be specified by 'dtype' argument, defaults to 4 byte floats.
        """
        data = numpy.asarray(numpy.concatenate([self.f, self.re(), self.im()]), dtype=dtype)
        flib.savebinary(filename, {'header': repr(header), 'channel0': data})

    def plotabs(self, *args, **vargs):
        """
        Plot absolute amplitudes of the spectrum.
        Optional 'y0' argument shifts spectrum vertically.
        Other arguments are passed to pylab.plot, whose result is returned.
        """
        y0 = vargs.pop('y0', 0)
        return pylab.plot(self.f, self.abs() + y0, *args, **vargs)

    def plothodo(self, phase = 0, *args, **vargs):
        """
        Plot a hodograph of the spectrum - real on x and imaginary on y.
        Axis is set to proportional aspect ratio.
        Optional 'phase' argument rotates the spectrum.
        """
        W = self.w*numpy.exp(1j*phase)
        pylab.axis('equal')
        return pylab.plot(numpy.real(W), numpy.imag(W), *args, **vargs)

    def plotri(self, phase=0, color=None, y0=0, **vargs):
        """
        Plot real (solid) and imaginary (dashed) parts of the spectrum.
        Both are plotted in the same colour, either specified with 'color'
        argument, or picked by pylab automatically.
        Optional 'phase' argument rotates the spectrum
        Optional 'y0' argument shifts the spectrum vertically
        If 'label' is specified, only the real part is labelled.
        """

        if color is None: color = ''

        W = self.w*numpy.exp(1j*phase)
        pre = pylab.plot(self.f, numpy.real(W) + y0, color + '-', **vargs)
        # reuse the colour picked by pylab for the imaginary part
        if color == '': color = pre[0].get_color()
        vargs.pop('label', None)
        pim = pylab.plot(self.f, numpy.imag(W) + y0, color + '--', **vargs)
        return pre + pim

    def ifft(self, N, dt):
        """
        Inverse FFT. Return nmr.Trace object.
        Only works if the spectrum is complete from 0 Hz to Nyquist Frequency,
        as if was produced by FFT of an 'N'-point trace with a sampling
        interval of 'dt'. NotImplementedError is raised otherwise.
        """
        if not numpy.array_equal(self.f, rfftfreq(N, dt)):
            raise NotImplementedError('inverse FFT is only available for complete spectra')
        # undo the normalisation by dt applied by Trace.fft
        return Trace(dt, numpy.fft.irfft(self.w / dt, N))

    def mhz(self):
        """
        Shift frequencies to focus on 1-1.1MHz range
        In a returned spectrum 1 MHz is subtracted from all frequencies and
        frequencies are expressed in kHz, rather than Hz.
        """
        return Spectrum((self.f - 1e6)/1e3, self.w)

    def findpeaks(self, alpha = 3, singlepeaks = True, minf=None, maxf=None):
        """
        Find peak frequencies. Peaks are those points which differ from
        neighbours by more than alpha times the rms absolute difference
        between the points. If 'singlepeaks' is True, a point is considered
        a peak only if it is higher that its both neighbours, otherwise
        points on sides of wider peaks are included as well.

        If 'minf'/'maxf' are specified, peaks are looked for only in a given
        frequency window.
        """
        # NOTE(review): the meaning of the second argument of flib.absdiff is
        # not visible here, the description above is the original one - confirm.
        if minf is None: minf = self.fmin()
        if maxf is None: maxf = self.fmax()

        threshold = alpha * numpy.mean(flib.absdiff(self.w, True)**2)**0.5
        peaks = self.f[flib.absdiff(self.w, not singlepeaks) > threshold]
        return peaks[(peaks >= minf) & (peaks <= maxf)]

    def removepeaks(self, **vargs):
        """
        Return a subset of the spectrum excluding the peaks found by
        nmr.Spectrum.findpeaks(), all arguments are passed to that method.
        """
        # numpy.setmember1d is gone from numpy; isin replaces in1d in newer versions
        member = getattr(numpy, 'isin', None)
        if member is None:
            member = numpy.in1d
        return self.subset(~member(self.f, self.findpeaks(**vargs)))

    def average(self, Nwindow = 10):
        """
        Thin the spectrum down by averaging every 'Nwindow' neighbour points
        together
        """
        return Spectrum(flib.average(numpy.array((self.f, self.w)).transpose(), Nwindow))

    def integrate(self):
        """
        Evaluate an integral of the spectrum between the minimum frequency
        and every frequency present in the spectrum. Return the result as an
        nmr.Spectrum object.
        """
        return Spectrum(self.f, flib.integrate(self.f, self.w))

    def trace_swing_Vpp(self):
        """
        return the time domain trace voltage swing, if this information was
        saved into the spectrum ('trace_max' and 'trace_min' attributes)
        or NaN otherwise.
        """
        try:
            return self.trace_max - self.trace_min
        except (AttributeError, TypeError):
            return numpy.nan
        
def rfftfreq(N, dt):
    """
    Return an array of frequencies for Real Fast Fourier Transform (only
    non-negative frequencies).
    N  - number of points in the original time trace
    dt - sampling interval of the original time trace

    The array is N//2 + 1 long, matching the output of numpy.fft.rfft for an
    N-point input. For even N the last entry is the Nyquist frequency.
    """
    # Integer division keeps the number of points right for odd N as well;
    # float() prevents integer division of the frequency indices on Python 2.
    return numpy.arange(int(N) // 2 + 1) / float(N * dt)
