diff --git a/fluents/lib/blmplots.py b/fluents/lib/blmplots.py index 902eb45..d282145 100644 --- a/fluents/lib/blmplots.py +++ b/fluents/lib/blmplots.py @@ -15,94 +15,169 @@ from fluents import plots from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt from matplotlib import cm import pylab as PB +import gtk -class PcaScorePlot(plots.ScatterPlot): - """PCA Score plot""" - def __init__(self, model, absi=0, ordi=1): - self._T = model.model['T'] - dataset_1 = model.as_dataset('T') - dataset_2 = dataset_1 + +class BlmScatterPlot(plots.ScatterPlot): + """Scatter plot used for scores and loadings in bilinear models.""" + + def __init__(self, title, model, absi=0, ordi=1, part_name='T', color_by=None): + if model.model.has_key(part_name)!=True: + raise ValueError("Model part: %s not found in model" %mod_param) + self._T = model.model[part_name] + if self._T.shape[1]==1: + logger.log('notice', 'Scores have only one component') + absi= ordi = 0 + self._absi = absi + self._ordi = ordi + self._colorbar = None + dataset_1 = model.as_dataset(part_name) id_dim = dataset_1.get_dim_name(0) sel_dim = dataset_1.get_dim_name(1) id_1, = dataset_1.get_identifiers(sel_dim, [absi]) id_2, = dataset_1.get_identifiers(sel_dim, [ordi]) - plots.ScatterPlot.__init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2 ,c='b' ,s=40 , name='pca-scores') - def set_absicca(self, n): - self.xaxis_data = self._T[:,n] + col = 'b' + if model.model.has_key(color_by): + col = model.model[color_by].ravel() + + plots.ScatterPlot.__init__(self, dataset_1, dataset_1, id_dim, sel_dim, id_1, id_2 ,c=col ,s=40 , name=title) - def set_ordinate(self, n): - self.yaxis_data = self._T[:,n] + self.add_pc_spin_buttons(self._T.shape[1], absi, ordi) + self._key_press = self.canvas.mpl_connect( + 'key_press_event', self._on_key_press) -class PcaLoadingPlot(plots.ScatterPlot): - """PCA Loading plot""" - def __init__(self, model, absi=0, ordi=1): - self._P = model.model['P'] - dataset_1 = model.as_dataset('P') - dataset_2 = dataset_1 - id_dim = dataset_1.get_dim_name(0) - sel_dim = dataset_1.get_dim_name(1) - id_1, = dataset_1.get_identifiers(sel_dim, [absi]) - id_2, = dataset_1.get_identifiers(sel_dim, [ordi]) - if model.model.has_key('p_tsq'): - col = model.model['p_tsq'].ravel() - col = normalise(col) - else: - col = 'g' - plots.ScatterPlot.__init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2,c=col,s=20, name='pls-loadings') + def _on_key_press(self, event): + if event.key=='c': + self.toggle_colorbar() - def set_absicca(self, n): - self.xaxis_data = self._P[:,n] + def set_facecolor(self, colors): + """Set patch facecolors. + """ + pass - def set_ordinate(self, n): - self.yaxis_data = self._P[:,n] - -class PlsScorePlot(plots.ScatterPlot): - """PLS Score plot""" - def __init__(self, model, absi=0, ordi=1): - self._T = model.model['T'] - dataset_1 = model.as_dataset('T') - dataset_2 = dataset_1 - id_dim = dataset_1.get_dim_name(0) - sel_dim = dataset_1.get_dim_name(1) - id_1, = dataset_1.get_identifiers(sel_dim, [absi]) - id_2, = dataset_1.get_identifiers(sel_dim, [ordi]) - - plots.ScatterPlot.__init__(self, dataset_1, dataset_2, - id_dim, sel_dim, id_1, id_2 , - c='b' ,s=40 , name='pls-scores') - - def set_absicca(self, n): - self.xaxis_data = self._T[:,n] + def set_alphas(self, alphas): + """Set alpha channel for all patches.""" + pass - def set_ordinate(self, n): - self.yaxis_data = self._T[:,n] + def set_sizes(self, sizes): + """Set patch sizes.""" + pass - -class PlsLoadingPlot(plots.ScatterPlot): - """PLS Loading plot""" - def __init__(self, model, absi=0, ordi=1): - self._P = model.model['P'] - dataset_1 = model.as_dataset('P') - dataset_2 = dataset_1 - id_dim = dataset_1.get_dim_name(0) - sel_dim = dataset_1.get_dim_name(1) - id_1, = dataset_1.get_identifiers(sel_dim, [absi]) - id_2, = dataset_1.get_identifiers(sel_dim, [ordi]) - if model.model.has_key('w_tsq'): - col = model.model['w_tsq'].ravel() - col = normalise(col) + def toggle_colorbar(self): + if self._colorbar==None: + if self.sc._A!=None: # we need colormapping + # get axes original position + self._ax_last_pos = self.ax.get_position() + self._colorbar = self.fig.colorbar(self.sc) + self._colorbar.draw_all() + self.canvas.draw() else: - col = 'g' - plots.ScatterPlot.__init__(self, dataset_1, dataset_2, - id_dim, sel_dim, id_1, id_2, - c=col, s=20, name='loadings') + # remove colorbar + # remove, axes, observers, colorbar instance, and restore viewlims + cb, ax = self.sc.colorbar + self.fig.delaxes(ax) + self.sc.observers = [obs for obs in self.sc.observers if obs !=self._colorbar] + self._colorbar = None + self.sc.colorbar = None + self.ax.set_position(self._ax_last_pos) + self.canvas.draw() + + def add_pc_spin_buttons(self, amax, absi, ordi): + sb_a = gtk.SpinButton(climb_rate=1) + sb_a.set_range(1, amax) + sb_a.set_value(absi) + sb_a.set_increments(1, 5) + sb_a.connect('value_changed', self.set_absicca) + sb_o = gtk.SpinButton(climb_rate=1) + sb_o.set_range(1, amax) + sb_o.set_value(ordi) + sb_o.set_increments(1, 5) + sb_o.connect('value_changed', self.set_ordinate) + hbox = gtk.HBox() + gtk_label_a = gtk.Label("A:") + gtk_label_o = gtk.Label(" O:") + toolitem = gtk.ToolItem() + toolitem.set_expand(False) + toolitem.set_border_width(2) + toolitem.add(hbox) + hbox.pack_start(gtk_label_a) + hbox.pack_start(sb_a) + hbox.pack_start(gtk_label_o) + hbox.pack_start(sb_o) + self._toolbar.insert(toolitem, -1) + toolitem.set_tooltip(self._toolbar.tooltips, "Set Principal component") + self._toolbar.show_all() #do i need this? + + def set_absicca(self, sb): + self._absi = sb.get_value_as_int() - 1 + xy = self._T[:,[self._absi, self._ordi]] + self.xaxis_data = xy[:,0] + self.yaxis_data = xy[:,1] + self.sc._offsets = xy + if self.use_blit==True: + self.canvas.restore_region(self._clean_bck) + self.ax.draw_artist(self.sc) + self.canvas.blit() + else: + self.canvas.draw_idle() - def set_absicca(self, n): - self.xaxis_data = self._P[:,n] + def set_ordinate(self, sb): + self._ordi = sb.get_value_as_int() - 1 + xy = self._T[:,[self._absi, self._ordi]] + self.xaxis_data = xy[:,0] + self.yaxis_data = xy[:,1] + self.sc._offsets = xy + if self.use_blit==True: + self.canvas.restore_region(self._clean_bck) + self.ax.draw_artist(self.sc) + self.canvas.blit() + else: + self.canvas.draw_idle() + + def show_labels(self, index=None): + if self._text_labels == None: + x = self.xaxis_data + y = self.yaxis_data + self._text_labels = {} + for name, n in self.dataset_1[self.current_dim].items(): + txt = self.ax.text(x[n],y[n], name) + txt.set_visible(False) + self._text_labels[n] = txt + if index!=None: + self.hide_labels() + for indx,txt in self._text_labels.items(): + if indx in index: + txt.set_visible(True) + self.canvas.draw() + + def hide_labels(self): + for txt in self._text_labels.values(): + txt.set_visible(False) + self.canvas.draw() + +class PcaScorePlot(BlmScatterPlot): + def __init__(self, model, absi=0, ordi=1): + title = "Pca scores (%s)" %model._dataset['X'].get_name() + BlmScatterPlot.__init__(self, title, model, absi, ordi, 'T') + + +class PcaLoadingPlot(BlmScatterPlot): + def __init__(self, model, absi=0, ordi=1): + title = "Pca loadings (%s)" %model._dataset['X'].get_name() + BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='P', color_by='p_tsq') - def set_ordinate(self, n): - self.yaxis_data = self._T[:,n] + +class PlsScorePlot(BlmScatterPlot): + def __init__(self, model, absi=0, ordi=1): + title = "Pls scores (%s)" %model._dataset['X'].get_name() + BlmScatterPlot.__init__(self, title, model, absi, ordi, 'T') + + +class PlsLoadingPlot(BlmScatterPlot): + def __init__(self, model, absi=0, ordi=1): + title = "Pca loadings (%s)" %model._dataset['X'].get_name() + BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='P', color_by='w_tsq') class LineViewXc(plots.LineViewPlot): @@ -127,7 +202,7 @@ class PlsQvalScatter(plots.ScatterPlot): """ def __init__(self, model, pc=0): if not model.model.has_key('w_tsq'): - return + return None self._W = model.model['P'] dataset_1 = model.as_dataset('P') dataset_2 = model.as_dataset('w_tsq') @@ -146,13 +221,14 @@ class PlsQvalScatter(plots.ScatterPlot): c=col, s=20, sel_dim_2=sel_dim_2, name='Load Volcano') + class PredictionErrorPlot(plots.Plot): """A boxplot of prediction error vs. comp. number. """ - def __init__(self, model, name="Pred. Err."): + def __init__(self, model, name="Prediction Error"): if not model.model.has_key('sep'): logger.log('notice', 'Model has no calculations of sep') - return + return None plots.Plot.__init__(self, name) self._frozen = True self.current_dim = 'johndoe' @@ -180,6 +256,15 @@ class InfluencePlot(plots.ScatterPlot): pass +class RMSEPPlot(plots.BarPlot): + def __init__(self, model, name="RMSEP"): + if not model.model.has_key('rmsep'): + logger.log('notice', 'Model has no calculations of sep') + return + dataset = model.as_dataset('rmsep') + plots.BarPlot.__init__(self, dataset, name=name) + + def normalise(x): """Scale vector x to [0,1] """