From 1091bea0e9252024323b9dcde33d431aa96aad92 Mon Sep 17 00:00:00 2001
From: flatberg
Date: Wed, 19 Apr 2006 10:37:44 +0000
Subject: [PATCH] category data and plot selection update|

---
Review notes: this patch was recovered from a copy whose line breaks were
collapsed and whose author address was lost. Indentation, blank context
lines and the removed selection test in plots.py are reconstructed and
must be checked against the pre-image before applying; the blob ids on
the "index" lines still refer to the original commit.
Changes made to the added code during review:
 * Annotations.to_dataset only marks real memberships, builds the
   identifier lists once and no longer adds trailing blank lines.
 * Dataset.extract_index_from_id accepts a single id, as documented.
 * CategoryDataset.get_elements_by_category passes the category
   dimension, really raises its errors and uses the index it computed.
 * ScatterPlot selection orders the rectangle corners instead of
   branching on the drag direction.

 system/annotations.py | 26 ++++++++++++++++++++
 system/dataset.py     | 59 ++++++++++++++++++++++++++++++++++-----------
 system/plots.py       | 18 ++++++++-----
 3 files changed, 82 insertions(+), 21 deletions(-)

diff --git a/system/annotations.py b/system/annotations.py
index 17cc4bc..8c02b69 100644
--- a/system/annotations.py
+++ b/system/annotations.py
@@ -1,6 +1,9 @@
 from sets import Set as set
 set.update = set.union_update
 
+import dataset
+import scipy
+
 class AnnotationsException(Exception):
     pass
 
@@ -63,3 +66,26 @@ class Annotations:
         """
         return self.dimensions.has_key(dim)
 
+    def to_dataset(self,dim):
+        """Returns a dataset representation of the annotations along dim.
+
+        Rows are the annotated identifiers and columns the categories; an
+        element is 1 when the identifier belongs to the category, else 0.
+        Returns None when dim is not an annotated dimension.
+        """
+        if not self.has_dimension(dim):
+            return None
+        annotations = self.dimensions[dim]
+        gene_list = list(annotations)
+        all_categories = set()
+        for categories in annotations.values():
+            all_categories.update(categories)
+        cat_list = list(all_categories)
+        matrix = scipy.zeros((len(gene_list),len(cat_list)),'bwu')
+        for i,gene in enumerate(gene_list):
+            for j,cat in enumerate(cat_list):
+                # only mark the categories this identifier belongs to
+                if cat in annotations[gene]:
+                    matrix[i,j] = 1
+        def_list = [['genes',gene_list],['go',cat_list]]
+        return dataset.Dataset(matrix,def_list)
diff --git a/system/dataset.py b/system/dataset.py
index 11ccdb3..f474b42 100644
--- a/system/dataset.py
+++ b/system/dataset.py
@@ -1,5 +1,5 @@
 import logger
-from scipy import array,take,asarray,shape
+from scipy import array,take,asarray,shape,nonzero
 import project
 from itertools import izip
 
@@ -10,19 +10,14 @@ class Dataset:
 
     A Dataset is an n-way array with defined string identifiers across all dimensions.
     """
-    def __init__(self,input_array,def_list,parents=None):
+    def __init__(self,input_array,def_list):
         self._data = asarray(input_array)
         self.dims = shape(self._data)
-        self.parents = parents
         self.def_list = def_list
         self._ids_set = set()
         self.ids={}
-        self.children=[]
         self._dim_num = {}
         self._dim_names = []
-        if parents!=None:
-            for parent in self.parents:
-                parent.children.append(self)
         if len(def_list)!=len(self.dims):
             raise ValueError,"array dims and identifyer mismatch"
         for axis,(dim_name,ids) in enumerate(def_list):
@@ -37,11 +32,6 @@ class Dataset:
                 self._ids_set = self._ids_set.union(set(ids))
             self._dim_num[dim_name] = axis
             self._dim_names.append(dim_name)
-            #if dim_name in project.c_p.dim_names:
-            #    if ids:
-            #        # check that identifers are same as before
-            #        raise NotImplementedError
-            #    else:
         for df,d in izip(def_list,self.dims):
             #check that data and labels match
             df=df[1]
