This repository has been archived on 2024-07-04. You can view files and clone it, but cannot push or open issues or pull requests.
laydi/fluents/plots.py

1092 lines
42 KiB
Python
Raw Normal View History

import pygtk
2006-06-01 15:51:16 +02:00
import gobject
import gtk
2007-02-26 10:08:50 +01:00
2007-03-14 17:27:23 +01:00
import matplotlib
2007-02-26 10:08:50 +01:00
from matplotlib.backends.backend_gtkagg import FigureCanvasGTKAgg
2007-01-15 14:47:18 +01:00
from matplotlib.nxutils import points_inside_poly
from matplotlib.figure import Figure
2006-09-18 19:23:34 +02:00
from matplotlib.collections import LineCollection
from matplotlib.patches import Polygon,Rectangle,Circle
from matplotlib.lines import Line2D
2006-09-18 19:23:34 +02:00
from matplotlib.mlab import prctile
from matplotlib.colors import ColorConverter
import networkx
2006-10-09 20:04:39 +02:00
import scipy
2007-07-24 14:19:13 +02:00
from numpy import matlib
2007-02-26 10:08:50 +01:00
import fluents
import logger
2007-02-27 17:28:03 +01:00
import view
2007-02-26 10:08:50 +01:00
2007-03-14 17:27:23 +01:00
def plotlogger(func, name=None):
    """Decorator that records the arguments a plot was constructed with.

    Intended for plot ``__init__`` methods.  Before delegating to ``func``
    the wrapper stores the positional arguments (without the instance) on
    the instance as the attribute ``__args`` and the keyword arguments as
    ``__kw``.  The wrapper keeps ``func``'s name, docstring and attribute
    dict so the decorated method still introspects like the original.

    :param func: method to wrap; called as ``func(parent, *args, **kw)``.
    :param name: unused; kept for backward compatibility.
    :returns: the wrapping function.
    """
    def wrapped(parent, *args, **kw):
        # Set from module-level code, so the names are NOT mangled: the
        # attributes are literally called '__args' and '__kw'.
        parent.__args = args
        parent.__kw = kw
        return func(parent, *args, **kw)
    # Equivalent of functools.wraps, copied by hand so the decorator does not
    # raise the minimum Python version required by this module.
    wrapped.__name__ = func.__name__
    wrapped.__doc__ = func.__doc__
    wrapped.__dict__.update(func.__dict__)
    return wrapped
class Plot(view.View):
    """Base class for fluents plots.

    Holds a matplotlib figure, canvas and axes inside a view, together with
    a plot toolbar. The project's current selection arrives through
    ``selection_changed``; selections made in the plot are reported to
    ``selection_listener``. Subclasses override ``set_current_selection``
    and the rectangle/lasso selection callbacks.
    """

    def __init__(self, title):
        view.View.__init__(self, title)
        logger.log('debug', 'plot %s init' % title)
        self.selection_listener = None
        self.current_dim = None          # dimension selections refer to
        self._current_selection = None   # last selection received
        self._frozen = False
        self._init_mpl()

    def _init_mpl(self):
        """Set up the matplotlib figure, canvas, axes and toolbar."""
        self._background = None   # saved canvas region used when blitting
        self._colorbar = None
        self._mappable = None     # artist the colorbar maps, if any
        self._use_blit = False
        self.fig = Figure()
        self.canvas = FigureCanvasGTKAgg(self.fig)
        self.axes = self.fig.gca()
        self._toolbar = view.PlotToolbar(self)
        self._key_press = self.canvas.mpl_connect(
            'key_press_event', self.on_key_press)
        self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK)
        self.add(self.canvas)
        self.canvas.show()

    def set_frozen(self, frozen):
        """A frozen plot will not be updated when the current
        selection is changed. On unfreezing, the last selection received
        is drawn."""
        self._frozen = frozen
        if not frozen and self._current_selection is not None:
            self.set_current_selection(self._current_selection)

    def get_title(self):
        return self.title

    def get_toolbar(self):
        return self._toolbar

    def selection_changed(self, dim_name, selection):
        """ Selection observer handle.

        The selection is always remembered, but it is only drawn if:
        1.) plot is sensitive to selections (not frozen)
        2.) plot is visible
        3.) the selection's dim_name is the plot's dimension.
        """
        self._current_selection = selection
        if self._frozen \
                or not self.get_property('visible') \
                or self.current_dim != dim_name:
            return
        self.set_current_selection(selection)

    def set_selection_listener(self, listener):
        """Allow project to listen to selections.

        The selection will propagate back to all plots through the
        selection_changed() method. The listener will be called as
        listener(dimension_name, ids).
        """
        self.selection_listener = listener

    def update_selection(self, ids, key=None):
        """Returns updated current selection from ids.

        If a key is pressed we use the appropriate mode.
        key map:
            shift   : union with the current selection in current_dim
            control : intersection with it
            other   : ids is returned unchanged
        A missing current selection (or one without current_dim) counts
        as empty.
        """
        logger.log('debug', 'Inside update selection, key: %s' % key)
        if key not in ('shift', 'control'):
            return ids
        current = []
        if self._current_selection:
            current = self._current_selection.get(self.current_dim, [])
        if key == 'shift':
            return set(ids).union(current)
        return set(ids).intersection(current)

    def set_current_selection(self, selection):
        """Called whenever the plot should change the selection.

        This method is a dummy method, so that specialized plots that have
        no implemented selection can ignore selections altogether.
        """
        pass

    def rectangle_select_callback(self, *args):
        """Overridden in subclass; the default just redraws the canvas."""
        if hasattr(self, 'canvas'):
            self.canvas.draw()

    def lasso_select_callback(self, *args):
        """Overridden in subclass; the default just redraws the canvas."""
        if hasattr(self, 'canvas'):
            self.canvas.draw()

    def get_index_from_selection(self, dataset, selection):
        """Returns the index vector of current selection in given dim."""
        if not selection:
            return []
        ids = selection.get(self.current_dim, [])  # current identifiers
        if not ids:
            return []
        return dataset.get_indices(self.current_dim, ids)

    def on_key_press(self, event):
        """Key 'c' toggles the colorbar."""
        if event.key == 'c':
            self._toggle_colorbar()

    def _toggle_colorbar(self):
        """Show a colorbar for self._mappable, or remove the shown one."""
        if self._colorbar is None:
            if self._mappable is None:
                logger.log('notice', 'No mappable in this plot')
                return
            # ``_A`` is normally an array: compare by identity, because
            # ``array != None`` is elementwise in numpy.
            if self._mappable._A is not None:  # we need colormapping
                # remember the axes position to restore it on removal
                self._ax_last_pos = self.axes.get_position()
                self._colorbar = self.fig.colorbar(self._mappable)
                self._colorbar.draw_all()
                self.canvas.draw()
        else:
            # remove axes, observers and colorbar instance, restore position
            cb, ax = self._mappable.colorbar
            self.fig.delaxes(ax)
            self._mappable.observers = [obs for obs in self._mappable.observers
                                        if obs != self._colorbar]
            self._colorbar = None
            self.axes.set_position(self._ax_last_pos)
            self.canvas.draw()
