"""
aver.py

Classes: 'Mean' represents an average of a set of numbers or a random number;
'LinearFit' represents the result of a straight-line fit.

Functions: 'average', 'combine', 'stdcorrection'

To get an array of mean values from an 'array' of 'Mean' objects you can try
'[x.mean for x in array]' or 'numpy.array([x.mean for x in array])'.

Lev, 17 Dec 2008
"""

# Changelog
#
# 4 Dec 2008
# First draft
#
# 5 Dec 2008
# add 'finites' option to 'average' and 'combine';
#
# 17 Dec 2008
# Add 'Mean.totex()'
#
# 20 Dec 2008
# Modify standard deviation handling to make it ((<x^2> - <x>^2)*(N+1)/N)^0.5
# this accounts for a fact that the mean value is uncertain.
# Add 'stdcorrection' helper function and 'Mean.rmsdev' property
# for this purpose
#
# 11-13 Jul 2010
# rework 'combine' heavily.
#
# 14 Jul 2010
# add 'x_func' to LinearFit.plot_envelope

import numpy, flib

class Mean(object):
    r"""
    Represent a result of averaging numbers together or a random number:

    mean - (the best estimate of) the mean

    std - standard deviation. Note that for correct combination of uncertainties,
        this should be calculated using
        std = \sqrt{(\sum x^2 - (\sum x)^2/N)/(N-1)},
        with N-1 in the denominator, rather than N (and 1 for N < 2).
        This accounts for the fact that the average, from which individual
        points deviate, is not known precisely.

    min, max - extrema of the data

    N - amount of data. 0 means unknown/undefined, such as in case of combining
        random quantities.

    derived properties (calculated on creation):

    error = std/sqrt(N) - error of the mean (equal to 'std' for N < 2)

    median = (max + min)/2
    halfspan = (max - min)/2 - to be used for errorbars that encompass all data points

    rmsdev - root mean square deviation of individual values from the best
        estimate of the mean:
        rmsdev = \sqrt{(N-1)/N} std for N > 1, 0 for N = 1 and just std for N < 1.

    stderror - error of the standard deviation (equal to 0 for N < 2)

    Arithmetic operations +,-,*,/,** are supported between Mean objects as well
    as between Mean objects and numbers. The arguments are considered uncorrelated.

    When combining uncertainties all information about the number of points is
    lost (because it can be different for the two arguments). Standard deviation,
    rather than error of the mean, is retained, such that the result represents
    the scatter of the hypothetical combination of the numbers averaged together.
    Use 'be()' before combining uncertainties to get best estimates of the
    means with their errors.

    DISCLAIMER: When dividing by or raising a 'Mean' object to a power, the
    standard deviation of the result is only valid under the assumption that
    for the argument it is much smaller than the mean.

    Only a real number can be used as a right hand side argument for **.

    A number is equal ('==' returns True) to a Mean object if it is within one
    'std' from the mean. Two Mean objects are equal if their difference is equal
    to 0 in the above sense.

    'fromtuple()'/'totuple()' methods are primarily for spreadsheet I/O.
    """

    def __init__(self, mean=numpy.nan, std=numpy.nan, min=-numpy.inf, max=numpy.inf, N=0):
        """
        Construct a Mean from its statistical properties.
        'mean' is (the best estimate of) the mean
        'std' is standard deviation (its sign is discarded)
        'min' and 'max' are value limits (accepted in either order)
        'N' is the number of points that contributed to the estimate of the
        mean/std, or 0 if unavailable.
        """
        self.mean = mean
        self.std = abs(std)
        # tolerate limits supplied in reverse order; NaN limits become infinite
        self.min = numpy.min([min, max])
        self.max = numpy.max([min, max])
        if numpy.isnan(self.min): self.min = -numpy.inf
        if numpy.isnan(self.max): self.max = numpy.inf
        self.N = int(N)
        self.calculate()

    def calculate(self):
        "Calculate the derived quantities (see the class docstring)"
        self.error = self.std if self.N < 2 else self.std / self.N**0.5
        self.median = 0.5 * (self.max + self.min)
        self.halfspan = 0.5 * (self.max - self.min)
        # stdcorrection(N) = sqrt(N/(N-1)) for N > 1, 1.0 otherwise
        self.rmsdev = self.std / stdcorrection(self.N) if self.N != 1 else 0
        self.stderror = 0 if self.N < 2 else self.error / numpy.sqrt(2)

    def be(self):
        """
        Return the best estimate of the mean: a Mean with both 'std' and
        'error' set to the 'error' of this object and zero as the number of
        points. This is useful before combining uncertainties, because in a
        combination the number of points is lost and the standard deviation
        (the scatter of points) would be kept as both 'std' and 'error'.
        """
        return Mean(self.mean, self.error, self.min, self.max, 0)

    def totuple(self):
        "Convert into a tuple (mean, std, min, max, N)"
        return (self.mean, self.std, self.min, self.max, self.N)

    @staticmethod
    def fromtuple(array):
        "Convert a tuple/list/array (mean, std, min, max, N) into a Mean object"
        return Mean(*array)

    def copy(self):
        "Return an independent copy of this object"
        return Mean.fromtuple(self.totuple())

    def __repr__(self):
        return "x = %g +- %g (%g <= x <= %g)%s" % (self.mean, self.std, self.min, self.max,
                        (", %d samples" % self.N if self.N > 0 else ""))

    # NOTE(review): __min__/__max__ are not part of the Python data model (the
    # builtins min() and max() never call them); kept for backward compatibility.
    def __min__(self): return self.min
    def __max__(self): return self.max

    def __neg__(self): return Mean(-self.mean, self.std, -self.max, -self.min, self.N)
    def __abs__(self): return abs(self.mean)
    def __complex__(self): return complex(self.mean)
    def __int__(self): return int(self.mean)
    def __float__(self): return float(self.mean)

    def __add__(self, other):
        if isinstance(other, Mean):
            # uncorrelated arguments: variances add; the number of points is lost
            return Mean(self.mean + other.mean, (self.std**2 + other.std**2)**0.5,
                                self.min + other.min, self.max + other.max, 0)
        elif numpy.isscalar(other):
            return Mean(self.mean + other, self.std,
                                self.min + other, self.max + other, self.N)
        # NotImplemented must be returned, not raised, so that Python can try
        # the reflected operation of the other operand
        return NotImplemented

    def __sub__(self, other): return self + (-other)

    def __mul__(self, other):
        if isinstance(other, Mean):
            # limits of a product: check all four corner combinations
            extrema = [self.min * other.min, self.max * other.max, self.min * other.max, self.max * other.min]
            return Mean(self.mean * other.mean,
                                (self.mean**2 * other.std**2 + other.mean**2 * self.std**2 +
                                self.std**2 * other.std**2)**0.5,
                                min(extrema), max(extrema), 0)
        elif numpy.isscalar(other):
            extrema = [self.min * other, self.max * other]
            return Mean(self.mean * other, abs(self.std * other),
                                min(extrema), max(extrema), self.N)
        return NotImplemented

    def __truediv__(self, other): return self * (other**-1)
    __div__ = __truediv__

    def __pow__(self, other):
        if not numpy.isscalar(other):
            return NotImplemented

        # if min and max are of opposite signs, negative powers hit infinity
        if self.min * self.max < 0 and other < 0:
            return Mean(self.mean**other, abs(other * self.mean**(other-1) * self.std), -numpy.inf, numpy.inf, self.N)

        extrema = [self.min**other, self.max**other]
        # if min and max are of opposite signs, then 0 can be one of the extrema (for even powers)
        if self.min*self.max < 0:
            extrema.append(0)

        return Mean(self.mean**other, abs(other * self.mean**(other-1) * self.std),
            min(extrema), max(extrema), self.N)

    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other): return (-self) + other
    def __rtruediv__(self, other): return other * (self**(-1))
    __rdiv__ = __rtruediv__

    def __eq__(self, other):
        if isinstance(other, Mean):
            return abs(self.mean - other.mean)**2 <= self.std**2 + other.std**2
        elif numpy.isscalar(other):
            return abs(self.mean - other)**2 <= self.std**2
        return NotImplemented

    def same(self, other):
        """
        Compare two Mean objects for identity of contents. Return True only if
        all 5 fields - the mean, standard deviation, value boundaries and the
        number of points - are the same.
        """
        return (isinstance(other, Mean) and flib.floatcmp(self.mean, other.mean)
                and flib.floatcmp(self.std, other.std) and flib.floatcmp(self.min, other.min)
                and flib.floatcmp(self.max, other.max) and self.N == other.N)

    def isfinite(self):
        "Return True if a Mean object has all properties finite, otherwise False"
        return all(numpy.isfinite(self.totuple()))

    def brief(self, fmt=None, tex=False):
        "Return a string representation without the boundaries"
        if fmt is None:
            # derive the number of digits from the order of magnitude of std
            if self.std > 0:
                std_order = numpy.floor(numpy.log10(self.std))
            else:
                std_order = 0
            if std_order < 0:
                fmt = '%%.%df' % (-std_order)
            elif std_order == 0:
                fmt = '%g'
            else:
                fmt = '%d'

        return fmt % self.mean + (r'\pm ' if tex else '+-') + fmt % self.std

    def totex(self, varname='x', fmt='%g'):
        "Return a neat TeX representation"
        # raw strings keep the TeX backslashes literal without escape warnings
        min = (fmt % self.min + r" \le") if numpy.isfinite(self.min) else r"-\infty <"
        max = (r"\le " + fmt % self.max) if numpy.isfinite(self.max) else r"< \infty"
        return (r'%s = %g \pm %g\,(%s %s %s)'.replace('%g', fmt)) % (varname, self.mean, self.std, min, varname, max)

