"""Specialised plots for functions defined in blmfuncs.py.

fixme:
        -- I'm normalising all color-mapping input vectors to [0, 1]. This
        destroys the informative numerical values in the colorbar (but we
        are not showing these anyway). A better fix would be to let the
        colorbar listen to the ScalarMappable instance and correct itself,
        but I did not get that to work ...

fixme2:
        -- If a scatter plot is not initialised with a color vector there is
        no colorbar, but the colorbar should be created when colors are added later.
"""
from fluents import plots
from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt
from matplotlib import cm
import pylab as PB

class PcaScorePlot(plots.ScatterPlot):
    """Scatter plot of PCA scores (the model's 'T' matrix).

    Column `absi` of 'T' supplies the x-axis and column `ordi` the y-axis;
    both axes are drawn from the same dataset.
    """
    def __init__(self, model, absi=0, ordi=1):
        self._T = model.model['T']
        scores = model.as_dataset('T')
        row_dim = scores.get_dim_name(0)
        comp_dim = scores.get_dim_name(1)
        # Single-element unpacking: exactly one identifier per requested column.
        (x_id,) = scores.get_identifiers(comp_dim, [absi])
        (y_id,) = scores.get_identifiers(comp_dim, [ordi])
        plots.ScatterPlot.__init__(self, scores, scores, row_dim, comp_dim,
                                   x_id, y_id, c='b', s=40,
                                   name='pca-scores')

    def set_absicca(self, n):
        """Set the x-axis data to score column `n`."""
        column = self._T[:, n]
        self.xaxis_data = column

    def set_ordinate(self, n):
        """Set the y-axis data to score column `n`."""
        column = self._T[:, n]
        self.yaxis_data = column

class PcaLoadingPlot(plots.ScatterPlot):
    """Scatter plot of PCA loadings (the model's 'P' matrix).

    Column `absi` of 'P' supplies the x-axis and column `ordi` the y-axis.
    Points are colored by the model's 'p_tsq' values (rescaled to [0, 1])
    when the model provides them; otherwise they are drawn in green.
    """
    def __init__(self, model, absi=0, ordi=1):
        self._P = model.model['P']
        dataset_1 = model.as_dataset('P')
        dataset_2 = dataset_1
        id_dim = dataset_1.get_dim_name(0)
        sel_dim = dataset_1.get_dim_name(1)
        id_1, = dataset_1.get_identifiers(sel_dim, [absi])
        id_2, = dataset_1.get_identifiers(sel_dim, [ordi])
        if 'p_tsq' in model.model:
            col = normalise(model.model['p_tsq'].ravel())
        else:
            col = 'g'
        # 'pca-' prefix matches PcaScorePlot's 'pca-scores' (this is a PCA
        # plot, not a PLS one).
        plots.ScatterPlot.__init__(self, dataset_1, dataset_2, id_dim, sel_dim,
                                   id_1, id_2, c=col, s=20,
                                   name='pca-loadings')

    def set_absicca(self, n):
        """Set the x-axis data to loading column `n`."""
        self.xaxis_data = self._P[:, n]

    def set_ordinate(self, n):
        """Set the y-axis data to loading column `n`."""
        self.yaxis_data = self._P[:, n]
    
class PlsScorePlot(plots.ScatterPlot):
    """Scatter plot of PLS scores (the model's 'T' matrix).

    The x-axis shows column `absi` of 'T' and the y-axis column `ordi`.
    """
    def __init__(self, model, absi=0, ordi=1):
        self._T = model.model['T']
        scores = model.as_dataset('T')
        row_dim = scores.get_dim_name(0)
        comp_dim = scores.get_dim_name(1)
        # Look up one identifier per axis, x first, then y.
        axis_ids = []
        for comp in (absi, ordi):
            ident, = scores.get_identifiers(comp_dim, [comp])
            axis_ids.append(ident)
        x_id, y_id = axis_ids
        plots.ScatterPlot.__init__(self, scores, scores,
                                   row_dim, comp_dim, x_id, y_id,
                                   c='b', s=40, name='pls-scores')

    def set_absicca(self, n):
        """Set the x-axis data to score column `n`."""
        self.xaxis_data = self._T[:, n]

    def set_ordinate(self, n):
        """Set the y-axis data to score column `n`."""
        self.yaxis_data = self._T[:, n]


class PlsLoadingPlot(plots.ScatterPlot):
    """Scatter plot of PLS loadings (the model's 'P' matrix).

    Column `absi` of 'P' supplies the x-axis and column `ordi` the y-axis.
    Points are colored by the model's 'w_tsq' values (rescaled to [0, 1])
    when present; otherwise they are drawn in green.
    """
    def __init__(self, model, absi=0, ordi=1):
        self._P = model.model['P']
        dataset_1 = model.as_dataset('P')
        dataset_2 = dataset_1
        id_dim = dataset_1.get_dim_name(0)
        sel_dim = dataset_1.get_dim_name(1)
        id_1, = dataset_1.get_identifiers(sel_dim, [absi])
        id_2, = dataset_1.get_identifiers(sel_dim, [ordi])
        if 'w_tsq' in model.model:
            col = normalise(model.model['w_tsq'].ravel())
        else:
            col = 'g'
        plots.ScatterPlot.__init__(self, dataset_1, dataset_2,
                                   id_dim, sel_dim, id_1, id_2,
                                   c=col, s=20, name='loadings')

    def set_absicca(self, n):
        """Set the x-axis data to loading column `n`."""
        self.xaxis_data = self._P[:, n]

    def set_ordinate(self, n):
        """Set the y-axis data to loading column `n`."""
        # Read from the loadings: this class never defines `_T`, so the
        # previous `self._T[:, n]` raised AttributeError.
        self.yaxis_data = self._P[:, n]