class LineViewPlot(Plot):
    """Line view plot with percentiles.

    A line view of vectors across a specified dimension of input dataset.
    No selection interaction is defined.
    Only support for 2d-arrays.

    input:
    -- major_axis : dim_number for line dim (see scipy.ndarray for axis def.)
    -- minor_axis : needs definition only for higher order arrays
    -- center     : subtract the mean along minor_axis (2d data only)

    fixme: slow
    """
    @plotlogger
    def __init__(self, dataset, major_axis=1, minor_axis=None, center=True,
                 name="Line view"):
        Plot.__init__(self, name)
        self.dataset = dataset
        self._data = dataset.asarray()
        # test for None so that an explicitly given axis 0 is honoured
        if len(self._data.shape) == 2 and minor_axis is None:
            minor_axis = major_axis - 1
        self.major_axis = major_axis
        self.minor_axis = minor_axis
        self.current_dim = self.dataset.get_dim_name(major_axis)
        self.data_is_centered = False
        self._mn_data = 0
        if center and len(self._data.shape) == 2:
            if minor_axis == 0:
                self._mn_data = self._data.mean(minor_axis)
            else:
                # restore the reduced axis so the means broadcast per row
                self._mn_data = self._data.mean(minor_axis)[:, scipy.newaxis]
            self._data = self._data - self._mn_data
            self.data_is_centered = True
        # selection lines are drawn on demand by set_current_selection
        self.line_coll = None
        self.make_lines()
        # draw background
        self.set_background()
        # Disable selection modes
        self._toolbar.freeze_button.set_sensitive(False)
        self._toolbar.set_mode_sensitive('select', False)
        self._toolbar.set_mode_sensitive('lassoselect', False)

    def make_lines(self):
        """Creates one line (list of (x, y) points) for each item along
        major axis in self.line_segs, removing any drawn selection lines."""
        if self.line_coll:
            self.axes.collections.remove(self.line_coll)
            self.line_coll = None
        x_axis = scipy.arange(self._data.shape[self.minor_axis])
        self.line_segs = []
        for xi in range(self._data.shape[self.major_axis]):
            yi = self._data.take([xi], self.major_axis).ravel()
            self.line_segs.append(list(zip(x_axis, yi)))

    def set_background(self):
        """Add three patches representing [min max], [5,95] and [25,75]
        percentiles, and a line at the median.

        Nothing is drawn when there are fewer than 6 points along the
        minor axis.
        """
        if self._data.shape[self.minor_axis] < 6:
            return
        # clean old patches and (median) lines, if any
        if len(self.axes.patches) > 0:
            self.axes.patches = []
        if len(self.axes.lines) > 0:
            self.axes.lines = []
        # settings
        patch_color = 'b'    # blue
        patch_lw = 0         # no edges
        patch_alpha = .15    # transparency
        median_color = 'b'   # blue
        median_width = 1.5   # linewidth
        percentiles = [0, 5, 25, 50, 75, 100]
        xax = scipy.arange(self._data.shape[self.minor_axis])
        # percentiles of every position along the minor axis, computed once
        prcts = [prctile(self._data.take([i], self.minor_axis), percentiles)
                 for i in xax]
        med = [p[3] for p in prcts]
        backward = list(zip(xax, prcts))
        backward.reverse()
        # band polygons for (0,100), (5,95) and (25,75): lower percentile
        # left to right, then the upper percentile back right to left
        for low, high in ((0, -1), (1, -2), (2, -3)):
            verts = [(i, p[low]) for i, p in zip(xax, prcts)]
            verts.extend([(i, p[high]) for i, p in backward])
            self.axes.add_patch(Polygon(verts, alpha=patch_alpha,
                                        lw=patch_lw, facecolor=patch_color))
        # median line
        self.axes.plot(xax, med, median_color, linewidth=median_width)
        # set y-limits
        padding = 0.1
        self.axes.set_ylim([self._data.min() - padding,
                            self._data.max() + padding])

    def set_current_selection(self, selection):
        """Draws the lines of the selected items in red."""
        index = self.get_index_from_selection(self.dataset, selection)
        if self.line_coll:
            self.axes.collections.remove(self.line_coll)
        segs = [self.line_segs[i] for i in index]
        self.line_coll = LineCollection(segs, colors=(1, 0, 0, 1))
        self.axes.add_collection(self.line_coll)
        # draw
        if self._use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            # the selection lines are the only artist that changed
            self.axes.draw_artist(self.line_coll)
            self.canvas.blit()
        else:
            self.canvas.draw()
class ScatterMarkerPlot(Plot):
    """The ScatterMarkerPlot is faster than regular scatterplot, but
    has no color and size options.

    Selected points are overlaid as red markers.
    """
    @plotlogger
    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim,
                 id_1, id_2, s=6, name="Scatter plot"):
        Plot.__init__(self, name)
        self.current_dim = id_dim
        self.dataset_1 = dataset_1
        self.ms = s  # marker size
        x_index = dataset_1[sel_dim][id_1]
        y_index = dataset_2[sel_dim][id_2]
        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]
        # init draw
        self._selection_line = None
        self.line = self.axes.plot(self.xaxis_data, self.yaxis_data, 'o',
                                   markeredgewidth=0, markersize=s)
        self.axes.axhline(0, color='k', lw=1., zorder=1)
        self.axes.axvline(0, color='k', lw=1., zorder=1)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the points strictly inside the given rectangle."""
        xdata = self.xaxis_data
        ydata = self.yaxis_data
        # normalise so that (x1, y1) is the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        index = scipy.nonzero((xdata > x1) & (xdata < x2) &
                              (ydata > y1) & (ydata < y2))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the points inside the lasso polygon ``verts``."""
        xys = scipy.c_[self.xaxis_data[:, scipy.newaxis],
                       self.yaxis_data[:, scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Redraw the red overlay markers for the selected points."""
        # remove old selection
        if self._selection_line:
            self.axes.lines.remove(self._selection_line)
        index = self.get_index_from_selection(self.dataset_1, selection)
        if len(index) == 0:
            # no selection
            self.canvas.draw()
            self._selection_line = None
            return
        xdata_new = self.xaxis_data.take(index)
        ydata_new = self.yaxis_data.take(index)
        self._selection_line = Line2D(xdata_new, ydata_new,
                                      marker='o', markersize=self.ms,
                                      linewidth=0, markerfacecolor='r',
                                      markeredgewidth=1.0)
        self.axes.add_line(self._selection_line)
        if self._use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            # the attribute is ``_selection_line``; there is no
            # ``selection_line`` attribute on this class
            if self._selection_line:
                self.axes.draw_artist(self._selection_line)
            self.canvas.blit()
        else:
            self.canvas.draw()
class ScatterPlot(Plot):
    """The ScatterPlot is slower than scattermarker, but has size option."""

    @plotlogger
    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2,
                 c='b', s=30, sel_dim_2=None, name="Scatter plot", **kw):
        """Initializes a scatter plot.

        x values are the column ``dataset_1[sel_dim][id_1]`` of dataset_1,
        y values the column ``dataset_2[sel_dim_2 or sel_dim][id_2]`` of
        dataset_2. ``id_dim`` is the dimension selections refer to.
        ``c`` and ``s`` are the marker colour and size; extra keywords are
        passed on to ``Axes.scatter``.
        """
        Plot.__init__(self, name)
        self.dataset_1 = dataset_1
        self.s = s
        self.c = c
        self.kw = kw
        self.current_dim = id_dim
        self._map_ids = dataset_1.get_identifiers(id_dim, sorted=True)
        x_index = dataset_1[sel_dim][id_1]
        if sel_dim_2:
            y_index = dataset_2[sel_dim_2][id_2]
        else:
            y_index = dataset_2[sel_dim][id_2]
        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]
        # init draw
        self.init_draw()
        # signals to enable correct use of blit
        self.connect('zoom-changed', self.onzoom)
        self.connect('pan-changed', self.onpan)
        self.need_redraw = False
        self.canvas.mpl_connect('resize_event', self.onresize)

    def onzoom(self, widget, mode):
        self.clean_redraw()

    def onpan(self, widget, mode):
        self.clean_redraw()

    def onresize(self, widget):
        self.clean_redraw()

    def clean_redraw(self):
        """Refresh (when blitting) or invalidate the saved background."""
        if self._use_blit:
            # grab the background without selection, then redraw it
            self.set_current_selection(None)
            self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.set_current_selection(self._current_selection)
        else:
            self._background = None

    def init_draw(self):
        """Draw the points, zero axes and the (invisible) selection layer."""
        lw = scipy.zeros(self.xaxis_data.shape)
        self.sc = self.axes.scatter(self.xaxis_data, self.yaxis_data,
                                    s=self.s, c=self.c, linewidth=lw,
                                    zorder=3, **self.kw)
        self._mappable = self.sc
        self.axes.axhline(0, color='k', lw=1., zorder=1)
        self.axes.axvline(0, color='k', lw=1., zorder=1)
        # selected points get a red edge by raising their linewidth
        self.selection_collection = self.axes.scatter(self.xaxis_data,
                                                      self.yaxis_data,
                                                      alpha=0,
                                                      c='w', s=self.s,
                                                      edgecolor='r',
                                                      linewidth=0,
                                                      zorder=4)
        self._background = self.canvas.copy_from_bbox(self.axes.bbox)

    def is_mappable_with(self, obj):
        """Returns True if dataset/selection is mappable with this plot,
        i.e. it has a dimension named current_dim; False otherwise.
        """
        if isinstance(obj, (fluents.dataset.Dataset,
                            fluents.dataset.Selection)):
            return self.current_dim in obj.get_dim_name()
        return False

    def _update_color_from_dataset(self, data):
        """Updates the facecolors from a dataset.

        The dataset is reduced to one value per row: a single column is
        used as is, a CategoryDataset gives sum_j array[i, j] * j (the
        column index for one-hot rows), any other matrix its row sums.
        Values are aligned with this plot's identifiers; identifiers absent
        from ``data`` get the minimum value, and infinite values are
        clipped to the finite range.

        :raises ValueError: if the dataset is not 2-dimensional.
        """
        array = data.asarray()
        # only support for 2d-arrays
        try:
            m, n = array.shape
        except ValueError:
            raise ValueError("No support for more than 2 dimensions.")
        # is dataset a vector or matrix?
        if n != 1:
            if isinstance(data, fluents.dataset.CategoryDataset):
                vec = scipy.dot(array, scipy.diag(scipy.arange(n))).sum(1)
            else:
                vec = array.sum(1)
        else:
            vec = array.ravel()
        identifiers = self.dataset_1.get_identifiers(self.current_dim,
                                                     sorted=True)
        indices = data.get_indices(self.current_dim, identifiers)
        existing_ids = data.existing_identifiers(self.current_dim, identifiers)
        v = vec[indices]
        vec_min = min(vec[vec > -scipy.inf])
        vec_max = max(vec[vec < scipy.inf])
        v[v == scipy.inf] = vec_max
        v[v == -scipy.inf] = vec_min
        indices = self.dataset_1.get_indices(self.current_dim, existing_ids)
        map_vec = vec_min * scipy.ones(len(identifiers))
        map_vec[indices] = v
        # update facecolors
        self.sc.set_array(map_vec)
        self.sc.set_clim(vec_min, vec_max)
        self.sc.update_scalarmappable()  # sets facecolors from array
        if hasattr(self.sc.cmap, "_lut"):
            logger.log('debug', 'changing lut')
            self.sc.cmap._lut[-1, :] = [.5, .5, .5, 1]
            self.sc.cmap._lut[0, :] = [.5, .5, .5, 1]
        else:
            logger.log('debug', 'No lut present')
        self.canvas.draw()

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the points strictly inside the given rectangle."""
        xdata = self.xaxis_data
        ydata = self.yaxis_data
        # normalise so that (x1, y1) is the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        index = scipy.nonzero((xdata > x1) & (xdata < x2) &
                              (ydata > y1) & (ydata < y2))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the points inside the lasso polygon ``verts``."""
        xys = scipy.c_[self.xaxis_data[:, scipy.newaxis],
                       self.yaxis_data[:, scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Give the selected points a visible edge in the overlay layer."""
        linewidth = scipy.zeros(self.xaxis_data.shape, 'f')
        index = self.get_index_from_selection(self.dataset_1, selection)
        if len(index) > 0:
            linewidth[index] = 1.5
        self.selection_collection.set_linewidth(linewidth)
        if self._use_blit and len(index) > 0:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            self.axes.draw_artist(self.selection_collection)
            self.canvas.blit()
        else:
            self.canvas.draw()
class ImagePlot(Plot):
    """Image (matrix) view of a dataset; selections are not supported.

    Keyword arguments:
    -- name : plot title, defaults to 'Image Plot'
    """
    @plotlogger
    def __init__(self, dataset, **kw):
        title = kw.get('name', 'Image Plot')
        Plot.__init__(self, title)
        self.dataset = dataset
        ax = self.axes
        # initial draw: nearest-neighbour image, axes fitted tightly
        ax.grid(False)
        ax.imshow(dataset.asarray(), interpolation='nearest')
        ax.axis('tight')
        # the image is what the colorbar (key 'c') maps
        self._mappable = ax.images[0]
        # no selection interaction for images
        toolbar = self._toolbar
        toolbar.freeze_button.set_sensitive(False)
        for mode in ('select', 'lassoselect'):
            toolbar.set_mode_sensitive(mode, False)
class HistogramPlot(Plot):
    """ Histogram plot.

    If dataset is 1-dim (one row or one column) the current_dim is set and
    selections may be performed. For datasets with more than one row and
    column the histogram is over all values and selections are not defined.

    Keyword arguments:
    -- name : plot title, defaults to 'Histogram Plot'
    -- bins : number of bins, by default chosen by _get_binsize()
    """
    @plotlogger
    def __init__(self, dataset, **kw):
        Plot.__init__(self, kw.get('name', 'Histogram Plot'))
        self.dataset = dataset
        self._data = dataset.asarray()
        # If dataset is 1-dim we may do selections
        if dataset.shape[0] == 1:
            self.current_dim = dataset.get_dim_name(1)
        if dataset.shape[1] == 1:
            self.current_dim = dataset.get_dim_name(0)
        # Set default parameters
        if 'bins' not in kw:
            kw['bins'] = self._get_binsize()
        # Initial draw
        self.axes.grid(False)
        bins = min(self._data.size, kw['bins'])
        count, lims, self.patches = self.axes.hist(self._data, bins=bins)
        if self.current_dim is not None:
            # Give every bar the indices of the items in its bin. The data
            # is flattened so indices follow the single non-trivial
            # dimension for both (1, n) and (n, 1) datasets.
            values = self._data.ravel()
            last = len(self.patches) - 1
            for i, patch in enumerate(self.patches):
                if i == last:
                    # the last bin is closed: it includes the maximum
                    in_bin = values >= lims[i]
                else:
                    # other bins are half open, [lims[i], lims[i+1])
                    in_bin = (values >= lims[i]) & (values < lims[i + 1])
                patch.index = scipy.where(in_bin)[0]
        else:
            # Disable selection modes
            logger.log('notice', 'Disabled selections in Histogram Plot')
            self._toolbar.freeze_button.set_sensitive(False)
            self._toolbar.set_mode_sensitive('select', False)
            self._toolbar.set_mode_sensitive('lassoselect', False)

    def _select_patches(self, key):
        """Report the items of self.active_patches to the listener."""
        ids = set()
        for patch in self.active_patches:
            ids.update(self.dataset.get_identifiers(self.current_dim,
                                                    patch.index))
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the bars that overlap the rectangle horizontally and
        reach above its lower edge."""
        if self.current_dim is None:
            return
        # make (x1, y1) the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        self.active_patches = []
        for patch in self.patches:
            xmin = patch.xy[0]
            xmax = xmin + patch.get_width()
            if xmax > x1 and xmin < x2 and patch.get_height() > y1:
                self.active_patches.append(patch)
        if not self.active_patches:
            return
        self._select_patches(key)

    def lasso_select_callback(self, verts, key=None):
        """Select the bars containing any of the lasso vertices."""
        if self.current_dim is None:
            return
        self.active_patches = []
        for patch in self.patches:
            if scipy.any(points_inside_poly(verts, patch.get_verts())):
                self.active_patches.append(patch)
        if not self.active_patches:
            return
        self._select_patches(key)

    def set_current_selection(self, selection):
        """Colour each bar by the share of its items that are selected:
        plain blue when none, shifting towards red as the share grows."""
        if self.current_dim is None:
            # bars carry no item indices when selections are disabled
            return
        index = self.get_index_from_selection(self.dataset, selection)
        for patch in self.patches:
            bin_selected = scipy.intersect1d(patch.index, index).size
            if bin_selected > 0:
                bin_total = len(patch.index)
                # fixme: engineering color
                prop = -scipy.log(1.0 * bin_selected / bin_total)
                b = min(prop, 1)
                r = max(.5, 1 - b)
                g = 0
                patch.set_facecolor((r, g, b, 1))
            else:
                patch.set_facecolor('b')
        self.canvas.draw()

    def _get_binsize(self, min_bins=2, max_bins=100):
        """ Automatic bin selection, as described by Shimazaki.

        For every bin count k in [min_bins, max_bins) the cost
        (2 * mean(counts) - var(counts)) / width**2 is evaluated; the k
        with the lowest cost is returned. Constant data returns min_bins.
        """
        span = self._data.ptp()
        if span == 0:
            # every bin width would be zero and the cost undefined
            return min_bins
        bin_vec = scipy.arange(min_bins, max_bins, 1)
        # float() avoids integer division for integer data (Python 2)
        D = float(span) / bin_vec
        cost = scipy.empty((bin_vec.shape[0],), 'f')
        for i, bins in enumerate(bin_vec):
            count, lims = scipy.histogram(self._data, bins)
            cost[i] = (2 * count.mean() - count.var()) / (D[i] ** 2)
        best_bin_size = bin_vec[scipy.argmin(cost)]
        return best_bin_size
class BarPlot(Plot):
    """Bar plot.

    Ordinary bar plot for (column) vectors.
    For matrices there is one color for each row; the bars of a column
    are placed next to each other.

    Keyword arguments:
    -- name : plot title, defaults to 'Bar Plot'
    """
    @plotlogger
    def __init__(self, dataset, **kw):
        Plot.__init__(self, kw.get('name', 'Bar Plot'))
        self.dataset = dataset
        # Initial draw
        self.axes.grid(False)
        n, m = dataset.shape
        if m > 1:
            clrs = matplotlib.cm.ScalarMappable().to_rgba(range(n))
            for i, row in enumerate(dataset.asarray()):
                # row i occupies every n-th slot, starting at i+1 (m bars)
                left = scipy.arange(i + 1, m * n + 1, n)
                color = clrs[i]
                self.axes.bar(left, row, color=(color[0], color[1], color[2]))
        else:
            height = dataset.asarray().ravel()
            # one bar per element at positions 1..n (n positions, n heights)
            left = scipy.arange(1, n + 1, 1)
            self.axes.bar(left, height)
        # Disable selection modes
        self._toolbar.freeze_button.set_sensitive(False)
        self._toolbar.set_mode_sensitive('select', False)
        self._toolbar.set_mode_sensitive('lassoselect', False)
class NetworkPlot(Plot):
    """Graph plot of a network dataset.

    Nodes are drawn as a scatter collection, positioned by ``pos`` or by a
    graphviz layout computed with ``prog``. Selected nodes get an outline,
    and edges whose two end nodes are both selected are drawn in the
    selected edge colour.
    """
    @plotlogger
    def __init__(self, dataset, pos=None, nodecolor='b', nodesize=40,
                 prog='neato', with_labels=False, name='Network Plot'):
        Plot.__init__(self, name)
        self.dataset = dataset
        self.graph = dataset.asnetworkx()
        self._prog = prog
        self._pos = pos
        self._nodesize = nodesize
        self._nodecolor = nodecolor
        self._with_labels = with_labels
        self.current_dim = self.dataset.get_dim_name(0)
        if not self._pos:
            self._pos = networkx.graphviz_layout(self.graph, self._prog)
        # node coordinates, in sorted identifier order
        node_ids = self.dataset.get_identifiers(self.current_dim, sorted=True)
        self._xy = scipy.asarray([self._pos[node] for node in node_ids])
        self.xaxis_data = self._xy[:, 0]
        self.yaxis_data = self._xy[:, 1]
        # Initial draw
        self.default_props = {'nodesize': 50,
                              'nodecolor': 'blue',
                              'edge_color': 'gray',
                              'edge_color_selected': 'red'}
        self.node_labels = None
        lw = scipy.zeros(self.xaxis_data.shape)
        self.node_collection = self.axes.scatter(self.xaxis_data,
                                                 self.yaxis_data,
                                                 s=self._nodesize,
                                                 c=self._nodecolor,
                                                 linewidth=lw,
                                                 zorder=3)
        self._mappable = self.node_collection
        # selected nodes is a transparent graph that adjusts node-edge
        # visibility according to the current selection; needed to get the
        # selected nodes 'on top' as zorder may not be defined individually
        self.selected_nodes = self.axes.scatter(self.xaxis_data,
                                                self.yaxis_data,
                                                s=self._nodesize,
                                                c=self._nodecolor,
                                                linewidth=lw,
                                                zorder=4,
                                                alpha=0)
        edge_color = self.default_props['edge_color']
        self.edge_collection = networkx.draw_networkx_edges(
            self.graph, self._pos, ax=self.axes, edge_color=edge_color)
        # edge color rgba-arrays: one default row per edge
        self._edge_color_rgba = matlib.repmat(
            ColorConverter().to_rgba(edge_color),
            self.graph.number_of_edges(), 1)
        self._edge_color_selected = ColorConverter().to_rgba(
            self.default_props['edge_color_selected'])
        if self._with_labels:
            self.node_labels = networkx.draw_networkx_labels(self.graph,
                                                             self._pos,
                                                             ax=self.axes)
        # remove axes, frame and grid
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.grid(False)
        self.axes.set_frame_on(False)
        self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the nodes strictly inside the given rectangle."""
        xdata = self.xaxis_data
        ydata = self.yaxis_data
        # normalise so that (x1, y1) is the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        index = scipy.nonzero((xdata > x1) & (xdata < x2) &
                              (ydata > y1) & (ydata < y2))[0]
        ids = self.dataset.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the nodes inside the lasso polygon ``verts``."""
        xys = scipy.c_[self.xaxis_data[:, scipy.newaxis],
                       self.yaxis_data[:, scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Outline the selected nodes and recolour the edges between them.

        Assumes the edge collection lists edges in ``self.graph.edges()``
        order, as the per-edge colour rows are matched by position.
        """
        linewidth = scipy.zeros(self.xaxis_data.shape)
        edge_color_rgba = self._edge_color_rgba.copy()
        index = self.get_index_from_selection(self.dataset, selection)
        if len(index) > 0:
            linewidth[index] = 2
            # set: O(1) membership tests in the edge scan below
            idents = set(selection[self.current_dim])
            for i, edge in enumerate(self.graph.edges()):
                if edge[0] in idents and edge[1] in idents:
                    edge_color_rgba[i, :] = self._edge_color_selected
        # NOTE(review): this sets ``_A`` on the plot, not on the edge
        # collection; confirm which object was intended.
        self._A = None
        # always assign, so clearing the selection restores default colours
        self.edge_collection._colors = edge_color_rgba
        self.selected_nodes.set_linewidth(linewidth)
        self.canvas.draw()
class VennPlot(Plot):
2007-03-14 17:27:23 +01:00
@plotlogger
def __init__(self, name="Venn diagram"):
Plot.__init__(self, name)
2007-02-27 17:28:03 +01:00
2007-02-27 16:05:21 +01:00
# init draw
2007-02-27 17:28:03 +01:00
self._init_bck()
for c in self._venn_patches:
2007-02-26 10:08:50 +01:00
self.axes.add_patch(c)
for mrk in self._markers:
2007-02-26 10:08:50 +01:00
self.axes.add_patch(mrk)
self.axes.set_xlim([-3, 3])
self.axes.set_ylim([-2.5, 3.5])
self._last_active = set()
2007-02-26 10:08:50 +01:00
self.axes.set_xticks([])
self.axes.set_yticks([])
self.axes.axis('equal')
2007-02-27 16:05:21 +01:00
self.axes.grid(False)
2007-02-26 10:08:50 +01:00
self.axes.set_frame_on(False)
self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
2007-02-26 10:08:50 +01:00
def _init_bck(self):
res = 50
a = .5
r = 1.5
mr = .2
self.c1 = c1 = Circle((-1,0), radius=r, alpha=a, facecolor='b')
self.c2 = c2 = Circle((1,0), radius=r, alpha=a, facecolor='r')
self.c3 = c3 = Circle((0, scipy.sqrt(3)), radius=r, alpha=a, facecolor='g')
self.c1marker = Circle((-1.25, -.25), radius=mr, facecolor='y', alpha=0)
self.c2marker = Circle((1.25, -.25), radius=mr, facecolor='y', alpha=0)
self.c3marker = Circle((0, scipy.sqrt(3)+.25), radius=mr, facecolor='y', alpha=0)
self.c1c2marker = Circle((0, -.15), radius=mr, facecolor='y', alpha=0)
self.c1c3marker = Circle((-scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0)
self.c2c3marker = Circle((scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0)
self.c1c2c3marker = Circle((0, .6), radius=mr, facecolor='y', alpha=0)
c1.elements = set(['a', 'b', 'c', 'f'])
c2.elements = set(['a', 'c', 'd', 'e'])
c3.elements = set(['a', 'e', 'f', 'g'])
self.active_elements = set()
self.all_elements = c1.elements.union(c2.elements).union(c3.elements)
c1.active = False
c2.active = False
c3.active = False
c1.name = 'Blue'
c2.name = 'Red'
c3.name = 'Green'
self._venn_patches = [c1, c2, c3]
self._markers = [self.c1marker, self.c2marker, self.c3marker,
self.c1c2marker, self.c1c3marker,
self.c2c3marker, self.c1c2c3marker]
self._tot_label = 'Tot: ' + str(len(self.all_elements))
self._sel_label = 'Sel: ' + str(len(self.active_elements))
2007-02-26 10:08:50 +01:00
self._legend = self.axes.legend((self._tot_label, self._sel_label),
loc='upper right')
def set_selection(self, selection, patch=None):
if patch:
patch.selection = selection
else:
selection_set = False
for patch in self._venn_patches:
if len(patch.elements)==0:
patch.elements = selection
selection_set = True
if not selection_set:
self.venn_patches[0].elements = selection
def lasso_select_callback(self, verts, key=None):
if verts==None:
verts = (self._event.xdata, self._event.ydata)
if key!='shift':
for m in self._markers:
m.set_alpha(0)
self._patches_within_verts(verts, key)
active = [i.active for i in self._venn_patches]
if active==[True, False, False]:
self.c1marker.set_alpha(1)
self.active_elements = self.c1.elements.difference(self.c2.elements.union(self.c3.elements))
elif active== [False, True, False]:
self.c2marker.set_alpha(1)
self.active_elements = self.c2.elements.difference(self.c1.elements.union(self.c3.elements))
elif active== [False, False, True]:
self.c3marker.set_alpha(1)
self.active_elements = self.c3.elements.difference(self.c2.elements.union(self.c1.elements))
elif active==[True, True, False]:
self.c1c2marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c2.elements)
elif active==[True, False, True]:
self.c1c3marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c3.elements)
elif active==[False, True, True]:
self.c2c3marker.set_alpha(1)
self.active_elements = self.c2.elements.intersection(self.c3.elements)
elif active==[True, True, True]:
self.c1c2c3marker.set_alpha(1)
self.active_elements = self.c1.elements.intersection(self.c3.elements).intersection(self.c2.elements)
if key=='shift':
self.active_elements = self.active_elements.union(self._last_active)
self._last_active = self.active_elements.copy()
self._sel_label = 'Sel: ' + str(len(self.active_elements))
self._legend.texts[1].set_text(self._sel_label)
2007-02-26 10:08:50 +01:00
self.axes.figure.canvas.draw()
def rectangle_select_callback(self, x1, y1, x2, y2, key):
    """Select the Venn region under a dragged rectangle.

    The rectangle is reduced to its two corner vertices and handled exactly
    like a lasso selection, so both selection tools share one implementation
    of the region lookup, marker highlighting and legend update.

    @param x1: x coordinate of the first corner.
    @param y1: y coordinate of the first corner.
    @param x2: x coordinate of the opposite corner.
    @param y2: y coordinate of the opposite corner.
    @param key: Modifier key of the triggering event ('shift' extends the
        current selection).
    """
    self.lasso_select_callback([(x1, y1), (x2, y2)], key)
