"""Specialised plots for functions defined in blmfuncs.py.

fixme:
        -- If the scatter plot is not initialised with a color vector there
        is no colorbar; one should be created when colors are added later.
"""

from matplotlib import cm,patches
import gtk
import fluents
from fluents import plots, main
from fluents import logger
import scipy
from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt,apply_along_axis,empty
from scipy.stats import corrcoef

def correlation_loadings(data, T, test=True):
    """Return the correlation between every data column and every score column.

    :input:
        - data: [nsamps, nvars], data (non-centered data)
        - T: [nsamps, a_max], Scores
        - test: unused, accepted for interface compatibility
    :output:
        - R: [nvars, a_max], Correlation loadings; R[k, a] is the
          correlation coefficient between data[:, k] and T[:, a]

    :raises:
        - IOError: if data and T do not have the same number of rows
    """
    n_rows, n_vars = data.shape
    n_rows_scores, n_comp = T.shape

    if n_rows != n_rows_scores:
        raise IOError("D/T mismatch")

    # Column-center the data before correlating.
    centered = data - data.mean(0)

    R = empty((n_vars, n_comp), 'd')
    for k in range(n_vars):
        column = centered[:, k]
        for a in range(n_comp):
            # corrcoef returns a 2x2 matrix; the off-diagonal is the value.
            R[k, a] = corrcoef(column, T[:, a])[0, 1]

    return R

