diff --git a/fluents/plots.py b/fluents/plots.py index d757908..9b4dab6 100644 --- a/fluents/plots.py +++ b/fluents/plots.py @@ -14,7 +14,7 @@ from matplotlib.nxutils import points_inside_poly from matplotlib.axes import Subplot, AxesImage from matplotlib.figure import Figure from matplotlib import cm,cbook -from pylab import Polygon +from pylab import Polygon, axis, Circle from matplotlib.collections import LineCollection from matplotlib.mlab import prctile import networkx @@ -227,7 +227,11 @@ class ViewFrame (gtk.Frame): if view.is_mappable_with(obj): view._update_color_from_dataset(obj) - # add selections below + elif isinstance(obj, fluents.dataset.Selection): + view = self.get_view() + if view.is_mappable_with(obj): + view.selection_changed(self.current_dim, obj) + class MainView (gtk.Notebook): @@ -501,11 +505,11 @@ class LineViewPlot(Plot): self.dataset = dataset self._data = dataset.asarray() if len(self._data.shape)==2 and not minor_axis: - minor_axis = major_axis-1 + minor_axis = major_axis - 1 self.major_axis = major_axis self.minor_axis = minor_axis Plot.__init__(self, name) - self.use_blit = False #fixme: blitting does work + self.use_blit = False #fixme: blitting should work self.current_dim = self.dataset.get_dim_name(major_axis) # make axes @@ -530,9 +534,11 @@ class LineViewPlot(Plot): #self._toolbar = PlotToolbar(self) self.canvas.mpl_connect('resize_event', self.clear_background) - def _set_background(self,ax): + def _set_background(self, ax): """Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median. """ + if self._data.shape[self.minor_axis]<10: + return # settings patch_color = 'b' #blue patch_lw = 0 #no edges @@ -577,6 +583,13 @@ class LineViewPlot(Plot): # median line ax.plot(xax, med, median_color, linewidth=median_width) + # Disable selection modes + btn = self._toolbar.get_button('select') + btn.set_sensitive(False) + btn = self._toolbar.get_button('lassoselect') + btn.set_sensitive(False) + self._toolbar.freeze_button.set_sensitive(False) + def clear_background(self, event): """Callback on resize event. Clears the background. """ @@ -686,10 +699,11 @@ class ScatterMarkerPlot(Plot): class ScatterPlot(Plot): """The ScatterPlot is slower than scattermarker, but has size option.""" - def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2,c='b',s=30,sel_dim_2=None, name="Scatter plot"): + def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2, c='b', s=30, sel_dim_2=None, name="Scatter plot"): Plot.__init__(self, name) self.use_blit = False self.ax = self.fig.add_subplot(111) + self._clean_bck = self.canvas.copy_from_bbox(self.ax.bbox) self.current_dim = id_dim self.dataset_1 = dataset_1 x_index = dataset_1[sel_dim][id_1] @@ -701,27 +715,32 @@ class ScatterPlot(Plot): self.yaxis_data = dataset_2._array[:, y_index] lw = scipy.zeros(self.xaxis_data.shape) self.sc = sc = self.ax.scatter(self.xaxis_data, self.yaxis_data, - s=s, c=c, linewidth=lw, - edgecolor='k', alpha=.8, - cmap=cm.jet) + s=s, c=c, linewidth=lw) if len(c)>1: self.fig.colorbar(sc, fraction=.05) self.ax.axhline(0, color='k', lw=1., zorder=1) self.ax.axvline(0, color='k', lw=1., zorder=1) - # collection - self.coll = self.ax.collections[0] - + # labels + self._text_labels = None + # add canvas to widget self.add(self.canvas) self.canvas.show() - def is_mappable_with(self, dataset): - """Returns True if dataset is mappable with this plot. + def is_mappable_with(self, obj): + """Returns True if dataset/selection is mappable with this plot. """ - if self.current_dim in dataset.get_dim_name() and dataset.asarray().shape[0] == self.xaxis_data.shape[0]: - return True - + if isinstance(obj, fluents.dataset.Dataset): + if self.current_dim in obj.get_dim_name() and obj.asarray().shape[0] == self.xaxis_data.shape[0]: + return True + elif isinstance(obj, fluents.dataset.Selection): + if self.current_dim in obj.get_dim_name(): + print "Selection is mappable" + return True + else: + return False + def _update_color_from_dataset(self, data): """Updates the facecolors from a dataset. """ @@ -730,7 +749,7 @@ class ScatterPlot(Plot): try: m, n = array.shape except: - raise ValueError, "No support for more tha 2 dimensions." + raise ValueError, "No support for more than 2 dimensions." # is dataset a vector or matrix? if not n==1: # we have a category dataset @@ -741,16 +760,12 @@ class ScatterPlot(Plot): else: map_vec = array.ravel() - # normalise mapping vector - map_vec = map_vec - map_vec.min() - map_vec = map_vec/map_vec.max() # update facecolors - self.sc._facecolors = self.sc.to_rgba(map_vec, self.sc._alpha) - # draw - self.sc._A = None # mean hack - self.ax.draw_artist(self.sc) + self.sc.set_array(map_vec) + self.sc.set_clim(map_vec.min(), map_vec.max()) + self.sc.update_scalarmappable() #sets facecolors from array self.canvas.draw() - + def rectangle_select_callback(self, x1, y1, x2, y2, key): ydata = self.yaxis_data xdata = self.xaxis_data @@ -775,13 +790,11 @@ class ScatterPlot(Plot): ids = self.dataset_1.get_identifiers(self.current_dim, index) ids = self.update_selection(ids, key) self.selection_listener(self.current_dim, ids) - self.canvas.widgetlock.release(self._lasso) def set_current_selection(self, selection): ids = selection[self.current_dim] # current identifiers if len(ids)==0: return - #self._toolbar.forward() #update data lims before draw index = self.dataset_1.get_indices(self.current_dim, ids) if self.use_blit: if self._background is None: @@ -790,12 +803,13 @@ class ScatterPlot(Plot): lw = scipy.zeros(self.xaxis_data.shape, 'f') if len(index)>0: lw.put(2., index) - self.coll.set_linewidth(lw) + self.sc.set_linewidth(lw) if self.use_blit: + self.ax.draw_artist(self.sc) self.canvas.blit() - self.ax.draw_artist(self.coll) else: + print self.ax.lines self.canvas.draw() @@ -811,10 +825,6 @@ class ImagePlot(Plot): self.ax.set_yticks([]) self.ax.grid(False) - # FIXME: ax shouldn't be in kw at all - if kw.has_key('ax'): - kw.pop('ax') - # Initial draw self.ax.imshow(dataset.asarray(), interpolation='nearest', aspect='auto') @@ -831,20 +841,15 @@ class ImagePlot(Plot): def get_toolbar(self): return self._toolbar - + + class HistogramPlot(Plot): def __init__(self, dataset, **kw): - self.dataset = dataset - self.keywords = kw - Plot.__init__(self, kw['name']) - + self.ax = self.fig.add_subplot(111) - #self.ax.set_xticks([]) - #self.ax.set_yticks([]) self.ax.grid(False) - # FIXME: ax shouldn't be in kw at all - + # Initial draw self.ax.hist(dataset.asarray(), bins=20) @@ -852,6 +857,56 @@ class HistogramPlot(Plot): self.add(self.canvas) self.canvas.show() + # Disable selection modes + btn = self._toolbar.get_button('select') + btn.set_sensitive(False) + btn = self._toolbar.get_button('lassoselect') + btn.set_sensitive(False) + self._toolbar.freeze_button.set_sensitive(False) + + def get_toolbar(self): + return self._toolbar + + +class BarPlot(Plot): + """Bar plot. + + Ordinary bar plot for (column) vectors. + For matrices there is one color for each row. + """ + def __init__(self, dataset, name): + self.dataset = dataset + n, m = dataset.shape + Plot.__init__(self, name) + self.ax = self.fig.add_subplot(111) + self.ax.grid(False) + + # Initial draw + if m>1: + sm = cm.ScalarMappable() + clrs = sm.to_rgba(range(n)) + for i, row in enumerate(dataset.asarray()): + left = scipy.arange(i+1, m*n+1, n) + height = row + color = clrs[i] + c = (color[0], color[1], color[2]) + self.ax.bar(left, height,color=c) + else: + height = dataset.asarray().ravel() + left = scipy.arange(1, n, 1) + self.ax.bar(left, height) + + # Add canvas and show + self.add(self.canvas) + self.canvas.show() + + # Disable selection modes + btn = self._toolbar.get_button('select') + btn.set_sensitive(False) + btn = self._toolbar.get_button('lassoselect') + btn.set_sensitive(False) + self._toolbar.freeze_button.set_sensitive(False) + def get_toolbar(self): return self._toolbar @@ -904,6 +959,7 @@ class NetworkPlot(Plot): self.ax.set_xticks([]) self.ax.set_yticks([]) self.ax.grid(False) + self.ax.set_frame_on(False) # FIXME: ax shouldn't be in kw at all if kw.has_key('ax'): kw.pop('ax') @@ -994,6 +1050,176 @@ class NetworkPlot(Plot): self.canvas.draw() +class VennPlot(Plot): + def __init__(self, name="Venn diagram"): + Plot.__init__(self, name) + self._ax = self.fig.add_subplot(111) + self._ax.grid(0) + self._init_bck() + for c in self._venn_patches: + self._ax.add_patch(c) + for mrk in self._markers: + self._ax.add_patch(mrk) + self._ax.set_xlim([-3, 3]) + self._ax.set_ylim([-2.5, 3.5]) + self._last_active = set() + self._ax.set_xticks([]) + self._ax.set_yticks([]) + self._ax.grid(0) + self._ax.axis('equal') + self._ax.set_frame_on(False) + # add canvas to widget + self.add(self.canvas) + self.canvas.show() + + def _init_bck(self): + res = 50 + a = .5 + r = 1.5 + mr = .2 + self.c1 = c1 = Circle((-1,0), radius=r, alpha=a, facecolor='b', resolution=res) + self.c2 = c2 = Circle((1,0), radius=r, alpha=a, facecolor='r', resolution=res) + self.c3 = c3 = Circle((0, scipy.sqrt(3)), radius=r, alpha=a, facecolor='g', resolution=res) + + self.c1marker = Circle((-1.25, -.25), radius=mr, facecolor='y', alpha=0) + self.c2marker = Circle((1.25, -.25), radius=mr, facecolor='y', alpha=0) + self.c3marker = Circle((0, scipy.sqrt(3)+.25), radius=mr, facecolor='y', alpha=0) + self.c1c2marker = Circle((0, -.15), radius=mr, facecolor='y', alpha=0) + + self.c1c3marker = Circle((-scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0) + self.c2c3marker = Circle((scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0) + self.c1c2c3marker = Circle((0, .6), radius=mr, facecolor='y', alpha=0) + + c1.elements = set(['a', 'b', 'c', 'f']) + c2.elements = set(['a', 'c', 'd', 'e']) + c3.elements = set(['a', 'e', 'f', 'g']) + self.active_elements = set() + self.all_elements = c1.elements.union(c2.elements).union(c3.elements) + + c1.active = False + c2.active = False + c3.active = False + + c1.name = 'Blue' + c2.name = 'Red' + c3.name = 'Green' + + self._venn_patches = [c1, c2, c3] + self._markers = [self.c1marker, self.c2marker, self.c3marker, + self.c1c2marker, self.c1c3marker, + self.c2c3marker, self.c1c2c3marker] + + self._tot_label = 'Tot: ' + str(len(self.all_elements)) + self._sel_label = 'Sel: ' + str(len(self.active_elements)) + self._legend = self._ax.legend((self._tot_label, self._sel_label), + loc='upper right') + + def set_selection(self, selection, patch=None): + if patch: + patch.selection = selection + else: + selection_set = False + for patch in self._venn_patches: + if len(patch.elements)==0: + patch.elements = selection + selection_set = True + if not selection_set: + self.venn_patches[0].elements = selection + + def lasso_select_callback(self, verts, key=None): + if verts==None: + print "ks" + verts = (self._event.xdata, self._event.ydata) + if key!='shift': + for m in self._markers: + m.set_alpha(0) + + self._patches_within_verts(verts, key) + active = [i.active for i in self._venn_patches] + if active==[True, False, False]: + self.c1marker.set_alpha(1) + self.active_elements = self.c1.elements.difference(self.c2.elements.union(self.c3.elements)) + elif active== [False, True, False]: + self.c2marker.set_alpha(1) + self.active_elements = self.c2.elements.difference(self.c1.elements.union(self.c3.elements)) + elif active== [False, False, True]: + self.c3marker.set_alpha(1) + self.active_elements = self.c3.elements.difference(self.c2.elements.union(self.c1.elements)) + elif active==[True, True, False]: + self.c1c2marker.set_alpha(1) + self.active_elements = self.c1.elements.intersection(self.c2.elements) + elif active==[True, False, True]: + self.c1c3marker.set_alpha(1) + self.active_elements = self.c1.elements.intersection(self.c3.elements) + elif active==[False, True, True]: + self.c2c3marker.set_alpha(1) + self.active_elements = self.c2.elements.intersection(self.c3.elements) + elif active==[True, True, True]: + self.c1c2c3marker.set_alpha(1) + self.active_elements = self.c1.elements.intersection(self.c3.elements).intersection(self.c2.elements) + + if key=='shift': + self.active_elements = self.active_elements.union(self._last_active) + self._last_active = self.active_elements.copy() + self._sel_label = 'Sel: ' + str(len(self.active_elements)) + self._legend.texts[1].set_text(self._sel_label) + self.canvas.widgetlock.release(self._lasso) + del self._lasso + self._ax.figure.canvas.draw() + + def rectangle_select_callback(self, x1, y1, x2, y2, key): + """event1 and event2 are the press and release events""" + #x1, y1 = event1.xdata, event1.ydata + #x2, y2 = event2.xdata, event2.ydata + #key = event1.key + + verts = [(x1, y1), (x2, y2)] + if key!='shift': + for m in self._markers: + m.set_alpha(0) + + self._patches_within_verts(verts, key) + active = [i.active for i in self._venn_patches] + if active==[True, False, False]: + self.c1marker.set_alpha(1) + self.active_elements = self.c1.elements.difference(self.c2.elements.union(self.c3.elements)) + elif active== [False, True, False]: + self.c2marker.set_alpha(1) + self.active_elements = self.c2.elements.difference(self.c1.elements.union(self.c3.elements)) + elif active== [False, False, True]: + self.c3marker.set_alpha(1) + self.active_elements = self.c3.elements.difference(self.c2.elements.union(self.c1.elements)) + elif active==[True, True, False]: + self.c1c2marker.set_alpha(1) + self.active_elements = self.c1.elements.intersection(self.c2.elements) + elif active==[True, False, True]: + self.c1c3marker.set_alpha(1) + self.active_elements = self.c1.elements.intersection(self.c3.elements) + elif active==[False, True, True]: + self.c2c3marker.set_alpha(1) + self.active_elements = self.c2.elements.intersection(self.c3.elements) + elif active==[True, True, True]: + self.c1c2c3marker.set_alpha(1) + self.active_elements = self.c1.elements.intersection(self.c3.elements).intersection(self.c2.elements) + + if key=='shift': + self.active_elements = self.active_elements.union(self._last_active) + self._last_active = self.active_elements.copy() + self._sel_label = 'Sel: ' + str(len(self.active_elements)) + self._legend.texts[1].set_text(self._sel_label) + self._ax.figure.canvas.draw() + + def _patches_within_verts(self, verts, key): + xy = scipy.array(verts).mean(0) + for venn_patch in self._venn_patches: + venn_patch.active = False + if self._distance(venn_patch.center,xy)