"""
Quaternions and tools for converting between rotation representations
(axis-angle and ZYZ Euler angles).

Lev Jun-Jul 2010
"""
import aver, numpy

class Quat(object):
    """Quaternion r + i*I + j*J + k*K with numpy.double components.

    Supports +, -, * and / with other quaternions and with real scalars.
    Unit quaternions built by `rotation` / `from_euler` act on 3-vectors
    through `rotate_vector`.
    """

    def __init__(self, r=0, i=0, j=0, k=0):
        self.r = numpy.double(r)
        self.i = numpy.double(i)
        self.j = numpy.double(j)
        self.k = numpy.double(k)

    def totuple(self):
        """Return the components as the tuple (r, i, j, k)."""
        return (self.r, self.i, self.j, self.k)

    @staticmethod
    def fromtuple(array):
        """Build a quaternion from a 4-element sequence (r, i, j, k).

        Raises ValueError if `array` does not have exactly four elements.
        """
        if len(array) != 4:
            raise ValueError('need a four element array to create a quaternion')
        return Quat(*array)

    def copy(self):
        """Return an independent quaternion with the same components."""
        return Quat(*self.totuple())

    def __float__(self):
        # Only the real part is kept. Return a builtin float: returning a float
        # subclass (numpy.float64) from __float__ is deprecated in Python 3.
        return float(self.r)

    def __abs__(self):
        """Euclidean norm sqrt(r^2 + i^2 + j^2 + k^2)."""
        return (self.r**2 + self.i**2 + self.j**2 + self.k**2)**0.5

    def __neg__(self):
        return Quat(-self.r, -self.i, -self.j, -self.k)

    def conjugate(self):
        """Return r - i*I - j*J - k*K."""
        return Quat(self.r, -self.i, -self.j, -self.k)

    def __repr__(self, rfmt='%g', ifmt=None):
        # By default the imaginary parts use the real-part format with a forced
        # sign, giving e.g. '(1+2i-3j+4k)'.
        if ifmt is None:
            ifmt = rfmt.replace('%', '%+')
        return '(' + (rfmt % self.r) + (ifmt % self.i) + 'i' + (ifmt % self.j) + 'j' + (ifmt % self.k) + 'k)'

    def __add__(self, other):
        """Add a quaternion or a real scalar (the scalar goes to the real part).

        Raises ValueError for any other operand type.
        """
        if isinstance(other, Quat):
            return Quat(self.r + other.r, self.i + other.i, self.j + other.j, self.k + other.k)
        elif numpy.isscalar(other) and numpy.isreal(other):
            return Quat(self.r + numpy.double(other), self.i, self.j, self.k)
        else:
            raise ValueError('need a quaternion to add to a quaternion')
    __radd__ = __add__

    def __sub__(self, other): return self + (-other)
    def __rsub__(self, other): return (-self) + other

    def __mul__(self, other):
        """Hamilton product with a quaternion, or scaling by a real scalar.

        Raises ValueError for any other operand type.
        """
        if isinstance(other, Quat):
            return Quat(self.r*other.r - self.i*other.i - self.j*other.j - self.k*other.k,
                        self.r*other.i + self.i*other.r + self.j*other.k - self.k*other.j,
                        self.r*other.j - self.i*other.k + self.j*other.r + self.k*other.i,
                        self.r*other.k + self.i*other.j - self.j*other.i + self.k*other.r)
        elif numpy.isscalar(other) and numpy.isreal(other):
            other = numpy.double(other)
            return Quat(self.r * other, self.i * other, self.j * other, self.k * other)
        else:
            raise ValueError('need a quaternion to multiply a quaternion by')
    # Only reached when the left operand is not a Quat (i.e. a scalar), and
    # scalar multiplication commutes, so reusing __mul__ is safe.
    __rmul__ = __mul__

    def __truediv__(self, other):
        """Right division: self * other^-1, or division by a real scalar.

        Raises ValueError for any other operand type.
        """
        if isinstance(other, Quat):
            # other^-1 = conj(other) / |other|^2
            return self * other.conjugate() * abs(other)**(-2)
        elif numpy.isscalar(other) and numpy.isreal(other):
            other = numpy.double(other)
            return Quat(self.r / other, self.i / other, self.j / other, self.k / other)
        else:
            raise ValueError('need a quaternion to divide a quaternion by')
    __div__ = __truediv__

    def __rtruediv__(self, other):
        """Compute other * self^-1 for a scalar `other`."""
        return other * self.conjugate() * abs(self)**(-2)
    __rdiv__ = __rtruediv__

    def vect(self):
        """Return the vector part as numpy.array([i, j, k])."""
        return numpy.array([self.i, self.j, self.k])

    @staticmethod
    def rotation(axis, angle):
        """Return the quaternion of a rotation by `angle` radians about `axis`.

        The angle is first wrapped into [-pi, pi), which keeps the real part
        cos(angle/2) non-negative. `axis` does not need to be normalised, and
        the caller's `axis` object is never modified.

        Raises ValueError if the wrapped angle is non-zero and `axis` is the
        zero vector.
        """
        angle = (angle + numpy.pi) % (2*numpy.pi) - numpy.pi
        if angle == 0:
            return Quat(1, 0, 0, 0)
        axis = numpy.asarray(axis, dtype=numpy.double)
        length = numpy.sum(axis**2)**0.5
        if length == 0:
            raise ValueError('the axis must have non-zero components')
        # Out-of-place on purpose: asarray() hands back the caller's own array
        # when it already is a float ndarray, so '/=' would mutate it.
        axis = axis / length
        c = numpy.cos(angle*0.5)
        s = numpy.sin(angle*0.5)
        return Quat(c, s*axis[0], s*axis[1], s*axis[2])

    @staticmethod
    def from_euler(alpha, beta, gamma):
        """Compose ZYZ Euler rotations: Rz(alpha) * Ry(beta) * Rz(gamma)."""
        return Quat.rotation([0, 0, 1], alpha) * Quat.rotation([0, 1, 0], beta) * Quat.rotation([0, 0, 1], gamma)

    def to_euler(self):
        """Return ZYZ Euler angles (alpha, beta, gamma) inverting `from_euler`.

        beta lies in [0, pi]. In the gimbal-lock cases (the z axis is mapped
        close to +z or -z), gamma is set to 0, beta to exactly 0 or pi, and
        the whole rotation about z is put into alpha.
        """
        x = numpy.array([1.0, 0.0, 0.0])
        y = numpy.array([0.0, 1.0, 0.0])
        z = numpy.array([0.0, 0.0, 1.0])
        X = self.rotate_vector(x)
        Y = self.rotate_vector(y)
        Z = self.rotate_vector(z)
        # N is the line of nodes: z x Z = sin(beta) * (-sin(alpha), cos(alpha), 0).
        N = numpy.cross(z, Z)
        if numpy.dot(N, N) > 1e-6:
            # arctan2 form is better conditioned than arccos(Z[2]) near 0 / pi.
            beta = numpy.arctan2(numpy.hypot(Z[0], Z[1]), Z[2])
            alpha = -numpy.arctan2(numpy.dot(N, x), numpy.dot(N, y))
            gamma = numpy.arctan2(numpy.dot(N, X), numpy.dot(N, Y))
        elif Z[2] > 0:
            # z -> +z: the rotation is Rz(alpha + gamma); fold it into alpha.
            alpha = numpy.arctan2(X[1], X[0])
            beta = gamma = 0.0
        else:
            # z -> -z: the rotation is Rz(alpha) * Ry(pi), and Ry(pi) maps x to -x,
            # so alpha is the azimuth of -X.
            alpha = numpy.arctan2(-X[1], -X[0])
            beta = numpy.pi
            gamma = 0.0
        return alpha, beta, gamma

    def to_rotation(self, reduce=True):
        """Return (axis, angle) describing the rotation of this quaternion.

        The axis is a unit 3-vector, or the zero vector when the vector part is
        zero (identity rotation). Without `reduce`, the angle is
        2*atan2(|v|, r), which lies in [0, 2*pi]. With `reduce`, the angle is
        wrapped into [0, pi], and the axis is flipped where needed.

        Raises ValueError for the zero quaternion.
        """
        if abs(self) == 0:
            raise ValueError('the zero quaternion does not define any rotation')
        axis = self.vect()
        c = self.r
        s = numpy.sum(axis**2)**0.5
        angle = numpy.arctan2(s, c)*2
        if s == 0:
            # Pure-real quaternion (r > 0 gives angle 0, r < 0 gives 2*pi): the
            # axis is undefined, and dividing by s would produce NaNs.
            axis = numpy.zeros(3)
        else:
            axis = axis / s
        if reduce:
            angle = (angle + numpy.pi) % (2*numpy.pi) - numpy.pi
            if angle < 0:
                angle = -angle
                axis = -axis
        return axis, angle

    def rotate_vector(self, v):
        """Rotate the 3-vector `v` as q * v * q^-1 (v taken as a pure quaternion)."""
        return (self * Quat(0, v[0], v[1], v[2]) / self).vect()