class BlmScatterPlot(plots.ScatterPlot):
    """Scatter plot used for scores and loadings in bilinear models.

    Two columns (components) of one model part are plotted against each
    other.  Two spin buttons are added to the toolbar to choose which
    component goes on the abscissa and which on the ordinate.
    """

    def __init__(self, title, model, absi=0, ordi=1, part_name='T', color_by=None):
        """Build the scatter plot from one part of a bilinear model.

        :input:
            - title: name passed on to the underlying scatter plot
            - model: object with a ``model`` mapping of model parts and an
              ``as_dataset(part_name)`` method
            - absi: zero-based component index for the x-axis
            - ordi: zero-based component index for the y-axis
            - part_name: key in ``model.model`` of the part to plot
            - color_by: key in ``model.model`` of a part used as color
              vector; a plain blue ('b') is used when the key is absent

        :raises:
            - ValueError: if ``part_name`` is not present in ``model.model``
        """
        if part_name not in model.model:
            raise ValueError("Model part: %s not found in model" % part_name)
        self._T = model.model[part_name]
        if self._T.shape[1] == 1:
            # Only one column available: plot it against itself.
            logger.log('notice', 'Scores have only one component')
            absi = ordi = 0
        self._absi = absi
        self._ordi = ordi
        self._cmap = cm.jet
        dataset_1 = model.as_dataset(part_name)
        id_dim = dataset_1.get_dim_name(0)
        sel_dim = dataset_1.get_dim_name(1)
        id_1, = dataset_1.get_identifiers(sel_dim, [absi])
        id_2, = dataset_1.get_identifiers(sel_dim, [ordi])
        col = 'b'
        if color_by in model.model:
            col = model.model[color_by].ravel()
        plots.ScatterPlot.__init__(self, dataset_1, dataset_1, id_dim, sel_dim,
                                   id_1, id_2, c=col, s=40, name=title)
        self._mappable.set_cmap(self._cmap)
        self.sc = self._mappable
        self.add_pc_spin_buttons(self._T.shape[1], absi, ordi)

    def _update_color_from_dataset(self, data):
        """Recolor the points from a dataset.

        A single-column dataset is used directly as the color vector.  For a
        multi-column CategoryDataset each column is weighted by its index
        before the row sum and the Paired colormap is used; any other
        multi-column dataset is reduced to its row sums and drawn with jet.

        :raises:
            - ValueError: if the dataset array is not two-dimensional
        """
        is_category = False
        array = data.asarray()
        # only support for 2d-arrays:
        try:
            n_rows, n_cols = array.shape
        except ValueError:
            raise ValueError("No support for more than 2 dimensions.")
        # is dataset a vector or matrix?
        if n_cols != 1:
            if isinstance(data, fluents.dataset.CategoryDataset):
                # we have a category dataset
                is_category = True
                map_vec = scipy.dot(array, scipy.diag(scipy.arange(n_cols))).sum(1)
            else:
                map_vec = array.sum(1)
        else:
            map_vec = array.ravel()

        # update facecolors
        self.sc.set_array(map_vec)
        self.sc.set_clim(map_vec.min(), map_vec.max())
        if is_category:
            cmap = cm.Paired
        else:
            cmap = cm.jet

        self.sc.set_cmap(cmap)
        self.sc.update_scalarmappable()  # sets facecolors from array
        self.canvas.draw()

    def set_facecolor(self, colors):
        """Set patch facecolors (not implemented, does nothing)."""
        pass

    def set_alphas(self, alphas):
        """Set alpha channel for all patches (not implemented, does nothing)."""
        pass

    def set_sizes(self, sizes):
        """Set patch sizes (not implemented, does nothing)."""
        pass

    def add_pc_spin_buttons(self, amax, absi, ordi):
        """Add the abscissa/ordinate component spin buttons to the toolbar.

        :input:
            - amax: number of components; upper bound of both spin buttons
            - absi: zero-based initial abscissa component
            - ordi: zero-based initial ordinate component

        The buttons show one-based component numbers.
        """
        def make_spin_button(index, callback):
            spin = gtk.SpinButton(climb_rate=1)
            spin.set_range(1, amax)
            # Set the value before connecting so no callback fires here.
            spin.set_value(index + 1)
            spin.set_increments(1, 5)
            spin.connect('value_changed', callback)
            return spin

        sb_a = make_spin_button(absi, self.set_absicca)
        sb_o = make_spin_button(ordi, self.set_ordinate)
        hbox = gtk.HBox()
        toolitem = gtk.ToolItem()
        toolitem.set_expand(False)
        toolitem.set_border_width(2)
        toolitem.add(hbox)
        hbox.pack_start(gtk.Label("A:"))
        hbox.pack_start(sb_a)
        hbox.pack_start(gtk.Label(" O:"))
        hbox.pack_start(sb_o)
        self._toolbar.insert(toolitem, -1)
        toolitem.set_tooltip(self._toolbar.tooltips, "Set Principal component")
        self._toolbar.show_all()  # do i need this?

    def _refresh_offsets(self):
        """Point the plotted collections at the currently selected components."""
        xy = self._T[:, [self._absi, self._ordi]]
        self.xaxis_data = xy[:, 0]
        self.yaxis_data = xy[:, 1]
        self.sc._offsets = xy
        self.selection_collection._offsets = xy
        # Keep any already created text labels attached to their points;
        # show_labels() only places them once and then reuses them.
        labels = getattr(self, '_text_labels', None)
        if labels:
            for n, txt in labels.items():
                txt.set_position((self.xaxis_data[n], self.yaxis_data[n]))

    @staticmethod
    def _padded_limits(values, fraction=0.05):
        """Return (min - pad, max + pad) with pad a fraction of the data range."""
        lower = values.min()
        upper = values.max()
        pad = abs(lower - upper) * fraction
        return (lower - pad, upper + pad)

    def set_absicca(self, sb):
        """Spin button callback: show component ``sb`` value - 1 on the x-axis."""
        self._absi = sb.get_value_as_int() - 1
        self._refresh_offsets()
        self.axes.set_xlim(self._padded_limits(self.xaxis_data), emit=True)
        self.canvas.draw_idle()

    def set_ordinate(self, sb):
        """Spin button callback: show component ``sb`` value - 1 on the y-axis."""
        self._ordi = sb.get_value_as_int() - 1
        self._refresh_offsets()
        self.axes.set_ylim(self._padded_limits(self.yaxis_data), emit=True)
        self.canvas.draw_idle()

    def show_labels(self, index=None):
        """Show identifier labels for the points whose indices are in ``index``.

        The text objects are created (hidden) on first use and then reused.
        With ``index`` left as None no visibility is changed.
        """
        if self._text_labels is None:
            x = self.xaxis_data
            y = self.yaxis_data
            self._text_labels = {}
            for name, n in self.dataset_1[self.current_dim].items():
                txt = self.axes.text(x[n], y[n], name)
                txt.set_visible(False)
                self._text_labels[n] = txt
        # 'is not None' rather than '!= None': index may be an array, for
        # which '!=' compares element-wise.
        if index is not None:
            self.hide_labels()
            for indx, txt in self._text_labels.items():
                if indx in index:
                    txt.set_visible(True)
        self.canvas.draw()

    def hide_labels(self):
        """Hide all text labels; does nothing if none were created yet."""
        if self._text_labels is None:
            return
        for txt in self._text_labels.values():
            txt.set_visible(False)
        self.canvas.draw()


