This repository has been archived on 2024-07-04. You can view files and clone it, but cannot push or open issues or pull requests.
laydi/fluents/plots.py

1092 lines
42 KiB
Python

import pygtk
import gobject
import gtk
import matplotlib
from matplotlib.backends.backend_gtkagg import FigureCanvasGTKAgg
from matplotlib.nxutils import points_inside_poly
from matplotlib.figure import Figure
from matplotlib.collections import LineCollection
from matplotlib.patches import Polygon,Rectangle,Circle
from matplotlib.lines import Line2D
from matplotlib.mlab import prctile
from matplotlib.colors import ColorConverter
import networkx
import scipy
from numpy import matlib
import fluents
import logger
import view
def plotlogger(func, name=None):
    """Decorator that records the arguments a plot was constructed with.

    The returned wrapper stores the positional arguments on the instance
    (its first argument) as ``__args`` and the keyword arguments as
    ``__kw``, then calls ``func`` unchanged and returns its result.

    @param func: Callable taking the instance as its first argument,
        typically a plot ``__init__``.
    @param name: Unused; kept for backward compatibility.
    @return: The wrapper, carrying ``func``'s name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapped(parent, *args, **kw):
        # This function is not inside a class body, so ``__args``/``__kw``
        # are NOT name-mangled: the attributes are literally '__args'/'__kw'.
        parent.__args = args
        parent.__kw = kw
        return func(parent, *args, **kw)
    return wrapped
class Plot(view.View):
    """Base class for the matplotlib-backed plot views.

    Owns the figure, canvas and toolbar, and keeps the bookkeeping shared
    by every plot: the last received selection, the dimension that
    selections refer to (``current_dim``), a frozen flag and a colorbar
    that can be toggled with the 'c' key.
    """

    def __init__(self, title):
        view.View.__init__(self, title)
        logger.log('debug', 'plot %s init' % title)
        self.selection_listener = None
        self.current_dim = None
        self._current_selection = None
        self._frozen = False
        self._init_mpl()

    def _init_mpl(self):
        """Create the figure, canvas and toolbar and hook up key presses."""
        self._background = None
        self._colorbar = None
        self._mappable = None
        self._use_blit = False
        self.fig = Figure()
        self.canvas = FigureCanvasGTKAgg(self.fig)
        self.axes = self.fig.gca()
        self._toolbar = view.PlotToolbar(self)
        self._key_press = self.canvas.mpl_connect(
            'key_press_event', self.on_key_press)
        self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK)
        self.add(self.canvas)
        self.canvas.show()

    def set_frozen(self, frozen):
        """A frozen plot will not be updated when the current
        selection is changed. Unfreezing redraws the last stored selection."""
        self._frozen = frozen
        if not frozen and self._current_selection is not None:
            self.set_current_selection(self._current_selection)

    def get_title(self):
        return self.title

    def get_toolbar(self):
        return self._toolbar

    def selection_changed(self, dim_name, selection):
        """ Selection observer handle.

        The selection is always stored, but only drawn if:
        1.) plot is sensitive to selections (not frozen)
        2.) plot is visible (has a view)
        3.) the selection's dim_name is the plot's dimension.
        """
        self._current_selection = selection
        if self._frozen \
               or not self.get_property('visible') \
               or self.current_dim != dim_name:
            return
        self.set_current_selection(selection)

    def set_selection_listener(self, listener):
        """Allow project to listen to selections.

        The selection will propagate back to all plots through the
        selection_changed() method. The listener will be called as
        listener(dimension_name, ids).
        """
        self.selection_listener = listener

    def update_selection(self, ids, key=None):
        """Returns updated current selection from ids.

        If a key is pressed we use the appropriate mode.
        key map:
          shift  : union
          control : intersection
        Any other key returns ids unchanged. With no stored selection (or
        none in current_dim) the stored side is treated as empty.
        """
        logger.log('debug', 'update selection, key: %s' % key)
        if key not in ('shift', 'control'):
            return ids
        if self._current_selection is None:
            current = []
        else:
            current = self._current_selection.get(self.current_dim, [])
        if key == 'shift':
            return set(ids).union(current)
        return set(ids).intersection(current)

    def set_current_selection(self, selection):
        """Called whenever the plot should change the selection.

        This method is a dummy method, so that specialized plots that have
        no implemented selection can ignore selections altogether.
        """
        pass

    def rectangle_select_callback(self, *args):
        """Overridden in subclass; the default just redraws."""
        if hasattr(self, 'canvas'):
            self.canvas.draw()

    def lasso_select_callback(self, *args):
        """Overridden in subclass; the default just redraws."""
        if hasattr(self, 'canvas'):
            self.canvas.draw()

    def get_index_from_selection(self, dataset, selection):
        """Returns the index vector of current selection in given dim."""
        if not selection:
            return []
        ids = selection.get(self.current_dim, [])  # current identifiers
        if not ids:
            return []
        return dataset.get_indices(self.current_dim, ids)

    def on_key_press(self, event):
        """Toggle the colorbar when 'c' is pressed."""
        if event.key == 'c':
            self._toggle_colorbar()

    def _toggle_colorbar(self):
        """Add a colorbar for the mappable, or remove the existing one."""
        if self._colorbar is None:
            if self._mappable is None:
                logger.log('notice', 'No mappable in this plot')
                return
            # ``is not None``: _A may be an ndarray, for which ``!= None``
            # is elementwise and has an ambiguous truth value.
            if self._mappable._A is not None:  # we need colormapping
                # remember axes position so it can be restored on removal
                self._ax_last_pos = self.axes.get_position()
                self._colorbar = self.fig.colorbar(self._mappable)
                self._colorbar.draw_all()
                self.canvas.draw()
        else:
            # remove axes, observers and colorbar instance, then restore
            # the original axes position
            _cb, ax = self._mappable.colorbar
            self.fig.delaxes(ax)
            self._mappable.observers = [obs for obs in self._mappable.observers
                                        if obs != self._colorbar]
            self._colorbar = None
            self.axes.set_position(self._ax_last_pos)
            self.canvas.draw()
class LineViewPlot(Plot):
    """Line view plot with percentiles.

    A line view of vectors across a specified dimension of input dataset.
    No selection interaction is defined.
    Only support for 2d-arrays.

    input:
        -- major_axis : dim_number for line dim (see scipy.ndarray for axis def.)
        -- minor_axis : needs definition only for higher order arrays

    fixme: slow
    """

    @plotlogger
    def __init__(self, dataset, major_axis=1, minor_axis=None, center=True,
                 name="Line view"):
        Plot.__init__(self, name)
        self.dataset = dataset
        self._data = dataset.asarray()
        if len(self._data.shape) == 2 and not minor_axis:
            minor_axis = major_axis - 1
        self.major_axis = major_axis
        self.minor_axis = minor_axis
        self.current_dim = self.dataset.get_dim_name(major_axis)

        self.data_is_centered = False
        self._mn_data = 0
        if center and len(self._data.shape) == 2:
            if minor_axis == 0:
                self._mn_data = self._data.mean(minor_axis)
            else:
                # column vector of means so it broadcasts against the rows
                self._mn_data = self._data.mean(minor_axis)[:, scipy.newaxis]
            self._data = self._data - self._mn_data
            self.data_is_centered = True

        # initial line collection
        self.line_coll = None
        self.make_lines()

        # draw background
        self.set_background()

        # Disable selection modes
        self._toolbar.freeze_button.set_sensitive(False)
        self._toolbar.set_mode_sensitive('select', False)
        self._toolbar.set_mode_sensitive('lassoselect', False)

    def make_lines(self):
        """Creates one line segment list for each item along major axis."""
        if self.line_coll:  # remove any previous selection lines
            self.axes.collections.remove(self.line_coll)
            self.line_coll = None
        self.line_segs = []
        x_axis = scipy.arange(self._data.shape[self.minor_axis])
        for xi in range(self._data.shape[self.major_axis]):
            yi = self._data.take([xi], self.major_axis).ravel()
            self.line_segs.append(list(zip(x_axis, yi)))

    def set_background(self):
        """Add three patches representing [min max], [5,95] and [25,75]
        percentiles, and a line at the median.

        Does nothing when there are fewer than 6 points along the minor axis.
        """
        if self._data.shape[self.minor_axis] < 6:
            return
        # clean old patches if any
        if len(self.axes.patches) > 0:
            self.axes.patches = []
        # clean old lines (median) if any
        if len(self.axes.lines) > 0:
            self.axes.lines = []

        # settings
        patch_color = 'b'   # blue
        patch_lw = 0        # no edges
        patch_alpha = .15   # transparency
        median_color = 'b'  # blue
        median_width = 1.5  # linewidth
        # symmetric around the median: prct[k] pairs with prct[-(k+1)],
        # giving the bands (0,100), (5,95) and (25,75); prct[3] is the median
        percentiles = [0, 5, 25, 50, 75, 95, 100]

        # abscissa
        xax = scipy.arange(self._data.shape[self.minor_axis])
        # percentiles per position, computed once and reused for both edges
        prcts = [prctile(self._data.take([i], self.minor_axis), percentiles)
                 for i in xax]

        # each band polygon: lower edge left->right, upper edge right->left
        for lo, hi in ((0, -1), (1, -2), (2, -3)):
            lower = [(i, p[lo]) for i, p in zip(xax, prcts)]
            upper = [(i, p[hi]) for i, p in zip(xax, prcts)]
            upper.reverse()
            self.axes.add_patch(Polygon(lower + upper, alpha=patch_alpha,
                                        lw=patch_lw, facecolor=patch_color))

        # median line
        med = [p[3] for p in prcts]
        self.axes.plot(xax, med, median_color, linewidth=median_width)

        # set y-limits
        padding = 0.1
        self.axes.set_ylim([self._data.min() - padding,
                            self._data.max() + padding])

    def set_current_selection(self, selection):
        """Draws the selected items as red lines."""
        index = self.get_index_from_selection(self.dataset, selection)
        if self.line_coll:
            self.axes.collections.remove(self.line_coll)

        segs = [self.line_segs[i] for i in index]
        self.line_coll = LineCollection(segs, colors=(1, 0, 0, 1))
        self.axes.add_collection(self.line_coll)

        # draw
        if self._use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            # the selection artist is line_coll (there is no self.lines)
            self.axes.draw_artist(self.line_coll)
            self.canvas.blit()
        else:
            self.canvas.draw()
class ScatterMarkerPlot(Plot):
    """The ScatterMarkerPlot is faster than regular scatterplot, but
    has no color and size options."""

    @plotlogger
    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim,
                 id_1, id_2, s=6, name="Scatter plot"):
        Plot.__init__(self, name)
        self.current_dim = id_dim
        self.dataset_1 = dataset_1
        self.ms = s
        x_index = dataset_1[sel_dim][id_1]
        y_index = dataset_2[sel_dim][id_2]
        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]

        # init draw
        self._selection_line = None
        self.line = self.axes.plot(self.xaxis_data,
                                   self.yaxis_data, 'o',
                                   markeredgewidth=0,
                                   markersize=s)
        self.axes.axhline(0, color='k', lw=1., zorder=1)
        self.axes.axvline(0, color='k', lw=1., zorder=1)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the points strictly inside the rectangle (x1,y1)-(x2,y2)."""
        ydata = self.yaxis_data
        xdata = self.xaxis_data

        # find indices of selected area
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1
        assert x1 <= x2
        assert y1 <= y2

        index = scipy.nonzero((xdata > x1) & (xdata < x2) &
                              (ydata > y1) & (ydata < y2))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the points inside the lasso polygon ``verts``."""
        xys = scipy.c_[self.xaxis_data[:, scipy.newaxis],
                       self.yaxis_data[:, scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Overlay the selected points as red markers."""
        # remove old selection
        if self._selection_line:
            self.axes.lines.remove(self._selection_line)
        index = self.get_index_from_selection(self.dataset_1, selection)
        if len(index) == 0:
            # no selection
            self.canvas.draw()
            self._selection_line = None
            return

        xdata_new = self.xaxis_data.take(index)  # take data
        ydata_new = self.yaxis_data.take(index)
        self._selection_line = Line2D(xdata_new, ydata_new,
                                      marker='o', markersize=self.ms,
                                      linewidth=0, markerfacecolor='r',
                                      markeredgewidth=1.0)
        self.axes.add_line(self._selection_line)

        if self._use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            # attribute is _selection_line (selection_line does not exist)
            if self._selection_line:
                self.axes.draw_artist(self._selection_line)
            self.canvas.blit()
        else:
            self.canvas.draw()