def stdcorrection(N):
    """
    Return the finite-sample correction factor sqrt(N/(N-1)) for N > 1,
    and 1.0 otherwise.

    Multiplying a standard deviation by this factor accounts for the
    uncertainty of the mean estimated from the same data.
    """
    if N > 1:
        return (N / (N - 1.0)) ** 0.5
    return 1.0

def average(*args, **vargs):
    """
    Construct a Mean object from an array of numbers.
    If a single argument is supplied and it is iterable, it is used as
    an array of values, otherwise each argument is one of the values:

    average([1, 2, 3]) is equivalent to average(1, 2, 3)

    If a keyword argument 'finites' is specified and is True, then all
    non-finite values (NaN and +-Inf) are discarded from averaging. Other
    keyword arguments are ignored and not taken as numbers to average.
    A NaN-Mean object is returned if no values remain.
    """
    # a single iterable argument is taken as the whole data set
    if len(args) == 1 and numpy.iterable(args[0]): args = args[0]

    value = numpy.asarray(args)

    if vargs.get('finites'):
        value = value[numpy.isfinite(value)]

    if len(value) < 1:
        # nothing to average
        return Mean(numpy.nan, numpy.nan, -numpy.inf, numpy.inf, 0)

    N = len(value)
    # numpy.std gives the population deviation (N in the denominator);
    # stdcorrection converts it to the sample deviation with N-1
    return Mean(numpy.mean(value), numpy.std(value) * stdcorrection(N),
                numpy.min(value), numpy.max(value), N)

