category data and plot selection update|
This commit is contained in:
		| @@ -1,6 +1,9 @@ | |||||||
|  |  | ||||||
| from sets import Set as set | from sets import Set as set | ||||||
| set.update = set.union_update | set.update = set.union_update | ||||||
|  | import dataset | ||||||
|  | import scipy | ||||||
|  |  | ||||||
|  |  | ||||||
| class AnnotationsException(Exception): | class AnnotationsException(Exception): | ||||||
|     pass |     pass | ||||||
| @@ -63,3 +66,28 @@ class Annotations: | |||||||
|         """ |         """ | ||||||
|         return self.dimensions.has_key(dim) |         return self.dimensions.has_key(dim) | ||||||
|  |  | ||||||
|  |     def to_dataset(self,dim): | ||||||
|  |         """ Returns a dataset representation of annotations. | ||||||
|  |         """ | ||||||
|  |         if self.has_dimension(dim): | ||||||
|  |             num_dim1 = len(set(self.dimensions[dim])) #number of unique genes | ||||||
|  |             all_genes = set(self.dimensions[dim]) | ||||||
|  |             all_categories = set() | ||||||
|  |             for cat in self.dimensions[dim].values(): | ||||||
|  |                 all_categories.update(cat) | ||||||
|  |             num_dim1 = len(all_genes) #number of unique genes | ||||||
|  |             num_dim2 = len(all_categories) #number of unique categories | ||||||
|  |             gene_list=[] | ||||||
|  |             cat_list=[] | ||||||
|  |         matrix = scipy.zeros((num_dim1,num_dim2),'bwu') | ||||||
|  |         for i,gene in enumerate(all_genes): | ||||||
|  |             gene_list.append(gene) | ||||||
|  |             for j,cat in enumerate(all_categories): | ||||||
|  |                 cat_list.append(cat) | ||||||
|  |                 matrix[i,j] = 1 | ||||||
|  |         def_list = [['genes',gene_list],['go',cat_list]] | ||||||
|  |  | ||||||
|  |         return dataset.Dataset(matrix,def_list) | ||||||
|  |          | ||||||
|  |          | ||||||
|  |      | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| import logger | import logger | ||||||
| from scipy import array,take,asarray,shape | from scipy import array,take,asarray,shape,nonzero | ||||||
| import project | import project | ||||||
| from itertools import izip  | from itertools import izip  | ||||||
|  |  | ||||||
| @@ -10,19 +10,14 @@ class Dataset: | |||||||
|     A Dataset is an n-way array with defined string identifiers across |     A Dataset is an n-way array with defined string identifiers across | ||||||
|     all dimensions. |     all dimensions. | ||||||
|     """ |     """ | ||||||
|     def __init__(self,input_array,def_list,parents=None): |     def __init__(self,input_array,def_list): | ||||||
|         self._data = asarray(input_array) |         self._data = asarray(input_array) | ||||||
|         self.dims = shape(self._data) |         self.dims = shape(self._data) | ||||||
|         self.parents = parents |  | ||||||
|         self.def_list = def_list |         self.def_list = def_list | ||||||
|         self._ids_set = set() |         self._ids_set = set() | ||||||
|         self.ids={} |         self.ids={} | ||||||
|         self.children=[] |  | ||||||
|         self._dim_num = {} |         self._dim_num = {} | ||||||
|         self._dim_names = [] |         self._dim_names = [] | ||||||
|         if parents!=None: |  | ||||||
|             for parent in self.parents: |  | ||||||
|                 parent.children.append(self) |  | ||||||
|         if len(def_list)!=len(self.dims): |         if len(def_list)!=len(self.dims): | ||||||
|             raise ValueError,"array dims and identifyer mismatch" |             raise ValueError,"array dims and identifyer mismatch" | ||||||
|         for axis,(dim_name,ids) in enumerate(def_list): |         for axis,(dim_name,ids) in enumerate(def_list): | ||||||
| @@ -37,11 +32,6 @@ class Dataset: | |||||||
|             self._ids_set = self._ids_set.union(set(ids)) |             self._ids_set = self._ids_set.union(set(ids)) | ||||||
|             self._dim_num[dim_name] = axis |             self._dim_num[dim_name] = axis | ||||||
|             self._dim_names.append(dim_name) |             self._dim_names.append(dim_name) | ||||||
|             #if dim_name in project.c_p.dim_names: |  | ||||||
|             #    if ids: |  | ||||||
|             #        # check that identifers are same as before |  | ||||||
|             #        raise NotImplementedError |  | ||||||
|             #    else: |  | ||||||
|                      |                      | ||||||
|         for df,d in izip(def_list,self.dims): #check that data and labels match  |         for df,d in izip(def_list,self.dims): #check that data and labels match  | ||||||
|             df=df[1] |             df=df[1] | ||||||
| @@ -75,10 +65,50 @@ class Dataset: | |||||||
|         n_dim = self.dims[axis] |         n_dim = self.dims[axis] | ||||||
|         return [str(axis) + '_' + str(i) for i in range(n_dim)] |         return [str(axis) + '_' + str(i) for i in range(n_dim)] | ||||||
|  |  | ||||||
|     def index_to_id(self,dim_name,index): |     def extract_id_from_index(self,dim_name,index): | ||||||
|  |         """Returns a set of ids from array/list of indexes.""" | ||||||
|         dim_ids = self.ids[dim_name] |         dim_ids = self.ids[dim_name] | ||||||
|         return [id for id,ind in dim_ids.items() if ind in index] |         return set([id for id,ind in dim_ids.items() if ind in index]) | ||||||
|  |  | ||||||
|  |     def extract_index_from_id(self,dim_name,id): | ||||||
|  |         """Returns an array of indexes from a set/list of identifiers | ||||||
|  |         (or a single id)""" | ||||||
|  |         dim_ids = self.ids[dim_name] | ||||||
|  |         return array([ind for name,ind in dim_ids.items() if name in id]) | ||||||
|  |          | ||||||
|      |      | ||||||
|  | class CategoryDataset(Dataset): | ||||||
|  |     def __init__(self,array,def_list): | ||||||
|  |         Dataset.__init__(self,array,def_list) | ||||||
|  |  | ||||||
|  |     def get_elements_by_category(self,dim,category): | ||||||
|  |         """Returns all elements along input dim belonging to category. | ||||||
|  |         Assumes a two-dim category data only! | ||||||
|  |         """ | ||||||
|  |         if type(category)!=list: | ||||||
|  |             raise ValueError, "category must be list" | ||||||
|  |         gene_ids = [] | ||||||
|  |         axis_dim = self._dim_num[dim] | ||||||
|  |         cat_index = self.extract_index_from_id(category) | ||||||
|  |         for ind in cat_index: | ||||||
|  |             if axis_dim==0: | ||||||
|  |                 gene_indx = nonzero(self._data[:,ind]) | ||||||
|  |             elif axis_dim==1: | ||||||
|  |                 gene_indx = nonzero(self._data[ind,:]) | ||||||
|  |             else: | ||||||
|  |                 ValueError, "Only support for 2-dim data" | ||||||
|  |             gene_ids.append(self.extract_id_from_index(dim,gene_index)) | ||||||
|  |         return gene_ids | ||||||
|  |                  | ||||||
|  |          | ||||||
|  |          | ||||||
|  |                  | ||||||
|  |              | ||||||
|  |          | ||||||
|  |          | ||||||
|  |  | ||||||
|  |      | ||||||
|  |  | ||||||
| class Selection: | class Selection: | ||||||
|     """Handles selected identifiers along each dimension of a dataset""" |     """Handles selected identifiers along each dimension of a dataset""" | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|   | |||||||
| @@ -237,9 +237,9 @@ class ScatterPlot (Plot): | |||||||
|         self.ax = ax = fig.add_subplot(111) |         self.ax = ax = fig.add_subplot(111) | ||||||
|         self.x_dataset = project.c_p.datasets[0] |         self.x_dataset = project.c_p.datasets[0] | ||||||
|         x = self.x_dataset._data |         x = self.x_dataset._data | ||||||
|         self.a = a =  x[:,0] |         self.xaxis_data = xaxis_data =  x[:,0] | ||||||
|         self.b = b = x[:,1] |         self.yaxis_data = yaxis_data = x[:,1] | ||||||
|         ax.plot(a,b,'og') |         ax.plot(xaxis_data,yaxis_data,'og') | ||||||
|         self.canvas = FigureCanvas(fig) |         self.canvas = FigureCanvas(fig) | ||||||
|         self.add(self.canvas) |         self.add(self.canvas) | ||||||
|         rectprops = dict(facecolor='blue', edgecolor = 'black', |         rectprops = dict(facecolor='blue', edgecolor = 'black', | ||||||
| @@ -256,14 +256,24 @@ class ScatterPlot (Plot): | |||||||
|         logger.log('debug', "(%3.2f, %3.2f) --> (%3.2f, %3.2f)"%(x1,y1,x2,y2)) |         logger.log('debug', "(%3.2f, %3.2f) --> (%3.2f, %3.2f)"%(x1,y1,x2,y2)) | ||||||
|         logger.log('debug',"The button you used were:%s, %s "%(event1.button, event2.button)) |         logger.log('debug',"The button you used were:%s, %s "%(event1.button, event2.button)) | ||||||
|         # get all points within x1, y1, x2, y2 |         # get all points within x1, y1, x2, y2 | ||||||
|         ydata = self.b |         ydata = self.yaxis_data | ||||||
|         xdata = self.a |         xdata = self.xaxis_data | ||||||
|         index =scipy.nonzero((xdata<x2) & (xdata>x1) & (ydata<y1) & (ydata>y2)) |         if x1>x2: | ||||||
|  |             if y1<y2: | ||||||
|  |                 index =scipy.nonzero((xdata<x1) & (xdata>x2) & (ydata>y1) & (ydata<y2)) | ||||||
|  |             else: | ||||||
|  |                 index =scipy.nonzero((xdata<x1) & (xdata>x2) & (ydata<y1) & (ydata>y2)) | ||||||
|  |         else: | ||||||
|  |             if y1<y2: | ||||||
|  |                 index =scipy.nonzero((xdata>x2) & (xdata<x1) & (ydata>y1) & (ydata<y2)) | ||||||
|  |             else: | ||||||
|  |                 index =scipy.nonzero((xdata>x2) & (xdata<x1) & (ydata<y1) & (ydata>y2)) | ||||||
|  |          | ||||||
|         if len(index)==0: |         if len(index)==0: | ||||||
|             logger.log('debug','No points selected!') |             logger.log('debug','No points selected!') | ||||||
|         else: |         else: | ||||||
|             logger.log('debug','Selected:\n%s'%index) |             logger.log('debug','Selected:\n%s'%index) | ||||||
|             ids = self.x_dataset.index_to_id('samples',index) |             ids = self.x_dataset.extract_id_from_index('samples',index) | ||||||
|             logger.log('debug','Selected identifiers:\n%s'%ids) |             logger.log('debug','Selected identifiers:\n%s'%ids) | ||||||
|             xdata_new = scipy.take(xdata,index) |             xdata_new = scipy.take(xdata,index) | ||||||
|             ydata_new = scipy.take(ydata,index) |             ydata_new = scipy.take(ydata,index) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user