def _patches_within_verts(self, verts, key):
xy = scipy.array(verts).mean(0)
for venn_patch in self._venn_patches:
venn_patch.active = False
if self._distance(venn_patch.center,xy)<venn_patch.radius:
venn_patch.active = True
def _distance(self, (x1,y1),(x2,y2)):
return scipy.sqrt( (x2-x1)**2 + (y2-y1)**2 )
class PlotThresholder:
    """Mixin class for plots that need to filter nodes within a threshold
    range.

    The host class is expected to provide C{_toolbar} (used when adding the
    spin buttons), and C{_map_ids}, C{_mappable} and C{canvas} (used by
    L{set_threshold}).
    """

    def __init__(self, text="x"):
        """Constructor.
        @param text: Name of the variable the threshold is on.
        """
        self._threshold_ds = None
        self._add_spin_buttons(text)
        # The spin buttons stay disabled until a dataset has been supplied
        # through set_threshold_dataset().
        self._sb_min.set_sensitive(False)
        self._sb_max.set_sensitive(False)

    def set_threshold_dataset(self, ds):
        """Sets the dataset to threshold on and enables the spin buttons.
        @param ds: A dataset where one dimension corresponds to the select
            dimension in the plot, and any other dimensions have length 1.
        """
        self._threshold_ds = ds
        self._sb_min.set_sensitive(True)
        self._sb_max.set_sensitive(True)

    def _add_spin_buttons(self, text):
        """Adds spin buttons to the toolbar for selecting minimum and maximum
        threshold values.

        Two spin buttons (range 0-100, two digits, steps 0.1 / 1.0) are
        placed around a " < text < " label and appended to the end of the
        host toolbar. Both buttons call _on_value_changed when edited.
        @param text: Name of the thresholded variable shown in the label.
        """
        sb_min = gtk.SpinButton(digits=2)
        sb_min.set_range(0, 100)
        sb_min.set_value(0)
        sb_min.set_increments(.1, 1.)
        sb_min.connect('value-changed', self._on_value_changed)
        self._sb_min = sb_min

        sb_max = gtk.SpinButton(digits=2)
        sb_max.set_range(0, 100)
        sb_max.set_value(1)
        sb_max.set_increments(.1, 1.)
        sb_max.connect('value-changed', self._on_value_changed)
        self._sb_max = sb_max

        label = gtk.Label(" < %s < " % text)
        hbox = gtk.HBox()
        hbox.pack_start(sb_min)
        hbox.pack_start(label)
        hbox.pack_start(sb_max)

        ti = gtk.ToolItem()
        ti.set_expand(False)
        ti.add(hbox)
        sb_min.show()
        sb_max.show()
        label.show()
        hbox.show()
        ti.show()
        self._toolbar.insert(ti, -1)
        ti.set_tooltip(self._toolbar.tooltips, "Set threshold")

    def set_threshold(self, min, max):
        """Sets min and max to the given values.

        Updates the plot so that only nodes whose value lies within
        [min, max] (inclusive) are shown. Nodes of the plot are given size
        50, all others size 0. The set of shown nodes is stored in
        C{self.visible}. Does nothing if no threshold dataset has been set.
        @param min: Do not show nodes with a value below this.
        @param max: Do not show nodes with a value above this.
        """
        # The parameter names shadow builtins but are kept so that keyword
        # callers keep working.
        import numpy
        ds = self._threshold_ds
        if ds is None:
            return
        # NOTE(review): the 'go-terms' dimension is hard-coded here, although
        # set_threshold_dataset documents a generic "select dimension".
        icnodes = ds.existing_identifiers('go-terms', self._map_ids)
        icindices = ds.get_indices('go-terms', icnodes)
        # Sum over axis 1 to get one value per node (the other dimensions
        # are documented to have length 1).
        values = ds.asarray()[icindices].sum(1)
        within = (values >= min) & (values <= max)
        good = set(numpy.asarray(icnodes)[within])

        sizes = numpy.zeros(len(self._map_ids))
        visible = set()
        for i, node in enumerate(self._map_ids):
            if node in good:
                sizes[i] = 50
                visible.add(node)
        self.visible = visible
        # The host may not have drawn anything yet (no mappable).
        mappable = getattr(self, '_mappable', None)
        if mappable is not None:
            # NOTE(review): writes a private matplotlib attribute; newer
            # matplotlib versions offer Collection.set_sizes() instead.
            mappable._sizes = sizes
        self.canvas.draw()

    def get_nodes_within_bounds(self):
        """Get a list of all nodes within the bounds of the selection in the
        selected dataset.

        Not implemented: currently returns None.
        """
        pass

    def filter_nodes(self, nodes):
        """Filter a list of nodes and return only those that are within the
        threshold boundaries.

        Not implemented: currently returns None.
        """
        pass

    def _on_value_changed(self, sb):
        """Callback on spin button value changes: re-applies the threshold
        using the current values of both spin buttons."""
        lower = self._sb_min.get_value()
        upper = self._sb_max.get_value()
        self.set_threshold(lower, upper)
# Register the custom signals emitted by Plot instances: 'zoom-changed',
# 'pan-changed' and 'plot-resize-changed'. Each is run-last, has no return
# value and carries a single Python object argument.
for _signal_name in ('zoom-changed', 'pan-changed', 'plot-resize-changed'):
    gobject.signal_new(_signal_name, Plot, gobject.SIGNAL_RUN_LAST, None,
                       (gobject.TYPE_PYOBJECT,))
del _signal_name