From 4de65f1085679e1bca15588bebe4218301037907 Mon Sep 17 00:00:00 2001 From: flatberg Date: Tue, 27 Feb 2007 15:05:21 +0000 Subject: [PATCH] Clean up --- fluents/plots.py | 300 +++++++++++++++++------------------------------ 1 file changed, 108 insertions(+), 192 deletions(-) diff --git a/fluents/plots.py b/fluents/plots.py index 722d7f2..b3e0efe 100644 --- a/fluents/plots.py +++ b/fluents/plots.py @@ -450,15 +450,22 @@ class Plot (View): View.__init__(self, title) logger.log('debug', 'plot %s init' %title) self.selection_listener = None - self.fig = Figure() - self.canvas = FigureCanvasGTKAgg(self.fig) - self._toolbar = PlotToolbar(self) - self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK) self.current_dim = None self._current_selection = None - self._background = None self._frozen = False - self.use_blit = False + self._init_mpl() + + def _init_mpl(self): + # init matplotlib related stuff + self._background = None + self._use_blit = False + self.fig = Figure() + self.canvas = FigureCanvasGTKAgg(self.fig) + self.axes = self.fig.gca() + self._toolbar = PlotToolbar(self) + self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK) + self.add(self.canvas) + self.canvas.show() def set_frozen(self, frozen): """A frozen plot will not be updated when the current @@ -550,45 +557,33 @@ class LineViewPlot(Plot): -- major_axis : dim_number for line dim (see scipy.ndarray for axis def.) -- minor_axis : needs definition only for higher order arrays - fixme: slow (cant get linecollection and blit to work) + fixme: slow """ def __init__(self, dataset, major_axis=1, minor_axis=None, name="Line view"): + Plot.__init__(self, name) self.dataset = dataset self._data = dataset.asarray() if len(self._data.shape)==2 and not minor_axis: minor_axis = major_axis - 1 self.major_axis = major_axis self.minor_axis = minor_axis - Plot.__init__(self, name) - self.use_blit = False #fixme: blitting should work self.current_dim = self.dataset.get_dim_name(major_axis) - self.line_coll = None - - # make axes - self.axes = self.fig.add_subplot(111) #initial draw - x_axis = scipy.arange(self._data.shape[minor_axis]) + self.line_coll = None self.line_segs = [] + x_axis = scipy.arange(self._data.shape[minor_axis]) for xi in range(self._data.shape[major_axis]): yi = self._data.take([xi], major_axis).ravel() self.line_segs.append([(xx,yy) for xx,yy in izip(x_axis, yi)]) # draw background self._set_background(self.axes) - - # add canvas - self.add(self.canvas) - self.canvas.show() - - # add toolbar - self._toolbar = PlotToolbar(self) # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) - #self.canvas.mpl_connect('resize_event', self.clear_background) def _set_background(self, ax): """Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median. @@ -652,7 +647,7 @@ class LineViewPlot(Plot): self.axes.add_collection(self.line_coll) #draw - if self.use_blit: + if self._use_blit: if self._background is None: self._background = self.canvas.copy_from_bbox(self.axes.bbox) self.canvas.restore_region(self._background) @@ -669,20 +664,19 @@ class ScatterMarkerPlot(Plot): def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2, s=6, name="Scatter plot"): Plot.__init__(self, name) - self.axes = self.fig.add_subplot(111) - self.axes.axhline(0, color='k', lw=1., zorder=1) - self.axes.axvline(0, color='k', lw=1., zorder=1) self.current_dim = id_dim self.dataset_1 = dataset_1 self.ms = s - self._selection_line = None x_index = dataset_1[sel_dim][id_1] y_index = dataset_2[sel_dim][id_2] self.xaxis_data = dataset_1._array[:, x_index] self.yaxis_data = dataset_2._array[:, y_index] + + # init draw + self._selection_line = None self.line = self.axes.plot(self.xaxis_data, self.yaxis_data, 'o', markeredgewidth=0, markersize=s) - self.add(self.canvas) - self.canvas.show() + self.axes.axhline(0, color='k', lw=1., zorder=1) + self.axes.axvline(0, color='k', lw=1., zorder=1) def rectangle_select_callback(self, x1, y1, x2, y2, key): ydata = self.yaxis_data @@ -726,7 +720,7 @@ class ScatterMarkerPlot(Plot): markeredgewidth=1.0) self.axes.add_line(self._selection_line) - if self.use_blit: + if self._use_blit: if self._background is None: self._background = self.canvas.copy_from_bbox(self.axes.bbox) self.canvas.restore_region(self._background) @@ -740,11 +734,11 @@ class ScatterMarkerPlot(Plot): class ScatterPlot(Plot): """The ScatterPlot is slower than scattermarker, but has size option.""" def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2, c='b', s=30, sel_dim_2=None, name="Scatter plot"): + + Plot.__init__(self, name) self.dataset_1 = dataset_1 self.s = s self.c = c - Plot.__init__(self, name) - self.use_blit = False self.current_dim = id_dim x_index = dataset_1[sel_dim][id_1] @@ -754,15 +748,11 @@ class ScatterPlot(Plot): y_index = dataset_2[sel_dim][id_2] self.xaxis_data = dataset_1._array[:, x_index] self.yaxis_data = dataset_2._array[:, y_index] - # labels - self._text_labels = None + # init draw self.init_draw() - - # add canvas to widget - self.add(self.canvas) - self.canvas.show() + # signals to enable correct use of blit self.connect('zoom-changed', self.onzoom) self.connect('pan-changed', self.onpan) self.need_redraw = False @@ -781,7 +771,7 @@ class ScatterPlot(Plot): self.clean_redraw() def clean_redraw(self): - if self.use_blit == True: + if self._use_blit == True: logger.log('notice', 'blit -> clean redraw ') self.set_current_selection(None) self._background = self.canvas.copy_from_bbox(self.axes.bbox) @@ -790,7 +780,6 @@ class ScatterPlot(Plot): self._background = None def init_draw(self): - self.axes = self.fig.add_subplot(111) lw = scipy.zeros(self.xaxis_data.shape) self.sc = self.axes.scatter(self.xaxis_data, self.yaxis_data, s=self.s, c=self.c, linewidth=lw, zorder=3) @@ -876,7 +865,7 @@ class ScatterPlot(Plot): linewidth.put(1, index) self.selection_collection.set_linewidth(linewidth) - if self.use_blit and len(index)>0 : + if self._use_blit and len(index)>0 : if self._background is None: self._background = self.canvas.copy_from_bbox(self.axes.bbox) self.canvas.restore_region(self._background) @@ -888,31 +877,19 @@ class ScatterPlot(Plot): class ImagePlot(Plot): def __init__(self, dataset, **kw): + Plot.__init__(self, kw.get('name', 'Image Plot')) self.dataset = dataset - self.keywords = kw - - Plot.__init__(self, kw['name']) - - self.axes = self.fig.add_subplot(111) - self.axes.set_xticks([]) - self.axes.set_yticks([]) - self.axes.grid(False) # Initial draw - self.axes.imshow(dataset.asarray(), interpolation='nearest', aspect='auto') - - # Add canvas and show - self.add(self.canvas) - self.canvas.show() + self.axes.grid(False) + self.axes.imshow(dataset.asarray(), interpolation='nearest') + self.axes.axis('tight') # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) - def get_toolbar(self): - return self._toolbar - class HistogramPlot(Plot): """ Histogram plot. @@ -924,8 +901,6 @@ class HistogramPlot(Plot): Plot.__init__(self, kw['name']) self.dataset = dataset self._data = dataset.asarray() - self.axes = self.fig.add_subplot(111) - self.axes.grid(False) # If dataset is 1-dim we may do selections if dataset.shape[0]==1: @@ -933,7 +908,8 @@ class HistogramPlot(Plot): if dataset.shape[1]==1: self.current_dim = dataset.get_dim_name(0) # Initial draw - bins = min(len(self.dataset[self.current_dim]), 20) + self.axes.grid(False) + bins = min(self._data.size, 20) count, lims, self.patches = self.axes.hist(self._data, bins=bins) # Add identifiers to the individual patches @@ -946,9 +922,6 @@ class HistogramPlot(Plot): bool_ind = scipy.bitwise_and(self._data>=lims[i], self._data<=end_lim) patch.index = scipy.where(bool_ind)[0] - # Add canvas and show - self.add(self.canvas) - self.canvas.show() if self.current_dim==None: # Disable selection modes @@ -1010,14 +983,13 @@ class BarPlot(Plot): Ordinary bar plot for (column) vectors. For matrices there is one color for each row. """ - def __init__(self, dataset, name): + def __init__(self, dataset, **kw): + Plot.__init__(self, kw.get('name', 'Bar Plot')) self.dataset = dataset - n, m = dataset.shape - Plot.__init__(self, name) - self.axes = self.fig.add_subplot(111) - self.axes.grid(False) - + # Initial draw + self.axes.grid(False) + n, m = dataset.shape if m>1: sm = cm.ScalarMappable() clrs = sm.to_rgba(range(n)) @@ -1032,157 +1004,110 @@ class BarPlot(Plot): left = scipy.arange(1, n, 1) self.axes.bar(left, height) - # Add canvas and show - self.add(self.canvas) - self.canvas.show() - # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) - def get_toolbar(self): - return self._toolbar - class NetworkPlot(Plot): - def __init__(self, dataset, **kw): - # Set member variables and call superclass' constructor + def __init__(self, dataset, pos=None, nodecolor='b', nodesize=40, + prog='neato', with_labels=False, name='Network Plot'): + + Plot.__init__(self, name) + self.dataset = dataset self.graph = dataset.asnetworkx() - self.dataset = dataset - self.keywords = kw - if not kw.has_key('name'): - kw['name'] = self.dataset.get_name() - Plot.__init__(self, kw['name']) + self._prog = prog + self._pos = pos + self._nodesize = nodesize + self._nodecolor = nodecolor + self._with_labels = with_labels + self.current_dim = self.dataset.get_dim_name(0) - if not kw.has_key('prog'): - kw['prog'] = 'neato' - if not kw.has_key('pos'): - kw['pos'] = networkx.graphviz_layout(self.graph, 'neato') - if not kw.has_key('nodelist'): - kw['nodelist'] = self.dataset.get_identifiers(self.current_dim, sorted=True) - if not kw.has_key('with_labels'): - kw['with_labels'] = True - # Keep node size and color as dicts for fast lookup - self.node_size = {} - if kw.has_key('node_size') and cbook.iterable(kw['node_size']): - kw.remove('node_size') - for id, size in zip(self.dataset[self.current_dim], kw['node_size']): - self.node_size[id] = size - else: - for id in dataset[self.current_dim]: - self.node_size[id] = 30 - - self.node_color = {} - if kw.has_key('node_color') and cbook.iterable(kw['node_color']): - kw.remove('node_color') - for id, color in zip(self.dataset[self.current_dim], kw['node_color']): - self.node_color[id] = color - else: - for id in self.dataset[self.current_dim]: - self.node_color[id] = 1.0 + if not self._pos: + self._pos = networkx.graphviz_layout(self.graph, self._prog) + self._xy = scipy.asarray([self._pos[node] for node in self.dataset.get_identifiers(self.current_dim, sorted=True)]) + self.xaxis_data = self._xy[:,0] + self.yaxis_data = self._xy[:,1] - self.axes = self.fig.add_subplot(111) + # Initial draw + self.default_props = {'nodesize': 50, 'nodecolor':'gray'} + self.node_collection = None + self.edge_collection = None + self.node_labels = None + lw = scipy.zeros(self.xaxis_data.shape) + self.node_collection = self.axes.scatter(self.xaxis_data, self.yaxis_data, + s=self._nodesize, + c=self._nodecolor, + linewidth=lw, + zorder=3) + # selected nodes is a transparent graph that adjust edge-visibility + # according to the current selection + self.selected_nodes = self.axes.scatter(self.xaxis_data, + self.yaxis_data, + s=self._nodesize, + c=self._nodecolor, + linewidth=lw, + zorder=4, + alpha=0) + + self.edge_collection = networkx.draw_networkx_edges(self.graph, + self._pos, + ax=self.axes, + edge_color='gray') + if self._with_labels: + self.node_labels = networkx.draw_networkx_labels(self.graph, + self._pos, + ax=self.axes) + + # remove axes, frame and grid self.axes.set_xticks([]) self.axes.set_yticks([]) self.axes.grid(False) self.axes.set_frame_on(False) - # Add canvas and show - self.add(self.canvas) - self.canvas.show() - - # Initial draw - networkx.draw_networkx(self.graph, ax=self.axes, node_size=30, node_color='gray', **kw) - del kw['nodelist'] - - def get_toolbar(self): - return self._toolbar - def rectangle_select_callback(self, x1, y1, x2, y2, key): - pos = self.keywords['pos'] - ydata = scipy.zeros((len(pos),), 'l') - xdata = scipy.zeros((len(pos),), 'l') - node_ids = [] - c = 0 - for name,(x,y) in pos.items(): - node_ids.append(name) - xdata[c] = x - ydata[c] = y - c+=1 + ydata = self.yaxis_data + xdata = self.xaxis_data # find indices of selected area - if x1 > x2: + if x1>x2: x1, x2 = x2, x1 - if y1 > y2: + if y1>y2: y1, y2 = y2, y1 + assert x1<=x2 + assert y1<=y2 + index = scipy.nonzero((xdata>x1) & (xdatay1) & (ydata 0: + linewidth.put(2, index) + nodelist = selection[self.current_dim] + selected_edges = self.grap + self.selected_nodes.set_linewidth(linewidth) self.canvas.draw() class VennPlot(Plot): def __init__(self, name="Venn diagram"): Plot.__init__(self, name) - self.axes = self.fig.add_subplot(111) - self.axes.grid(0) + self._init_bck() + # init draw for c in self._venn_patches: self.axes.add_patch(c) for mrk in self._markers: @@ -1192,14 +1117,10 @@ class VennPlot(Plot): self._last_active = set() self.axes.set_xticks([]) self.axes.set_yticks([]) - self.axes.grid(0) self.axes.axis('equal') + self.axes.grid(False) self.axes.set_frame_on(False) - # add canvas to widget - self.add(self.canvas) - self.canvas.show() - def _init_bck(self): res = 50 a = .5 @@ -1293,11 +1214,6 @@ class VennPlot(Plot): self.axes.figure.canvas.draw() def rectangle_select_callback(self, x1, y1, x2, y2, key): - """event1 and event2 are the press and release events""" - #x1, y1 = event1.xdata, event1.ydata - #x2, y2 = event2.xdata, event2.ydata - #key = event1.key - verts = [(x1, y1), (x2, y2)] if key!='shift': for m in self._markers: