Added center check button on lineview
This commit is contained in:
		| @@ -5,13 +5,39 @@ fixme: | ||||
|         colorbar, but when adding colors the colorbar shoud be created. | ||||
| """ | ||||
|  | ||||
| from matplotlib import cm | ||||
| from matplotlib import cm,patches | ||||
| import gtk | ||||
| import fluents | ||||
| from fluents import plots | ||||
| from fluents import plots, main | ||||
| import scipy | ||||
| from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt,apply_along_axis | ||||
| from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt,apply_along_axis,empty | ||||
| from scipy.stats import corrcoef | ||||
|  | ||||
| def correlation_loadings(data, T, test=True): | ||||
|     """ Returns correlation loadings. | ||||
|  | ||||
|     :input: | ||||
|         - D: [nsamps, nvars], data (non-centered data) | ||||
|         - T: [nsamps, a_max], Scores | ||||
|     :ouput: | ||||
|         - R: [nvars, a_max], Correlation loadings | ||||
|  | ||||
|     :notes: | ||||
|      | ||||
|     """ | ||||
|     nsamps, nvars = data.shape | ||||
|     nsampsT, a_max = T.shape | ||||
|      | ||||
|     if nsamps!=nsampsT: raise IOError("D/T mismatch") | ||||
|      | ||||
|     # center | ||||
|     data = data - data.mean(0) | ||||
|     R = empty((nvars, a_max),'d') | ||||
|     for a in range(a_max): | ||||
|         for k in range(nvars): | ||||
|             R[k,a] = corrcoef(data[:,k], T[:,a])[0,1] | ||||
|      | ||||
|     return R | ||||
|  | ||||
| class BlmScatterPlot(plots.ScatterPlot): | ||||
|     """Scatter plot used for scores and loadings in bilinear models.""" | ||||
| @@ -31,13 +57,10 @@ class BlmScatterPlot(plots.ScatterPlot): | ||||
|         sel_dim = dataset_1.get_dim_name(1) | ||||
|         id_1, = dataset_1.get_identifiers(sel_dim, [absi]) | ||||
|         id_2, = dataset_1.get_identifiers(sel_dim, [ordi]) | ||||
|  | ||||
|         col = 'b' | ||||
|         if model.model.has_key(color_by): | ||||
|             col = model.model[color_by].ravel() | ||||
|              | ||||
|         plots.ScatterPlot.__init__(self, dataset_1, dataset_1, id_dim, sel_dim, id_1, id_2 ,c=col ,s=40 , name=title) | ||||
|          | ||||
|         self._mappable.set_cmap(self._cmap) | ||||
|         self.sc = self._mappable | ||||
|         self.add_pc_spin_buttons(self._T.shape[1], absi, ordi) | ||||
| @@ -200,19 +223,95 @@ class LplsZLoadingPlot(BlmScatterPlot): | ||||
|         title = "Lpls z-loadings (%s)" %model._dataset['Z'].get_name() | ||||
|         BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='L', color_by='tsqz') | ||||
|  | ||||
| class LplsXCorrelationPlot(BlmScatterPlot): | ||||
|     def __init__(self, model, absi=0, ordi=1): | ||||
|         title = "Lpls x-corr. loads (%s)" %model._dataset['X'].get_name() | ||||
|         if not model.model.has_key('Rx'): | ||||
|             R = correlation_loadings(model._data['X'], model.model['T']) | ||||
|             model.model['Rx'] = R | ||||
|         BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rx') | ||||
|         radius = 1 | ||||
|         center = (0,0) | ||||
|         c100 = patches.Circle(center,radius=radius, | ||||
|                               facecolor='gray', | ||||
|                               alpha=.1, | ||||
|                               zorder=1) | ||||
|         c50 = patches.Circle(center, radius=radius/2.0, | ||||
|                              facecolor='gray', | ||||
|                              alpha=.1, | ||||
|                              zorder=2) | ||||
|         self.axes.add_patch(c100) | ||||
|         self.axes.add_patch(c50) | ||||
|         self.axes.axhline(lw=1.5,color='k') | ||||
|         self.axes.axvline(lw=1.5,color='k') | ||||
|         self.axes.set_xlim([-1.05,1.05]) | ||||
|         self.axes.set_ylim([-1.05, 1.05]) | ||||
|         self.canvas.show() | ||||
|          | ||||
| class LplsZCorrelationPlot(BlmScatterPlot): | ||||
|     def __init__(self, model, absi=0, ordi=1): | ||||
|         title = "Lpls z-corr. loads (%s)" %model._dataset['Z'].get_name() | ||||
|         if not model.model.has_key('Rz'): | ||||
|             R = correlation_loadings(model._data['Z'].T, model.model['W']) | ||||
|             model.model['Rz'] = R | ||||
|         BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rz') | ||||
|         radius = 1 | ||||
|         center = (0,0) | ||||
|         c100 = patches.Circle(center,radius=radius, | ||||
|                               facecolor='gray', | ||||
|                               alpha=.1, | ||||
|                               zorder=1) | ||||
|         c50 = patches.Circle(center, radius=radius/2.0, | ||||
|                              facecolor='gray', | ||||
|                              alpha=.1, | ||||
|                              zorder=2) | ||||
|         self.axes.add_patch(c100) | ||||
|         self.axes.add_patch(c50) | ||||
|         self.axes.axhline(lw=1.5,color='k') | ||||
|         self.axes.axvline(lw=1.5,color='k') | ||||
|         self.axes.set_xlim([-1.05,1.05]) | ||||
|         self.axes.set_ylim([-1.05, 1.05]) | ||||
|         self.canvas.show() | ||||
|  | ||||
|  | ||||
| class LplsHypoidCorrelationPlot(BlmScatterPlot): | ||||
|     def __init__(self, model, absi=0, ordi=1): | ||||
|         title = "Hypoid correlations(%s)" %model._dataset['X'].get_name() | ||||
|         BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='W') | ||||
|      | ||||
|  | ||||
|  | ||||
| class LineViewXc(plots.LineViewPlot): | ||||
|     """A line view of centered raw data | ||||
|     """ | ||||
|     def __init__(self, model, name='Profiles'): | ||||
|         # copy, center, plot | ||||
|         x = model._dataset['X'].copy() | ||||
|         x._array = x._array - mean(x._array,0)[newaxis] | ||||
|         plots.LineViewPlot.__init__(self, x, 1, None, name) | ||||
|         dx = model._dataset['X'] | ||||
|         plots.LineViewPlot.__init__(self, dx, 1, None, False,name) | ||||
|         self.add_center_check_button(self.data_is_centered) | ||||
|      | ||||
|     def add_center_check_button(self, ticked): | ||||
|         """Add a checker button for centerd view of data.""" | ||||
|         cb = gtk.CheckButton("Center") | ||||
|         cb.set_active(ticked) | ||||
|         cb.connect('toggled', self._toggle_center) | ||||
|         toolitem = gtk.ToolItem()    | ||||
|         toolitem.set_expand(False) | ||||
|         toolitem.set_border_width(2) | ||||
|         toolitem.add(cb) | ||||
|         self._toolbar.insert(toolitem, -1) | ||||
|         toolitem.set_tooltip(self._toolbar.tooltips, "Column center the line view") | ||||
|         self._toolbar.show_all() #do i need this? | ||||
|  | ||||
|     def _toggle_center(self, active): | ||||
|         if self.data_is_centered: | ||||
|             self._data = self._data + self._mn_data | ||||
|             self.data_is_centered = False | ||||
|         else: | ||||
|             self._mn_data = self._data.mean(0) | ||||
|             self._data = self._data - self._mn_data | ||||
|             self.data_is_centered = True | ||||
|         self.make_lines() | ||||
|         self.set_background() | ||||
|         self.set_current_selection(main.project.get_selection()) | ||||
|  | ||||
|  | ||||
| class ParalellCoordinates(plots.Plot): | ||||
|   | ||||
| @@ -171,7 +171,7 @@ class LineViewPlot(Plot): | ||||
|     fixme: slow | ||||
|     """ | ||||
|     @plotlogger | ||||
|     def __init__(self, dataset, major_axis=1, minor_axis=None, name="Line view"): | ||||
|     def __init__(self, dataset, major_axis=1, minor_axis=None, center=True,name="Line view"): | ||||
|         Plot.__init__(self, name) | ||||
|         self.dataset = dataset | ||||
|         self._data = dataset.asarray() | ||||
| @@ -180,28 +180,48 @@ class LineViewPlot(Plot): | ||||
|         self.major_axis = major_axis | ||||
|         self.minor_axis = minor_axis | ||||
|         self.current_dim = self.dataset.get_dim_name(major_axis) | ||||
|          | ||||
|         #initial draw | ||||
|         self.data_is_centered = False | ||||
|         self._mn_data = 0 | ||||
|         if center and len(self._data.shape)==2: | ||||
|             if minor_axis==0: | ||||
|                 self._mn_data = self._data.mean(minor_axis) | ||||
|             else: | ||||
|                 self._mn_data = self._data.mean(minor_axis)[:,newaxis] | ||||
|             self._data = self._data - self._mn_data  | ||||
|             self.data_is_centered = True | ||||
|         #initial line collection | ||||
|         self.line_coll = None | ||||
|         self.line_segs = [] | ||||
|         x_axis = scipy.arange(self._data.shape[minor_axis]) | ||||
|         for xi in range(self._data.shape[major_axis]): | ||||
|             yi = self._data.take([xi], major_axis).ravel() | ||||
|             self.line_segs.append([(xx,yy) for xx,yy in zip(x_axis, yi)]) | ||||
|          | ||||
|         self.make_lines() | ||||
|         # draw background | ||||
|         self._set_background(self.axes) | ||||
|         self.set_background() | ||||
|          | ||||
|         # Disable selection modes | ||||
|         self._toolbar.freeze_button.set_sensitive(False) | ||||
|         self._toolbar.set_mode_sensitive('select', False) | ||||
|         self._toolbar.set_mode_sensitive('lassoselect', False) | ||||
|  | ||||
|     def make_lines(self): | ||||
|         """Creates one line for each item along major axis.""" | ||||
|         if self.line_coll: # remove any previous selection lines, if any | ||||
|             self.axes.collections.remove(self.line_coll) | ||||
|         self.line_coll = None | ||||
|         self.line_segs = [] | ||||
|         x_axis = scipy.arange(self._data.shape[self.minor_axis]) | ||||
|         for xi in range(self._data.shape[self.major_axis]): | ||||
|             yi = self._data.take([xi], self.major_axis).ravel() | ||||
|             self.line_segs.append([(xx,yy) for xx,yy in zip(x_axis, yi)]) | ||||
|          | ||||
|     def _set_background(self, ax): | ||||
|     def set_background(self): | ||||
|         """Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median. | ||||
|         """ | ||||
|         if self._data.shape[self.minor_axis]<6: | ||||
|             return  | ||||
|             return | ||||
|         # clean old patches if any | ||||
|         if len(self.axes.patches)>0: | ||||
|             self.axes.patches = [] | ||||
|         # clean old lines (median) if any | ||||
|         if len(self.axes.lines)>0: | ||||
|             self.axes.lines = [] | ||||
|         # settings | ||||
|         patch_color = 'b' #blue | ||||
|         patch_lw = 0 #no edges | ||||
| @@ -240,13 +260,16 @@ class LineViewPlot(Plot): | ||||
|                        facecolor=patch_color) | ||||
|  | ||||
|         # add polygons to axes | ||||
|         ax.add_patch(bck0) | ||||
|         ax.add_patch(bck1) | ||||
|         ax.add_patch(bck2) | ||||
|         self.axes.add_patch(bck0) | ||||
|         self.axes.add_patch(bck1) | ||||
|         self.axes.add_patch(bck2) | ||||
|         # median line | ||||
|         ax.plot(xax, med, median_color, linewidth=median_width) | ||||
|         self.axes.plot(xax, med, median_color, linewidth=median_width) | ||||
|  | ||||
|          | ||||
|         # set y-limits | ||||
|         padding = 0.1 | ||||
|         self.axes.set_ylim([self._data.min() - padding, self._data.max() + padding]) | ||||
|                            | ||||
|     def set_current_selection(self, selection): | ||||
|         """Draws the current selection. | ||||
|         """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user