This repository has been archived on 2024-07-04. You can view files and clone it, but cannot push or open issues or pull requests.
laydi/fluents/lib/blmplots.py

459 lines
17 KiB
Python
Raw Normal View History

2006-12-18 12:59:12 +01:00
"""Specialised plots for functions defined in blmfuncs.py.
fixme:
-- If a scatter plot is not initialized with a color vector, no
colorbar is created; when colors are added later, the colorbar should be created.
"""
2007-03-14 17:17:21 +01:00
2007-07-26 20:26:50 +02:00
from matplotlib import cm,patches
2007-01-31 13:59:21 +01:00
import gtk
2007-03-14 17:17:21 +01:00
import fluents
2007-08-24 11:14:24 +02:00
from fluents import plots, main,logger
2007-03-14 17:17:21 +01:00
import scipy
2007-07-26 20:26:50 +02:00
from scipy import dot,sum,diag,arange,log,mean,newaxis,sqrt,apply_along_axis,empty
from scipy.stats import corrcoef
2007-07-26 20:26:50 +02:00
def correlation_loadings(data, T, test=True):
    """ Returns correlation loadings.

    :input:
        - data: [nsamps, nvars], data (need not be centered)
        - T: [nsamps, a_max], scores
        - test: unused; kept for backward compatibility.
    :output:
        - R: [nvars, a_max], correlation loadings. R[k, a] is the Pearson
          correlation between column k of data and column a of T.
    :raises:
        - IOError if data and T have a different number of rows.
    :notes:
        All correlations are computed in one vectorised step instead of one
        corrcoef call per (variable, component) pair. A constant column
        gives NaN, as corrcoef does.
    """
    nsamps, nvars = data.shape
    nsampsT, a_max = T.shape
    if nsamps != nsampsT:
        raise IOError("D/T mismatch")
    # Pearson r = <xc, tc> / (|xc| * |tc|), where xc and tc are centred columns.
    Xc = data - data.mean(0)
    Tc = T - T.mean(0)
    num = dot(Xc.T, Tc)
    denom = sqrt((Xc ** 2).sum(0))[:, newaxis] * sqrt((Tc ** 2).sum(0))[newaxis, :]
    return (num / denom).astype('d')
