diff --git a/fluents/lib/blmplots.py b/fluents/lib/blmplots.py index d282145..58849dd 100644 --- a/fluents/lib/blmplots.py +++ b/fluents/lib/blmplots.py @@ -1,22 +1,16 @@ """Specialised plots for functions defined in blmfuncs.py. fixme: - -- Im normalsing all color mapping input vectors to [0,1]. This will - destroy informative numerical values in colorbar (but we - are not showing these anyway). A better fix would be to let the - colorbar listen to the scalarmappable instance and corect itself, but - I did not get that to work ... - -fixme2: -- If scatterplot is not inited with a colorvector there will be no colorbar, but when adding colors the colorbar shoud be created. """ -from fluents import plots -from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt -from matplotlib import cm -import pylab as PB -import gtk +from matplotlib import cm +import gtk +import fluents +from fluents import plots +import scipy +from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt class BlmScatterPlot(plots.ScatterPlot): """Scatter plot used for scores and loadings in bilinear models.""" @@ -31,6 +25,7 @@ class BlmScatterPlot(plots.ScatterPlot): self._absi = absi self._ordi = ordi self._colorbar = None + self._cmap = cm.jet dataset_1 = model.as_dataset(part_name) id_dim = dataset_1.get_dim_name(0) sel_dim = dataset_1.get_dim_name(1) @@ -42,11 +37,46 @@ class BlmScatterPlot(plots.ScatterPlot): col = model.model[color_by].ravel() plots.ScatterPlot.__init__(self, dataset_1, dataset_1, id_dim, sel_dim, id_1, id_2 ,c=col ,s=40 , name=title) - + + self.sc.set_cmap(self._cmap) self.add_pc_spin_buttons(self._T.shape[1], absi, ordi) self._key_press = self.canvas.mpl_connect( 'key_press_event', self._on_key_press) + + def _update_color_from_dataset(self, data): + """Overriding scatter for testing of colormaps. + """ + is_category = False + array = data.asarray() + #only support for 2d-arrays: + try: + m, n = array.shape + except: + raise ValueError, "No support for more than 2 dimensions." + # is dataset a vector or matrix? + if not n==1: + # we have a category dataset + if isinstance(data, fluents.dataset.CategoryDataset): + is_category = True + map_vec = scipy.dot(array, scipy.diag(scipy.arange(n))).sum(1) + else: + map_vec = array.sum(1) + else: + map_vec = array.ravel() + + # update facecolors + self.sc.set_array(map_vec) + self.sc.set_clim(map_vec.min(), map_vec.max()) + if is_category: + cmap = cm.Paired + else: + cmap = cm.jet + + self.sc.set_cmap(cmap) + self.sc.update_scalarmappable() #sets facecolors from array + self.canvas.draw() + def _on_key_press(self, event): if event.key=='c': self.toggle_colorbar() @@ -68,7 +98,7 @@ class BlmScatterPlot(plots.ScatterPlot): if self._colorbar==None: if self.sc._A!=None: # we need colormapping # get axes original position - self._ax_last_pos = self.ax.get_position() + self._ax_last_pos = self.axes.get_position() self._colorbar = self.fig.colorbar(self.sc) self._colorbar.draw_all() self.canvas.draw() @@ -80,18 +110,18 @@ class BlmScatterPlot(plots.ScatterPlot): self.sc.observers = [obs for obs in self.sc.observers if obs !=self._colorbar] self._colorbar = None self.sc.colorbar = None - self.ax.set_position(self._ax_last_pos) + self.axes.set_position(self._ax_last_pos) self.canvas.draw() def add_pc_spin_buttons(self, amax, absi, ordi): sb_a = gtk.SpinButton(climb_rate=1) sb_a.set_range(1, amax) - sb_a.set_value(absi) + sb_a.set_value(absi+1) sb_a.set_increments(1, 5) sb_a.connect('value_changed', self.set_absicca) sb_o = gtk.SpinButton(climb_rate=1) sb_o.set_range(1, amax) - sb_o.set_value(ordi) + sb_o.set_value(ordi+1) sb_o.set_increments(1, 5) sb_o.connect('value_changed', self.set_ordinate) hbox = gtk.HBox() @@ -115,12 +145,12 @@ class BlmScatterPlot(plots.ScatterPlot): self.xaxis_data = xy[:,0] self.yaxis_data = xy[:,1] self.sc._offsets = xy - if self.use_blit==True: - self.canvas.restore_region(self._clean_bck) - self.ax.draw_artist(self.sc) - self.canvas.blit() - else: - self.canvas.draw_idle() + self.selection_collection._offsets = xy + self.canvas.draw_idle() + pad = abs(self.xaxis_data.min()-self.xaxis_data.max())*0.05 + new_lims = (self.xaxis_data.min()+pad, self.xaxis_data.max()+pad) + self.axes.set_xlim(new_lims, emit=True) + self.canvas.draw_idle() def set_ordinate(self, sb): self._ordi = sb.get_value_as_int() - 1 @@ -128,20 +158,19 @@ class BlmScatterPlot(plots.ScatterPlot): self.xaxis_data = xy[:,0] self.yaxis_data = xy[:,1] self.sc._offsets = xy - if self.use_blit==True: - self.canvas.restore_region(self._clean_bck) - self.ax.draw_artist(self.sc) - self.canvas.blit() - else: - self.canvas.draw_idle() - + self.selection_collection._offsets = xy + pad = abs(self.yaxis_data.min()-self.yaxis_data.max())*0.05 + new_lims = (self.yaxis_data.min()+pad, self.yaxis_data.max()+pad) + self.axes.set_ylim(new_lims, emit=True) + self.canvas.draw_idle() + def show_labels(self, index=None): if self._text_labels == None: x = self.xaxis_data y = self.yaxis_data self._text_labels = {} for name, n in self.dataset_1[self.current_dim].items(): - txt = self.ax.text(x[n],y[n], name) + txt = self.axes.text(x[n],y[n], name) txt.set_visible(False) self._text_labels[n] = txt if index!=None: @@ -156,6 +185,7 @@ class BlmScatterPlot(plots.ScatterPlot): txt.set_visible(False) self.canvas.draw() + class PcaScorePlot(BlmScatterPlot): def __init__(self, model, absi=0, ordi=1): title = "Pca scores (%s)" %model._dataset['X'].get_name() @@ -213,7 +243,7 @@ class PlsQvalScatter(plots.ScatterPlot): id_2, = dataset_2.get_identifiers(sel_dim_2, [0]) if model.model.has_key('w_tsq'): col = model.model['w_tsq'].ravel() - col = normalise(col) + #col = normalise(col) else: col = 'g' plots.ScatterPlot.__init__(self, dataset_1, dataset_2, @@ -232,13 +262,13 @@ class PredictionErrorPlot(plots.Plot): plots.Plot.__init__(self, name) self._frozen = True self.current_dim = 'johndoe' - self.ax = self.fig.add_subplot(111) + self.axes = self.fig.add_subplot(111) # draw sep = model.model['sep'] aopt = model.model['aopt'] - bx_plot_lines = self.ax.boxplot(sqrt(sep)) - aopt_marker = self.ax.axvline(aopt, linewidth=10, + bx_plot_lines = self.axes.boxplot(sqrt(sep)) + aopt_marker = self.axes.axvline(aopt, linewidth=10, color='r',zorder=0, alpha=.5)