"""
twoD.py

Library for handling 2D datasets

Lev, July 2009
"""

import numpy, flib, ascii2numpy, gc
import numpy.fft
import _twoD

class Dataset(object):
    """
    2D data set on a rectangular grid.

    This class is used both for AFM scans and for the correlations.

    Attributes:

    x    - 1D array, grid coordinates along the first axis of 'z' (N values)
    y    - 1D array, grid coordinates along the second axis of 'z' (M values)
    z    - 2D array (NxM), the data
    mask - 2D array (NxM), all ones unless specified. Its meaning depends
           on the method that produced the dataset: 'correlation' reads it
           as a boolean map and stores numbers of pairs in its result,
           'fft' stores the phase in it.
    """

    def __init__(self, x, y, z, mask=None):
        """
        Dataset(x, y, z)
        Create a dataset.

        formats:

        1) x, y - scalars, z - 2D array NxM.
            z - the data, x and y are used as steps in the NxM grid

        2) x, y - N and M long 1D arrays, z - 2D array NxM.
            z - the data, x and y are used as the points of the grid

        3) x, y, z - 1D arrays
            create a grid based on all combinations of values in x and y
            and fill it with z, set NaN where z is not specified.
            If the same (x, y) pair is listed more than once, numpy does
            not guarantee which of the values ends up in the grid.

        1a), 2a), 3a) - similar to 1-3, but an extra argument 'mask' is
            supplied, similar in shape to 'z'.

        ValueError is raised if the arguments fit none of the formats or
        if 'mask' does not have the same shape as 'z'.
        """
        z = numpy.asarray(z)
        if mask is None:
            mask = numpy.ones(z.shape)
        else:
            mask = numpy.asarray(mask)
            if mask.shape != z.shape:
                raise ValueError("'mask' should have the same shape as 'z'")

        if numpy.isscalar(x) and numpy.isscalar(y):
            if z.ndim != 2:
                raise ValueError("'z' should be 2D, if x and y are scalar")
            x = numpy.arange(z.shape[0]) * x
            y = numpy.arange(z.shape[1]) * y
        else:
            x = numpy.asarray(x)
            y = numpy.asarray(y)

            if z.ndim == 1 and z.shape == x.shape and z.shape == y.shape:
                # scattered points: the sorted unique coordinates form the
                # grid, the inverse indices tell where every point goes
                x, i_x = numpy.unique(x, return_inverse=True)
                y, i_y = numpy.unique(y, return_inverse=True)

                grid_z = numpy.zeros([len(x), len(y)]) * numpy.nan
                grid_mask = numpy.ones([len(x), len(y)])
                grid_z[i_x, i_y] = z
                grid_mask[i_x, i_y] = mask

                z = grid_z
                mask = grid_mask
            elif not (z.ndim == 2 and z.shape[:1] == x.shape and z.shape[1:] == y.shape):
                raise ValueError("Unsupported arguments")

        self.x = x
        self.y = y
        self.z = z
        self.mask = mask
        gc.collect()

    def n_x(self):
        "number of grid points along x (size of the first axis of 'z')"
        return len(self.x)

    def n_y(self):
        "number of grid points along y (size of the second axis of 'z')"
        return len(self.y)

    def dx(self):
        "average grid spacing along x, NaN if there are fewer than two points"
        n = self.n_x()
        if n < 2:
            # a single point has no spacing (and would divide by zero)
            return numpy.nan
        return float(self.x.max() - self.x.min()) / (n - 1)

    def dy(self):
        "average grid spacing along y, NaN if there are fewer than two points"
        n = self.n_y()
        if n < 2:
            return numpy.nan
        return float(self.y.max() - self.y.min()) / (n - 1)

    def min_x(self):
        "the minimum x coordinate"
        return self.x.min()

    def max_x(self):
        "the maximum x coordinate"
        return self.x.max()

    def min_y(self):
        "the minimum y coordinate"
        return self.y.min()

    def max_y(self):
        "the maximum y coordinate"
        return self.y.max()

    @staticmethod
    def loadascii(filename, load_mask=False):
        """
        load a 2D dataset from an ASCII file. If 'load_mask' is True,
        load a four column file: X, Y, Z, mask,
        otherwise load a three column file: X, Y, Z.
        """
        data = ascii2numpy.loadascii(filename, 4 if load_mask else 3)
        if load_mask:
            x, y, z, mask = data.transpose()
        else:
            x, y, z = data.transpose()
            mask = None
        return Dataset(x, y, z, mask)

    def saveascii(self, filename, save_mask=False):
        """
        Save the dataset as an ASCII file.
        Depending on the value of 'save_mask', either 3 (X, Y, Z) or
        4 (X, Y, Z, mask) columns are stored.
        """
        if save_mask:
            header = "# [0] x\n# [1] y\n# [2] z\n# [3] mask"
        else:
            header = "# [0] x\n# [1] y\n# [2] z"
        flib.saveascii(filename, self.flatten(save_mask), append=False, header=header)

    @staticmethod
    def loadbinary(filename):
        """
        Load a dataset from a binary file. Compatible with 'savebinary'.
        If the file has a 'header' section, it is stored in the 'header'
        attribute of the result.
        ValueError is raised if one of the sections 'x', 'y', 'z', 'mask'
        is missing.
        """
        d = flib.loadbinary(filename, full=True)[2]
        try:
            x = d['x']
            y = d['y']
            z = d['z']
            mask = d['mask']
        except KeyError:
            raise ValueError('Invalid file format')

        # 'z' and 'mask' are stored flattened, see 'savebinary'
        shape = [len(x), len(y)]
        dataset = Dataset(x, y, z.reshape(shape), mask.reshape(shape))
        if 'header' in d.keys():
            dataset.header = d['header']
        return dataset

    def savebinary(self, filename, header=''):
        """
        Save the dataset in a binary format, compatible with 'loadbinary'
        The following sections are stored in the output file:

        'header' - repr() of the optional 'header' argument of this method.
        'x' - 1D array of possible X values (N elements)
        'y' - 1D array of possible Y values (M elements)
        'z' - flattened 2D array containing the Z (NxM)
        'mask' - flattened 2D array containing the mask (NxM)
        """
        flib.savebinary(filename, {
            'header': repr(header),
            'x': self.x,
            'y': self.y,
            'z': self.z.flatten(),
            'mask': self.mask.flatten()})

    def x2D(self):
        "return a 2D array of x coordinates of corresponding z's"
        return numpy.repeat(self.x[:, numpy.newaxis], self.n_y(), axis=1)

    def y2D(self):
        "return a 2D array of y coordinates of corresponding z's"
        return numpy.repeat(self.y[numpy.newaxis, :], self.n_x(), axis=0)

    def r2D(self, x0=0.0, y0=0.0):
        "return a 2D array of distances from given point (x0,y0) to the points of the Dataset"
        return ((self.x2D() - x0)**2 + (self.y2D() - y0)**2)**0.5

    def flatten(self, flatten_mask=False):
        "convert into three/four 1D arrays: x, y, z (, mask)"
        x = self.x2D().flatten()
        y = self.y2D().flatten()
        z = self.z.flatten()

        if flatten_mask:
            return x, y, z, self.mask.flatten()
        return x, y, z

    def window(self, min_x = None, min_y = None, max_x = None, max_y = None):
        """
        return a subset of the dataset framed within:
        min_x <= x <= max_x,
        min_y <= y <= max_y.
        A limit that is not specified does not restrict the subset.
        """
        if min_x is None: min_x = self.min_x()
        if max_x is None: max_x = self.max_x()
        if min_y is None: min_y = self.min_y()
        if max_y is None: max_y = self.max_y()

        ii_x = (self.x >= min_x) & (self.x <= max_x)
        ii_y = (self.y >= min_y) & (self.y <= max_y)
        return Dataset(self.x[ii_x], self.y[ii_y],
                       self.z[ii_x, :][:, ii_y], self.mask[ii_x, :][:, ii_y])

    def correlation(self, max_di=-1, max_dj=-1, max_dx=None, max_dy=None):
        """
        Calculate 2D correlation over the dataset.
        The range of Delta x, Delta y is symmetric around origin and
        defined by max_di/max_dx, max_dj/max_dy.
        max_di and max_dj define the range in units of dx() and dy().
        max_dx and max_dy define them in the units of the original dataset.
        A negative or too large range is replaced with the largest
        possible one (n_x() - 1, n_y() - 1).

        'mask' of the dataset is used as a boolean map:
        only points with 'mask != 0' are used.

        The result is returned as a Dataset object with a range spanning
        from -max_dx, -max_dy to max_dx, max_dy.
        In the result 'mask' contains numbers of pairs of points of the
        original dataset used to calculate a particular point in the correlation.
        """
        dx = self.dx()
        dy = self.dy()
        if max_dx is not None:
            max_di = int(max_dx / dx)
        if max_dy is not None:
            max_dj = int(max_dy / dy)

        n_x = self.n_x()
        n_y = self.n_y()
        if max_di >= n_x or max_di < 0:
            max_di = n_x - 1
        if max_dj >= n_y or max_dj < 0:
            max_dj = n_y - 1

        ii, jj, K, N = _twoD.correlation(self.z, numpy.asarray(self.mask, dtype=bool), max_di, max_dj)

        gc.collect()
        return Dataset(ii * dx, jj * dy, K, N)

    @staticmethod
    def _window_centres(coords, n):
        """
        return the coordinates of the centres of all windows 'n' points
        wide that fit completely into the 1D array 'coords'
        (len(coords) - n + 1 values).
        """
        first = n // 2
        last = len(coords) + first - n + 1
        if n % 2 == 0:
            # an even window is centred half way between two grid points
            return 0.5 * (coords[first:last] + coords[first - 1:last - 1])
        return coords[first:last]

    def filter(self, di=1, dj=1, dx=None, dy=None, filter_mask=None, residue=False):
        """
        Apply a low pass filter to the dataset.
        The filter has a step-function kernel, e.g.
        for every point except near the edge a rectangle di by dj points
        around it is averaged to get the new value. The rectangle can be
        defined in number of pixels using 'di' and 'dj' or in lengths
        using 'dx' and 'dy'.

        'filter_mask' can be used to make a non-square filtering window.
        It is then a 2D array of weights of different points within
        the window; it is normalised to a unit sum before use and
        the array passed by the caller is left untouched.

        The result is returned as a Dataset object with a di by dj smaller
        size than the original one. Its coordinates are the centres of
        the averaging windows.

        If 'residue' is True, the difference between a subset of the original
        dataset, in a window where the filtered version exists,
        and the filtered version is returned.

        ValueError is raised if 'filter_mask' is not a 2D array.
        """
        if filter_mask is None:
            if dx is not None:
                di = int(dx / self.dx())
            if dy is not None:
                dj = int(dy / self.dy())
            kernel = numpy.ones([di, dj])
        else:
            kernel = numpy.asarray(filter_mask, dtype=float)
            if kernel.ndim != 2:
                raise ValueError("'filter_mask' should be a 2D array")
            di, dj = kernel.shape

        # not in-place: 'kernel' may share memory with the caller's array
        kernel = kernel / numpy.sum(kernel)

        # NOTE(review): _twoD.filter is expected to return the
        # (n_x - di + 1) by (n_y - dj + 1) region where the whole window
        # fits into the scan - the slice used for 'residue' relies on it.
        f = _twoD.filter(self.z, kernel)
        if residue:
            min_i = di // 2
            min_j = dj // 2
            max_i = self.n_x() + min_i - di + 1
            max_j = self.n_y() + min_j - dj + 1
            f = self.z[min_i:max_i, min_j:max_j] - f

        gc.collect()
        return Dataset(self._window_centres(self.x, di),
                       self._window_centres(self.y, dj), f)

    def _radius_in_pixels(self, n, r):
        """
        return the filter radius as a whole number of pixels, given either
        a number of pixels 'n' or a length 'r' (converted using dx(),
        'r' takes precedence).
        """
        if n is None and r is None:
            raise ValueError("Either 'n' or 'r' should be specified")
        if r is not None:
            n = r / self.dx()
        return int(numpy.ceil(n))

    def round_filter(self, n=None, r=None, residue=False):
        """
        Return a version of the dataset filtered using a "round" filter:
        the filter function is "1" for points within a distance 'r'
        or under 'n' pixels away and "0" outside. Either 'n' or 'r'
        is expected, otherwise ValueError is raised.
        If the radius is less than one pixel, a copy of the dataset
        is returned.

        the 'residue' optional argument works in the same way as
        for the 'filter' method.
        """
        n = self._radius_in_pixels(n, r)
        if n < 1:
            return self.copy()

        offsets = numpy.arange(-n, n + 1)
        # squared distances are integers, so the comparison is exact
        dist2 = offsets[:, numpy.newaxis]**2 + offsets[numpy.newaxis, :]**2
        kernel = numpy.asarray(dist2 <= n**2, dtype=int)

        return self.filter(filter_mask=kernel, residue=residue)

    def round_gaussian_filter(self, n=None, r=None, N=3, residue=False):
        """
        Return a version of the dataset filtered using a "round" gaussian filter:
        the filter function is a gaussian with a width specified
        by 'r' (length) or 'n' (in number of pixels). The mask is created
        2*n*N+1 by 2*n*N+1 large (a real gaussian never goes to zero, so
        for a finite sample a cutoff is needed).
        Either 'n' or 'r' is expected, otherwise ValueError is raised.
        If the width is less than one pixel, a copy of the dataset
        is returned.

        the 'residue' optional argument works in the same way as
        for the 'filter' method.
        """
        n = self._radius_in_pixels(n, r)
        if n < 1:
            return self.copy()

        offsets = numpy.arange(-N * n, N * n + 1)
        dist2 = offsets[:, numpy.newaxis]**2 + offsets[numpy.newaxis, :]**2
        kernel = numpy.exp(-0.5 * dist2 / float(n)**2)

        return self.filter(filter_mask=kernel, residue=residue)

    def __repr__(self):
        return "2D Dataset (%dx%d pixels, %.3gx%.3g um)" % (
            self.n_x(), self.n_y(),
            self.max_x() - self.min_x(), self.max_y() - self.min_y())

    def copy(self):
        "return a copy of the dataset"
        return Dataset(self.x.copy(), self.y.copy(), self.z.copy(), self.mask.copy())

    def hist_z(self, bins=100):
        """
        return a histogram of the height distribution as a tuple
        (density, bin edges), normalised as a probability density.
        Points where 'z' is not finite (e.g. the NaN gaps of a dataset
        created from scattered points) are not counted.
        """
        z = self.z[numpy.isfinite(self.z)]
        return numpy.histogram(z, bins, density=True)

    def rz(self, minN = 100, return_mask=False):
        """
        Return a tuple of (r, z), where r=(x**2 + y**2)**0.5 is the distance
        from the origin to the specified point on the dataset.
        Only points for which the 'mask' is greater than or equal to 'minN'
        are returned, sorted by 'r'.
        If 'return_mask' is True, a tuple (r, z, mask) is returned instead.
        """
        x, y, z, N = self.flatten(True)
        r = (x**2 + y**2)**0.5

        selected = N >= minN
        r, z, N = r[selected], z[selected], N[selected]

        order = numpy.argsort(r)
        if return_mask:
            return r[order], z[order], N[order]
        return r[order], z[order]

    def group_rz(self, bins = None, dr = None, di = None, minN = 100, return_mask=False):
        """
        Return a tuple of (r, z), where results of rz() are grouped together based on 'bins'
        and averaged within each group. If 'return_mask' is True, (r, z, mask) is returned
        instead, similar to the rz() method. Empty groups are left out.

        Only points with 'mask' greater than or equal to 'minN' are used in the averaging.
        During the averaging the 'mask' is used to weight different points averaged together.
        If no point passes the 'minN' criterion, empty arrays are returned.

        Only one out of 'bins', 'dr' and 'di' should be specified,
        otherwise ValueError is raised.

        * if 'bins' is specified and is a number, then this number is used as
        a number of bins extending from the minimum to the maximum value 'r' takes.

        * if 'bins' is specified and is an array, the values in this array are used
        as the boundaries: bins[n] <= r < bins[n+1], etc. The array should be sorted
        in an ascending order.

        * if 'dr' is specified then the range from the minimum to the maximum value
        'r' takes is divided into bins 'dr' long.

        * specifying 'di' is equivalent to setting 'dr = di * dx()', e.g. the
        step of binning is set in units of the scan lattice.

        If none of 'bins', 'dr' and 'di' is specified, the default behaviour is to set 'di=0.5'
        """
        # check the arguments specifying the bins
        given = [arg is not None for arg in (bins, dr, di)]
        if sum(given) > 1:
            raise ValueError("only one out of 'bins', 'dr' and 'di' can be specified")
        if not any(given):
            di = 0.5

        r, z, N = self.rz(minN, True)
        if len(r) == 0:
            # nothing to group, and the range of 'r' is undefined
            if return_mask:
                return r, z, N
            return r, z

        if di is not None:
            dr = di * self.dx()
        if dr is None and numpy.isscalar(bins):
            dr = (r.max() - r.min()) / bins

        if dr is not None:
            # one extra step past the maximum, so that the largest 'r'
            # is inside the last bin
            bins = numpy.arange(r.min(), r.max() + dr, dr)
        else:
            bins = numpy.asarray(bins)

        r, z, N = _twoD.group1d(r, z, N, bins)
        filled = N > 0
        if return_mask:
            return r[filled], z[filled], N[filled]
        return r[filled], z[filled]

    def plot(self, interpolation='nearest', origin='lower', accurate_edges=False, **vargs):
        """
        Plot the map using pylab.imshow, x along the horizontal axis.
        If 'accurate_edges' is True, the image is extended by half a pixel
        in every direction, so that the centres of the pixels rather than
        their edges are at the coordinates of the dataset.
        Extra keyword arguments are passed to imshow.
        """
        k = 0.5 if accurate_edges else 0.0
        extent = [
            self.min_x() - k*self.dx(), self.max_x() + k*self.dx(),
            self.min_y() - k*self.dy(), self.max_y() + k*self.dy()]
        if origin != 'lower':
            extent[2], extent[3] = extent[3], extent[2]
        from pylab import imshow
        return imshow(self.z.transpose(), interpolation=interpolation, origin=origin,
                      extent=extent, **vargs)

    def plot_and_colorbar(self, interpolation='nearest', origin='lower', orientation='vertical', accurate_edges=False, label=None, cbar_vargs=None, **vargs):
        """
        Plot the map using pylab.imshow and display a colorbar.
        'label' optional argument defines the colorbar label.
        'cbar_vargs' is an optional dictionary of keyword arguments
        for pylab.colorbar, other keyword arguments are passed to imshow.
        Returns a tuple (image, colorbar).
        """
        from pylab import colorbar, draw

        if cbar_vargs is None:
            cbar_vargs = {}

        m = self.plot(interpolation=interpolation, origin=origin, accurate_edges=accurate_edges, **vargs)
        # the colorbar is bound to this image, not to the "current" one
        p = colorbar(m, orientation=orientation, **cbar_vargs)
        if label is not None:
            p.set_label(label)
            draw()
        return m, p

    def plot_hist_z(self, bins=100, **vargs):
        """
        plot a histogram (PDF) of the height distribution.
        Extra keyword arguments are passed to pylab.bar.
        """
        from pylab import bar
        n, edges = self.hist_z(bins)
        # positions are the left edges of the bins, unless the caller
        # asks for something else
        vargs.setdefault('align', 'edge')
        return bar(edges[:-1], n, width=numpy.diff(edges), **vargs)

    def plot_rz(self, minN = 100, **vargs):
        """
        plot 'z' against the distance 'r' from the origin, see rz().
        Extra keyword arguments are passed to pylab.plot.
        """
        from pylab import plot
        r, z = self.rz(minN)
        return plot(r, z, **vargs)

    def plot_log_rz(self, bins=1000, **vargs):
        """
        plot 'z' grouped by the distance 'r' from the origin
        on a log-log scale. All points are used with equal weights,
        whatever the 'mask' is; the dataset itself is not modified.
        Only points with positive finite 'r' and 'z' are shown.
        """
        from pylab import loglog
        unweighted = Dataset(self.x, self.y, self.z, numpy.ones_like(self.mask))
        r, z = unweighted.group_rz(minN = 0, bins=bins)
        shown = numpy.isfinite(r) & (r > 0) & numpy.isfinite(z) & (z > 0)
        return loglog(r[shown], z[shown], **vargs)

    def plot_grouped_rz(self, bins = None, dr = None, di = None, minN = 1000, **vargs):
        """
        plot a correlation grouped over different directions in the plane,
        see group_rz() for the arguments.
        """
        from pylab import plot
        r, z = self.group_rz(bins, dr, di, minN)
        return plot(r, z, **vargs)

    @staticmethod
    def random(Nx, Ny, dx, dy, mean=0.0, std=1.0):
        """
        return a randomly generated dataset Nx by Ny points at 'dx'/'dy' spacing,
        heights are uncorrelated normally distributed with given 'mean' and 'std'.
        """
        import numpy.random
        return Dataset(dx, dy, numpy.random.normal(mean, std, (Nx, Ny)))

    def fft(self, best_cut=True):
        """
        return an FFT of the scan.
        the result is in a form of a 'Dataset' object with 'z' containing
        the amplitude and 'mask' the phase (rad) of the Fourier transform,
        'x' and 'y' are the frequencies in ascending order.
        if 'best_cut' is True, then a maximum subset of the dataset is used
        with powers of two points in both dimensions, the subset is
        centered in the scan.
        """
        z = self.z
        if best_cut:
            n_x, n_y = z.shape
            N_x = int(2**numpy.floor(numpy.log2(n_x)))
            N_y = int(2**numpy.floor(numpy.log2(n_y)))
            i0 = n_x // 2 - N_x // 2
            j0 = n_y // 2 - N_y // 2
            z = z[i0:i0 + N_x, j0:j0 + N_y]

        w = numpy.fft.fft2(z) * self.dx() * self.dy()
        n_x, n_y = w.shape
        # bring the second half of every axis in front of the first one:
        # a roll by -(n // 2), which is what 'ifftshift' does. It matches
        # the frequency axes below; for an even number of points
        # (always the case with 'best_cut') it is the same as 'fftshift'.
        w = numpy.fft.ifftshift(w)
        x = Dataset(1./self.dx()/n_x * numpy.arange(n_x // 2 - n_x, n_x // 2),
                    1./self.dy()/n_y * numpy.arange(n_y // 2 - n_y, n_y // 2),
                    abs(w), flib.phase(w))
        gc.collect()
        return x

    def sub_poly(self, N=1):
        """
        return a copy of the dataset with a polynomial background subtracted.
        the polynomial is 2D of order N (e.g. for N=2 the scan is approximated
        using a sum of 1, x, y, x**2, x*y, y**2), found as a least squares fit.
        if N is equal to 0, only the mean value is subtracted.
        the mask is not taken into account and preserved.
        """
        s = self.copy()
        # not in-place, so that it also works for integer data
        s.z = s.z - numpy.mean(s.z)

        if N > 0:
            x = s.x2D().flatten()
            y = s.y2D().flatten()

            # A[i,j] - values of the polynomial terms at all points of the scan:
            # i - an index of a term x**iX * y**iY with iX + iY <= N
            # j - 1D index of a point in the scan (flattened x2D,y2D,z)
            A = numpy.zeros([(N + 1) * (N + 2) // 2, len(x)])
            i = 0
            for iX in range(N + 1):
                for iY in range(N + 1 - iX):
                    A[i, :] = x**iX * y**iY
                    i += 1

            # least squares solution of  fit * A = z
            fit = numpy.dot(s.z.flatten(), numpy.linalg.pinv(A))

            s.z = s.z - numpy.dot(fit, A).reshape(s.z.shape)
        return s

    def sub_mean_x(self):
        """
        return a copy of the dataset where for every x the mean of 'z'
        over all y is subtracted.
        """
        s = self.copy()
        s.z = s.z - s.z.mean(axis=1)[:, numpy.newaxis]
        return s

    def sub_mean_y(self):
        """
        return a copy of the dataset where for every y the mean of 'z'
        over all x is subtracted.
        """
        s = self.copy()
        s.z = s.z - s.z.mean(axis=0)[numpy.newaxis, :]
        return s