def combine(*args, **vargs):
    """
    Combine a set of Mean objects representing similar measurements into a single one.
    If all arguments have a non-zero number of points '.N', these are used for
    weighting. Otherwise all objects are considered to consist of equal numbers
    of points.

    The function can be invoked in two ways: with a single argument containing
    a list of objects to combine, or one object per argument.

    If a keyword argument 'finites' is specified and is True, then all arguments
    with any non-finite property are discarded from averaging.

    If a keyword argument 'finite_mean_and_error' is specified and is True,
    then all arguments with a non-finite mean or error are discarded from averaging.

    If a keyword argument 'weightonstd' is specified and is True, then 1/error^2
    is used as a weight, rather than the number of points.

    Other keyword arguments are ignored and not taken as numbers to average.
    A NaN-Mean is returned if no suitable arguments are encountered.
    """
    # a single iterable argument is taken as the list of objects to combine
    if len(args) == 1 and numpy.iterable(args[0]): args = args[0]
    # materialise as a list: the filters below must yield something with len()
    # (the original used filter(), which returns an iterator on Python 3)
    args = list(args)

    if vargs.get('finites'):
        args = [a for a in args if a.isfinite()]
    elif vargs.get('finite_mean_and_error'):
        # the original referenced a bare 'isfinite' here, a NameError;
        # numpy.isfinite is the intended test
        args = [a for a in args if numpy.isfinite(a.mean) and numpy.isfinite(a.error)]

    weightonstd = bool(vargs.get('weightonstd'))

    if len(args) < 1: return Mean()

    Ns = numpy.zeros(len(args))
    means = numpy.zeros(len(args))
    rmsdevs = numpy.zeros(len(args))
    errors = numpy.zeros(len(args))
    mins = numpy.zeros(len(args))
    maxes = numpy.zeros(len(args))

    for i, a in enumerate(args):
        if not isinstance(a, Mean): raise ValueError("Argument %d of an unsupported type" % (i+1))

        Ns[i] = a.N
        means[i] = a.mean
        rmsdevs[i] = a.rmsdev
        errors[i] = a.error
        mins[i] = a.min
        maxes[i] = a.max

    if weightonstd:
        # inverse-variance weighting on the errors of the means
        w = errors**-2
        m = numpy.sum(w*means) / numpy.sum(w)
        s = numpy.sum(w)**-0.5
        chi2 = numpy.mean(w * (means - m)**2)
        if chi2 > 1:
            # the scatter exceeds the stated errors: inflate the error instead
            s *= chi2**0.5
            chi2 = 1.0
        m = Mean(m, s, numpy.min(mins), numpy.max(maxes), 0)
        m.chi2 = chi2
        return m

    if numpy.min(Ns) < 1:
        # equalise weights of all arguments if any of them has lost the
        # information about the number of points; Ntotal is then used only
        # for the construction of the result
        Ns[:] = 1
        Ntotal = 0
    else:
        Ntotal = sum(Ns)

    # pooled mean and pooled mean-square deviation about it
    mean = numpy.sum(Ns * means) / numpy.sum(Ns)
    rmsdev2 = numpy.sum(Ns * (rmsdevs**2 + means**2)) / numpy.sum(Ns) - mean**2

    return Mean(mean, rmsdev2**0.5 * stdcorrection(Ntotal), numpy.min(mins), numpy.max(maxes), Ntotal)