class PcaScorePlot(BlmScatterPlot):
    """Scatter plot of model part 'T', titled "Pca scores" with the X dataset name."""

    def __init__(self, model, absi=0, ordi=1):
        x_name = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pca scores (%s)" % x_name, model,
                                absi, ordi, part_name='T')


class PcaLoadingPlot(BlmScatterPlot):
    """Scatter plot of model part 'P', colored by 'p_tsq' when that part exists."""

    def __init__(self, model, absi=0, ordi=1):
        x_name = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pca loadings (%s)" % x_name, model,
                                absi, ordi,
                                part_name='P', color_by='p_tsq')
        

class PlsScorePlot(BlmScatterPlot):
    """Scatter plot of model part 'T', titled "Pls scores" with the X dataset name."""

    def __init__(self, model, absi=0, ordi=1):
        x_name = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pls scores (%s)" % x_name, model,
                                absi, ordi, part_name='T')


class PlsLoadingPlot(BlmScatterPlot):
    """Scatter plot of model part 'P', colored by 'w_tsq' when that part exists."""

    def __init__(self, model, absi=0, ordi=1):
        x_name = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pls loadings (%s)" % x_name, model,
                                absi, ordi,
                                part_name='P', color_by='w_tsq')


class PlsCorrelationLoadingPlot(BlmScatterPlot):
    """Scatter plot of model part 'CP', titled "Pls correlation loadings"."""

    def __init__(self, model, absi=0, ordi=1):
        x_name = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pls correlation loadings (%s)" % x_name,
                                model, absi, ordi, part_name='CP')
        
class LplsXLoadingPlot(BlmScatterPlot):
    """Scatter plot of model part 'P', colored by 'tsqx' when that part exists."""

    def __init__(self, model, absi=0, ordi=1):
        x_name = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Lpls x-loadings (%s)" % x_name, model,
                                absi, ordi,
                                part_name='P', color_by='tsqx')

class LplsZLoadingPlot(BlmScatterPlot):
    """Scatter plot of model part 'L', colored by 'tsqz' when that part exists."""

    def __init__(self, model, absi=0, ordi=1):
        z_name = model._dataset['Z'].get_name()
        BlmScatterPlot.__init__(self, "Lpls z-loadings (%s)" % z_name, model,
                                absi, ordi,
                                part_name='L', color_by='tsqz')

class LplsXCorrelationPlot(BlmScatterPlot):
    """Scatter plot of the 'Rx' correlation loadings with guide circles.

    If the model has no 'Rx' part yet, it is computed with
    correlation_loadings() from model._data['X'] and model.model['T'] and
    stored in model.model['Rx'] before plotting.
    """

    def __init__(self, model, absi=0, ordi=1):
        title = "Lpls x-corr. loads (%s)" % model._dataset['X'].get_name()
        if 'Rx' not in model.model:
            model.model['Rx'] = correlation_loadings(model._data['X'],
                                                     model.model['T'])
        BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rx')
        self._draw_correlation_guides()
        self.canvas.show()

    def _draw_correlation_guides(self, radius=1, center=(0, 0)):
        """Draw circles at radius and radius/2, the zero axes and fixed limits."""
        c100 = patches.Circle(center, radius=radius,
                              facecolor='gray',
                              alpha=.1,
                              zorder=1)
        c50 = patches.Circle(center, radius=radius/2.0,
                             facecolor='gray',
                             alpha=.1,
                             zorder=2)
        self.axes.add_patch(c100)
        self.axes.add_patch(c50)
        self.axes.axhline(lw=1.5, color='k')
        self.axes.axvline(lw=1.5, color='k')
        # Correlations lie in [-1, 1]; leave a small margin around the circle.
        self.axes.set_xlim([-1.05, 1.05])
        self.axes.set_ylim([-1.05, 1.05])
        
