From b35f814ef0ce275e34079414937eb403c2ddc447 Mon Sep 17 00:00:00 2001 From: flatberg Date: Thu, 26 Jul 2007 18:26:50 +0000 Subject: [PATCH] Added center check button on lineview --- fluents/lib/blmplots.py | 121 ++++++++++++++++++++++++++++++++++++---- fluents/plots.py | 57 +++++++++++++------ 2 files changed, 150 insertions(+), 28 deletions(-) diff --git a/fluents/lib/blmplots.py b/fluents/lib/blmplots.py index a0d976f..236a9a6 100644 --- a/fluents/lib/blmplots.py +++ b/fluents/lib/blmplots.py @@ -5,13 +5,39 @@ fixme: colorbar, but when adding colors the colorbar shoud be created. """ -from matplotlib import cm +from matplotlib import cm,patches import gtk import fluents -from fluents import plots +from fluents import plots, main import scipy -from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt,apply_along_axis +from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt,apply_along_axis,empty +from scipy.stats import corrcoef +def correlation_loadings(data, T, test=True): + """ Returns correlation loadings. + + :input: + - D: [nsamps, nvars], data (non-centered data) + - T: [nsamps, a_max], Scores + :ouput: + - R: [nvars, a_max], Correlation loadings + + :notes: + + """ + nsamps, nvars = data.shape + nsampsT, a_max = T.shape + + if nsamps!=nsampsT: raise IOError("D/T mismatch") + + # center + data = data - data.mean(0) + R = empty((nvars, a_max),'d') + for a in range(a_max): + for k in range(nvars): + R[k,a] = corrcoef(data[:,k], T[:,a])[0,1] + + return R class BlmScatterPlot(plots.ScatterPlot): """Scatter plot used for scores and loadings in bilinear models.""" @@ -31,13 +57,10 @@ class BlmScatterPlot(plots.ScatterPlot): sel_dim = dataset_1.get_dim_name(1) id_1, = dataset_1.get_identifiers(sel_dim, [absi]) id_2, = dataset_1.get_identifiers(sel_dim, [ordi]) - col = 'b' if model.model.has_key(color_by): col = model.model[color_by].ravel() - plots.ScatterPlot.__init__(self, dataset_1, dataset_1, id_dim, sel_dim, id_1, id_2 ,c=col ,s=40 , name=title) - self._mappable.set_cmap(self._cmap) self.sc = self._mappable self.add_pc_spin_buttons(self._T.shape[1], absi, ordi) @@ -200,19 +223,95 @@ class LplsZLoadingPlot(BlmScatterPlot): title = "Lpls z-loadings (%s)" %model._dataset['Z'].get_name() BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='L', color_by='tsqz') +class LplsXCorrelationPlot(BlmScatterPlot): + def __init__(self, model, absi=0, ordi=1): + title = "Lpls x-corr. loads (%s)" %model._dataset['X'].get_name() + if not model.model.has_key('Rx'): + R = correlation_loadings(model._data['X'], model.model['T']) + model.model['Rx'] = R + BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rx') + radius = 1 + center = (0,0) + c100 = patches.Circle(center,radius=radius, + facecolor='gray', + alpha=.1, + zorder=1) + c50 = patches.Circle(center, radius=radius/2.0, + facecolor='gray', + alpha=.1, + zorder=2) + self.axes.add_patch(c100) + self.axes.add_patch(c50) + self.axes.axhline(lw=1.5,color='k') + self.axes.axvline(lw=1.5,color='k') + self.axes.set_xlim([-1.05,1.05]) + self.axes.set_ylim([-1.05, 1.05]) + self.canvas.show() + +class LplsZCorrelationPlot(BlmScatterPlot): + def __init__(self, model, absi=0, ordi=1): + title = "Lpls z-corr. loads (%s)" %model._dataset['Z'].get_name() + if not model.model.has_key('Rz'): + R = correlation_loadings(model._data['Z'].T, model.model['W']) + model.model['Rz'] = R + BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rz') + radius = 1 + center = (0,0) + c100 = patches.Circle(center,radius=radius, + facecolor='gray', + alpha=.1, + zorder=1) + c50 = patches.Circle(center, radius=radius/2.0, + facecolor='gray', + alpha=.1, + zorder=2) + self.axes.add_patch(c100) + self.axes.add_patch(c50) + self.axes.axhline(lw=1.5,color='k') + self.axes.axvline(lw=1.5,color='k') + self.axes.set_xlim([-1.05,1.05]) + self.axes.set_ylim([-1.05, 1.05]) + self.canvas.show() + + class LplsHypoidCorrelationPlot(BlmScatterPlot): def __init__(self, model, absi=0, ordi=1): title = "Hypoid correlations(%s)" %model._dataset['X'].get_name() BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='W') - + + class LineViewXc(plots.LineViewPlot): """A line view of centered raw data """ def __init__(self, model, name='Profiles'): - # copy, center, plot - x = model._dataset['X'].copy() - x._array = x._array - mean(x._array,0)[newaxis] - plots.LineViewPlot.__init__(self, x, 1, None, name) + dx = model._dataset['X'] + plots.LineViewPlot.__init__(self, dx, 1, None, False,name) + self.add_center_check_button(self.data_is_centered) + + def add_center_check_button(self, ticked): + """Add a checker button for centerd view of data.""" + cb = gtk.CheckButton("Center") + cb.set_active(ticked) + cb.connect('toggled', self._toggle_center) + toolitem = gtk.ToolItem() + toolitem.set_expand(False) + toolitem.set_border_width(2) + toolitem.add(cb) + self._toolbar.insert(toolitem, -1) + toolitem.set_tooltip(self._toolbar.tooltips, "Column center the line view") + self._toolbar.show_all() #do i need this? + + def _toggle_center(self, active): + if self.data_is_centered: + self._data = self._data + self._mn_data + self.data_is_centered = False + else: + self._mn_data = self._data.mean(0) + self._data = self._data - self._mn_data + self.data_is_centered = True + self.make_lines() + self.set_background() + self.set_current_selection(main.project.get_selection()) class ParalellCoordinates(plots.Plot): diff --git a/fluents/plots.py b/fluents/plots.py index a6e4b64..d9ffd2e 100644 --- a/fluents/plots.py +++ b/fluents/plots.py @@ -171,7 +171,7 @@ class LineViewPlot(Plot): fixme: slow """ @plotlogger - def __init__(self, dataset, major_axis=1, minor_axis=None, name="Line view"): + def __init__(self, dataset, major_axis=1, minor_axis=None, center=True,name="Line view"): Plot.__init__(self, name) self.dataset = dataset self._data = dataset.asarray() @@ -180,28 +180,48 @@ class LineViewPlot(Plot): self.major_axis = major_axis self.minor_axis = minor_axis self.current_dim = self.dataset.get_dim_name(major_axis) - - #initial draw + self.data_is_centered = False + self._mn_data = 0 + if center and len(self._data.shape)==2: + if minor_axis==0: + self._mn_data = self._data.mean(minor_axis) + else: + self._mn_data = self._data.mean(minor_axis)[:,newaxis] + self._data = self._data - self._mn_data + self.data_is_centered = True + #initial line collection self.line_coll = None - self.line_segs = [] - x_axis = scipy.arange(self._data.shape[minor_axis]) - for xi in range(self._data.shape[major_axis]): - yi = self._data.take([xi], major_axis).ravel() - self.line_segs.append([(xx,yy) for xx,yy in zip(x_axis, yi)]) - + self.make_lines() # draw background - self._set_background(self.axes) + self.set_background() # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) + + def make_lines(self): + """Creates one line for each item along major axis.""" + if self.line_coll: # remove any previous selection lines, if any + self.axes.collections.remove(self.line_coll) + self.line_coll = None + self.line_segs = [] + x_axis = scipy.arange(self._data.shape[self.minor_axis]) + for xi in range(self._data.shape[self.major_axis]): + yi = self._data.take([xi], self.major_axis).ravel() + self.line_segs.append([(xx,yy) for xx,yy in zip(x_axis, yi)]) - def _set_background(self, ax): + def set_background(self): """Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median. """ if self._data.shape[self.minor_axis]<6: - return + return + # clean old patches if any + if len(self.axes.patches)>0: + self.axes.patches = [] + # clean old lines (median) if any + if len(self.axes.lines)>0: + self.axes.lines = [] # settings patch_color = 'b' #blue patch_lw = 0 #no edges @@ -240,13 +260,16 @@ class LineViewPlot(Plot): facecolor=patch_color) # add polygons to axes - ax.add_patch(bck0) - ax.add_patch(bck1) - ax.add_patch(bck2) + self.axes.add_patch(bck0) + self.axes.add_patch(bck1) + self.axes.add_patch(bck2) # median line - ax.plot(xax, med, median_color, linewidth=median_width) + self.axes.plot(xax, med, median_color, linewidth=median_width) - + # set y-limits + padding = 0.1 + self.axes.set_ylim([self._data.min() - padding, self._data.max() + padding]) + def set_current_selection(self, selection): """Draws the current selection. """