class LineViewXc(plots.LineViewPlot):
    """Line view of the model's raw data 'X' after column centring.

    The dataset is copied first, and the centred values are assigned as a new
    array on the copy rather than being written into the existing one.
    """
    def __init__(self, model, name='Profiles'):
        centred = model._dataset['X'].copy()
        # Column means (reduced over axis 0), lifted to a row for broadcasting.
        col_means = mean(centred._array, 0)
        centred._array = centred._array - col_means[newaxis]
        plots.LineViewPlot.__init__(self, centred, 1, None, name)


class ParalellCoordinates(plots.Plot):
    """Parallel-coordinates plot for scores/loadings with many components.

    Not implemented yet: the constructor accepts its arguments and does
    nothing else (in particular it never calls plots.Plot.__init__).
    """
    def __init__(self, model, p='loads'):
        # Placeholder until the plot is implemented.
        return


class PlsQvalScatter(plots.ScatterPlot):
    """A volcano-like plot of PLS loadings vs. the model's 'w_tsq' values.

    The x-axis is column `pc` of the 'P' dataset. The y-axis is the first
    (presumably only -- TODO confirm) column of the 'w_tsq' dataset. The
    'w_tsq' values, rescaled to [0, 1], also color the points.

    If the model has no 'w_tsq' entry the constructor returns early, leaving
    the plot uninitialised.
    """
    def __init__(self, model, pc=0):
        if 'w_tsq' not in model.model:
            return
        self._W = model.model['P']
        dataset_1 = model.as_dataset('P')
        dataset_2 = model.as_dataset('w_tsq')
        id_dim = dataset_1.get_dim_name(0)     # genes
        sel_dim = dataset_1.get_dim_name(1)    # components
        sel_dim_2 = dataset_2.get_dim_name(1)  # _zero_dim
        # Honour the requested component; it was previously hard-coded to 0.
        id_1, = dataset_1.get_identifiers(sel_dim, [pc])
        id_2, = dataset_2.get_identifiers(sel_dim_2, [0])
        # 'w_tsq' is guaranteed present by the guard above, so no fallback
        # color is needed.
        col = normalise(model.model['w_tsq'].ravel())
        plots.ScatterPlot.__init__(self, dataset_1, dataset_2,
                                   id_dim, sel_dim, id_1, id_2,
                                   c=col, s=20, sel_dim_2=sel_dim_2,
                                   name='Load Volcano')

class PredictionErrorPlot(plots.Plot):
    """A boxplot of prediction error vs. component number.

    Reads 'sep' and 'aopt' from the model. The boxplot is drawn from
    sqrt(model.model['sep']) (presumably squared errors, one column per
    component -- TODO confirm). A translucent red vertical line marks x =
    'aopt'. If the model has no 'sep' entry, a notice is logged and the
    constructor returns early, leaving the plot uninitialised.
    """
    def __init__(self, model, name="Pred. Err."):
        if 'sep' not in model.model:
            # `logger` is not imported at module level in this file; import it
            # here so this early-exit path does not raise NameError.
            from fluents import logger
            logger.log('notice', 'Model has no calculations of sep')
            return
        plots.Plot.__init__(self, name)
        self._frozen = True
        self.current_dim = 'johndoe'
        self.ax = self.fig.add_subplot(111)

        # draw
        sep = model.model['sep']
        aopt = model.model['aopt']
        self.ax.boxplot(sqrt(sep))
        # zorder=0 keeps the marker behind the boxes.
        self.ax.axvline(aopt, linewidth=10, color='r', zorder=0, alpha=.5)

        # add canvas
        self.add(self.canvas)
        self.canvas.show()

    def set_current_selection(self, selection):
        """Ignore selection updates; this plot does not react to them."""
        pass
    
    
class InfluencePlot(plots.ScatterPlot):
    """Placeholder for an influence plot; not implemented yet.

    Adds nothing to plots.ScatterPlot at present.
    """
        

def normalise(x):
    """Linearly rescale array `x` so that its values span [0, 1].

    The minimum maps to 0 and the maximum to 1. An input with no spread
    (all values equal) is returned as zeros instead of the NaNs a 0/0 division
    would produce. An empty input is returned as an empty float array.

    Parameters
    ----------
    x : array exposing ``size``, ``min`` and ``max`` (e.g. a numpy array).

    Returns
    -------
    A new float array; `x` itself is not modified.
    """
    if x.size == 0:
        # min()/max() are undefined on an empty array.
        return x * 1.0
    shifted = x - x.min()
    span = shifted.max()
    if span == 0:
        return shifted * 0.0
    # float() avoids integer floor division on integer arrays (Python 2
    # division semantics).
    return shifted / float(span)