@@ -75,10 +65,51 @@ class Dataset:
         n_dim = self.dims[axis]
         return [str(axis) + '_' + str(i) for i in range(n_dim)]
 
-    def index_to_id(self,dim_name,index):
+    def extract_id_from_index(self,dim_name,index):
+        """Returns a set of ids from array/list of indexes."""
         dim_ids = self.ids[dim_name]
-        return [id for id,ind in dim_ids.items() if ind in index]
+        return set([id for id,ind in dim_ids.items() if ind in index])
+
+    def extract_index_from_id(self,dim_name,id):
+        """Returns an array of indexes from a set/list of identifiers
+        (or a single id)"""
+        if isinstance(id,basestring):
+            # a bare string would otherwise be matched by substring
+            id = [id]
+        dim_ids = self.ids[dim_name]
+        return array([ind for name,ind in dim_ids.items() if name in id])
+
+
+class CategoryDataset(Dataset):
+    """Two-dim dataset of category memberships (nonzero = member)."""
+    def __init__(self,array,def_list):
+        Dataset.__init__(self,array,def_list)
+
+    def get_elements_by_category(self,dim,category):
+        """Returns all elements along input dim belonging to category.
+
+        category is a list of identifiers along the other dimension; the
+        result holds one set of element identifiers per matched category.
+        Assumes a two-dim category data only!
+        """
+        if not isinstance(category,list):
+            raise ValueError("category must be list")
+        axis_dim = self._dim_num[dim]
+        if len(self.dims)!=2:
+            raise ValueError("Only support for 2-dim data")
+        # the categories are identifiers along the other axis
+        cat_dim = self._dim_names[1-axis_dim]
+        cat_index = self.extract_index_from_id(cat_dim,category)
+        gene_ids = []
+        for ind in cat_index:
+            if axis_dim==0:
+                gene_indx = nonzero(self._data[:,ind])
+            else:
+                gene_indx = nonzero(self._data[ind,:])
+            gene_ids.append(self.extract_id_from_index(dim,gene_indx))
+        return gene_ids
+
 
 class Selection:
     """Handles selected identifiers along each dimension of a dataset"""
     def __init__(self):
diff --git a/system/plots.py b/system/plots.py
index b415414..a59d257 100644
--- a/system/plots.py
+++ b/system/plots.py
@@ -237,9 +237,9 @@ class ScatterPlot (Plot):
         self.ax = ax = fig.add_subplot(111)
         self.x_dataset = project.c_p.datasets[0]
         x = self.x_dataset._data
-        self.a = a = x[:,0]
-        self.b = b = x[:,1]
-        ax.plot(a,b,'og')
+        self.xaxis_data = xaxis_data = x[:,0]
+        self.yaxis_data = yaxis_data = x[:,1]
+        ax.plot(xaxis_data,yaxis_data,'og')
         self.canvas = FigureCanvas(fig)
         self.add(self.canvas)
         rectprops = dict(facecolor='blue', edgecolor = 'black',
@@ -256,14 +256,18 @@ class ScatterPlot (Plot):
         logger.log('debug', "(%3.2f, %3.2f) --> (%3.2f, %3.2f)"%(x1,y1,x2,y2))
         logger.log('debug',"The button you used were:%s, %s "%(event1.button, event2.button))
         # get all points within x1, y1, x2, y2
-        ydata = self.b
-        xdata = self.a
-        index =scipy.nonzero((xdata<x2) & (xdata>x1) & (ydata<y1) & (ydata>y2))
+        ydata = self.yaxis_data
+        xdata = self.xaxis_data
+        # the rectangle may be dragged in any direction: order its corners
+        xmin,xmax = min(x1,x2),max(x1,x2)
+        ymin,ymax = min(y1,y2),max(y1,y2)
+        index = scipy.nonzero((xdata>xmin) & (xdata<xmax) & (ydata>ymin) & (ydata<ymax))
+
         if len(index)==0:
             logger.log('debug','No points selected!')
         else:
             logger.log('debug','Selected:\n%s'%index)
-            ids = self.x_dataset.index_to_id('samples',index)
+            ids = self.x_dataset.extract_id_from_index('samples',index)
             logger.log('debug','Selected identifiers:\n%s'%ids)
             xdata_new = scipy.take(xdata,index)
             ydata_new = scipy.take(ydata,index)