2007-01-31 13:59:21 +01:00
class BlmScatterPlot(plots.ScatterPlot):
    """Scatter plot used for scores and loadings in bilinear models.

    Two columns (components) of the model part ``model.model[part_name]``
    are plotted against each other. Toolbar spin buttons select which
    components are shown on the abscissa and the ordinate.
    """

    def __init__(self, title, model, absi=0, ordi=1, part_name='T', color_by=None):
        """Build the plot.

        :input:
            - title: plot name.
            - model: model whose ``model`` attribute maps part names to arrays
              and which provides ``as_dataset(part_name)``.
            - absi, ordi: zero-based component indices for the x and y axes.
            - part_name: key of the model part to plot.
            - color_by: optional model key whose raveled values colour the
              points. Points are drawn blue ('b') when the key is absent.
        :raises:
            - ValueError if ``part_name`` is not in the model.
        """
        self.model = model
        if part_name not in model.model:
            raise ValueError("Model part: %s not found in model" % part_name)
        self._T = model.model[part_name]
        if self._T.shape[1] == 1:
            logger.log('notice', 'Scores have only one component')
            absi = ordi = 0
        self._absi = absi
        self._ordi = ordi
        self._cmap = cm.summer

        dataset_1 = model.as_dataset(part_name)
        id_dim = dataset_1.get_dim_name(0)
        sel_dim = dataset_1.get_dim_name(1)
        id_1, = dataset_1.get_identifiers(sel_dim, [absi])
        id_2, = dataset_1.get_identifiers(sel_dim, [ordi])

        col = 'b'
        if color_by in model.model:
            col = model.model[color_by].ravel()
        plots.ScatterPlot.__init__(self, dataset_1, dataset_1, id_dim, sel_dim,
                                   id_1, id_2, c=col, s=40, name=title)
        self._mappable.set_cmap(self._cmap)
        self.sc = self._mappable

        self.add_pc_spin_buttons(self._T.shape[1], absi, ordi)

    def set_facecolor(self, colors):
        """Set patch facecolors (not implemented)."""
        pass

    def set_alphas(self, alphas):
        """Set alpha channel for all patches (not implemented)."""
        pass

    def set_sizes(self, sizes):
        """Set patch sizes (not implemented)."""
        pass

    def set_expvar_axlabels(self, param=None):
        """Label both axes with component number and explained variance.

        :input:
            - param: model key holding the explained variance. When None,
              the key remembered from an earlier call is reused. If no key
              was ever given, the labels are left unchanged.
        """
        if param is None:
            param = getattr(self, '_expvar_param', None)
            if param is None:
                # Plots that never chose an expvar key still reach this
                # method from the spin-button callbacks.
                return
        else:
            self._expvar_param = param

        expvar = self.model.model.get(param)
        if expvar is None:
            logger.log('notice', 'Param: %s not in model' % param)
            return  # fixme: do expvar calc here if not present
        # NOTE(review): expvar is indexed with a +1 offset, presumably because
        # entry 0 belongs to the zero-component model. Confirm this.
        xstr = "Comp: %s , %.1f " % (self._absi, expvar[self._absi + 1])
        ystr = "Comp: %s , %.1f " % (self._ordi, expvar[self._ordi + 1])
        self.axes.set_xlabel(xstr)
        self.axes.set_ylabel(ystr)

    def add_pc_spin_buttons(self, amax, absi, ordi):
        """Add 'A:' / 'O:' spin buttons that select the plotted components.

        :input:
            - amax: number of components (upper bound of both buttons).
            - absi, ordi: zero-based initial components, shown one-based.
        """
        sb_a = gtk.SpinButton(climb_rate=1)
        sb_a.set_range(1, amax)
        sb_a.set_value(absi + 1)
        sb_a.set_increments(1, 5)
        sb_a.connect('value_changed', self.set_absicca)
        sb_o = gtk.SpinButton(climb_rate=1)
        sb_o.set_range(1, amax)
        sb_o.set_value(ordi + 1)
        sb_o.set_increments(1, 5)
        sb_o.connect('value_changed', self.set_ordinate)
        hbox = gtk.HBox()
        gtk_label_a = gtk.Label("A:")
        gtk_label_o = gtk.Label(" O:")
        toolitem = gtk.ToolItem()
        toolitem.set_expand(False)
        toolitem.set_border_width(2)
        toolitem.add(hbox)
        hbox.pack_start(gtk_label_a)
        hbox.pack_start(sb_a)
        hbox.pack_start(gtk_label_o)
        hbox.pack_start(sb_o)
        self._toolbar.insert(toolitem, -1)
        toolitem.set_tooltip(self._toolbar.tooltips, "Set Principal component")
        self._toolbar.show_all()

    def _update_pc_offsets(self):
        """Move all scatter points to the currently selected components."""
        xy = self._T[:, [self._absi, self._ordi]]
        self.xaxis_data = xy[:, 0]
        self.yaxis_data = xy[:, 1]
        self.sc._offsets = xy
        self.selection_collection._offsets = xy

    @staticmethod
    def _padded_lims(values):
        """Return (min, max) of ``values`` widened by 5% of their range."""
        lo, hi = values.min(), values.max()
        pad = abs(lo - hi) * 0.05
        return (lo - pad, hi + pad)

    def set_absicca(self, sb):
        """Spin-button callback: show component ``sb`` (one-based) on the x axis."""
        self._absi = sb.get_value_as_int() - 1
        self._update_pc_offsets()
        self.axes.set_xlim(self._padded_lims(self.xaxis_data), emit=True)
        self.set_expvar_axlabels()
        self.canvas.draw_idle()

    def set_ordinate(self, sb):
        """Spin-button callback: show component ``sb`` (one-based) on the y axis."""
        self._ordi = sb.get_value_as_int() - 1
        self._update_pc_offsets()
        self.axes.set_ylim(self._padded_lims(self.yaxis_data), emit=True)
        self.set_expvar_axlabels()
        self.canvas.draw_idle()

    def show_labels(self, index=None):
        """Show text labels for the points whose indices are in ``index``.

        Label artists are created on the first call, hidden. With ``index``
        None, the visibility of existing labels is not changed.
        """
        if self._text_labels is None:
            x = self.xaxis_data
            y = self.yaxis_data
            self._text_labels = {}
            for name, n in self.dataset_1[self.current_dim].items():
                txt = self.axes.text(x[n], y[n], name)
                txt.set_visible(False)
                self._text_labels[n] = txt
        if index is not None:
            self.hide_labels()
            for indx, txt in self._text_labels.items():
                if indx in index:
                    txt.set_visible(True)
        self.canvas.draw_idle()

    def hide_labels(self):
        """Hide all text labels. Does nothing if no label has been created yet."""
        if self._text_labels:
            for txt in self._text_labels.values():
                txt.set_visible(False)
        self.canvas.draw_idle()