class ScatterPlot(Plot):
    """The ScatterPlot is slower than scattermarker, but has size option."""

    @plotlogger
    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2,
                 c='b', s=30, sel_dim_2=None, name="Scatter plot", **kw):
        """Initializes a scatter plot.

        Extra keyword arguments are passed on to ``axes.scatter``.
        """
        Plot.__init__(self, name)
        self.dataset_1 = dataset_1
        self.s = s
        self.c = c
        self.kw = kw
        self.current_dim = id_dim
        self._map_ids = dataset_1.get_identifiers(id_dim, sorted=True)

        x_index = dataset_1[sel_dim][id_1]
        if sel_dim_2:
            y_index = dataset_2[sel_dim_2][id_2]
        else:
            y_index = dataset_2[sel_dim][id_2]

        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]

        # init draw
        self.init_draw()

        # signals to enable correct use of blit
        self.connect('zoom-changed', self.onzoom)
        self.connect('pan-changed', self.onpan)
        self.need_redraw = False
        self.canvas.mpl_connect('resize_event', self.onresize)

    def onzoom(self, widget, mode):
        self.clean_redraw()

    def onpan(self, widget, mode):
        self.clean_redraw()

    def onresize(self, widget):
        self.clean_redraw()

    def clean_redraw(self):
        """Refresh the blit background after the view changed."""
        if self._use_blit:
            # grab a background without the selection, then redraw it
            self.set_current_selection(None)
            self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.set_current_selection(self._current_selection)
        else:
            self._background = None

    def init_draw(self):
        """Draw the scatter, the axis lines and an invisible overlay
        collection whose edge widths mark the selected points."""
        lw = scipy.zeros(self.xaxis_data.shape)
        self.sc = self.axes.scatter(self.xaxis_data, self.yaxis_data,
                                    s=self.s, c=self.c, linewidth=lw,
                                    zorder=3, **self.kw)
        self._mappable = self.sc
        self.axes.axhline(0, color='k', lw=1., zorder=1)
        self.axes.axvline(0, color='k', lw=1., zorder=1)
        self.selection_collection = self.axes.scatter(self.xaxis_data,
                                                      self.yaxis_data,
                                                      alpha=0,
                                                      c='w', s=self.s,
                                                      edgecolor='r',
                                                      linewidth=0,
                                                      zorder=4)
        self._background = self.canvas.copy_from_bbox(self.axes.bbox)

    def is_mappable_with(self, obj):
        """Returns True if dataset/selection is mappable with this plot,
        i.e. it has a dimension named ``current_dim``; False otherwise.
        """
        if isinstance(obj, (fluents.dataset.Dataset,
                            fluents.dataset.Selection)):
            return self.current_dim in obj.get_dim_name()
        return False

    def _update_color_from_dataset(self, data):
        """Updates the facecolors from a dataset.

        The dataset is reduced to one value per identifier (category
        datasets map to their category number); identifiers missing from
        ``data`` get the minimum value, and +/-inf are clipped to the
        finite extremes.

        @raises ValueError: if the dataset is not 2-dimensional.
        """
        array = data.asarray()

        # only support for 2d-arrays:
        try:
            m, n = array.shape
        except ValueError:
            raise ValueError("No support for more than 2 dimensions.")

        # is dataset a vector or matrix?
        if not n == 1:
            # we have a category dataset
            if isinstance(data, fluents.dataset.CategoryDataset):
                vec = scipy.dot(array, scipy.diag(scipy.arange(n))).sum(1)
            else:
                vec = array.sum(1)
        else:
            vec = array.ravel()

        identifiers = self.dataset_1.get_identifiers(self.current_dim,
                                                     sorted=True)
        indices = data.get_indices(self.current_dim, identifiers)
        existing_ids = data.existing_identifiers(self.current_dim, identifiers)
        v = vec[indices]
        vec_min = min(vec[vec > -scipy.inf])
        vec_max = max(vec[vec < scipy.inf])
        v[v == scipy.inf] = vec_max
        v[v == -scipy.inf] = vec_min
        indices = self.dataset_1.get_indices(self.current_dim, existing_ids)
        map_vec = vec_min * scipy.ones(len(identifiers))
        map_vec[indices] = v

        # update facecolors
        self.sc.set_array(map_vec)
        self.sc.set_clim(vec_min, vec_max)
        self.sc.update_scalarmappable()  # sets facecolors from array
        if hasattr(self.sc.cmap, "_lut"):
            logger.log('debug', 'changing lut')
            # paint both ends of the lookup table grey
            self.sc.cmap._lut[-1, :] = [.5, .5, .5, 1]
            self.sc.cmap._lut[0, :] = [.5, .5, .5, 1]
        else:
            logger.log('debug', 'No lut present')
        self.canvas.draw()

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the points strictly inside the rectangle (x1,y1)-(x2,y2)."""
        ydata = self.yaxis_data
        xdata = self.xaxis_data

        # find indices of selected area
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1
        assert x1 <= x2
        assert y1 <= y2

        index = scipy.nonzero((xdata > x1) & (xdata < x2) &
                              (ydata > y1) & (ydata < y2))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the points inside the lasso polygon ``verts``."""
        xys = scipy.c_[self.xaxis_data[:, scipy.newaxis],
                       self.yaxis_data[:, scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Outline the selected points by giving them a non-zero edge width."""
        linewidth = scipy.zeros(self.xaxis_data.shape, 'f')
        index = self.get_index_from_selection(self.dataset_1, selection)
        if len(index) > 0:
            linewidth[index] = 1.5

        self.selection_collection.set_linewidth(linewidth)

        if self._use_blit and len(index) > 0:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            self.axes.draw_artist(self.selection_collection)
            self.canvas.blit()
        else:
            self.canvas.draw()
class ImagePlot(Plot):
    """Displays a dataset as an image; selection modes are disabled.

    Keyword ``name`` sets the title (default 'Image Plot').
    """

    @plotlogger
    def __init__(self, dataset, **kw):
        title = kw.get('name', 'Image Plot')
        Plot.__init__(self, title)
        self.dataset = dataset

        # Initial draw
        ax = self.axes
        ax.grid(False)
        ax.imshow(dataset.asarray(), interpolation='nearest')
        ax.axis('tight')
        # the freshly drawn image is the colour-mapped artist (colorbar)
        self._mappable = ax.images[0]

        # Disable selection modes
        toolbar = self._toolbar
        toolbar.freeze_button.set_sensitive(False)
        for mode in ('select', 'lassoselect'):
            toolbar.set_mode_sensitive(mode, False)
class HistogramPlot(Plot):
    """ Histogram plot.

    If dataset is 1-dim the current_dim is set and selections may
    be performed. For dataset > 1-dim the histogram is over all values
    and selections are not defined.

    Keywords: ``name`` (title, default 'Histogram Plot') and ``bins``
    (default chosen by _get_binsize).
    """

    @plotlogger
    def __init__(self, dataset, **kw):
        Plot.__init__(self, kw.get('name', 'Histogram Plot'))
        self.dataset = dataset
        self._data = dataset.asarray()

        # If dataset is 1-dim we may do selections
        if dataset.shape[0] == 1:
            self.current_dim = dataset.get_dim_name(1)
        if dataset.shape[1] == 1:
            self.current_dim = dataset.get_dim_name(0)

        # Set default parameters
        if 'bins' not in kw:
            kw['bins'] = self._get_binsize()

        # Initial draw
        self.axes.grid(False)
        bins = min(self._data.size, kw['bins'])
        count, lims, self.patches = self.axes.hist(self._data, bins=bins)

        # Add identifiers to the individual patches
        if self.current_dim is not None:
            # work on the flat vector so positions are identifier indices
            # for both 1xn and nx1 datasets (where()[0] on a 1xn array
            # would give row indices, i.e. all zeros)
            flat = self._data.ravel()
            last = len(self.patches) - 1
            for i, patch in enumerate(self.patches):
                if i == last:
                    end_lim = flat.max() + 1
                else:
                    end_lim = lims[i + 1]
                # half-open [lims[i], end_lim): a value on an interior edge
                # belongs to exactly one bin
                in_bin = (flat >= lims[i]) & (flat < end_lim)
                patch.index = scipy.where(in_bin)[0]

        if self.current_dim is None:
            # Disable selection modes
            logger.log('notice', 'Disabled selections in Histogram Plot')
            self._toolbar.freeze_button.set_sensitive(False)
            self._toolbar.set_mode_sensitive('select', False)
            self._toolbar.set_mode_sensitive('lassoselect', False)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select every bar overlapping [x1, x2] horizontally and reaching
        above the lower edge of the rectangle."""
        if self.current_dim is None:
            return

        # make (x1, y1) the lower left corner
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1

        self.active_patches = []
        for patch in self.patches:
            xmin = patch.xy[0]
            xmax = xmin + patch.get_width()
            ymax = patch.get_height()
            # after the swap y1 <= y2, so "ymax > y2 or ymax > y1" == "ymax > y1"
            if xmax > x1 and xmin < x2 and ymax > y1:
                self.active_patches.append(patch)

        if not self.active_patches:
            return

        ids = set()
        for patch in self.active_patches:
            ids.update(self.dataset.get_identifiers(self.current_dim,
                                                    patch.index))
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select every bar containing at least one lasso vertex."""
        if self.current_dim is None:
            return
        self.active_patches = []
        for patch in self.patches:
            if scipy.any(points_inside_poly(verts, patch.get_verts())):
                self.active_patches.append(patch)
        if not self.active_patches:
            return
        ids = set()
        for patch in self.active_patches:
            ids.update(self.dataset.get_identifiers(self.current_dim,
                                                    patch.index))
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Recolor bars by the fraction of their members that is selected."""
        index = self.get_index_from_selection(self.dataset, selection)
        for patch in self.patches:
            patch.set_facecolor('b')
        for patch in self.patches:
            bin_selected = scipy.intersect1d(patch.index, index).size
            if bin_selected > 0:
                bin_total = len(patch.index)
                # fixme: engineering color
                prop = -scipy.log(1.0 * bin_selected / bin_total)
                b = min(prop, 1)
                r = max(.5, 1 - b)
                g = 0
                patch.set_facecolor((r, g, b, 1))
        self.canvas.draw()

    def _get_binsize(self, min_bins=2, max_bins=100):
        """ Automatic bin selection, as described by Shimazaki.

        Returns the bin count in [min_bins, max_bins) minimising
        (2*mean - var) / width**2 of the bin counts.
        """
        bin_vec = scipy.arange(min_bins, max_bins, 1)
        D = self._data.ptp() / bin_vec
        cost = scipy.empty((bin_vec.shape[0],), 'f')
        for i, bins in enumerate(bin_vec):
            count, lims = scipy.histogram(self._data, bins)
            cost[i] = (2 * count.mean() - count.var()) / (D[i] ** 2)
        best_bin_size = bin_vec[scipy.argmin(cost)]
        return best_bin_size
class BarPlot(Plot):
    """Bar plot.

    Ordinary bar plot for (column) vectors.
    For matrices there is one color for each row; bars of row i are
    placed at positions i+1, i+1+n, ... so rows interleave.
    Keyword ``name`` sets the title (default 'Bar Plot').
    """

    @plotlogger
    def __init__(self, dataset, **kw):
        Plot.__init__(self, kw.get('name', 'Bar Plot'))
        self.dataset = dataset

        # Initial draw
        self.axes.grid(False)
        n, m = dataset.shape
        if m > 1:
            clrs = matplotlib.cm.ScalarMappable().to_rgba(range(n))
            for i, row in enumerate(dataset.asarray()):
                left = scipy.arange(i + 1, m * n + 1, n)
                height = row
                color = clrs[i]
                c = (color[0], color[1], color[2])  # drop alpha
                self.axes.bar(left, height, color=c)
        else:
            height = dataset.asarray().ravel()
            # one position per value: 1..n (arange(1, n) gave only n-1)
            left = scipy.arange(1, n + 1, 1)
            self.axes.bar(left, height)

        # Disable selection modes
        self._toolbar.freeze_button.set_sensitive(False)
        self._toolbar.set_mode_sensitive('select', False)
        self._toolbar.set_mode_sensitive('lassoselect', False)
class NetworkPlot(Plot):
    """Draws ``dataset.asnetworkx()`` with node and edge selection.

    Node positions come from ``pos`` or, if not given, from a graphviz
    layout computed with program ``prog``.
    """

    @plotlogger
    def __init__(self, dataset, pos=None, nodecolor='b', nodesize=40,
                 prog='neato', with_labels=False, name='Network Plot'):
        Plot.__init__(self, name)
        self.dataset = dataset
        self.graph = dataset.asnetworkx()
        self._prog = prog
        self._pos = pos
        self._nodesize = nodesize
        self._nodecolor = nodecolor
        self._with_labels = with_labels

        self.current_dim = self.dataset.get_dim_name(0)

        if not self._pos:
            self._pos = networkx.graphviz_layout(self.graph, self._prog)
        self._xy = scipy.asarray([self._pos[node] for node in
                                  self.dataset.get_identifiers(self.current_dim,
                                                               sorted=True)])
        self.xaxis_data = self._xy[:, 0]
        self.yaxis_data = self._xy[:, 1]

        # Initial draw
        self.default_props = {'nodesize': 50,
                              'nodecolor': 'blue',
                              'edge_color': 'gray',
                              'edge_color_selected': 'red'}
        self.node_collection = None
        self.edge_collection = None
        self.node_labels = None
        lw = scipy.zeros(self.xaxis_data.shape)
        self.node_collection = self.axes.scatter(self.xaxis_data,
                                                 self.yaxis_data,
                                                 s=self._nodesize,
                                                 c=self._nodecolor,
                                                 linewidth=lw,
                                                 zorder=3)
        self._mappable = self.node_collection

        # selected nodes is a transparent graph that adjust node-edge visibility
        # according to the current selection needed to get the selected
        # nodes 'on top' as zorder may not be defined individually
        self.selected_nodes = self.axes.scatter(self.xaxis_data,
                                                self.yaxis_data,
                                                s=self._nodesize,
                                                c=self._nodecolor,
                                                linewidth=lw,
                                                zorder=4,
                                                alpha=0)

        edge_color = self.default_props['edge_color']
        self.edge_collection = networkx.draw_networkx_edges(self.graph,
                                                            self._pos,
                                                            ax=self.axes,
                                                            edge_color=edge_color)
        # edge color rgba-arrays, one row per edge
        self._edge_color_rgba = matlib.repmat(ColorConverter().to_rgba(edge_color),
                                              self.graph.number_of_edges(), 1)
        self._edge_color_selected = ColorConverter().to_rgba(
            self.default_props['edge_color_selected'])
        if self._with_labels:
            self.node_labels = networkx.draw_networkx_labels(self.graph,
                                                             self._pos,
                                                             ax=self.axes)

        # remove axes, frame and grid
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.grid(False)
        self.axes.set_frame_on(False)
        self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the nodes strictly inside the rectangle (x1,y1)-(x2,y2)."""
        ydata = self.yaxis_data
        xdata = self.xaxis_data

        # find indices of selected area
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1
        assert x1 <= x2
        assert y1 <= y2

        index = scipy.nonzero((xdata > x1) & (xdata < x2) &
                              (ydata > y1) & (ydata < y2))[0]
        ids = self.dataset.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the nodes inside the lasso polygon ``verts``."""
        xys = scipy.c_[self.xaxis_data[:, scipy.newaxis],
                       self.yaxis_data[:, scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Outline selected nodes and recolor edges between two selected nodes.

        Assumes ``self.graph.edges()`` enumerates edges in the same order
        as ``edge_collection`` draws them — NOTE(review): confirm.
        """
        linewidth = scipy.zeros(self.xaxis_data.shape)
        edge_color_rgba = self._edge_color_rgba.copy()
        index = self.get_index_from_selection(self.dataset, selection)
        if len(index) > 0:
            linewidth[index] = 2
            # set membership instead of list scans: O(E) rather than O(E*S)
            idents = set(selection[self.current_dim])
            for i, edge in enumerate(self.graph.edges()):
                if edge[0] in idents and edge[1] in idents:
                    edge_color_rgba[i, :] = self._edge_color_selected

        # NOTE(review): sets _A on the plot itself, not on a collection;
        # kept as-is, it has no visible effect in this file.
        self._A = None
        self.edge_collection._colors = edge_color_rgba
        self.selected_nodes.set_linewidth(linewidth)
        self.canvas.draw()
class VennPlot(Plot):
    """Three-circle Venn diagram whose regions can be selected.

    Clicking (rectangle or lasso) picks the region under the centre of the
    drawn vertices; shift adds the region to the previous selection.
    The circles start with hard-coded demo element sets.
    """

    @plotlogger
    def __init__(self, name="Venn diagram"):
        Plot.__init__(self, name)
        # init draw
        self._init_bck()
        for c in self._venn_patches:
            self.axes.add_patch(c)
        for mrk in self._markers:
            self.axes.add_patch(mrk)
        self.axes.set_xlim([-3, 3])
        self.axes.set_ylim([-2.5, 3.5])
        self._last_active = set()
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.axis('equal')
        self.axes.grid(False)
        self.axes.set_frame_on(False)
        self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    def _init_bck(self):
        """Create the three circles, the region markers and the legend."""
        a = .5    # circle alpha
        r = 1.5   # circle radius
        mr = .2   # marker radius
        self.c1 = c1 = Circle((-1, 0), radius=r, alpha=a, facecolor='b')
        self.c2 = c2 = Circle((1, 0), radius=r, alpha=a, facecolor='r')
        self.c3 = c3 = Circle((0, scipy.sqrt(3)), radius=r, alpha=a, facecolor='g')

        # one (initially invisible) marker per Venn region
        self.c1marker = Circle((-1.25, -.25), radius=mr, facecolor='y', alpha=0)
        self.c2marker = Circle((1.25, -.25), radius=mr, facecolor='y', alpha=0)
        self.c3marker = Circle((0, scipy.sqrt(3) + .25), radius=mr, facecolor='y', alpha=0)
        self.c1c2marker = Circle((0, -.15), radius=mr, facecolor='y', alpha=0)
        self.c1c3marker = Circle((-scipy.sqrt(2) / 2, 1), radius=mr, facecolor='y', alpha=0)
        self.c2c3marker = Circle((scipy.sqrt(2) / 2, 1), radius=mr, facecolor='y', alpha=0)
        self.c1c2c3marker = Circle((0, .6), radius=mr, facecolor='y', alpha=0)

        c1.elements = set(['a', 'b', 'c', 'f'])
        c2.elements = set(['a', 'c', 'd', 'e'])
        c3.elements = set(['a', 'e', 'f', 'g'])
        self.active_elements = set()
        self.all_elements = c1.elements.union(c2.elements).union(c3.elements)

        c1.active = False
        c2.active = False
        c3.active = False

        c1.name = 'Blue'
        c2.name = 'Red'
        c3.name = 'Green'

        self._venn_patches = [c1, c2, c3]
        self._markers = [self.c1marker, self.c2marker, self.c3marker,
                         self.c1c2marker, self.c1c3marker,
                         self.c2c3marker, self.c1c2c3marker]
        self._tot_label = 'Tot: ' + str(len(self.all_elements))
        self._sel_label = 'Sel: ' + str(len(self.active_elements))
        self._legend = self.axes.legend((self._tot_label, self._sel_label),
                                        loc='upper right')

    def set_selection(self, selection, patch=None):
        """Store a selection in a circle.

        Without ``patch`` the selection becomes the elements of the first
        circle with no elements, or of the first circle if none is empty.
        """
        if patch:
            # NOTE(review): sets ``.selection``, which nothing in this class
            # reads (regions use ``.elements``); confirm intent.
            patch.selection = selection
        else:
            for venn_patch in self._venn_patches:
                if len(venn_patch.elements) == 0:
                    venn_patch.elements = selection
                    break
            else:
                self._venn_patches[0].elements = selection

    def lasso_select_callback(self, verts, key=None):
        """Select the region under the centre of the lasso vertices."""
        if verts is None:
            verts = (self._event.xdata, self._event.ydata)
        self._select_region(verts, key)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the region under the centre of the rectangle."""
        verts = [(x1, y1), (x2, y2)]
        self._select_region(verts, key)

    def _select_region(self, verts, key):
        """Shared body of both callbacks: mark the hit region, update
        ``active_elements`` and the 'Sel:' legend entry, and redraw."""
        if key != 'shift':
            for m in self._markers:
                m.set_alpha(0)
        self._patches_within_verts(verts, key)

        e1, e2, e3 = self.c1.elements, self.c2.elements, self.c3.elements
        regions = {
            (True, False, False): (self.c1marker,
                                   lambda: e1.difference(e2.union(e3))),
            (False, True, False): (self.c2marker,
                                   lambda: e2.difference(e1.union(e3))),
            (False, False, True): (self.c3marker,
                                   lambda: e3.difference(e2.union(e1))),
            (True, True, False): (self.c1c2marker,
                                  lambda: e1.intersection(e2)),
            (True, False, True): (self.c1c3marker,
                                  lambda: e1.intersection(e3)),
            (False, True, True): (self.c2c3marker,
                                  lambda: e2.intersection(e3)),
            (True, True, True): (self.c1c2c3marker,
                                 lambda: e1.intersection(e3).intersection(e2)),
        }
        active = tuple(p.active for p in self._venn_patches)
        entry = regions.get(active)
        if entry is None:
            # outside every circle: nothing is selected (previously the
            # old selection silently remained)
            self.active_elements = set()
        else:
            marker, elements = entry
            marker.set_alpha(1)
            self.active_elements = elements()

        if key == 'shift':
            self.active_elements = self.active_elements.union(self._last_active)
        self._last_active = self.active_elements.copy()
        self._sel_label = 'Sel: ' + str(len(self.active_elements))
        self._legend.texts[1].set_text(self._sel_label)
        self.axes.figure.canvas.draw()

    def _patches_within_verts(self, verts, key):
        """Flag each circle as active if the mean of ``verts`` lies in it."""
        xy = scipy.array(verts).mean(0)
        for venn_patch in self._venn_patches:
            venn_patch.active = False
            if self._distance(venn_patch.center, xy) < venn_patch.radius:
                venn_patch.active = True

    def _distance(self, p1, p2):
        """Euclidean distance between two (x, y) points."""
        (x1, y1), (x2, y2) = p1, p2
        return scipy.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
class PlotThresholder:
    """Mixin class for plots that needs to filter nodes within a threshold
    range.

    Expects the host plot to provide ``_toolbar``, ``_map_ids``,
    ``_mappable`` and ``canvas``.
    """

    def __init__(self, text="x"):
        """Constructor.

        @param text: Name of the variable the threshold is on.
        """
        self._threshold_ds = None
        self._add_spin_buttons(text)
        # no threshold dataset yet, so the spin buttons start disabled
        self._sb_min.set_sensitive(False)
        self._sb_max.set_sensitive(False)

    def set_threshold_dataset(self, ds):
        """Sets the dataset to threshold on.

        @param ds: A dataset where one dimension corresponds to the select dimension
        in the plot, and any other dimensions have length 1
        """
        self._threshold_ds = ds
        self._sb_min.set_sensitive(True)
        self._sb_max.set_sensitive(True)

    def _add_spin_buttons(self, text):
        """Adds spin buttons to the toolbar for selecting minimum and maximum
        threshold values on information content."""
        sb_min = gtk.SpinButton(digits=2)
        sb_min.set_range(0, 100)
        sb_min.set_value(0)
        sb_min.set_increments(.1, 1.)
        sb_min.connect('value-changed', self._on_value_changed)
        self._sb_min = sb_min

        sb_max = gtk.SpinButton(digits=2)
        sb_max.set_range(0, 100)
        sb_max.set_value(1)
        sb_max.set_increments(.1, 1.)
        sb_max.connect('value-changed', self._on_value_changed)
        self._sb_max = sb_max

        label = gtk.Label(" < %s < " % text)

        hbox = gtk.HBox()
        hbox.pack_start(sb_min)
        hbox.pack_start(label)
        hbox.pack_start(sb_max)

        ti = gtk.ToolItem()
        ti.set_expand(False)
        ti.add(hbox)

        sb_min.show()
        sb_max.show()
        label.show()
        hbox.show()
        ti.show()
        self._toolbar.insert(ti, -1)
        ti.set_tooltip(self._toolbar.tooltips, "Set threshold")

    def set_threshold(self, min, max):
        """Sets min and max to the given values.

        Updates the plot accordingly to show only values that have a
        value within the boundaries. Other values are
        also excluded from being selected from the plot.
        Does nothing until a threshold dataset has been set.

        @param min: Do not show nodes whose summed value is below this.
        @param max: Do not show nodes whose summed value is above this.
        """
        # NOTE: ``min``/``max`` shadow builtins but are kept as the public
        # parameter names for keyword callers.
        ds = self._threshold_ds
        if ds is None:
            return

        # NOTE(review): dimension name is hard-coded to 'go-terms'.
        icnodes = ds.existing_identifiers('go-terms', self._map_ids)
        icindices = ds.get_indices('go-terms', icnodes)
        a = ds.asarray()[icindices].sum(1)
        good = set(scipy.array(icnodes)[(a >= min) & (a <= max)])

        sizes = scipy.zeros(len(self._map_ids))
        visible = set()
        for i, node in enumerate(self._map_ids):
            if node in good:
                sizes[i] = 50
                visible.add(node)
        self.visible = visible
        # hidden nodes get size 0
        self._mappable._sizes = sizes
        self.canvas.draw()

    def get_nodes_within_bounds(self):
        """Get a list of all nodes within the bounds of the selection in the
        selected dataset. (Not implemented.)
        """
        pass

    def filter_nodes(self, nodes):
        """Filter a list of nodes and return only those that are within the
        threshold boundaries. (Not implemented.)"""
        pass

    def _on_value_changed(self, sb):
        """Callback on spin button value changes."""
        lower = self._sb_min.get_value()
        upper = self._sb_max.get_value()
        self.set_threshold(lower, upper)
# Register the custom GObject signals emitted by Plot widgets:
# 'zoom-changed', 'pan-changed' and 'plot-resize-changed'. Each is
# run-last, returns nothing and carries one Python object as payload.
for _signal_name in ('zoom-changed', 'pan-changed', 'plot-resize-changed'):
    gobject.signal_new(_signal_name, Plot, gobject.SIGNAL_RUN_LAST, None,
                       (gobject.TYPE_PYOBJECT,))
del _signal_name