from polytope import Polytope, cheby_ball
import polytope
import numpy as np
import matplotlib.patches


class Polytope2:
    '''
    Class encoding a polytope {x : A x <= b} in d-dimensions.  Similar to:
        https://tulip-control.github.io/polytope/
    This class however does not enforce normalization, and also provides a
    couple of new methods for extracting a 2D slice of a polytope as well as
    plotting that slice.  It calls some methods from polytope.Polytope when
    necessary.  It omits many methods such as projections (aka 'shadows'),
    unions of polytopes, checking to see if the polyhedron defined by the
    provided inequalities is closed (i.e., is a polytope), etc.

    Eventually any dependencies on the original polytope class will be removed.

    Arguments:
        A:      an m x d np matrix consisting of hyperplane normals as rows
        b:      an m dimensional numpy array consisting of hyperplane offsets
        normalize:  whether to normalize the hyperplane normals, etc.
                    (defaults to False)

    Attributes set by methods (all None until computed):
        cheby_radius, cheby_center, fulldim, is_closed:  see chebyshev_ball
        minrep:  set to True by remove_redundants
    '''

    def __init__(self, A, b, normalize=False):
        # A must be assigned first: the b setter validates b against A's rows
        self.A = A
        self.b = b
        # Normalize rows of A (and entries of b by the same factor)
        if normalize:
            self.normalize()
        # parameters set by class methods
        self.cheby_radius = None
        self.cheby_center = None
        self.fulldim = None  # not yet known
        self.minrep = None  # not yet known
        self.is_closed = None  # not yet known

    def __str__(self):
        '''
        Render the system A x <= b row by row when the printed rows of A are
        at most 50 characters wide; otherwise print a short placeholder.
        '''
        output = 'Polytope (if full-dim and closed) is all x such that:\n'
        A_rows = str(np.round(self.A, 4)).split('\n')
        b_rows = str(np.round(self.b.reshape(-1, 1), 4)).split('\n')
        max_len = max(map(len, A_rows))
        if max_len <= 50:
            # the 'x <=' label is printed once, on the middle row
            mid_row = self.A.shape[0] // 2
            for (i, (A_row, b_row)) in enumerate(zip(A_rows, b_rows)):
                if i == mid_row:
                    output += '{:<50}{:<10}{:>10}'.format(A_row, 'x <=',
                                                          b_row + '\n')
                else:
                    output += '{:<50}{:>20}'.format(A_row, b_row + '\n')
        else:
            output += '\t Ax <= b'
            output += '\t (see instance attributes A, b) \n'
        if self.minrep is None:
            output += '.remove_redundants() has not been called yet.'
        return output

    @property
    def A(self):
        '''m x d float array of hyperplane normals (one per row).'''
        return self.__A

    @A.setter
    def A(self, A):
        if not isinstance(A, np.ndarray):
            raise TypeError('A: must be a numpy array')
        elif len(A.shape) != 2:
            raise TypeError('A: must be a numpy matrix')
        else:
            # astype always copies, so the caller's array is never aliased
            self.__A = A.astype(float)

    @property
    def b(self):
        '''Flat float array of m hyperplane offsets.'''
        return self.__b

    @b.setter
    def b(self, b):
        # Type is checked before length so that non-array-likes always raise
        # TypeError regardless of whether they happen to have a len().
        if not isinstance(b, (np.ndarray, list, tuple)):
            raise TypeError('b: must be array-like object')
        b_flat = np.array(b, dtype=float).ravel()
        # Ensure A and b are consistent with each other.  The comparison is
        # made on the flattened vector that is actually stored, so e.g. an
        # (m, 2) array cannot slip through with 2m entries.
        if b_flat.size != self.A.shape[0]:
            raise ValueError('A and b are not conformable')
        self.__b = b_flat

    def copy(self):
        '''
        Returns a new Polytope2 with copies of A and b.  Attributes computed
        by other methods (cheby_radius, minrep, ...) are reset to None in the
        copy.
        '''
        return Polytope2(A=self.A.copy(), b=self.b.copy())

    def normalize(self):
        '''
        Scales each row of A to unit Euclidean norm and divides the matching
        entry of b by the same factor.  Rows of A that are entirely zero have
        no direction and are left unscaled (instead of dividing by zero).
        Updates:
            self.A
            self.b
        '''
        row_norms = np.linalg.norm(self.A, axis=1)
        scale = np.where(row_norms > 0, row_norms, 1.0)
        self.A = self.A / scale[:, np.newaxis]
        self.b = self.b / scale

    def chebyshev_ball(self, inner_tol=1e-8, outer_tol=1e16):
        '''
        Calculate the largest inscribed ball, aka, the Chebyshev ball, and
        record as instance attributes:
        self.cheby_radius:    radius
        self.cheby_center:    center
        self.fulldim:   whether polytope is full dimensional
                        (self.fulldim = self.cheby_radius >= inner_tol)
        self.is_closed: self.cheby_radius <= outer_tol
        Calls polytope.cheby_ball from
            https://tulip-control.github.io/polytope/
        Arguments:
            inner_tol:  tolerance for inner ball radius to declare full dim
            outer_tol:  tolerance for outer ball radius to declare closed
        '''
        cheby = cheby_ball(Polytope(A=self.A, b=self.b))
        self.cheby_radius = cheby[0]
        self.cheby_center = cheby[1]
        self.fulldim = self.cheby_radius >= inner_tol
        self.is_closed = self.cheby_radius <= outer_tol

    def remove_redundants(self, tol=1e-8):
        '''
        Removes redundant inequalities in the hyperplane representation of the
        polytope.  Calls reduce method from:
            https://tulip-control.github.io/polytope/
        Arguments:
            tol:    tolerance to use in calling polytope.reduce
        Updates:
            self.A
            self.b
            self.minrep (set to True)
        '''
        poly = polytope.reduce(Polytope(A=self.A, b=self.b), abs_tol=tol)
        # A first: the b setter checks the new b against the new row count
        self.A = poly.A
        self.b = poly.b
        self.minrep = True

    def get_vertices(self):
        '''
        Gets the vertices of Polytope (if closed and bounded).  Calls 'extreme'
        method from:
            https://tulip-control.github.io/polytope/
        Returns:
            (k by d) numpy array of the k vertices of the polytope
        '''
        return polytope.extreme(Polytope(A=self.A, b=self.b))

    def is_inside(self, y, tol=0):
        '''
        Tests whether a supplied point is inside the polytope.
        Arguments:
            y:      point to be tested
            tol:    error tolerance for any of the hyperplanes
        Returns:
            True or False indicating whether point is inside
        '''
        y = np.array(y).ravel()
        # inside iff no inequality A y <= b is violated by more than tol
        return not np.any(np.matmul(self.A, y) > self.b + tol)

    def get_slice(self, z, verbose=False):
        '''
        Gets a two-dimensional slice of a polytope
        Arguments:
            z:          A numpy array corresponding to the slice.  If the
                        polytope is d > 2 dimensions, then (d-2) dimensions
                        of this array should be fixed values, and the remaining
                        2 dimensions should be np.nan corresponding to the
                        free coordinates.
            verbose:    Boolean indicating whether to print out free
                        coordinates, fixed coordinates and their values,
        Returns:
            Polytope2 instance corresponding to 2-D representation.  If the
            polytope is already 2-dimensional, self is returned (z is not
            inspected).
        Raises:
            TypeError:  z is not a list, tuple or numpy array
            ValueError: the polytope has fewer than 2 dimensions, z does not
                        have one entry per dimension, z does not contain
                        exactly two np.nan entries, or the fixed values
                        violate an inequality that does not involve the free
                        coordinates (so the slice is certainly empty).
        Note: only inequalities that do not involve the free coordinates are
        checked here; no linear program is solved, so a returned slice can
        still turn out to be empty.
        '''
        dim = self.A.shape[1]
        if dim == 2:
            if verbose:
                print('get_slice: polytope already 2-dimensional')
            return self
        if dim < 2:
            raise ValueError(
                'get_slice: polytope must have at least 2 dimensions')
        if not isinstance(z, (list, tuple, np.ndarray)):
            raise TypeError('z must be an array like object')
        z = np.array(z, dtype=float).ravel()
        if z.size != dim:
            raise ValueError('get_slice: z must have one entry per dimension')
        free = np.isnan(z)
        if np.count_nonzero(free) != 2:
            raise ValueError('Need 2 free dimensions specified by np.nan')
        slice_idx = np.flatnonzero(free)
        idx = np.flatnonzero(~free)
        z1 = z[idx]
        if verbose:
            print(
                'get_slice, free coordinates:',
                '\t' + str(slice_idx),
                sep='\n')
            print(
                'get_slice, fixed coordinates:',
                '\t' + str(idx),
                sep='\n')
            print(
                'get slice, fixed coordinate values:',
                '\t' + str(z1),
                sep='\n')
        A_free = self.A[:, slice_idx]
        # substituting the fixed values moves their contribution into b
        b_slice = self.b - np.dot(self.A[:, idx], z1)
        # Rows with zero coefficients on both free coordinates read
        # 0 <= b_slice; they are decided by the fixed values alone.  Rows that
        # do involve a free coordinate cannot be judged individually (testing
        # them at free = 0 would reject valid slices that do not pass through
        # the origin of the free coordinates).
        decided = ~np.any(A_free, axis=1)
        if np.any(b_slice[decided] < 0):
            raise ValueError('get_slice: z not inside polytope')
        return Polytope2(A=A_free, b=b_slice)

    def get_patch(self, **kwargs):
        '''
        Returns a matplotlib patch to plot 2D Polytope. Note that if you are
        plotting a slice of a higher dimensional polytope, first generate a 2D
        slice using the method Polytope2.get_slice
        Calls self.chebyshev_ball() if self.cheby_radius is None.
        Arguments:
            **kwargs: valid kwargs for matplotlib.patches.Polygon
        Raises:
            ValueError: the polytope is not 2-dimensional
        '''
        if self.A.shape[1] != 2:
            raise ValueError('Require 2-dimensional Polytope for plotting')
        vertices = self.get_vertices()
        if self.cheby_radius is None:
            self.chebyshev_ball()
        # Order the vertices by their angle around the Chebyshev center so
        # that they trace the polygon boundary.  The angle is measured from
        # the second coordinate axis towards the first; either orientation
        # yields a valid boundary ordering.
        along = vertices[:, 1] - self.cheby_center[1]
        across = vertices[:, 0] - self.cheby_center[0]
        order = np.argsort(np.arctan2(across, along))
        return matplotlib.patches.Polygon(
            vertices[order, :], closed=True, **kwargs)


def make_standard_simplex(dim):
    '''
    Function to construct a standard simplex (i.e., the corner of a
        d-dimensional cube in the positive orthant which touches the origin).
    The simplex is the set of x with x >= 0 and sum(x) <= 1.
    Arguments:
        dim: The dimension of the simplex (a positive Python int, a numpy
             integer scalar, or a single-element numpy array)
    Returns:
        Polytope2 instance corresponding to the simplex
    Raises:
        TypeError: dim is not a positive integer
    '''
    # dimension check
    if isinstance(dim, np.ndarray):
        dim = dim.item()
    # numpy integer scalars (e.g. np.int64) are not instances of int
    if isinstance(dim, np.integer):
        dim = int(dim)
    if not isinstance(dim, int) or dim < 1:
        raise TypeError('dim: positive integer dimension required')
    # positive orthant (-x_i <= 0), then one row: variables sum to at most 1
    A = np.concatenate((-np.eye(dim), np.ones((1, dim))), axis=0)
    b = np.concatenate((np.zeros(dim), [1]))
    return Polytope2(A=A, b=b)


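# NOTE: not implemented yet.  The technique sketched below produces the
# vertices of the simplex, so a convex hull algorithm would be needed to
# turn them into the hyperplane representation used by Polytope2.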
# def make_regular_simplex(dim):
#     '''
#     Helper function to construct a regular simplex
#     Arguments:
#         dim: The dimension of the simplex
#     Returns:
#         v:      vertices of the regular simplex
#         b:      an m dimensional numpy array consisting of hyperplane offsets
#     Use the returned values A and b to construct a simplex using Polytope2's
#     constructor.
#     '''
#     # what follows works in R to find the vertices...
#     #simplex <- function(n) {
#     #    qr.Q(qr(matrix(1,nrow=n)),complete=T)[,-1]
#     #}


def make_hyperrect(dim=None, center=None, sidelengths=None):
    '''
    Helper function to construct a hyperrectangle (i.e., d-orthotope) using the
    Polytope2 class.
    Arguments:
        dim:            The dimension of the hyperrectangle.  If not provided
                        it is inferred from len(center) when center is
                        array-like, otherwise from len(sidelengths) when that
                        is array-like, otherwise it is 1.
        center:         The center (set to origin if not provided)
                        If float or integer provided, uses same value for all
                        coordinates.
        sidelengths:    An array-like object containing sidelengths (required)
                        If float or integer provided, uses same sidelengths
                        for all sides (hypercube generated).
    Returns:
        Polytope2 instance with A = [-I; I] and
        b = [-(center - sidelengths/2); center + sidelengths/2]
    Raises:
        TypeError:  sidelengths missing; dim not a positive integer; center or
                    sidelengths neither a number nor a list/tuple/numpy array;
                    center or sidelengths not of length dim
        ValueError: some sidelength is not strictly positive (NaN included)
    '''
    scalar_types = (int, float, np.integer, np.floating)
    array_types = (list, tuple, np.ndarray)
    if sidelengths is None:
        raise TypeError('sidelengths: argument required')
    # Flatten array-like inputs up front; scalars are broadcast once dim is
    # known.
    if isinstance(sidelengths, array_types):
        sidelengths = np.array(sidelengths, dtype=float).ravel()
    elif not isinstance(sidelengths, scalar_types):
        raise TypeError('sidelengths: must be array-like object')
    if isinstance(center, array_types):
        center = np.array(center, dtype=float).ravel()
    elif center is not None and not isinstance(center, scalar_types):
        raise TypeError('center: must be array-like object')
    # dimension check
    if dim is None:
        if isinstance(center, np.ndarray):
            dim = center.size
        elif isinstance(sidelengths, np.ndarray):
            dim = sidelengths.size
        else:
            dim = 1
    elif isinstance(dim, np.ndarray):
        dim = dim.item()
    # numpy integer scalars (e.g. np.int64) are not instances of int
    if isinstance(dim, np.integer):
        dim = int(dim)
    if not isinstance(dim, int) or dim < 1:
        raise TypeError('dim: positive integer dimension >= 1 required')
    # process sidelengths
    if not isinstance(sidelengths, np.ndarray):
        sidelengths = np.full(dim, sidelengths, dtype=float)
    elif sidelengths.size != dim:
        raise TypeError('sidelengths: must be of length dim')
    # "not all > 0" (rather than "any <= 0") also rejects NaN entries
    if not np.all(sidelengths > 0):
        raise ValueError('sidelengths: must be positive')
    # process center
    if center is None:
        center = np.zeros(dim, dtype=float)
    elif not isinstance(center, np.ndarray):
        center = np.full(dim, center, dtype=float)
    elif center.size != dim:
        raise TypeError('center: must be of length dim')
    # return equivalent polytope: -x <= -(lower corner), x <= upper corner
    half = sidelengths / 2
    A = np.concatenate((-np.eye(dim), np.eye(dim)), axis=0)
    b = np.concatenate((-(center - half), center + half))
    return Polytope2(A=A, b=b)


def make_hypercube(dim=None, center=None, sidelengths=2):
    '''
    Helper function to construct a hypercube using the Polytope2 class.
    Thin wrapper around make_hyperrect.
    Arguments:
        dim:            The dimension of the hypercube
                        (inferred from the number of entries of center if not
                        provided and center is array-like)
        center:         The center (set to origin if not provided)
        sidelengths:    int or float specifying common sidelength
                        (defaults to 2 corresponding to l-infinity ball)
    Returns:
        The value returned by make_hyperrect(dim, center, sidelengths)
    '''
    # Infer dim here rather than relying on make_hyperrect, so that the
    # documented "inferred from center" behaviour holds for a scalar
    # sidelength.
    if dim is None and isinstance(center, (list, tuple, np.ndarray)):
        dim = np.size(center)
    return make_hyperrect(dim=dim, center=center, sidelengths=sidelengths)