class PcaScreePlot(plots.BarPlot):
    """Bar plot of the PCA eigenvalues (scree plot)."""

    def __init__(self, model):
        title = "Pca, (%s) Scree" % model._dataset['X'].get_name()
        ds = model.as_dataset('eigvals')
        if ds is None:
            logger.log('notice', 'Model does not contain eigvals')
            # Without eigenvalues there is nothing to draw. Stop here instead
            # of handing None to BarPlot (same early exit as RMSEPPlot).
            return
        plots.BarPlot.__init__(self, ds, name=title)
2006-12-18 12:59:12 +01:00
2007-03-14 17:17:21 +01:00
2007-01-31 13:59:21 +01:00
class PcaScorePlot(BlmScatterPlot):
    """PCA score plot (model part 'T'), axes labelled from 'expvarx'."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pca scores (%s)" % xname, model,
                                absi, ordi, 'T')
        self.set_expvar_axlabels(param="expvarx")
2006-12-18 12:59:12 +01:00
2007-01-31 13:59:21 +01:00
class PcaLoadingPlot(BlmScatterPlot):
    """PCA loading plot (model part 'P') coloured by 'p_tsq'."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pca loadings (%s)" % xname, model,
                                absi, ordi, part_name='P', color_by='p_tsq')
        self.set_expvar_axlabels(param="expvarx")
2007-01-31 13:59:21 +01:00
class PlsScorePlot(BlmScatterPlot):
    """PLS score plot of model part 'T'."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pls scores (%s)" % xname, model,
                                absi, ordi, 'T')
2007-08-24 11:14:24 +02:00
class PlsXLoadingPlot(BlmScatterPlot):
    """PLS x-loading plot (model part 'P') coloured by 'w_tsq'."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pls x-loadings (%s)" % xname, model,
                                absi, ordi, part_name='P', color_by='w_tsq')
        # NOTE: no explained-variance axis labels are set for this plot yet.
2006-12-18 12:59:12 +01:00
2007-08-24 11:14:24 +02:00
class PlsYLoadingPlot(BlmScatterPlot):
    """PLS y-loading plot of model part 'Q'."""

    def __init__(self, model, absi=0, ordi=1):
        yname = model._dataset['Y'].get_name()
        BlmScatterPlot.__init__(self, "Pls y-loadings (%s)" % yname, model,
                                absi, ordi, part_name='Q')
class PlsCorrelationLoadingPlot(BlmScatterPlot):
    """PLS correlation-loading plot of model part 'CP'."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Pls correlation loadings (%s)" % xname,
                                model, absi, ordi, part_name='CP')
2007-08-14 18:12:28 +02:00
class LplsScorePlot(BlmScatterPlot):
    """L-PLS score plot (model part 'T'), axes labelled from 'evx'."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "L-pls scores (%s)" % xname, model,
                                absi, ordi, 'T')
        self.set_expvar_axlabels(param="evx")