class LplsZCorrelationPlot(BlmScatterPlot):
    """Scatter plot of the 'Rz' correlation loadings with guide circles.

    If the model has no 'Rz' part yet, it is computed with
    correlation_loadings() from the transpose of model._data['Z'] and
    model.model['W'] and stored in model.model['Rz'] before plotting.
    """

    def __init__(self, model, absi=0, ordi=1):
        title = "Lpls z-corr. loads (%s)" % model._dataset['Z'].get_name()
        if 'Rz' not in model.model:
            model.model['Rz'] = correlation_loadings(model._data['Z'].T,
                                                     model.model['W'])
        BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rz')
        self._draw_correlation_guides()
        self.canvas.show()

    def _draw_correlation_guides(self, radius=1, center=(0, 0)):
        """Draw circles at radius and radius/2, the zero axes and fixed limits."""
        c100 = patches.Circle(center, radius=radius,
                              facecolor='gray',
                              alpha=.1,
                              zorder=1)
        c50 = patches.Circle(center, radius=radius/2.0,
                             facecolor='gray',
                             alpha=.1,
                             zorder=2)
        self.axes.add_patch(c100)
        self.axes.add_patch(c50)
        self.axes.axhline(lw=1.5, color='k')
        self.axes.axvline(lw=1.5, color='k')
        # Correlations lie in [-1, 1]; leave a small margin around the circle.
        self.axes.set_xlim([-1.05, 1.05])
        self.axes.set_ylim([-1.05, 1.05])


class LplsHypoidCorrelationPlot(BlmScatterPlot):
    """Scatter plot of model part 'W', titled "Hypoid correlations"."""

    def __init__(self, model, absi=0, ordi=1):
        x_name = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Hypoid correlations(%s)" % x_name,
                                model, absi, ordi, part_name='W')


class LineViewXc(plots.LineViewPlot):
    """Line view of the model's X dataset with a toolbar toggle for centering."""

    def __init__(self, model, name='Profiles'):
        x_dataset = model._dataset['X']
        plots.LineViewPlot.__init__(self, x_dataset, 1, None, False, name)
        # The check button starts in whatever state the base plot reports.
        self.add_center_check_button(self.data_is_centered)

    def add_center_check_button(self, ticked):
        """Append a "Center" check button to the toolbar.

        :input:
            - ticked: initial state of the check button
        """
        check = gtk.CheckButton("Center")
        # Set the state before connecting so the handler is not triggered.
        check.set_active(ticked)
        check.connect('toggled', self._toggle_center)

        item = gtk.ToolItem()
        item.set_expand(False)
        item.set_border_width(2)
        item.add(check)

        toolbar = self._toolbar
        toolbar.insert(item, -1)
        item.set_tooltip(toolbar.tooltips, "Column center the line view")
        toolbar.show_all()  # do i need this?

    def _toggle_center(self, active):
        """Handler for 'toggled': switch between centered and raw data.

        ``active`` is whatever the signal passes (presumably the check
        button); it is not used.  Centering subtracts the column means and
        remembers them so un-centering can add them back.
        """
        if not self.data_is_centered:
            self._mn_data = self._data.mean(0)
            self._data = self._data - self._mn_data
            self.data_is_centered = True
        else:
            self._data = self._data + self._mn_data
            self.data_is_centered = False
        # Redraw the lines and restore the current selection.
        self.make_lines()
        self.set_background()
        self.set_current_selection(main.project.get_selection())


class ParalellCoordinates(plots.Plot):
    """Parallel coordinates for scores/loadings with many components.

    Not implemented: the constructor ignores its arguments and does not
    initialise the base plot.
    """

    def __init__(self, model, p='loads'):
        return None


class PlsQvalScatter(plots.ScatterPlot):
    """A volcano-like plot of one column of 'W' against 'w_tsq'.

    Points are colored by the 'w_tsq' values.
    """

    def __init__(self, model, pc=0):
        """Build the plot from model parts 'W' and 'w_tsq'.

        :input:
            - model: object with a ``model`` mapping of model parts and an
              ``as_dataset(part_name)`` method
            - pc: zero-based index of the 'W' column shown on the x-axis

        If the model has no 'w_tsq' part the constructor returns early and
        the base scatter plot is NOT initialised.
        """
        if 'w_tsq' not in model.model:
            return
        self._W = model.model['W']
        dataset_1 = model.as_dataset('W')
        dataset_2 = model.as_dataset('w_tsq')
        id_dim = dataset_1.get_dim_name(0)  # genes
        sel_dim = dataset_1.get_dim_name(1)  # _comp
        sel_dim_2 = dataset_2.get_dim_name(1)  # _zero_dim
        # Honor the requested component instead of always using the first.
        id_1, = dataset_1.get_identifiers(sel_dim, [pc])
        id_2, = dataset_2.get_identifiers(sel_dim_2, [0])
        # 'w_tsq' is known to exist here (checked above).
        col = model.model['w_tsq'].ravel()
        plots.ScatterPlot.__init__(self, dataset_1, dataset_2,
                                   id_dim, sel_dim, id_1, id_2,
                                   c=col, s=20, sel_dim_2=sel_dim_2,
                                   name='Load Volcano')


class PredictionErrorPlot(plots.Plot):
    """A boxplot of prediction error vs. comp. number.

    Draws a box plot of sqrt(model.model['sep']) and marks
    model.model['aopt'] with a wide, semi-transparent red vertical line.
    """

    def __init__(self, model, name="Prediction Error"):
        """Build the plot.

        If the model has no 'sep' part a notice is logged and the
        constructor returns early WITHOUT initialising the base plot.
        """
        if 'sep' not in model.model:
            logger.log('notice', 'Model has no calculations of sep')
            return
        plots.Plot.__init__(self, name)
        self._frozen = True
        self.current_dim = 'johndoe'
        self.axes = self.fig.add_subplot(111)

        # draw
        sep = model.model['sep']
        aopt = model.model['aopt']
        self.axes.boxplot(sqrt(sep))
        # zorder=0 keeps the marker behind the boxes.
        self.axes.axvline(aopt, linewidth=10,
                          color='r', zorder=0,
                          alpha=.5)

        # add canvas
        self.add(self.canvas)
        self.canvas.show()

    def set_current_selection(self, selection):
        """Ignore selections; this plot does not display them."""
        pass


class TRBiplot(BlmScatterPlot):
    """Scatter plot of model part 'B', titled "Target rotation biplot".

    Derives from BlmScatterPlot because its initialiser is the one being
    called; that initialiser relies on BlmScatterPlot methods such as
    add_pc_spin_buttons().
    """

    def __init__(self, model, absi=0, ordi=1):
        title = "Target rotation biplot(%s)" % model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, title, model, absi, ordi, 'B')
        # TODO: the target rotation itself is not plotted yet.  An earlier
        # sketch normalised B (w_rot = B / norm(B)) and projected the
        # column-centered X onto it (t_rot = dot(Xc, w_rot)), but the results
        # were never used, so the computation has been left out.
    

class InfluencePlot(plots.ScatterPlot):
    """Placeholder scatter plot subclass; adds nothing to plots.ScatterPlot."""
        

class RMSEPPlot(plots.BarPlot):
    """Bar plot of the model's 'rmsep' part."""

    def __init__(self, model, name="RMSEP"):
        """Build the bar plot from model.as_dataset('rmsep').

        If the model has no 'rmsep' part a notice is logged and the
        constructor returns early WITHOUT initialising the base plot.
        """
        if 'rmsep' not in model.model:
            # The message names the key that is actually missing.
            logger.log('notice', 'Model has no calculations of rmsep')
            return
        dataset = model.as_dataset('rmsep')
        plots.BarPlot.__init__(self, dataset, name=name)


def normalise(x):
    """Scale vector x to [0,1].

    :input:
        - x: array with ``min``/``max`` methods
    :output:
        - float array: (x - min) / (max - min); all zeros when every element
          of x is equal, since the range is zero and cannot be scaled
    """
    shifted = x - x.min()
    span = shifted.max()
    if span == 0:
        # Constant input: avoid dividing by zero.
        return shifted * 0.0
    # float() forces true division for integer arrays as well.
    return shifted / float(span)