import pygtk import gobject import gtk import matplotlib from matplotlib.backends.backend_gtkagg import FigureCanvasGTKAgg from matplotlib.nxutils import points_inside_poly from matplotlib.figure import Figure from matplotlib.collections import LineCollection from matplotlib.patches import Polygon,Rectangle,Circle from matplotlib.lines import Line2D from matplotlib.mlab import prctile from matplotlib.colors import ColorConverter import networkx import scipy import fluents import logger import view def plotlogger(func, name=None): def wrapped(parent, *args, **kw): parent.__args = args parent.__kw = kw return func(parent, *args, **kw) return wrapped class Plot(view.View): def __init__(self, title): view.View.__init__(self, title) logger.log('debug', 'plot %s init' %title) self.selection_listener = None self.current_dim = None self._current_selection = None self._frozen = False self._init_mpl() def _init_mpl(self): # init matplotlib related stuff self._background = None self._colorbar = None self._mappable = None self._use_blit = False self.fig = Figure() self.canvas = FigureCanvasGTKAgg(self.fig) self.axes = self.fig.gca() self._toolbar = view.PlotToolbar(self) self._key_press = self.canvas.mpl_connect( 'key_press_event', self.on_key_press) self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK) self.add(self.canvas) self.canvas.show() def set_frozen(self, frozen): """A frozen plot will not be updated when the current selection is changed.""" self._frozen = frozen if not frozen and self._current_selection != None: self.set_current_selection(self._current_selection) def get_title(self): return self.title def get_toolbar(self): return self._toolbar def selection_changed(self, dim_name, selection): """ Selection observer handle. A selection change in a plot is only drawn if: 1.) plot is sensitive to selections (not freezed) 2.) plot is visible (has a view) 3.) the selections dim_name is the plot's dimension. """ self._current_selection = selection if self._frozen \ or not self.get_property('visible') \ or self.current_dim != dim_name: return self.set_current_selection(selection) def set_selection_listener(self, listener): """Allow project to listen to selections. The selection will propagate back to all plots through the selection_changed() method. The listener will be called as listener(dimension_name, ids). """ self.selection_listener = listener def update_selection(self, ids, key=None): """Returns updated current selection from ids. If a key is pressed we use the appropriate mode. key map: shift : union control : intersection """ if key == 'shift': ids = set(ids).union(self._current_selection[self.current_dim]) elif key == 'control': ids = set(ids).intersection(self._current_selection[self.current_dim]) return ids def set_current_selection(self, selection): """Called whenever the plot should change the selection. This method is a dummy method, so that specialized plots that have no implemented selection can ignore selections alltogether. """ pass def rectangle_select_callback(self, *args): """Overrriden in subclass.""" if hasattr(self, 'canvas'): self.canvas.draw() def lasso_select_callback(self, *args): """Overrriden in subclass.""" if hasattr(self, 'canvas'): self.canvas.draw() def get_index_from_selection(self, dataset, selection): """Returns the index vector of current selection in given dim.""" if not selection: return [] ids = selection.get(self.current_dim, []) # current identifiers if not ids : return [] return dataset.get_indices(self.current_dim, ids) def on_key_press(self, event): if event.key == 'c': self._toggle_colorbar() def _toggle_colorbar(self): if self._colorbar == None: if self._mappable == None: logger.log('notice', 'No mappable in this plot') return if self._mappable._A != None: # we need colormapping # get axes original position self._ax_last_pos = self.axes.get_position() self._colorbar = self.fig.colorbar(self._mappable) self._colorbar.draw_all() self.canvas.draw() else: # remove colorbar # remove, axes, observers, colorbar instance, and restore viewlims cb, ax = self._mappable.colorbar self.fig.delaxes(ax) self._mappable.observers = [obs for obs in self._mappable.observers if obs !=self._colorbar] self._colorbar = None self.axes.set_position(self._ax_last_pos) self.canvas.draw() class LineViewPlot(Plot): """Line view plot with percentiles. A line view of vectors across a specified dimension of input dataset. No selection interaction is defined. Only support for 2d-arrays. input: -- major_axis : dim_number for line dim (see scipy.ndarray for axis def.) -- minor_axis : needs definition only for higher order arrays fixme: slow """ @plotlogger def __init__(self, dataset, major_axis=1, minor_axis=None, name="Line view"): Plot.__init__(self, name) self.dataset = dataset self._data = dataset.asarray() if len(self._data.shape)==2 and not minor_axis: minor_axis = major_axis - 1 self.major_axis = major_axis self.minor_axis = minor_axis self.current_dim = self.dataset.get_dim_name(major_axis) #initial draw self.line_coll = None self.line_segs = [] x_axis = scipy.arange(self._data.shape[minor_axis]) for xi in range(self._data.shape[major_axis]): yi = self._data.take([xi], major_axis).ravel() self.line_segs.append([(xx,yy) for xx,yy in zip(x_axis, yi)]) # draw background self._set_background(self.axes) # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) def _set_background(self, ax): """Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median. """ if self._data.shape[self.minor_axis]<6: return # settings patch_color = 'b' #blue patch_lw = 0 #no edges patch_alpha = .15 # transparancy median_color = 'b' #blue median_width = 1.5 #linewidth percentiles = [0, 5, 25, 50, 75, 100] # ordinate xax = scipy.arange(self._data.shape[self.minor_axis]) #vertices verts_0 = [] #100,0 verts_1 = [] # 90,10 verts_2 = [] # 75,25 med = [] # add top vertices the low vertices (do i need an order?)#background for i in xax: prct = prctile(self._data.take([i], self.minor_axis), percentiles) verts_0.append((i, prct[0])) verts_1.append((i, prct[1])) verts_2.append((i, prct[2])) med.append(prct[3]) for i in xax[::-1]: prct = prctile(self._data.take([i], self.minor_axis), percentiles) verts_0.append((i, prct[-1])) verts_1.append((i, prct[-2])) verts_2.append((i, prct[-3])) # make polygons from vertices bck0 = Polygon(verts_0, alpha=patch_alpha, lw=patch_lw, facecolor=patch_color) bck1 = Polygon(verts_1, alpha=patch_alpha, lw=patch_lw, facecolor=patch_color) bck2 = Polygon(verts_2, alpha=patch_alpha, lw=patch_lw, facecolor=patch_color) # add polygons to axes ax.add_patch(bck0) ax.add_patch(bck1) ax.add_patch(bck2) # median line ax.plot(xax, med, median_color, linewidth=median_width) def set_current_selection(self, selection): """Draws the current selection. """ index = self.get_index_from_selection(self.dataset, selection) if self.line_coll: self.axes.collections.remove(self.line_coll) segs = [self.line_segs[i] for i in index] self.line_coll = LineCollection(segs, colors=(1,0,0,1)) self.axes.add_collection(self.line_coll) #draw if self._use_blit: if self._background is None: self._background = self.canvas.copy_from_bbox(self.axes.bbox) self.canvas.restore_region(self._background) self.axes.draw_artist(self.lines) self.canvas.blit() else: self.canvas.draw() class ScatterMarkerPlot(Plot): """The ScatterMarkerPlot is faster than regular scatterplot, but has no color and size options.""" @plotlogger def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2, s=6, name="Scatter plot"): Plot.__init__(self, name) self.current_dim = id_dim self.dataset_1 = dataset_1 self.ms = s x_index = dataset_1[sel_dim][id_1] y_index = dataset_2[sel_dim][id_2] self.xaxis_data = dataset_1._array[:, x_index] self.yaxis_data = dataset_2._array[:, y_index] # init draw self._selection_line = None self.line = self.axes.plot(self.xaxis_data, self.yaxis_data, 'o', markeredgewidth=0, markersize=s) self.axes.axhline(0, color='k', lw=1., zorder=1) self.axes.axvline(0, color='k', lw=1., zorder=1) def rectangle_select_callback(self, x1, y1, x2, y2, key): ydata = self.yaxis_data xdata = self.xaxis_data # find indices of selected area if x1>x2: x1, x2 = x2, x1 if y1>y2: y1, y2 = y2, y1 assert x1<=x2 assert y1<=y2 index = scipy.nonzero((xdata>x1) & (xdatay1) & (ydata clean redraw ') self.set_current_selection(None) self._background = self.canvas.copy_from_bbox(self.axes.bbox) self.set_current_selection(self._current_selection) else: self._background = None def init_draw(self): lw = scipy.zeros(self.xaxis_data.shape) self.sc = self.axes.scatter(self.xaxis_data, self.yaxis_data, s=self.s, c=self.c, linewidth=lw, zorder=3, **self.kw) self._mappable = self.sc self.axes.axhline(0, color='k', lw=1., zorder=1) self.axes.axvline(0, color='k', lw=1., zorder=1) self.selection_collection = self.axes.scatter(self.xaxis_data, self.yaxis_data, alpha=0, c='w',s=self.s, linewidth=0, zorder=4) self._background = self.canvas.copy_from_bbox(self.axes.bbox) def is_mappable_with(self, obj): """Returns True if dataset/selection is mappable with this plot. """ if isinstance(obj, fluents.dataset.Dataset): if self.current_dim in obj.get_dim_name() \ and obj.asarray().shape[0] == self.xaxis_data.shape[0]: return True elif isinstance(obj, fluents.dataset.Selection): if self.current_dim in obj.get_dim_name(): print "Selection is mappable" return True else: return False def _update_color_from_dataset(self, data): """Updates the facecolors from a dataset. """ array = data.asarray() #only support for 2d-arrays: try: m, n = array.shape except: raise ValueError, "No support for more than 2 dimensions." # is dataset a vector or matrix? if not n==1: # we have a category dataset if isinstance(data, fluents.dataset.CategoryDataset): map_vec = scipy.dot(array, scipy.diag(scipy.arange(n))).sum(1) else: map_vec = array.sum(1) else: map_vec = array.ravel() # update facecolors self.sc.set_array(map_vec) self.sc.set_clim(map_vec.min(), map_vec.max()) self.sc.update_scalarmappable() #sets facecolors from array self.canvas.draw() def rectangle_select_callback(self, x1, y1, x2, y2, key): ydata = self.yaxis_data xdata = self.xaxis_data # find indices of selected area if x1>x2: x1, x2 = x2, x1 if y1>y2: y1, y2 = y2, y1 assert x1<=x2 assert y1<=y2 index = scipy.nonzero((xdata>x1) & (xdatay1) & (ydata 0: linewidth.put(1, index) self.selection_collection.set_linewidth(linewidth) if self._use_blit and len(index)>0 : if self._background is None: self._background = self.canvas.copy_from_bbox(self.axes.bbox) self.canvas.restore_region(self._background) self.axes.draw_artist(self.selection_collection) self.canvas.blit() else: self.canvas.draw() class ImagePlot(Plot): @plotlogger def __init__(self, dataset, **kw): Plot.__init__(self, kw.get('name', 'Image Plot')) self.dataset = dataset # Initial draw self.axes.grid(False) self.axes.imshow(dataset.asarray(), interpolation='nearest') self.axes.axis('tight') self._mappable = self.axes.images[0] # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) class HistogramPlot(Plot): """ Histogram plot. If dataset is 1-dim the current_dim is set and selections may be performed. For dataset> 1.dim the histogram is over all values and selections are not defined,""" @plotlogger def __init__(self, dataset, **kw): Plot.__init__(self, kw['name']) self.dataset = dataset self._data = dataset.asarray() # If dataset is 1-dim we may do selections if dataset.shape[0]==1: self.current_dim = dataset.get_dim_name(1) if dataset.shape[1]==1: self.current_dim = dataset.get_dim_name(0) # Initial draw self.axes.grid(False) bins = min(self._data.size, 20) count, lims, self.patches = self.axes.hist(self._data, bins=bins) # Add identifiers to the individual patches if self.current_dim != None: for i, patch in enumerate(self.patches): if i==len(self.patches)-1: end_lim = self._data.max() + 1 else: end_lim = lims[i+1] bool_ind = scipy.bitwise_and(self._data>=lims[i], self._data<=end_lim) patch.index = scipy.where(bool_ind)[0] if self.current_dim==None: # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) def rectangle_select_callback(self, x1, y1, x2, y2, key): if self.current_dim == None: return # make (x1, y1) the lower left corner if x1>x2: x1, x2 = x2, x1 if y1>y2: y1, y2 = y2, y1 self.active_patches = [] for patch in self.patches: xmin = patch.xy[0] xmax = xmin + patch.get_width() ymin, ymax = 0, patch.get_height() if xmax>x1 and xmin y2 or ymax>y1): self.active_patches.append(patch) if not self.active_patches: return ids = set() for patch in self.active_patches: ids.update(self.dataset.get_identifiers(self.current_dim, patch.index)) ids = self.update_selection(ids, key) self.selection_listener(self.current_dim, ids) def lasso_select_callback(self, verts, key): if self.current_dim == None: return self.active_patches = [] for patch in self.patches: if scipy.any(points_inside_poly(verts, patch.get_verts())): self.active_patches.append(patch) if not self.active_patches: return ids = set() for patch in self.active_patches: ids.update(self.dataset.get_identifiers(self.current_dim, patch.index)) ids = self.update_selection(ids, key) self.selection_listener(self.current_dim, ids) def set_current_selection(self, selection): index = self.get_index_from_selection(self.dataset, selection) for patch in self.patches: patch.set_facecolor('b') for patch in self.patches: if scipy.intersect1d(patch.index, index).size>1: patch.set_facecolor('r') self.canvas.draw() class BarPlot(Plot): """Bar plot. Ordinary bar plot for (column) vectors. For matrices there is one color for each row. """ @plotlogger def __init__(self, dataset, **kw): Plot.__init__(self, kw.get('name', 'Bar Plot')) self.dataset = dataset # Initial draw self.axes.grid(False) n, m = dataset.shape if m>1: clrs = matplotlib.cm.ScalarMappable().to_rgba(range(n)) for i, row in enumerate(dataset.asarray()): left = scipy.arange(i+1, m*n+1, n) height = row color = clrs[i] c = (color[0], color[1], color[2]) self.axes.bar(left, height,color=c) else: height = dataset.asarray().ravel() left = scipy.arange(1, n, 1) self.axes.bar(left, height) # Disable selection modes self._toolbar.freeze_button.set_sensitive(False) self._toolbar.set_mode_sensitive('select', False) self._toolbar.set_mode_sensitive('lassoselect', False) class NetworkPlot(Plot): @plotlogger def __init__(self, dataset, pos=None, nodecolor='b', nodesize=40, prog='neato', with_labels=False, name='Network Plot'): Plot.__init__(self, name) self.dataset = dataset self.graph = dataset.asnetworkx() self._prog = prog self._pos = pos self._nodesize = nodesize self._nodecolor = nodecolor self._with_labels = with_labels self.current_dim = self.dataset.get_dim_name(0) if not self._pos: self._pos = networkx.graphviz_layout(self.graph, self._prog) self._xy = scipy.asarray([self._pos[node] for node in self.dataset.get_identifiers(self.current_dim, sorted=True)]) self.xaxis_data = self._xy[:,0] self.yaxis_data = self._xy[:,1] # Initial draw self.default_props = {'nodesize' : 50, 'nodecolor' : 'blue', 'edge_color' : 'gray', 'edge_color_selected' : 'red'} self.node_collection = None self.edge_collection = None self.node_labels = None lw = scipy.zeros(self.xaxis_data.shape) self.node_collection = self.axes.scatter(self.xaxis_data, self.yaxis_data, s=self._nodesize, c=self._nodecolor, linewidth=lw, zorder=3) self._mappable = self.node_collection # selected nodes is a transparent graph that adjust node-edge visibility # according to the current selection needed to get get the selected # nodes 'on top' as zorder may not be defined individually self.selected_nodes = self.axes.scatter(self.xaxis_data, self.yaxis_data, s=self._nodesize, c=self._nodecolor, linewidth=lw, zorder=4, alpha=0) edge_color = self.default_props['edge_color'] self.edge_collection = networkx.draw_networkx_edges(self.graph, self._pos, ax=self.axes, edge_color=edge_color) # edge color rgba-arrays self._edge_color_rgba = scipy.repmat(ColorConverter().to_rgba(edge_color), self.graph.number_of_edges(),1) self._edge_color_selected = ColorConverter().to_rgba(self.default_props['edge_color_selected']) if self._with_labels: self.node_labels = networkx.draw_networkx_labels(self.graph, self._pos, ax=self.axes) # remove axes, frame and grid self.axes.set_xticks([]) self.axes.set_yticks([]) self.axes.grid(False) self.axes.set_frame_on(False) self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1) def rectangle_select_callback(self, x1, y1, x2, y2, key): ydata = self.yaxis_data xdata = self.xaxis_data # find indices of selected area if x1>x2: x1, x2 = x2, x1 if y1>y2: y1, y2 = y2, y1 assert x1<=x2 assert y1<=y2 index = scipy.nonzero((xdata>x1) & (xdatay1) & (ydata 0: linewidth.put(2, index) idents = selection[self.current_dim] edge_index = [i for i,edge in enumerate(self.graph.edges()) if (edge[0] in idents and edge[1] in idents)] if len(edge_index)>0: for i in edge_index: edge_color_rgba[i,:] = self._edge_color_selected self._A = None self.edge_collection._colors = edge_color_rgba self.selected_nodes.set_linewidth(linewidth) self.canvas.draw() class VennPlot(Plot): @plotlogger def __init__(self, name="Venn diagram"): Plot.__init__(self, name) # init draw self._init_bck() for c in self._venn_patches: self.axes.add_patch(c) for mrk in self._markers: self.axes.add_patch(mrk) self.axes.set_xlim([-3, 3]) self.axes.set_ylim([-2.5, 3.5]) self._last_active = set() self.axes.set_xticks([]) self.axes.set_yticks([]) self.axes.axis('equal') self.axes.grid(False) self.axes.set_frame_on(False) self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1) def _init_bck(self): res = 50 a = .5 r = 1.5 mr = .2 self.c1 = c1 = Circle((-1,0), radius=r, alpha=a, facecolor='b') self.c2 = c2 = Circle((1,0), radius=r, alpha=a, facecolor='r') self.c3 = c3 = Circle((0, scipy.sqrt(3)), radius=r, alpha=a, facecolor='g') self.c1marker = Circle((-1.25, -.25), radius=mr, facecolor='y', alpha=0) self.c2marker = Circle((1.25, -.25), radius=mr, facecolor='y', alpha=0) self.c3marker = Circle((0, scipy.sqrt(3)+.25), radius=mr, facecolor='y', alpha=0) self.c1c2marker = Circle((0, -.15), radius=mr, facecolor='y', alpha=0) self.c1c3marker = Circle((-scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0) self.c2c3marker = Circle((scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0) self.c1c2c3marker = Circle((0, .6), radius=mr, facecolor='y', alpha=0) c1.elements = set(['a', 'b', 'c', 'f']) c2.elements = set(['a', 'c', 'd', 'e']) c3.elements = set(['a', 'e', 'f', 'g']) self.active_elements = set() self.all_elements = c1.elements.union(c2.elements).union(c3.elements) c1.active = False c2.active = False c3.active = False c1.name = 'Blue' c2.name = 'Red' c3.name = 'Green' self._venn_patches = [c1, c2, c3] self._markers = [self.c1marker, self.c2marker, self.c3marker, self.c1c2marker, self.c1c3marker, self.c2c3marker, self.c1c2c3marker] self._tot_label = 'Tot: ' + str(len(self.all_elements)) self._sel_label = 'Sel: ' + str(len(self.active_elements)) self._legend = self.axes.legend((self._tot_label, self._sel_label), loc='upper right') def set_selection(self, selection, patch=None): if patch: patch.selection = selection else: selection_set = False for patch in self._venn_patches: if len(patch.elements)==0: patch.elements = selection selection_set = True if not selection_set: self.venn_patches[0].elements = selection def lasso_select_callback(self, verts, key=None): if verts==None: verts = (self._event.xdata, self._event.ydata) if key!='shift': for m in self._markers: m.set_alpha(0) self._patches_within_verts(verts, key) active = [i.active for i in self._venn_patches] if active==[True, False, False]: self.c1marker.set_alpha(1) self.active_elements = self.c1.elements.difference(self.c2.elements.union(self.c3.elements)) elif active== [False, True, False]: self.c2marker.set_alpha(1) self.active_elements = self.c2.elements.difference(self.c1.elements.union(self.c3.elements)) elif active== [False, False, True]: self.c3marker.set_alpha(1) self.active_elements = self.c3.elements.difference(self.c2.elements.union(self.c1.elements)) elif active==[True, True, False]: self.c1c2marker.set_alpha(1) self.active_elements = self.c1.elements.intersection(self.c2.elements) elif active==[True, False, True]: self.c1c3marker.set_alpha(1) self.active_elements = self.c1.elements.intersection(self.c3.elements) elif active==[False, True, True]: self.c2c3marker.set_alpha(1) self.active_elements = self.c2.elements.intersection(self.c3.elements) elif active==[True, True, True]: self.c1c2c3marker.set_alpha(1) self.active_elements = self.c1.elements.intersection(self.c3.elements).intersection(self.c2.elements) if key=='shift': self.active_elements = self.active_elements.union(self._last_active) self._last_active = self.active_elements.copy() self._sel_label = 'Sel: ' + str(len(self.active_elements)) self._legend.texts[1].set_text(self._sel_label) self.axes.figure.canvas.draw() def rectangle_select_callback(self, x1, y1, x2, y2, key): verts = [(x1, y1), (x2, y2)] if key!='shift': for m in self._markers: m.set_alpha(0) self._patches_within_verts(verts, key) active = [i.active for i in self._venn_patches] if active==[True, False, False]: self.c1marker.set_alpha(1) self.active_elements = self.c1.elements.difference(self.c2.elements.union(self.c3.elements)) elif active== [False, True, False]: self.c2marker.set_alpha(1) self.active_elements = self.c2.elements.difference(self.c1.elements.union(self.c3.elements)) elif active== [False, False, True]: self.c3marker.set_alpha(1) self.active_elements = self.c3.elements.difference(self.c2.elements.union(self.c1.elements)) elif active==[True, True, False]: self.c1c2marker.set_alpha(1) self.active_elements = self.c1.elements.intersection(self.c2.elements) elif active==[True, False, True]: self.c1c3marker.set_alpha(1) self.active_elements = self.c1.elements.intersection(self.c3.elements) elif active==[False, True, True]: self.c2c3marker.set_alpha(1) self.active_elements = self.c2.elements.intersection(self.c3.elements) elif active==[True, True, True]: self.c1c2c3marker.set_alpha(1) self.active_elements = self.c1.elements.intersection(self.c3.elements).intersection(self.c2.elements) if key=='shift': self.active_elements = self.active_elements.union(self._last_active) self._last_active = self.active_elements.copy() self._sel_label = 'Sel: ' + str(len(self.active_elements)) self._legend.texts[1].set_text(self._sel_label) self.axes.figure.canvas.draw() def _patches_within_verts(self, verts, key): xy = scipy.array(verts).mean(0) for venn_patch in self._venn_patches: venn_patch.active = False if self._distance(venn_patch.center,xy)