Projects/laydi
Projects
/
laydi
Archived
7
0
Fork 0

Barplot, lots of changes in scatter, disabled modes, removed labels, bugfix on lasso

This commit is contained in:
Arnar Flatberg 2007-01-31 12:02:11 +00:00
parent e716db3fd2
commit 088f180b5d
1 changed files with 272 additions and 47 deletions

View File

@ -14,7 +14,7 @@ from matplotlib.nxutils import points_inside_poly
from matplotlib.axes import Subplot, AxesImage from matplotlib.axes import Subplot, AxesImage
from matplotlib.figure import Figure from matplotlib.figure import Figure
from matplotlib import cm,cbook from matplotlib import cm,cbook
from pylab import Polygon from pylab import Polygon, axis, Circle
from matplotlib.collections import LineCollection from matplotlib.collections import LineCollection
from matplotlib.mlab import prctile from matplotlib.mlab import prctile
import networkx import networkx
@ -227,7 +227,11 @@ class ViewFrame (gtk.Frame):
if view.is_mappable_with(obj): if view.is_mappable_with(obj):
view._update_color_from_dataset(obj) view._update_color_from_dataset(obj)
# add selections below elif isinstance(obj, fluents.dataset.Selection):
view = self.get_view()
if view.is_mappable_with(obj):
view.selection_changed(self.current_dim, obj)
class MainView (gtk.Notebook): class MainView (gtk.Notebook):
@ -505,7 +509,7 @@ class LineViewPlot(Plot):
self.major_axis = major_axis self.major_axis = major_axis
self.minor_axis = minor_axis self.minor_axis = minor_axis
Plot.__init__(self, name) Plot.__init__(self, name)
self.use_blit = False #fixme: blitting does work self.use_blit = False #fixme: blitting should work
self.current_dim = self.dataset.get_dim_name(major_axis) self.current_dim = self.dataset.get_dim_name(major_axis)
# make axes # make axes
@ -533,6 +537,8 @@ class LineViewPlot(Plot):
def _set_background(self, ax): def _set_background(self, ax):
"""Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median. """Add three patches representing [min max],[5,95] and [25,75] percentiles, and a line at the median.
""" """
if self._data.shape[self.minor_axis]<10:
return
# settings # settings
patch_color = 'b' #blue patch_color = 'b' #blue
patch_lw = 0 #no edges patch_lw = 0 #no edges
@ -577,6 +583,13 @@ class LineViewPlot(Plot):
# median line # median line
ax.plot(xax, med, median_color, linewidth=median_width) ax.plot(xax, med, median_color, linewidth=median_width)
# Disable selection modes
btn = self._toolbar.get_button('select')
btn.set_sensitive(False)
btn = self._toolbar.get_button('lassoselect')
btn.set_sensitive(False)
self._toolbar.freeze_button.set_sensitive(False)
def clear_background(self, event): def clear_background(self, event):
"""Callback on resize event. Clears the background. """Callback on resize event. Clears the background.
""" """
@ -690,6 +703,7 @@ class ScatterPlot(Plot):
Plot.__init__(self, name) Plot.__init__(self, name)
self.use_blit = False self.use_blit = False
self.ax = self.fig.add_subplot(111) self.ax = self.fig.add_subplot(111)
self._clean_bck = self.canvas.copy_from_bbox(self.ax.bbox)
self.current_dim = id_dim self.current_dim = id_dim
self.dataset_1 = dataset_1 self.dataset_1 = dataset_1
x_index = dataset_1[sel_dim][id_1] x_index = dataset_1[sel_dim][id_1]
@ -701,26 +715,31 @@ class ScatterPlot(Plot):
self.yaxis_data = dataset_2._array[:, y_index] self.yaxis_data = dataset_2._array[:, y_index]
lw = scipy.zeros(self.xaxis_data.shape) lw = scipy.zeros(self.xaxis_data.shape)
self.sc = sc = self.ax.scatter(self.xaxis_data, self.yaxis_data, self.sc = sc = self.ax.scatter(self.xaxis_data, self.yaxis_data,
s=s, c=c, linewidth=lw, s=s, c=c, linewidth=lw)
edgecolor='k', alpha=.8,
cmap=cm.jet)
if len(c)>1: if len(c)>1:
self.fig.colorbar(sc, fraction=.05) self.fig.colorbar(sc, fraction=.05)
self.ax.axhline(0, color='k', lw=1., zorder=1) self.ax.axhline(0, color='k', lw=1., zorder=1)
self.ax.axvline(0, color='k', lw=1., zorder=1) self.ax.axvline(0, color='k', lw=1., zorder=1)
# collection # labels
self.coll = self.ax.collections[0] self._text_labels = None
# add canvas to widget # add canvas to widget
self.add(self.canvas) self.add(self.canvas)
self.canvas.show() self.canvas.show()
def is_mappable_with(self, dataset): def is_mappable_with(self, obj):
"""Returns True if dataset is mappable with this plot. """Returns True if dataset/selection is mappable with this plot.
""" """
if self.current_dim in dataset.get_dim_name() and dataset.asarray().shape[0] == self.xaxis_data.shape[0]: if isinstance(obj, fluents.dataset.Dataset):
if self.current_dim in obj.get_dim_name() and obj.asarray().shape[0] == self.xaxis_data.shape[0]:
return True return True
elif isinstance(obj, fluents.dataset.Selection):
if self.current_dim in obj.get_dim_name():
print "Selection is mappable"
return True
else:
return False
def _update_color_from_dataset(self, data): def _update_color_from_dataset(self, data):
"""Updates the facecolors from a dataset. """Updates the facecolors from a dataset.
@ -730,7 +749,7 @@ class ScatterPlot(Plot):
try: try:
m, n = array.shape m, n = array.shape
except: except:
raise ValueError, "No support for more tha 2 dimensions." raise ValueError, "No support for more than 2 dimensions."
# is dataset a vector or matrix? # is dataset a vector or matrix?
if not n==1: if not n==1:
# we have a category dataset # we have a category dataset
@ -741,14 +760,10 @@ class ScatterPlot(Plot):
else: else:
map_vec = array.ravel() map_vec = array.ravel()
# normalise mapping vector
map_vec = map_vec - map_vec.min()
map_vec = map_vec/map_vec.max()
# update facecolors # update facecolors
self.sc._facecolors = self.sc.to_rgba(map_vec, self.sc._alpha) self.sc.set_array(map_vec)
# draw self.sc.set_clim(map_vec.min(), map_vec.max())
self.sc._A = None # mean hack self.sc.update_scalarmappable() #sets facecolors from array
self.ax.draw_artist(self.sc)
self.canvas.draw() self.canvas.draw()
def rectangle_select_callback(self, x1, y1, x2, y2, key): def rectangle_select_callback(self, x1, y1, x2, y2, key):
@ -775,13 +790,11 @@ class ScatterPlot(Plot):
ids = self.dataset_1.get_identifiers(self.current_dim, index) ids = self.dataset_1.get_identifiers(self.current_dim, index)
ids = self.update_selection(ids, key) ids = self.update_selection(ids, key)
self.selection_listener(self.current_dim, ids) self.selection_listener(self.current_dim, ids)
self.canvas.widgetlock.release(self._lasso)
def set_current_selection(self, selection): def set_current_selection(self, selection):
ids = selection[self.current_dim] # current identifiers ids = selection[self.current_dim] # current identifiers
if len(ids)==0: if len(ids)==0:
return return
#self._toolbar.forward() #update data lims before draw
index = self.dataset_1.get_indices(self.current_dim, ids) index = self.dataset_1.get_indices(self.current_dim, ids)
if self.use_blit: if self.use_blit:
if self._background is None: if self._background is None:
@ -790,12 +803,13 @@ class ScatterPlot(Plot):
lw = scipy.zeros(self.xaxis_data.shape, 'f') lw = scipy.zeros(self.xaxis_data.shape, 'f')
if len(index)>0: if len(index)>0:
lw.put(2., index) lw.put(2., index)
self.coll.set_linewidth(lw) self.sc.set_linewidth(lw)
if self.use_blit: if self.use_blit:
self.ax.draw_artist(self.sc)
self.canvas.blit() self.canvas.blit()
self.ax.draw_artist(self.coll)
else: else:
print self.ax.lines
self.canvas.draw() self.canvas.draw()
@ -811,10 +825,6 @@ class ImagePlot(Plot):
self.ax.set_yticks([]) self.ax.set_yticks([])
self.ax.grid(False) self.ax.grid(False)
# FIXME: ax shouldn't be in kw at all
if kw.has_key('ax'):
kw.pop('ax')
# Initial draw # Initial draw
self.ax.imshow(dataset.asarray(), interpolation='nearest', aspect='auto') self.ax.imshow(dataset.asarray(), interpolation='nearest', aspect='auto')
@ -832,18 +842,13 @@ class ImagePlot(Plot):
def get_toolbar(self): def get_toolbar(self):
return self._toolbar return self._toolbar
class HistogramPlot(Plot): class HistogramPlot(Plot):
def __init__(self, dataset, **kw): def __init__(self, dataset, **kw):
self.dataset = dataset
self.keywords = kw
Plot.__init__(self, kw['name']) Plot.__init__(self, kw['name'])
self.ax = self.fig.add_subplot(111) self.ax = self.fig.add_subplot(111)
#self.ax.set_xticks([])
#self.ax.set_yticks([])
self.ax.grid(False) self.ax.grid(False)
# FIXME: ax shouldn't be in kw at all
# Initial draw # Initial draw
self.ax.hist(dataset.asarray(), bins=20) self.ax.hist(dataset.asarray(), bins=20)
@ -852,6 +857,56 @@ class HistogramPlot(Plot):
self.add(self.canvas) self.add(self.canvas)
self.canvas.show() self.canvas.show()
# Disable selection modes
btn = self._toolbar.get_button('select')
btn.set_sensitive(False)
btn = self._toolbar.get_button('lassoselect')
btn.set_sensitive(False)
self._toolbar.freeze_button.set_sensitive(False)
def get_toolbar(self):
return self._toolbar
class BarPlot(Plot):
"""Bar plot.
Ordinary bar plot for (column) vectors.
For matrices there is one color for each row.
"""
def __init__(self, dataset, name):
self.dataset = dataset
n, m = dataset.shape
Plot.__init__(self, name)
self.ax = self.fig.add_subplot(111)
self.ax.grid(False)
# Initial draw
if m>1:
sm = cm.ScalarMappable()
clrs = sm.to_rgba(range(n))
for i, row in enumerate(dataset.asarray()):
left = scipy.arange(i+1, m*n+1, n)
height = row
color = clrs[i]
c = (color[0], color[1], color[2])
self.ax.bar(left, height,color=c)
else:
height = dataset.asarray().ravel()
left = scipy.arange(1, n, 1)
self.ax.bar(left, height)
# Add canvas and show
self.add(self.canvas)
self.canvas.show()
# Disable selection modes
btn = self._toolbar.get_button('select')
btn.set_sensitive(False)
btn = self._toolbar.get_button('lassoselect')
btn.set_sensitive(False)
self._toolbar.freeze_button.set_sensitive(False)
def get_toolbar(self): def get_toolbar(self):
return self._toolbar return self._toolbar
@ -904,6 +959,7 @@ class NetworkPlot(Plot):
self.ax.set_xticks([]) self.ax.set_xticks([])
self.ax.set_yticks([]) self.ax.set_yticks([])
self.ax.grid(False) self.ax.grid(False)
self.ax.set_frame_on(False)
# FIXME: ax shouldn't be in kw at all # FIXME: ax shouldn't be in kw at all
if kw.has_key('ax'): if kw.has_key('ax'):
kw.pop('ax') kw.pop('ax')
@ -994,6 +1050,176 @@ class NetworkPlot(Plot):
self.canvas.draw() self.canvas.draw()
class VennPlot(Plot):
def __init__(self, name="Venn diagram"):
Plot.__init__(self, name)
self._ax = self.fig.add_subplot(111)
self._ax.grid(0)
self._init_bck()
for c in self._venn_patches:
self._ax.add_patch(c)
for mrk in self._markers:
self._ax.add_patch(mrk)
self._ax.set_xlim([-3, 3])
self._ax.set_ylim([-2.5, 3.5])
self._last_active = set()
self._ax.set_xticks([])
self._ax.set_yticks([])
self._ax.grid(0)
self._ax.axis('equal')
self._ax.set_frame_on(False)
# add canvas to widget
self.add(self.canvas)
self.canvas.show()
def _init_bck(self):
res = 50
a = .5
r = 1.5
mr = .2
self.c1 = c1 = Circle((-1,0), radius=r, alpha=a, facecolor='b', resolution=res)
self.c2 = c2 = Circle((1,0), radius=r, alpha=a, facecolor='r', resolution=res)
self.c3 = c3 = Circle((0, scipy.sqrt(3)), radius=r, alpha=a, facecolor='g', resolution=res)
self.c1marker = Circle((-1.25, -.25), radius=mr, facecolor='y', alpha=0)
self.c2marker = Circle((1.25, -.25), radius=mr, facecolor='y', alpha=0)
self.c3marker = Circle((0, scipy.sqrt(3)+.25), radius=mr, facecolor='y', alpha=0)
self.c1c2marker = Circle((0, -.15), radius=mr, facecolor='y', alpha=0)
self.c1c3marker = Circle((-scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0)
self.c2c3marker = Circle((scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0)
self.c1c2c3marker = Circle((0, .6), radius=mr, facecolor='y', alpha=0)
c1.elements = set(['a', 'b', 'c', 'f'])
c2.elements = set(['a', 'c', 'd', 'e'])
c3.elements = set(['a', 'e', 'f', 'g'])
self.active_elements = set()
self.all_elements = c1.elements.union(c2.elements).union(c3.elements)
c1.active = False
c2.active = False
c3.active = False
c1.name = 'Blue'
c2.name = 'Red'
c3.name = 'Green'
self._venn_patches = [c1, c2, c3]
self._markers = [self.c1marker, self.c2marker, self.c3marker,
self.c1c2marker, self.c1c3marker,
self.c2c3marker, self.c1c2c3marker]
self._tot_label = 'Tot: ' + str(len(self.all_elements))
self._sel_label = 'Sel: ' + str(len(self.active_elements))
self._legend = self._ax.legend((self._tot_label, self._sel_label),
loc='upper right')
def set_selection(self, selection, patch=None):
if patch:
patch.selection = selection
else:
selection_set = False
for patch in self._venn_patches:
if len(patch.elements)==0:
patch.elements = selection
selection_set = True
if not selection_set:
self.venn_patches[0].elements = selection
def lasso_select_callback(self, verts, key=None):
if verts==None:
print "ks"
verts = (self._event.xdata, self._event.ydata)
if key!='shift':
for m in self._markers:
m.set_alpha(0)
self._patches_within_verts(verts, key)
active = [i.active for i in self._venn_patches]
if active==[True, False, False]:
self.c1marker.set_alpha(1)
self.active_elements = self.c1.elements.difference(self.c2.elements.union(self.c3.elements))
elif active== [False, True, False]:
self.c2marker.set_alpha(1)
self.active_elements = self.c2.elements.difference(self.c1.elements.union(self.c3.elements))
elif active== [False, False, True]:
self.c3marker.set_alpha(1)
self.active_elements = self.c3.elements.difference(self.c2.elements.union(self.c1.elements))
elif active==[True, True, False]:
self.c1c2marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c2.elements)
elif active==[True, False, True]:
self.c1c3marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c3.elements)
elif active==[False, True, True]:
self.c2c3marker.set_alpha(1)
self.active_elements = self.c2.elements.intersection(self.c3.elements)
elif active==[True, True, True]:
self.c1c2c3marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c3.elements).intersection(self.c2.elements)
if key=='shift':
self.active_elements = self.active_elements.union(self._last_active)
self._last_active = self.active_elements.copy()
self._sel_label = 'Sel: ' + str(len(self.active_elements))
self._legend.texts[1].set_text(self._sel_label)
self.canvas.widgetlock.release(self._lasso)
del self._lasso
self._ax.figure.canvas.draw()
def rectangle_select_callback(self, x1, y1, x2, y2, key):
"""event1 and event2 are the press and release events"""
#x1, y1 = event1.xdata, event1.ydata
#x2, y2 = event2.xdata, event2.ydata
#key = event1.key
verts = [(x1, y1), (x2, y2)]
if key!='shift':
for m in self._markers:
m.set_alpha(0)
self._patches_within_verts(verts, key)
active = [i.active for i in self._venn_patches]
if active==[True, False, False]:
self.c1marker.set_alpha(1)
self.active_elements = self.c1.elements.difference(self.c2.elements.union(self.c3.elements))
elif active== [False, True, False]:
self.c2marker.set_alpha(1)
self.active_elements = self.c2.elements.difference(self.c1.elements.union(self.c3.elements))
elif active== [False, False, True]:
self.c3marker.set_alpha(1)
self.active_elements = self.c3.elements.difference(self.c2.elements.union(self.c1.elements))
elif active==[True, True, False]:
self.c1c2marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c2.elements)
elif active==[True, False, True]:
self.c1c3marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c3.elements)
elif active==[False, True, True]:
self.c2c3marker.set_alpha(1)
self.active_elements = self.c2.elements.intersection(self.c3.elements)
elif active==[True, True, True]:
self.c1c2c3marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c3.elements).intersection(self.c2.elements)
if key=='shift':
self.active_elements = self.active_elements.union(self._last_active)
self._last_active = self.active_elements.copy()
self._sel_label = 'Sel: ' + str(len(self.active_elements))
self._legend.texts[1].set_text(self._sel_label)
self._ax.figure.canvas.draw()
def _patches_within_verts(self, verts, key):
xy = scipy.array(verts).mean(0)
for venn_patch in self._venn_patches:
venn_patch.active = False
if self._distance(venn_patch.center,xy)<venn_patch.radius:
venn_patch.active = True
def _distance(self, (x1,y1),(x2,y2)):
return scipy.sqrt( (x2-x1)**2 + (y2-y1)**2 )
class PlotMode: class PlotMode:
"""A PlotMode object corresponds to a mouse mode in a plot. """A PlotMode object corresponds to a mouse mode in a plot.
@ -1250,13 +1476,12 @@ class SelectPlotMode2 (PlotMode):
self.plot._lasso = None self.plot._lasso = None
def _on_select(self, event): def _on_select(self, event):
if self.canvas.widgetlock.locked(): return if event.inaxes is None:
if event.inaxes is None: return logger.log('debug', 'Lasso select not in axes')
return
self.plot._lasso = Lasso(event.inaxes, (event.xdata, event.ydata), self.lasso_callback) self.plot._lasso = Lasso(event.inaxes, (event.xdata, event.ydata), self.lasso_callback)
self.plot._lasso.line.set_linewidth(1) self.plot._lasso.line.set_linewidth(1)
self.plot._lasso.line.set_linestyle('--') self.plot._lasso.line.set_linestyle('--')
# get a lock on the widget
self.canvas.widgetlock(self.plot._lasso)
self._event = event self._event = event
def lasso_callback(self, verts): def lasso_callback(self, verts):