class LinearFit(object):
    """
    Straight line fit y(x) = A*x + B with errorbars.
    Keeps the information necessary to reconstruct the uncertainties of
    the intercepts of the fit with x=X.
    """

    def __init__(self, A, B, A_err, B_err, x_mean, x_std2, chi2):
        """
        Construct a linear fit object. The model is:
        
        y(x) = A*(x-x_mean) + B.
        
        A - slope of the fit
        B - intercept of the fit (the value of y at x=x_mean)
        A_err, B_err - uncertainties of the slope and of the intercept
        x_mean - weighted mean of the x data the fit was made on
        x_std2 - weighted mean of (x-x_mean)^2 of that data
        chi2 - reduced chi-square of the fit
        """
        self.A = A
        self.B = B
        self.A_err = A_err
        self.B_err = B_err
        self.x_mean = x_mean
        self.x_std2 = x_std2
        self.chi2 = chi2

    @staticmethod
    def fit(x, y, y_err=None):
        """
        Evaluate a straight line fit to a model y(x) = A*x + B.
        x are considered precise and y uncertain.
        
        A fit is unweighted unless y are 'Mean' objects or 'y_err' are supplied.
        
        Return a LinearFit object
        """
        
        x = numpy.asarray(x)
        
        if isinstance(y[0], Mean):
            if y_err is not None: raise ValueError("'y_err' should not be specified when 'y' have uncertainties themselves.")
                
            # unpack Mean objects into values and their uncertainties
            y_err = numpy.asarray([i.std for i in y])
            y = numpy.asarray([i.mean for i in y])
        else:
            y = numpy.asarray(y)
                
        if len(x) < 2: raise ValueError("it takes at least 2 points to fit a straight line")
        if len(x) != len(y): raise ValueError("'x' and 'y' are different size")
        
        if y_err is not None:
            if numpy.isscalar(y_err):
                # a single number means the same uncertainty for every point
                y_err = y_err * numpy.ones(x.shape)
            else:
                y_err = numpy.asarray(y_err)
            if len(y) != len(y_err): raise ValueError("'y' and 'y_err' are different size")        
            # weights based on y uncertainties
            w = y_err**-2
        else:
            # unweighted fit
            w = numpy.ones(x.shape)
        
        sum_w = numpy.sum(w)
        mean_x = numpy.sum(w*x) / sum_w        
        # centre x on its weighted mean; this decorrelates the A and B estimates
        x = x - mean_x
        
        mean_y = numpy.sum(w*y) / sum_w
        mean_x2 = numpy.sum(w*x**2) / sum_w
        mean_xy = numpy.sum(w*x*y) / sum_w
        
        # weighted least-squares solution of the centred model y = A*x + B
        A = mean_xy / mean_x2
        B = mean_y
        
        if y_err is None:
            # estimate the y uncertainty from the scatter of the residuals
            # (len(x)-2 degrees of freedom); chi2 is then 1 by construction
            y_err = numpy.ones(x.shape) * (sum((A*x + B - y)**2 * w) / (len(x) - 2))**0.5
            w = y_err**-2
            sum_w = numpy.sum(w)
            if sum_w > 0:
                chi2 = 1.0
            else:
                chi2 = 0.0
        else:
            # reduced chi-square of the fit against the supplied errors
            chi2 = sum((A*x + B - y)**2 * w) / (len(x) - 2)
        
        A_err = (sum_w * mean_x2)**-0.5
        B_err = (sum_w)**-0.5
        
        return LinearFit(A, B, A_err, B_err, mean_x, mean_x2, chi2)
    
    def slope(self): "slope of the linear fit with its uncertainty, as a Mean"; return Mean(self.A, self.A_err)

    def y(self, X):
        "intercept of the linear fit with x=X, a Mean (or a list of them for array X)"
        if numpy.isscalar(X):
            return Mean(self.y_mean(X), self.y_err(X))
        else:
            return [Mean(self.y_mean(x), self.y_err(x)) for x in X]

    def x(self, Y):
        "intercept of the linear fit with y=Y, a Mean (or a list of them for array Y)"
        if numpy.isscalar(Y):
            # invert the model; the x error is the y error at that point over the slope
            x0 = self.x_mean + (Y - self.B)/self.A
            return Mean(x0, self.y_err(x0)/self.A)
        else:
            return [self.x(y) for y in Y]

    def y_mean(self, x):
        "best-fit value of y at x"
        return self.A*(x - self.x_mean) + self.B

    def y_err(self, x):
        "uncertainty of y(x): B_err at x_mean, growing away from it with the slope error"
        return self.B_err * (1 + (x - self.x_mean)**2 / self.x_std2)**0.5

    def plot_line(self, x, *args, **vargs):
        "plot the best-fit line over the values 'x'; extra arguments are passed to pylab.plot"
        import pylab
        return pylab.plot(x, self.y_mean(x), *args, **vargs)

    def plot_envelope(self, x, *args, **vargs):
        """
        plot the "envelope" of best fits (y0 +- dy0) for every value of 'x'.
        if 'x_func' is supplied, then 'x_func(x)' is plotted instead of 'x'
        (useful for transformed abscissae). Extra arguments are passed to
        pylab.fill_between.
        """
        import pylab
        if 'x_func' in vargs.keys():
            x_func = vargs.pop('x_func')
        else:
            x_func = lambda x: x
        return pylab.fill_between(x_func(x), self.y_mean(x) - self.y_err(x), self.y_mean(x) + self.y_err(x), *args, **vargs)

    def __repr__(self): return self.str(tex=False, x='x', y='y')
    
    def str(self, tex=False, x='x', y='y'):
        "Return the equation of the fitted line, in plain text or (if 'tex') TeX form"
        slope = self.slope()
        intercept = self.y(0)
        
        # write the sign of the intercept explicitly for readability
        if intercept.mean < 0:
            intercept.mean *= -1
            isign = '-'
        else:
            isign = '+'
        
        return "%s = (%s)%s%s %s %s" % (y, slope.brief(tex=tex), r'\times ' if tex else ' ', x, isign, intercept.brief(tex=tex))
    
    @staticmethod
    def fit_with_xyerr(x, y, x_err, y_err, Niter=1):
        """
        fit a straight line through (y+-y_err) = A*(x+-x_err) + B.
        The x uncertainty is added to the y uncertainty scaled by the slope A
        obtained in a fit with x_err=0. If Niter is greater than one, the fit
        is repeated with a more and more precise value of A.
        """
        x = numpy.asarray(x)
        y = numpy.asarray(y)
        x_err = numpy.asarray(x_err)
        y_err = numpy.asarray(y_err)

        # zero y errors would give infinite weights, so start from an
        # unweighted fit in that case (only relevant when iterating)
        if any(y_err == 0) and Niter > 1:
            fit = LinearFit.fit(x, y)
        else:
            fit = LinearFit.fit(x, y, y_err)
        for n in range(Niter - 1):
            # fold the x errors into the y errors using the current slope
            fit = LinearFit.fit(x, y, (y_err**2 + (fit.slope().mean*x_err)**2)**0.5)
        return fit

    def unify_chi2(self):
        "amend the linear fit errorbars to make chi2=1"
        # scale both errors by sqrt(chi2) so the reduced chi-square becomes unity
        self.A_err *= self.chi2**0.5
        self.B_err *= self.chi2**0.5
        self.chi2 = 1
