This commit is contained in:
Arnar Flatberg 2007-02-27 15:05:21 +00:00
parent f73a6db0ee
commit 4de65f1085

View File

@ -450,15 +450,22 @@ class Plot (View):
View.__init__(self, title)
logger.log('debug', 'plot %s init' %title)
self.selection_listener = None
self.fig = Figure()
self.canvas = FigureCanvasGTKAgg(self.fig)
self._toolbar = PlotToolbar(self)
self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK)
self.current_dim = None
self._current_selection = None
self._background = None
self._frozen = False
self.use_blit = False
self._init_mpl()
def _init_mpl(self):
# init matplotlib related stuff
self._background = None
self._use_blit = False
self.fig = Figure()
self.canvas = FigureCanvasGTKAgg(self.fig)
self.axes = self.fig.gca()
self._toolbar = PlotToolbar(self)
self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK)
self.add(self.canvas)
self.canvas.show()
def set_frozen(self, frozen):
"""A frozen plot will not be updated when the current
@ -550,45 +557,33 @@ class LineViewPlot(Plot):
-- major_axis : dim_number for line dim (see scipy.ndarray for axis def.)
-- minor_axis : needs definition only for higher order arrays
fixme: slow (cant get linecollection and blit to work)
fixme: slow
"""
def __init__(self, dataset, major_axis=1, minor_axis=None, name="Line view"):
Plot.__init__(self, name)
self.dataset = dataset
self._data = dataset.asarray()
if len(self._data.shape)==2 and not minor_axis:
minor_axis = major_axis - 1
self.major_axis = major_axis
self.minor_axis = minor_axis
Plot.__init__(self, name)
self.use_blit = False #fixme: blitting should work
self.current_dim = self.dataset.get_dim_name(major_axis)
self.line_coll = None
# make axes
self.axes = self.fig.add_subplot(111)
#initial draw
x_axis = scipy.arange(self._data.shape[minor_axis])
self.line_coll = None
self.line_segs = []
x_axis = scipy.arange(self._data.shape[minor_axis])
for xi in range(self._data.shape[major_axis]):
yi = self._data.take([xi], major_axis).ravel()
self.line_segs.append([(xx,yy) for xx,yy in izip(x_axis, yi)])
# draw background
self._set_background(self.axes)
# add canvas
self.add(self.canvas)
self.canvas.show()
# add toolbar
self._toolbar = PlotToolbar(self)
# Disable selection modes
self._toolbar.freeze_button.set_sensitive(False)
self._toolbar.set_mode_sensitive('select', False)
self._toolbar.set_mode_sensitive('lassoselect', False)
#self.canvas.mpl_connect('resize_event', self.clear_background)
def _set_background(self, ax):
"""Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median.
@ -652,7 +647,7 @@ class LineViewPlot(Plot):
self.axes.add_collection(self.line_coll)
#draw
if self.use_blit:
if self._use_blit:
if self._background is None:
self._background = self.canvas.copy_from_bbox(self.axes.bbox)
self.canvas.restore_region(self._background)
@ -669,20 +664,19 @@ class ScatterMarkerPlot(Plot):
def __init__(self, dataset_1, dataset_2, id_dim, sel_dim,
id_1, id_2, s=6, name="Scatter plot"):
Plot.__init__(self, name)
self.axes = self.fig.add_subplot(111)
self.axes.axhline(0, color='k', lw=1., zorder=1)
self.axes.axvline(0, color='k', lw=1., zorder=1)
self.current_dim = id_dim
self.dataset_1 = dataset_1
self.ms = s
self._selection_line = None
x_index = dataset_1[sel_dim][id_1]
y_index = dataset_2[sel_dim][id_2]
self.xaxis_data = dataset_1._array[:, x_index]
self.yaxis_data = dataset_2._array[:, y_index]
# init draw
self._selection_line = None
self.line = self.axes.plot(self.xaxis_data, self.yaxis_data, 'o', markeredgewidth=0, markersize=s)
self.add(self.canvas)
self.canvas.show()
self.axes.axhline(0, color='k', lw=1., zorder=1)
self.axes.axvline(0, color='k', lw=1., zorder=1)
def rectangle_select_callback(self, x1, y1, x2, y2, key):
ydata = self.yaxis_data
@ -726,7 +720,7 @@ class ScatterMarkerPlot(Plot):
markeredgewidth=1.0)
self.axes.add_line(self._selection_line)
if self.use_blit:
if self._use_blit:
if self._background is None:
self._background = self.canvas.copy_from_bbox(self.axes.bbox)
self.canvas.restore_region(self._background)
@ -740,11 +734,11 @@ class ScatterMarkerPlot(Plot):
class ScatterPlot(Plot):
"""The ScatterPlot is slower than scattermarker, but has size option."""
def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2, c='b', s=30, sel_dim_2=None, name="Scatter plot"):
Plot.__init__(self, name)
self.dataset_1 = dataset_1
self.s = s
self.c = c
Plot.__init__(self, name)
self.use_blit = False
self.current_dim = id_dim
x_index = dataset_1[sel_dim][id_1]
@ -754,15 +748,11 @@ class ScatterPlot(Plot):
y_index = dataset_2[sel_dim][id_2]
self.xaxis_data = dataset_1._array[:, x_index]
self.yaxis_data = dataset_2._array[:, y_index]
# labels
self._text_labels = None
# init draw
self.init_draw()
# add canvas to widget
self.add(self.canvas)
self.canvas.show()
# signals to enable correct use of blit
self.connect('zoom-changed', self.onzoom)
self.connect('pan-changed', self.onpan)
self.need_redraw = False
@ -781,7 +771,7 @@ class ScatterPlot(Plot):
self.clean_redraw()
def clean_redraw(self):
if self.use_blit == True:
if self._use_blit == True:
logger.log('notice', 'blit -> clean redraw ')
self.set_current_selection(None)
self._background = self.canvas.copy_from_bbox(self.axes.bbox)
@ -790,7 +780,6 @@ class ScatterPlot(Plot):
self._background = None
def init_draw(self):
self.axes = self.fig.add_subplot(111)
lw = scipy.zeros(self.xaxis_data.shape)
self.sc = self.axes.scatter(self.xaxis_data, self.yaxis_data,
s=self.s, c=self.c, linewidth=lw, zorder=3)
@ -876,7 +865,7 @@ class ScatterPlot(Plot):
linewidth.put(1, index)
self.selection_collection.set_linewidth(linewidth)
if self.use_blit and len(index)>0 :
if self._use_blit and len(index)>0 :
if self._background is None:
self._background = self.canvas.copy_from_bbox(self.axes.bbox)
self.canvas.restore_region(self._background)
@ -888,31 +877,19 @@ class ScatterPlot(Plot):
class ImagePlot(Plot):
def __init__(self, dataset, **kw):
Plot.__init__(self, kw.get('name', 'Image Plot'))
self.dataset = dataset
self.keywords = kw
Plot.__init__(self, kw['name'])
self.axes = self.fig.add_subplot(111)
self.axes.set_xticks([])
self.axes.set_yticks([])
self.axes.grid(False)
# Initial draw
self.axes.imshow(dataset.asarray(), interpolation='nearest', aspect='auto')
# Add canvas and show
self.add(self.canvas)
self.canvas.show()
self.axes.grid(False)
self.axes.imshow(dataset.asarray(), interpolation='nearest')
self.axes.axis('tight')
# Disable selection modes
self._toolbar.freeze_button.set_sensitive(False)
self._toolbar.set_mode_sensitive('select', False)
self._toolbar.set_mode_sensitive('lassoselect', False)
def get_toolbar(self):
return self._toolbar
class HistogramPlot(Plot):
""" Histogram plot.
@ -924,8 +901,6 @@ class HistogramPlot(Plot):
Plot.__init__(self, kw['name'])
self.dataset = dataset
self._data = dataset.asarray()
self.axes = self.fig.add_subplot(111)
self.axes.grid(False)
# If dataset is 1-dim we may do selections
if dataset.shape[0]==1:
@ -933,7 +908,8 @@ class HistogramPlot(Plot):
if dataset.shape[1]==1:
self.current_dim = dataset.get_dim_name(0)
# Initial draw
bins = min(len(self.dataset[self.current_dim]), 20)
self.axes.grid(False)
bins = min(self._data.size, 20)
count, lims, self.patches = self.axes.hist(self._data, bins=bins)
# Add identifiers to the individual patches
@ -946,9 +922,6 @@ class HistogramPlot(Plot):
bool_ind = scipy.bitwise_and(self._data>=lims[i],
self._data<=end_lim)
patch.index = scipy.where(bool_ind)[0]
# Add canvas and show
self.add(self.canvas)
self.canvas.show()
if self.current_dim==None:
# Disable selection modes
@ -1010,14 +983,13 @@ class BarPlot(Plot):
Ordinary bar plot for (column) vectors.
For matrices there is one color for each row.
"""
def __init__(self, dataset, name):
def __init__(self, dataset, **kw):
Plot.__init__(self, kw.get('name', 'Bar Plot'))
self.dataset = dataset
n, m = dataset.shape
Plot.__init__(self, name)
self.axes = self.fig.add_subplot(111)
self.axes.grid(False)
# Initial draw
self.axes.grid(False)
n, m = dataset.shape
if m>1:
sm = cm.ScalarMappable()
clrs = sm.to_rgba(range(n))
@ -1032,157 +1004,110 @@ class BarPlot(Plot):
left = scipy.arange(1, n, 1)
self.axes.bar(left, height)
# Add canvas and show
self.add(self.canvas)
self.canvas.show()
# Disable selection modes
self._toolbar.freeze_button.set_sensitive(False)
self._toolbar.set_mode_sensitive('select', False)
self._toolbar.set_mode_sensitive('lassoselect', False)
def get_toolbar(self):
return self._toolbar
class NetworkPlot(Plot):
def __init__(self, dataset, **kw):
# Set member variables and call superclass' constructor
def __init__(self, dataset, pos=None, nodecolor='b', nodesize=40,
prog='neato', with_labels=False, name='Network Plot'):
Plot.__init__(self, name)
self.dataset = dataset
self.graph = dataset.asnetworkx()
self.dataset = dataset
self.keywords = kw
if not kw.has_key('name'):
kw['name'] = self.dataset.get_name()
Plot.__init__(self, kw['name'])
self._prog = prog
self._pos = pos
self._nodesize = nodesize
self._nodecolor = nodecolor
self._with_labels = with_labels
self.current_dim = self.dataset.get_dim_name(0)
if not kw.has_key('prog'):
kw['prog'] = 'neato'
if not kw.has_key('pos'):
kw['pos'] = networkx.graphviz_layout(self.graph, 'neato')
if not kw.has_key('nodelist'):
kw['nodelist'] = self.dataset.get_identifiers(self.current_dim, sorted=True)
if not kw.has_key('with_labels'):
kw['with_labels'] = True
# Keep node size and color as dicts for fast lookup
self.node_size = {}
if kw.has_key('node_size') and cbook.iterable(kw['node_size']):
kw.remove('node_size')
for id, size in zip(self.dataset[self.current_dim], kw['node_size']):
self.node_size[id] = size
else:
for id in dataset[self.current_dim]:
self.node_size[id] = 30
self.node_color = {}
if kw.has_key('node_color') and cbook.iterable(kw['node_color']):
kw.remove('node_color')
for id, color in zip(self.dataset[self.current_dim], kw['node_color']):
self.node_color[id] = color
else:
for id in self.dataset[self.current_dim]:
self.node_color[id] = 1.0
if not self._pos:
self._pos = networkx.graphviz_layout(self.graph, self._prog)
self._xy = scipy.asarray([self._pos[node] for node in self.dataset.get_identifiers(self.current_dim, sorted=True)])
self.xaxis_data = self._xy[:,0]
self.yaxis_data = self._xy[:,1]
self.axes = self.fig.add_subplot(111)
# Initial draw
self.default_props = {'nodesize': 50, 'nodecolor':'gray'}
self.node_collection = None
self.edge_collection = None
self.node_labels = None
lw = scipy.zeros(self.xaxis_data.shape)
self.node_collection = self.axes.scatter(self.xaxis_data, self.yaxis_data,
s=self._nodesize,
c=self._nodecolor,
linewidth=lw,
zorder=3)
# selected nodes is a transparent graph that adjust edge-visibility
# according to the current selection
self.selected_nodes = self.axes.scatter(self.xaxis_data,
self.yaxis_data,
s=self._nodesize,
c=self._nodecolor,
linewidth=lw,
zorder=4,
alpha=0)
self.edge_collection = networkx.draw_networkx_edges(self.graph,
self._pos,
ax=self.axes,
edge_color='gray')
if self._with_labels:
self.node_labels = networkx.draw_networkx_labels(self.graph,
self._pos,
ax=self.axes)
# remove axes, frame and grid
self.axes.set_xticks([])
self.axes.set_yticks([])
self.axes.grid(False)
self.axes.set_frame_on(False)
# Add canvas and show
self.add(self.canvas)
self.canvas.show()
# Initial draw
networkx.draw_networkx(self.graph, ax=self.axes, node_size=30, node_color='gray', **kw)
del kw['nodelist']
def get_toolbar(self):
return self._toolbar
def rectangle_select_callback(self, x1, y1, x2, y2, key):
pos = self.keywords['pos']
ydata = scipy.zeros((len(pos),), 'l')
xdata = scipy.zeros((len(pos),), 'l')
node_ids = []
c = 0
for name,(x,y) in pos.items():
node_ids.append(name)
xdata[c] = x
ydata[c] = y
c+=1
ydata = self.yaxis_data
xdata = self.xaxis_data
# find indices of selected area
if x1 > x2:
if x1>x2:
x1, x2 = x2, x1
if y1 > y2:
if y1>y2:
y1, y2 = y2, y1
assert x1<=x2
assert y1<=y2
index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
ids = [node_ids[i] for i in index]
ids = self.dataset.get_identifiers(self.current_dim, index)
ids = self.update_selection(ids, key)
self.selection_listener(self.current_dim, ids)
def lasso_select_callback(self, verts, key=None):
pos = self.keywords['pos']
xys = []
node_ids = []
c = 0
for name,(x,y) in pos.items():
node_ids.append(name)
xys.append((x,y))
c+=1
xys = scipy.c_[self.xaxis_data[:,scipy.newaxis], self.yaxis_data[:,scipy.newaxis]]
index = scipy.nonzero(points_inside_poly(xys, verts))[0]
ids = [node_ids[i] for i in index]
ids = self.dataset.get_identifiers(self.current_dim, index)
ids = self.update_selection(ids, key)
self.selection_listener(self.current_dim, ids)
def set_current_selection(self, selection):
ids = selection[self.current_dim] # current identifiers
node_set = set(self.graph.nodes())
selected_nodes = list(ids.intersection(node_set))
unselected_nodes = list(node_set.difference(ids))
if self.node_color:
unselected_colors = [self.node_color[x] for x in unselected_nodes]
else:
unselected_colors = 'gray'
if self.node_size:
unselected_sizes = [self.node_size[x] for x in unselected_nodes]
selected_sizes = [self.node_size[x] for x in selected_nodes]
else:
unselected_sizes = 30
selected_sizes = 30
self.axes.collections = []
networkx.draw_networkx_edges(self.graph,
edge_list=self.graph.edges(),
ax=self.axes,
**self.keywords)
networkx.draw_networkx_labels(self.graph, **self.keywords)
if unselected_nodes:
networkx.draw_networkx_nodes(self.graph, nodelist=unselected_nodes, \
node_color='gray', node_size=unselected_sizes, ax=self.axes, **self.keywords)
if selected_nodes:
networkx.draw_networkx_nodes(self.graph, nodelist=selected_nodes, \
node_color='r', node_size=selected_sizes, ax=self.axes, **self.keywords)
self.axes.collections[-1].set_zorder(3)
linewidth = scipy.zeros(self.xaxis_data.shape, 'f')
index = self.get_index_from_selection(self.dataset, selection)
if len(index) > 0:
linewidth.put(2, index)
nodelist = selection[self.current_dim]
selected_edges = self.grap
self.selected_nodes.set_linewidth(linewidth)
self.canvas.draw()
class VennPlot(Plot):
def __init__(self, name="Venn diagram"):
Plot.__init__(self, name)
self.axes = self.fig.add_subplot(111)
self.axes.grid(0)
self._init_bck()
# init draw
for c in self._venn_patches:
self.axes.add_patch(c)
for mrk in self._markers:
@ -1192,14 +1117,10 @@ class VennPlot(Plot):
self._last_active = set()
self.axes.set_xticks([])
self.axes.set_yticks([])
self.axes.grid(0)
self.axes.axis('equal')
self.axes.grid(False)
self.axes.set_frame_on(False)
# add canvas to widget
self.add(self.canvas)
self.canvas.show()
def _init_bck(self):
res = 50
a = .5
@ -1293,11 +1214,6 @@ class VennPlot(Plot):
self.axes.figure.canvas.draw()
def rectangle_select_callback(self, x1, y1, x2, y2, key):
"""event1 and event2 are the press and release events"""
#x1, y1 = event1.xdata, event1.ydata
#x2, y2 = event2.xdata, event2.ydata
#key = event1.key
verts = [(x1, y1), (x2, y2)]
if key!='shift':
for m in self._markers: