From 0b58c5ea286262796a5dccd518f68a92d3b4b2c0 Mon Sep 17 00:00:00 2001 From: flatberg <flatberg@pvv.ntnu.no> Date: Thu, 31 Aug 2006 10:04:19 +0000 Subject: [PATCH] Switch to numpy (scipy_version>1.0) changes --- system/dataset.py | 21 ++++++++++----------- system/plots.py | 22 +++++++++------------- workflows/test_workflow.py | 6 +++--- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/system/dataset.py b/system/dataset.py index 2c4bc49..9d39665 100644 --- a/system/dataset.py +++ b/system/dataset.py @@ -1,4 +1,4 @@ -from scipy import atleast_2d,asarray,ArrayType,shape,nonzero,io,transpose +from scipy import ndarray,atleast_2d,asarray from scipy import sort as array_sort from itertools import izip import shelve @@ -40,11 +40,11 @@ class Dataset: self._name = name self._identifiers = identifiers self._type = 'n' - if isinstance(array,ArrayType): + if isinstance(array,ndarray): array = atleast_2d(asarray(array)) # vectors are column vectors if array.shape[0]==1: - array = transpose(array) + array = array.T self.shape = array.shape if identifiers!=None: self._set_identifiers(identifiers,self._all_dims) @@ -55,10 +55,7 @@ class Dataset: self._array = array else: - raise ValueError, "Array input must be of ArrayType" - - #def __str__(self): - # return self._name + ":\n" + "Dim names: " + self._dims.__str__() + raise ValueError, "Array input must be of type ndarray" def __iter__(self): """Returns an iterator over dimensions of dataset.""" @@ -177,7 +174,9 @@ class Dataset: Identifiers are the unique names (strings) for a variable in a given dim. - Index (Indices) are the Identifiers position in a matrix in a given dim.""" + Index (Indices) are the Identifiers position in a matrix in a given dim. + If none of the input identifiers are found an empty index is returned + """ if idents==None: index = array_sort(self._map[dim].values()) else: @@ -217,7 +216,7 @@ class CategoryDataset(Dataset): """ data={} for name,ind in self._map[self.get_dim_name(0)].items(): - data[name] = self.get_identifiers(self.get_dim_name(1),list(nonzero(self._array[ind,:]))) + data[name] = self.get_identifiers(self.get_dim_name(1),list(self._array[ind,:].nonzero())) self._dictlists = data self.has_dictlists = True return data @@ -227,7 +226,7 @@ class CategoryDataset(Dataset): """ ret_list = [] for cat_name,ind in self._map[self.get_dim_name(1)].items(): - ids = self.get_identifiers(self.get_dim_name(0),nonzero(self._array[:,ind])) + ids = self.get_identifiers(self.get_dim_name(0),self._array[:,ind].nonzero()) selection = Selection(cat_name) selection.select(cat_name,ids) ret_list.append(selection) @@ -263,7 +262,7 @@ class GraphDataset(Dataset): """ import networkx as nx - m,n = shape(A)# adjacency matrix must be of type that evals to true/false for neigbours + m,n = A.shape# adjacency matrix must be of type that evals to true/false for neigbours if m!=n: raise IOError, "Adjacency matrix must be square" if nx_type=='graph': diff --git a/system/plots.py b/system/plots.py index d74b55d..cb66c58 100644 --- a/system/plots.py +++ b/system/plots.py @@ -418,7 +418,7 @@ class LineViewPlot(Plot): self.line_collection = {} x_axis = scipy.arrayrange(self._data.shape[minor_axis]) for i in range(self._data.shape[major_axis]): - yi = scipy.take(self._data,[i],axis=major_axis) + yi = self._data.take([i],major_axis) if self.use_blit: l,=self.ax.plot(x_axis,yi,'k',alpha=.05,animated=True) else: @@ -513,16 +513,15 @@ has no color and size options.""" y1, y2 = y2, y1 assert x1<=x2 assert y1<=y2 - - index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2)) + index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0] ids = self.dataset_1.get_identifiers(self.current_dim, index) self.selection_listener(self.current_dim, ids) def set_current_selection(self, selection): ids = selection[self.current_dim] # current identifiers index = self.dataset_1.get_indices(self.current_dim, ids) - xdata_new = scipy.take(self.xaxis_data, index) #take data - ydata_new = scipy.take(self.yaxis_data, index) + xdata_new = self.xaxis_data.take(index) #take data + ydata_new = self.yaxis_data.take(index) #remove old selection if self._selection_line: self.ax.lines.remove(self._selection_line) @@ -545,11 +544,11 @@ has no color and size options.""" y_index = dataset_2[sel_dim][id_2] self.xaxis_data = dataset_1._array[:,x_index] self.yaxis_data = dataset_2._array[:,y_index] - lw = scipy.zeros(self.xaxis_data.shape,'f') + lw = scipy.zeros(self.xaxis_data.shape) self.ax.scatter(self.xaxis_data,self.yaxis_data,s=s,c=c,linewidth=lw,edgecolor='k',alpha=.6,cmap = cm.Set1) self.ax.set_title(self.get_title()) # collection - self.coll = ax.collections[0] + self.coll = self.ax.collections[0] # add canvas to widget self.add(self.canvas) @@ -575,22 +574,19 @@ has no color and size options.""" assert x1<=x2 assert y1<=y2 - index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2)) - + index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0] ids = self.dataset_1.get_identifiers(self.current_dim, index) self.selection_listener(self.current_dim, ids) def set_current_selection(self, selection): ids = selection[self.current_dim] # current identifiers index = self.dataset_1.get_indices(self.current_dim, ids) - lw = scipy.zeros(self.xaxis_data.shape,'f') + lw = scipy.zeros(self.xaxis_data.shape) scipy.put(lw,index,2.) - zo = lw.copy() + 1 #z-order, selected on top self.coll.set_linewidth(lw) - self.coll.set_zorder(zo) self._toolbar.forward() #update data lims before draw self.canvas.draw() - + class NetworkPlot(Plot): def __init__(self, dataset, **kw): diff --git a/workflows/test_workflow.py b/workflows/test_workflow.py index 799060b..272a274 100644 --- a/workflows/test_workflow.py +++ b/workflows/test_workflow.py @@ -102,10 +102,10 @@ class TestDataFunction(workflow.Function): def run(self): logger.log('notice', 'Injecting foo test data') - x = randn(500,30) + x = randn(5000,4) X = dataset.Dataset(x) - p = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='p') - p2 = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='p2') + p = plots.ScatterPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='scatter') + p2 = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='marker') graph = networkx.XGraph() for x in 'ABCDEF': for y in 'ADE':