import numpy as np
from scipy.linalg import fractional_matrix_power, pinvh
import matplotlib.patches


class Ellipsoid:
    '''
    Class encoding an ellipsoid in d-dimensions.
    Ellipsoid is specified by E = {x: (x - xc)^T K (x - xc) <= r^2}
    Arguments:
        K:      positive definite matrix for the ellipsoid (numpy array,
                numpy matrix, or nested list); stored as a float copy
        xc:     list, tuple, or numpy array for the center of ellipsoid
                (defaults to origin if not provided)
        r:      positive scalar indicating the radius of the ellipsoid
                (defaults to 1)
        pos_def_check:  whether to check that the supplied K is symmetric
                        and positive definite (defaults to False)
        tol:    tolerance for positive definiteness in terms of the minimum
                eigenvalue of K
                (defaults to 1e-8, applicable only if pos_def_check=True)
    '''

    def __init__(self, K=None, xc=None, r=1, pos_def_check=False, tol=1e-8):
        # The validation settings must exist before K is assigned, because
        # the K setter reads them.
        self.__pos_def_check = pos_def_check
        self.__tol = tol
        self.K = K
        if xc is None:
            self.xc = np.zeros(shape=self.K.shape[0], dtype=float)
        else:
            self.xc = xc
        self.r = r

    def __str__(self):
        parts = ['Ellipsoid is all x such that:\n']
        K_rows = str(self.K).split('\n')
        if max(map(len, K_rows)) <= 45:
            # Anchor the "(x-x_c)^T ... <= r^2" labels on the middle printed
            # row.  Use the printed row count (numpy may summarize large
            # matrices with "...") rather than K.shape.
            mid_row = len(K_rows) // 2
            for i, row in enumerate(K_rows):
                if i == mid_row:
                    parts.append('{:<10}{:^45}{:>15}'.format(
                        '(x-x_c)^T', row, '(x-x_c) <= r^2\n'))
                else:
                    parts.append('{:<10}{:^45}{:>15}'.format('', row, '\n'))
        else:
            parts.append('\t (x-x_c)^T K (x-x_c) <= r^2')
            parts.append('\t (see instance attribute K) \n')
        parts.append('Ellipsoid radius: \n')
        parts.append('\t r = ' + str(self.r) + '\n')
        parts.append('Ellipsoid center: \n')
        parts.append('\t x_c^T = ' + str(self.xc) + '\n')
        return ''.join(parts)

    @property
    def K(self):
        return self.__K

    @K.setter
    def K(self, K):
        '''
        Validate and store K as a float copy.
        Raises:
            TypeError:  K is None, not 2-dimensional, or not square
            ValueError: (pos_def_check only) K is not symmetric, or its
                        minimum eigenvalue is <= tol
        '''
        if K is None:
            raise TypeError('K must be provided as a 2-dim numpy array')
        # asanyarray keeps numpy.matrix inputs as matrices (as the previous
        # K.astype did) while also accepting nested lists; astype copies, so
        # later edits to the caller's array do not leak into the ellipsoid.
        K = np.asanyarray(K).astype(float)
        if K.ndim != 2:
            raise TypeError('K must be 2-dim numpy array')
        if K.shape[0] != K.shape[1]:
            raise TypeError('K must be a square matrix')
        if self.__pos_def_check:
            if not np.allclose(K, K.T):
                raise ValueError('K must be symmetric')
            # Validate before assigning so a rejected K never replaces the
            # currently stored one.
            if np.any(np.linalg.eigvalsh(K) <= self.__tol):
                raise ValueError('K must be positive definite')
        self.__K = K

    @property
    def xc(self):
        return self.__xc

    @xc.setter
    def xc(self, xc):
        '''Store xc as a flat float array; its length must match K.'''
        xc = np.array(xc, dtype=float).ravel()
        if xc.shape[0] != self.K.shape[1]:
            raise TypeError('xc not conformable with K')
        self.__xc = xc

    @property
    def r(self):
        return self.__r

    @r.setter
    def r(self, r):
        '''
        Store r as a Python float.
        Accepts Python/numpy scalars and size-1 arrays.
        Raises:
            TypeError:  r is not a real scalar
            ValueError: r is not strictly positive (including NaN)
        '''
        if isinstance(r, (np.ndarray, np.generic)):
            if np.size(r) != 1:
                raise TypeError('r must be a scalar')
            r = r.item()  # numpy scalar/array -> plain Python scalar
        if not isinstance(r, (int, float)):
            raise TypeError('r must be a scalar')
        # "not r > 0" also rejects NaN, which compares False to everything
        if not r > 0:
            raise ValueError('r must be positive')
        self.__r = float(r)

    def copy(self):
        '''Return a new Ellipsoid with copies of K and xc and the same settings.'''
        return Ellipsoid(K=self.K.copy(), xc=self.xc.copy(), r=self.r,
                         pos_def_check=self.__pos_def_check, tol=self.__tol)

    def is_inside(self, y):
        '''
        Tests whether a point y is contained inside the ellipsoid.
        Arguments:
            y:  test point
        Returns:
            True or False depending on whether y is in the ellipsoid
            (points on the boundary count as inside).
        '''
        diff = np.array(y, dtype=float).ravel() - self.xc
        val = np.dot(diff, np.dot(np.asarray(self.K), diff))
        return bool(val <= self.r**2)

    def get_slice(self, z, verbose=False):
        '''
        Gets a two-dimensional slice of an ellipsoid.
        Arguments:
            z:          An array-like with one entry per dimension.  If the
                        ellipsoid is d > 2 dimensions, then (d-2) entries
                        hold the fixed coordinate values, and the remaining
                        2 entries should be np.nan marking the free
                        coordinates.
            verbose:    Boolean indicating whether to print out free
                        coordinates, fixed coordinates and their values.
        Returns:
            Ellipsoid instance for the 2-D cross-section in the free
            coordinates.  For a 2-D ellipsoid, this same instance is
            returned (not a copy).
        Raises:
            TypeError:  dimension < 2, or z is not list/tuple/ndarray
            ValueError: z has the wrong length, does not have exactly two
                        NaN entries, or the slice does not intersect the
                        ellipsoid interior
        '''
        dim = self.K.shape[0]
        if dim < 2:
            raise TypeError('get_slice: requires dimension >= 2')
        if dim == 2:
            if verbose:
                print('get_slice: ellipsoid already 2-dimensional')
            return self
        if not isinstance(z, (list, tuple, np.ndarray)):
            raise TypeError('z must be an array like object')
        z = np.array(z, dtype=float).ravel()
        if z.shape[0] != dim:
            raise ValueError('z must have one entry per ellipsoid dimension')
        free = np.isnan(z)
        if np.count_nonzero(free) != 2:
            raise ValueError('Need 2 free dimensions specified by np.nan')
        slice_idx = np.flatnonzero(free)
        idx = np.flatnonzero(~free)
        z1 = z[idx]
        if verbose:
            print('free coordinates:', '\t' + str(slice_idx), sep='\n')
            print('fixed coordinates:', '\t' + str(idx), sep='\n')
            print('fixed coordinate values:', '\t' + str(z1), sep='\n')
        K = np.asarray(self.K)
        # The quadratic form only depends on the symmetric part of K; using it
        # keeps the block algebra exact even if K was not checked for symmetry
        # (for a symmetric K it equals K).
        K_sym = 0.5 * (K + K.T)
        K_ff = K_sym[np.ix_(slice_idx, slice_idx)]
        K_fz = K_sym[np.ix_(slice_idx, idx)]
        K_zz = K_sym[np.ix_(idx, idx)]
        # With u = x_free - xc_free and d = z_fixed - xc_fixed:
        #   u^T K_ff u + 2 u^T K_fz d + d^T K_zz d <= r^2
        # Completing the square in u gives
        #   (u + s)^T K_ff (u + s) <= r^2 - d^T (K_zz - K_fz^T K_ff^-1 K_fz) d
        # with s = K_ff^-1 K_fz d.
        d = z1 - self.xc[idx]
        coupling = K_fz @ d
        shift = np.linalg.solve(K_ff, coupling)
        xc = self.xc[slice_idx] - shift
        r_sq = self.r**2 - (d @ K_zz @ d - coupling @ shift)
        if r_sq <= 0:
            raise ValueError('get_slice: z not inside ellipsoid')
        return Ellipsoid(K=K[np.ix_(slice_idx, slice_idx)], xc=xc,
                         r=float(np.sqrt(r_sq)),
                         pos_def_check=self.__pos_def_check,
                         tol=self.__tol)

    def get_patch(self, **kwargs):
        '''
        Returns a matplotlib patch to plot 2D Ellipsoid. Note that if you are
        plotting a slice of a higher dimensional ellipsoid, first generate a 2D
        slice using the method Ellipsoid.get_slice
        Arguments:
            **kwargs: valid kwargs for matplotlib.patches.Ellipse
        Raises:
            ValueError: the ellipsoid is not 2-dimensional
        '''
        if self.K.shape[0] != 2:
            raise ValueError('Require 2-dimensional Ellipsoid for plotting')
        # Eigenvalues of (K/r^2)^-1 are the squared semi-axis lengths; the
        # matching eigenvectors give the axis directions.
        vals, vecs = np.linalg.eigh(np.linalg.inv(self.K/self.r**2))
        # rotation of the ellipse, taken from the first eigenvector
        x, y = vecs[:, 0]
        theta = np.degrees(np.arctan2(y, x))
        # full axis lengths along each eigenvector
        w, h = 2 * np.sqrt(vals)
        return matplotlib.patches.Ellipse(
            xy=self.xc, width=w, height=h, angle=theta, **kwargs)

    def _unit_directions(self, n):
        '''Draw n directions uniformly distributed on the unit sphere in R^d.'''
        X = np.random.randn(n, self.K.shape[1])
        X /= np.linalg.norm(X, axis=1)[:, None]
        return X

    def _to_ellipsoid(self, U):
        '''Map the rows u of U to x = xc + r K^(-1/2) u.'''
        K_inv_sqrt = fractional_matrix_power(pinvh(self.K), .5)
        return (self.r * np.dot(K_inv_sqrt, U.T)).T + self.xc

    def sample_surface(self, n):
        '''
        Generate random samples on the surface of the ellipsoid.
        Directions u are drawn uniformly on the unit sphere and mapped to
        x = xc + r K^(-1/2) u.  The points are uniform with respect to
        surface area only when K is a multiple of the identity; otherwise
        they are denser where the ellipsoid is more strongly curved.
        Arguments:
            n: the number of samples to draw
        Returns:
            an n by dim matrix of samples
        '''
        return self._to_ellipsoid(self._unit_directions(n))

    def sample(self, n):
        '''
        Generate uniformly distributed samples from the ellipsoid (uniform
        with respect to volume).
        Arguments:
            n: the number of samples to draw
        Returns:
            an n by dim matrix of samples
        '''
        U = self._unit_directions(n)
        # In d dimensions the fraction of a ball's volume within radius s is
        # s^d, so radii must be drawn as V^(1/d) with V ~ Uniform(0, 1) to be
        # uniform in the unit ball; the linear map onto the ellipsoid then
        # preserves that uniformity.
        U *= np.random.uniform(size=n).reshape(n, 1) ** (1.0 / U.shape[1])
        return self._to_ellipsoid(U)
