import numpy as np

from polytope2 import Polytope2
from ellipsoid import Ellipsoid
from scipy.stats import bernoulli
from scipy.stats import multivariate_normal as mvn
from scipy.linalg import pinvh


class DikinWalk():
    '''
    Class for the Dikin Walk as described here:
        https://arxiv.org/abs/0911.3950
    Currently, this class only supports the Dikin Walk on polytopes.

    Arguments for initialization:
        body:       Convex body (currently only polytope2 instance is supported)
        x0:         an initial point to start the walk (must be inside body)
                    (e.g., Chebyshev center calculated by Polytope2.cheby_ball)
        r:          radius for Dikin Walk (defaults to 3/40); must be positive
        proposal:   member of {'gaussian', 'uniform'} (case-insensitive)
                    corresponding to the proposal distribution type
                    (defaults to 'gaussian')
        lazy_steps: whether to make the walk lazy, i.e., move only
                    if a fair coin flip is heads (defaults to True)
    Properties: The above arguments for initialization, and also:
        dim:            number of columns of body.A (ambient dimension)
        H:              Hessian of the barrier corresponding to body
                        at current point
        x:              The current point
        x_last:         The point held before the most recent call to .step
        reject_type:    Indicates which type of step was used.  Member of
                        {'none', 'lazy', 'body', 'metropolis'}
                        referring to whether self.x corresponds to
                        a proper step, a lazy rejection, an outside of body
                        rejection, or a rejection due to a metropolis filter,
                        respectively.
    Methods:
        .get_hessian:   Gets the Hessian of the self-concordant barrier at a
                        given point
        .step:          make one step
    '''

    # Allowed values; class-level so the validating setters do not depend
    # on __init__ having run first.
    __proposal_types = {'gaussian', 'uniform'}
    __reject_types = {'none', 'lazy', 'body', 'metropolis'}

    def __init__(self, body, x0, r=3/40, proposal='gaussian', lazy_steps=True):
        # Order matters: x0's setter needs self.body (for is_inside and the
        # Hessian) to already be set.
        self.body = body
        self.x0 = x0
        self.r = r
        self.proposal = proposal.lower()
        self.lazy_steps = lazy_steps

    def __str__(self):
        def row(label, value):
            # one left-aligned 20 + 40 column line
            return '{:<20}{:<40}'.format(label, value) + '\n'

        rule = 60*'-' + '\n'
        output = rule + '{:^60}'.format('DikinWalk Instance:') + '\n' + rule
        output += row('body type:', str(type(self.body)))
        output += row('ellipsoid radius:', str(self.r))
        output += row('proposal type:', str(self.proposal))
        output += row('lazy steps:', str(self.lazy_steps))
        x0_str = str(np.round(self.x0, 4))
        x_str = str(np.round(self.x, 4))
        x_last_str = str(np.round(self.x_last, 4))
        # only inline the points when every one fits in the value column
        if max(len(x0_str), len(x_str), len(x_last_str)) <= 40:
            output += row('initial point:', x0_str)
            output += row('last point:', x_last_str)
            output += row('current point:', x_str)
        else:
            output += row('initial point:', 'see .x0')
            output += row('last point:', 'see .x_last')
            output += row('current point:', 'see .x')
        return output

    @property
    def body(self):
        return self.__body

    @body.setter
    def body(self, body):
        if not isinstance(body, Polytope2):
            raise TypeError('Currently only Polytope2 bodies allowed')
        self.dim = body.A.shape[1]
        self.__body = body

    @property
    def x0(self):
        return self.__x0

    @x0.setter
    def x0(self, x0):
        '''Validate x0 and reset the walk (x, x_last, H, reject_type) to it.'''
        if not self.body.is_inside(x0):
            raise ValueError('Initialization point x0 must be inside body')
        self.__x0 = np.array(x0).ravel()
        self.x = self.x0.copy()
        self.x_last = self.x0.copy()
        self.H = self.get_hessian()
        self.reject_type = 'none'

    @property
    def r(self):
        return self.__r

    @r.setter
    def r(self, r):
        if r <= 0:
            raise ValueError('Radius r must be positive')
        self.__r = r

    @property
    def proposal(self):
        return self.__proposal

    @proposal.setter
    def proposal(self, proposal):
        # accept any capitalisation, matching what __init__ already allows
        if isinstance(proposal, str):
            proposal = proposal.lower()
        if proposal not in self.__proposal_types:
            msg = "proposal must be either 'gaussian' or 'uniform'"
            raise ValueError(msg)
        self.__proposal = proposal

    @property
    def reject_type(self):
        return self.__reject_type

    @reject_type.setter
    def reject_type(self, reject_type):
        if reject_type not in self.__reject_types:
            msg = 'self.reject_type must be one of ' + str(self.__reject_types)
            raise ValueError(msg)
        self.__reject_type = reject_type

    def get_hessian(self, z=None):
        '''
        Return the Hessian of the log-barrier of the polytope {x : A x <= b}
        at point z (if None, defaults to the current step, self.x).

        Raises TypeError if self.body is not a Polytope2.
        '''
        if not isinstance(self.body, Polytope2):
            raise TypeError('Currently only Polytope2 bodies allowed')
        # See, e.g., Nisheeth Vishnoi's notes on Convex Optimization, Ch. 4
        if z is None:
            z = self.x
        return self._barrier_hessian(self.body.A, self.body.b, z)

    @staticmethod
    def _barrier_hessian(A, b, z):
        '''
        Hessian of phi(z) = -sum_i log(b_i - a_i^T z), i.e. A^T S^-2 A with
        S = diag(b - A z).  b may be (m,) or (m, 1); z may be a scalar in 1-d.
        '''
        A = np.asarray(A, dtype=float)
        s = np.ravel(b) - np.matmul(A, np.ravel(z))
        # scale row i of A by 1/s_i without materialising an m x m diag matrix
        A_z = A / s[:, np.newaxis]
        return np.matmul(A_z.T, A_z)

    @staticmethod
    def _in_dikin_ellipsoid(H, center, point, r):
        '''True iff (point - center)^T H (point - center) <= r^2.'''
        d = np.ravel(point) - np.ravel(center)
        return float(d @ H @ d) <= r**2

    @staticmethod
    def _metropolis_accept(log_ratio):
        '''
        Accept with probability min(1, exp(log_ratio)).  A NaN log-ratio
        (e.g. both densities are -inf) is treated as a rejection instead of
        being silently accepted.
        '''
        if np.isnan(log_ratio):
            return False
        accept_prob = float(np.exp(min(0.0, log_ratio)))
        return bool(bernoulli(p=accept_prob).rvs(1)[0])

    def step(self):
        '''
        Make one step of the walk.  Updates x_last (always), and x, H and
        reject_type according to the outcome of the step.
        '''
        self.x_last = self.x.copy()
        if self.lazy_steps and not bernoulli(p=.5).rvs(1)[0]:
            # lazy walk: stay put with probability 1/2
            self.reject_type = 'lazy'
        elif self.proposal == 'gaussian':
            self.__step_gaussian()
        else:
            self.__step_uniform()

    def __step_gaussian(self):
        '''
        Gaussian Dikin step: propose y ~ N(x, (r^2/dim) H_x^+) and accept
        with probability min(1, g_y(x) / g_x(y)), where g_z denotes the
        proposal density centred at z.
        '''
        scale = self.r**2 / self.dim
        # for numeric stability, use pseudo inverse for hermitians,
        # and allow singular covariance matrices
        sigma_x = scale * pinvh(self.H)
        proposal_x = mvn(mean=self.x, cov=sigma_x, allow_singular=True)
        # ravel: for dim == 1 scipy returns a scalar rather than a 1-d array
        y = np.ravel(proposal_x.rvs(1))
        if not self.body.is_inside(y):
            # reject due to point outside of the body
            self.reject_type = 'body'
            return
        H_y = self.get_hessian(z=y)
        sigma_y = scale * pinvh(H_y)
        # work with log-densities: the densities themselves can underflow
        # to 0 in high dimension, making their ratio undefined
        log_g_xy = proposal_x.logpdf(y)
        log_g_yx = mvn(mean=y, cov=sigma_y, allow_singular=True).logpdf(self.x)
        if not self._metropolis_accept(log_g_yx - log_g_xy):
            # reject due to metropolis filter
            self.reject_type = 'metropolis'
            return
        self.x = y
        self.H = H_y
        self.reject_type = 'none'

    def __step_uniform(self):
        '''
        Uniform Dikin step: draw y uniformly from the Dikin ellipsoid D_x;
        reject if x is not in D_y, otherwise accept with probability
        min(1, vol(D_x) / vol(D_y)) = min(1, sqrt(|H_y| / |H_x|)).
        '''
        # NOTE(review): assumes Ellipsoid(K=H, xc=x, r=r) is the set
        # {y : (y - x)^T H (y - x) <= r^2}, consistent with the determinant
        # ratio used below -- confirm against ellipsoid.py
        E_x = Ellipsoid(K=self.H, xc=self.x, r=self.r)
        y = E_x.sample(1).ravel()
        if not self.body.is_inside(y):
            # the unit Dikin ellipsoid lies in the polytope, so for r <= 1
            # this only guards numerical edge cases; for r > 1 it prevents
            # evaluating the barrier outside its domain
            self.reject_type = 'body'
            return
        H_y = self.get_hessian(z=y)
        if not self._in_dikin_ellipsoid(H_y, y, self.x, self.r):
            # the reverse move y -> x has zero proposal density, so the
            # Metropolis ratio is 0: reject
            self.reject_type = 'metropolis'
            return
        # for numeric stability, use log dets for metropolis filter
        H_x_ldet = np.linalg.slogdet(self.H)[1]
        H_y_ldet = np.linalg.slogdet(H_y)[1]
        if not self._metropolis_accept(.5*(H_y_ldet - H_x_ldet)):
            # reject due to metropolis filter
            self.reject_type = 'metropolis'
            return
        self.x = y
        self.H = H_y
        self.reject_type = 'none'