2007-07-23 20:07:10 +02:00
class LplsXLoadingPlot(BlmScatterPlot):
    """L-PLS x-loading plot (model part 'P') coloured by 'tsqx'."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Lpls x-loadings (%s)" % xname, model,
                                absi, ordi, part_name='P', color_by='tsqx')
        self.set_expvar_axlabels(param="evx")
2007-07-23 20:07:10 +02:00
class LplsZLoadingPlot(BlmScatterPlot, plots.PlotThresholder):
    """L-PLS z-loading plot (model part 'L') coloured by 'tsqz', with a
    thresholder labelled "IC".
    """

    def __init__(self, model, absi=0, ordi=1):
        zname = model._dataset['Z'].get_name()
        BlmScatterPlot.__init__(self, "Lpls z-loadings (%s)" % zname, model,
                                absi, ordi, part_name='L', color_by='tsqz')
        self.set_expvar_axlabels(param="evz")
        plots.PlotThresholder.__init__(self, "IC")

    def _update_color_from_dataset(self, ds):
        """Recolour the points from ``ds`` and use ``ds`` as the threshold dataset."""
        BlmScatterPlot._update_color_from_dataset(self, ds)
        self.set_threshold_dataset(ds)
2007-07-23 19:33:21 +02:00
2007-07-26 20:26:50 +02:00
class LplsXCorrelationPlot(BlmScatterPlot):
    """L-PLS x correlation loadings (model part 'Rx') inside reference circles.

    'Rx' is computed from the X data and the scores 'T' and cached in the
    model if it is missing.
    """

    def __init__(self, model, absi=0, ordi=1):
        title = "Lpls x-corr. loads (%s)" % model._dataset['X'].get_name()
        if 'Rx' not in model.model:
            model.model['Rx'] = correlation_loadings(model._data['X'], model.model['T'])
        BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rx')
        self.set_expvar_axlabels(param="evx")

        radius = 1
        center = (0, 0)
        # Outer circle at the unit radius, inner one at sqrt(radius / 2)
        # (the 50% circle).
        for r, z in ((radius, 1), (sqrt(radius / 2.0), 2)):
            self.axes.add_patch(patches.Circle(center, radius=r,
                                               facecolor='gray',
                                               alpha=.1,
                                               zorder=z))
        self.axes.axhline(lw=1.5, color='k')
        self.axes.axvline(lw=1.5, color='k')
        self.axes.set_xlim([-1.05, 1.05])
        self.axes.set_ylim([-1.05, 1.05])
        self.canvas.show()
class LplsZCorrelationPlot(BlmScatterPlot):
    """L-PLS z correlation loadings (model part 'Rz') inside reference circles.

    'Rz' is computed from the transposed Z data and 'W' and cached in the
    model if it is missing.
    """

    def __init__(self, model, absi=0, ordi=1):
        title = "Lpls z-corr. loads (%s)" % model._dataset['Z'].get_name()
        if 'Rz' not in model.model:
            model.model['Rz'] = correlation_loadings(model._data['Z'].T, model.model['W'])
        BlmScatterPlot.__init__(self, title, model, absi, ordi, part_name='Rz')
        self.set_expvar_axlabels(param="evz")

        radius = 1
        center = (0, 0)
        # Outer circle at the unit radius, inner one at sqrt(radius / 2)
        # (the 50% circle).
        for r, z in ((radius, 1), (sqrt(radius / 2.0), 2)):
            self.axes.add_patch(patches.Circle(center, radius=r,
                                               facecolor='gray',
                                               alpha=.1,
                                               zorder=z))
        self.axes.axhline(lw=1.5, color='k')
        self.axes.axvline(lw=1.5, color='k')
        self.axes.set_xlim([-1.05, 1.05])
        self.axes.set_ylim([-1.05, 1.05])
        self.canvas.show()
2007-07-23 19:33:21 +02:00
class LplsHypoidCorrelationPlot(BlmScatterPlot):
    """Scatter plot of model part 'W', titled as hypoid correlations."""

    def __init__(self, model, absi=0, ordi=1):
        xname = model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, "Hypoid correlations(%s)" % xname, model,
                                absi, ordi, part_name='W')
2007-07-26 20:26:50 +02:00
2007-08-14 18:12:28 +02:00
class LplsExplainedVariancePlot(plots.Plot):
    """Line plot of the model's 'evx', 'evy' and 'evz' series.

    The x axis is 0..len(evx)-1. Each series is labelled X, Y or Z and shown
    in a legend.
    """

    def __init__(self, model):
        self.model = model
        plots.Plot.__init__(self, "Explained variance")
        xax = scipy.arange(model.model['evx'].shape[0])
        for key, style, label in (('evx', 'b-', 'X'),
                                  ('evy', 'k-', 'Y'),
                                  ('evz', 'g-', 'Z')):
            self.axes.plot(xax, model.model[key], style, label=label, linewidth=1.5)
        # The series carry labels; without a legend they are never shown.
        self.axes.legend()
        self.canvas.draw()
2006-12-18 12:59:12 +01:00
class LineViewXc(plots.LineViewPlot):
    """Line view of the X data with a toolbar toggle for column centering."""

    def __init__(self, model, name='Profiles'):
        plots.LineViewPlot.__init__(self, model._dataset['X'], 1, None, False, name)
        self.add_center_check_button(self.data_is_centered)

    def add_center_check_button(self, ticked):
        """Append a 'Center' check button to the toolbar.

        :input:
            - ticked: initial state of the check button.
        """
        check = gtk.CheckButton("Center")
        check.set_active(ticked)
        check.connect('toggled', self._toggle_center)
        item = gtk.ToolItem()
        item.set_expand(False)
        item.set_border_width(2)
        item.add(check)
        self._toolbar.insert(item, -1)
        item.set_tooltip(self._toolbar.tooltips, "Column center the line view")
        self._toolbar.show_all()

    def _toggle_center(self, active):
        """'toggled' callback: switch between raw and column-centred data.

        The new state is derived from ``self.data_is_centered``; the
        ``active`` argument is not inspected.
        """
        if self.data_is_centered:
            # Restore the raw data from the means saved when centering.
            self._data = self._data + self._mn_data
        else:
            self._mn_data = self._data.mean(0)
            self._data = self._data - self._mn_data
        self.data_is_centered = not self.data_is_centered
        self.make_lines()
        self.set_background()
        self.set_current_selection(main.project.get_selection())
2007-01-25 12:58:10 +01:00
2006-12-18 12:59:12 +01:00
class ParalellCoordinates(plots.Plot):
    """Parallel-coordinates plot for scores/loadings with many components.

    Not implemented yet: the constructor does nothing and does not call
    ``plots.Plot.__init__``.
    """

    def __init__(self, model, p='loads'):
        # Placeholder: ``model`` and ``p`` are currently ignored.
        return None
2007-01-25 12:58:10 +01:00
2006-12-18 12:59:12 +01:00
class PlsQvalScatter(plots.ScatterPlot):
    """A volcano-like plot of loads vs. qvals.

    x: column ``pc`` of model part 'W'; y: model part 'w_tsq'. Points are
    coloured by 'w_tsq'.
    """

    def __init__(self, model, pc=0):
        """
        :input:
            - model: model providing 'W' and 'w_tsq'.
            - pc: zero-based component of 'W' on the x axis (default 0).
        """
        if 'w_tsq' not in model.model:
            # NOTE(review): this leaves a half-constructed object; callers
            # presumably handle that.
            return
        self._W = model.model['W']
        dataset_1 = model.as_dataset('W')
        dataset_2 = model.as_dataset('w_tsq')
        id_dim = dataset_1.get_dim_name(0)     # genes
        sel_dim = dataset_1.get_dim_name(1)    # _comp
        sel_dim_2 = dataset_2.get_dim_name(1)  # _zero_dim
        id_1, = dataset_1.get_identifiers(sel_dim, [pc])
        id_2, = dataset_2.get_identifiers(sel_dim_2, [0])
        # 'w_tsq' is guaranteed present by the guard above.
        col = model.model['w_tsq'].ravel()
        plots.ScatterPlot.__init__(self, dataset_1, dataset_2,
                                   id_dim, sel_dim, id_1, id_2,
                                   c=col, s=20, sel_dim_2=sel_dim_2,
                                   name='Load Volcano')
2007-01-31 13:59:21 +01:00
2007-01-25 12:58:10 +01:00
class PredictionErrorPlot(plots.Plot):
    """A boxplot of prediction error vs. comp. number.

    Draws a boxplot of sqrt(model.model['sep']) and a wide, translucent red
    vertical line at model.model['aopt'].
    """

    def __init__(self, model, name="Prediction Error"):
        if 'sep' not in model.model:
            logger.log('notice', 'Model has no calculations of sep')
            return
        plots.Plot.__init__(self, name)
        self._frozen = True
        # NOTE(review): placeholder dimension name; presumably this plot has
        # no real selectable dimension. Confirm.
        self.current_dim = 'johndoe'
        self.axes = self.fig.add_subplot(111)
        # draw
        sep = model.model['sep']
        aopt = model.model['aopt']
        self.axes.boxplot(sqrt(sep))
        self.axes.axvline(aopt, linewidth=10, color='r', zorder=0, alpha=.5)
        # add canvas
        self.add(self.canvas)
        self.canvas.show()

    def set_current_selection(self, selection):
        """Ignore selection changes; this plot does not react to them."""
        pass
class TRBiplot(BlmScatterPlot):
    """Target rotation biplot of model part 'B'.

    Derives from BlmScatterPlot because its ``__init__`` is called below.
    Under Python 2, calling that unbound method on an instance of an
    unrelated class raises TypeError.
    """

    def __init__(self, model, absi=0, ordi=1):
        title = "Target rotation biplot(%s)" % model._dataset['X'].get_name()
        BlmScatterPlot.__init__(self, title, model, absi, ordi, 'B')
        # BlmScatterPlot.__init__ has already raised if 'B' is missing.
        B = model.model['B']
        # Normalise each row of B to unit Euclidean length.
        Bnorm = sqrt((B ** 2).sum(1))
        X = model._dataset['X']._array
        Xc = X - X.mean(0)
        # NOTE(review): these rotated weights/scores are computed but not yet
        # plotted; they are kept as attributes for later use.
        self._w_rot = B / Bnorm[:, newaxis]
        self._t_rot = dot(Xc, self._w_rot)
2007-01-25 12:58:10 +01:00
2006-12-18 12:59:12 +01:00
class InfluencePlot(plots.ScatterPlot):
    """ Returns a leverage vs. residual scatter plot.

    x: first column of model part 'levx'; y: first column of 'ssqx'.
    """

    def __init__(self, model, dim, name="Influence"):
        # NOTE(review): ``dim`` is accepted but unused. It was presumably meant
        # to pick the identifier dimension; confirm with callers.
        if 'levx' not in model.model:
            logger.log('notice', 'Model has no calculations of leverages')
            return
        if 'ssqx' not in model.model:
            logger.log('notice', 'Model has no calculations of residuals')
            return
        ds1 = model.as_dataset('levx')
        ds2 = model.as_dataset('ssqx')
        # Dimension/identifier selection follows PlsQvalScatter. This assumes
        # both datasets are 2-D (objects x components). TODO confirm.
        id_dim = ds1.get_dim_name(0)
        sel_dim = ds1.get_dim_name(1)
        sel_dim_2 = ds2.get_dim_name(1)
        id_1, = ds1.get_identifiers(sel_dim, [0])
        id_2, = ds2.get_identifiers(sel_dim_2, [0])
        plots.ScatterPlot.__init__(self, ds1, ds2,
                                   id_dim, sel_dim, id_1, id_2,
                                   c='b', s=20, sel_dim_2=sel_dim_2,
                                   name=name)
2006-12-18 12:59:12 +01:00
2007-01-31 13:59:21 +01:00
class RMSEPPlot(plots.BarPlot):
    """Bar plot of the model's 'rmsep' values."""

    def __init__(self, model, name="RMSEP"):
        if 'rmsep' not in model.model:
            logger.log('notice', 'Model has no calculations of rmsep')
            return
        dataset = model.as_dataset('rmsep')
        plots.BarPlot.__init__(self, dataset, name=name)
2006-12-18 12:59:12 +01:00
def normalise(x):
    """Scale array x linearly to [0, 1].

    Integer and boolean input is converted to float first, so that Python 2
    does not floor-divide. A constant array maps to all zeros instead of NaN.

    :input:
        - x: numpy array (must be non-empty).
    :output:
        - array of the same shape, with min 0 and max 1 (or all 0).
    """
    x = x - x.min()
    if x.dtype.kind in 'biu':
        x = x.astype('d')
    top = x.max()
    if top == 0:
        return x
    return x / top