Projects/laydi
Projects
/
laydi
Archived
7
0
Fork 0

Switch to numpy (scipy_version>1.0) changes

This commit is contained in:
Arnar Flatberg 2006-08-31 10:04:19 +00:00
parent a0786d521a
commit 0b58c5ea28
3 changed files with 22 additions and 27 deletions

View File

@ -1,4 +1,4 @@
from scipy import atleast_2d,asarray,ArrayType,shape,nonzero,io,transpose from scipy import ndarray,atleast_2d,asarray
from scipy import sort as array_sort from scipy import sort as array_sort
from itertools import izip from itertools import izip
import shelve import shelve
@ -40,11 +40,11 @@ class Dataset:
self._name = name self._name = name
self._identifiers = identifiers self._identifiers = identifiers
self._type = 'n' self._type = 'n'
if isinstance(array,ArrayType): if isinstance(array,ndarray):
array = atleast_2d(asarray(array)) array = atleast_2d(asarray(array))
# vectors are column vectors # vectors are column vectors
if array.shape[0]==1: if array.shape[0]==1:
array = transpose(array) array = array.T
self.shape = array.shape self.shape = array.shape
if identifiers!=None: if identifiers!=None:
self._set_identifiers(identifiers,self._all_dims) self._set_identifiers(identifiers,self._all_dims)
@ -55,10 +55,7 @@ class Dataset:
self._array = array self._array = array
else: else:
raise ValueError, "Array input must be of ArrayType" raise ValueError, "Array input must be of type ndarray"
#def __str__(self):
# return self._name + ":\n" + "Dim names: " + self._dims.__str__()
def __iter__(self): def __iter__(self):
"""Returns an iterator over dimensions of dataset.""" """Returns an iterator over dimensions of dataset."""
@ -177,7 +174,9 @@ class Dataset:
Identifiers are the unique names (strings) for a variable in a given dim. Identifiers are the unique names (strings) for a variable in a given dim.
Index (Indices) are the Identifiers position in a matrix in a given dim.""" Index (Indices) are the Identifiers position in a matrix in a given dim.
If none of the input identifiers are found an empty index is returned
"""
if idents==None: if idents==None:
index = array_sort(self._map[dim].values()) index = array_sort(self._map[dim].values())
else: else:
@ -217,7 +216,7 @@ class CategoryDataset(Dataset):
""" """
data={} data={}
for name,ind in self._map[self.get_dim_name(0)].items(): for name,ind in self._map[self.get_dim_name(0)].items():
data[name] = self.get_identifiers(self.get_dim_name(1),list(nonzero(self._array[ind,:]))) data[name] = self.get_identifiers(self.get_dim_name(1),list(self._array[ind,:].nonzero()))
self._dictlists = data self._dictlists = data
self.has_dictlists = True self.has_dictlists = True
return data return data
@ -227,7 +226,7 @@ class CategoryDataset(Dataset):
""" """
ret_list = [] ret_list = []
for cat_name,ind in self._map[self.get_dim_name(1)].items(): for cat_name,ind in self._map[self.get_dim_name(1)].items():
ids = self.get_identifiers(self.get_dim_name(0),nonzero(self._array[:,ind])) ids = self.get_identifiers(self.get_dim_name(0),self._array[:,ind].nonzero())
selection = Selection(cat_name) selection = Selection(cat_name)
selection.select(cat_name,ids) selection.select(cat_name,ids)
ret_list.append(selection) ret_list.append(selection)
@ -263,7 +262,7 @@ class GraphDataset(Dataset):
""" """
import networkx as nx import networkx as nx
m,n = shape(A)# adjacency matrix must be of type that evals to true/false for neigbours m,n = A.shape# adjacency matrix must be of type that evals to true/false for neigbours
if m!=n: if m!=n:
raise IOError, "Adjacency matrix must be square" raise IOError, "Adjacency matrix must be square"
if nx_type=='graph': if nx_type=='graph':

View File

@ -418,7 +418,7 @@ class LineViewPlot(Plot):
self.line_collection = {} self.line_collection = {}
x_axis = scipy.arrayrange(self._data.shape[minor_axis]) x_axis = scipy.arrayrange(self._data.shape[minor_axis])
for i in range(self._data.shape[major_axis]): for i in range(self._data.shape[major_axis]):
yi = scipy.take(self._data,[i],axis=major_axis) yi = self._data.take([i],major_axis)
if self.use_blit: if self.use_blit:
l,=self.ax.plot(x_axis,yi,'k',alpha=.05,animated=True) l,=self.ax.plot(x_axis,yi,'k',alpha=.05,animated=True)
else: else:
@ -513,16 +513,15 @@ has no color and size options."""
y1, y2 = y2, y1 y1, y2 = y2, y1
assert x1<=x2 assert x1<=x2
assert y1<=y2 assert y1<=y2
index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))
ids = self.dataset_1.get_identifiers(self.current_dim, index) ids = self.dataset_1.get_identifiers(self.current_dim, index)
self.selection_listener(self.current_dim, ids) self.selection_listener(self.current_dim, ids)
def set_current_selection(self, selection): def set_current_selection(self, selection):
ids = selection[self.current_dim] # current identifiers ids = selection[self.current_dim] # current identifiers
index = self.dataset_1.get_indices(self.current_dim, ids) index = self.dataset_1.get_indices(self.current_dim, ids)
xdata_new = scipy.take(self.xaxis_data, index) #take data xdata_new = self.xaxis_data.take(index) #take data
ydata_new = scipy.take(self.yaxis_data, index) ydata_new = self.yaxis_data.take(index)
#remove old selection #remove old selection
if self._selection_line: if self._selection_line:
self.ax.lines.remove(self._selection_line) self.ax.lines.remove(self._selection_line)
@ -545,11 +544,11 @@ has no color and size options."""
y_index = dataset_2[sel_dim][id_2] y_index = dataset_2[sel_dim][id_2]
self.xaxis_data = dataset_1._array[:,x_index] self.xaxis_data = dataset_1._array[:,x_index]
self.yaxis_data = dataset_2._array[:,y_index] self.yaxis_data = dataset_2._array[:,y_index]
lw = scipy.zeros(self.xaxis_data.shape,'f') lw = scipy.zeros(self.xaxis_data.shape)
self.ax.scatter(self.xaxis_data,self.yaxis_data,s=s,c=c,linewidth=lw,edgecolor='k',alpha=.6,cmap = cm.Set1) self.ax.scatter(self.xaxis_data,self.yaxis_data,s=s,c=c,linewidth=lw,edgecolor='k',alpha=.6,cmap = cm.Set1)
self.ax.set_title(self.get_title()) self.ax.set_title(self.get_title())
# collection # collection
self.coll = ax.collections[0] self.coll = self.ax.collections[0]
# add canvas to widget # add canvas to widget
self.add(self.canvas) self.add(self.canvas)
@ -575,22 +574,19 @@ has no color and size options."""
assert x1<=x2 assert x1<=x2
assert y1<=y2 assert y1<=y2
index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2)) index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
ids = self.dataset_1.get_identifiers(self.current_dim, index) ids = self.dataset_1.get_identifiers(self.current_dim, index)
self.selection_listener(self.current_dim, ids) self.selection_listener(self.current_dim, ids)
def set_current_selection(self, selection): def set_current_selection(self, selection):
ids = selection[self.current_dim] # current identifiers ids = selection[self.current_dim] # current identifiers
index = self.dataset_1.get_indices(self.current_dim, ids) index = self.dataset_1.get_indices(self.current_dim, ids)
lw = scipy.zeros(self.xaxis_data.shape,'f') lw = scipy.zeros(self.xaxis_data.shape)
scipy.put(lw,index,2.) scipy.put(lw,index,2.)
zo = lw.copy() + 1 #z-order, selected on top
self.coll.set_linewidth(lw) self.coll.set_linewidth(lw)
self.coll.set_zorder(zo)
self._toolbar.forward() #update data lims before draw self._toolbar.forward() #update data lims before draw
self.canvas.draw() self.canvas.draw()
class NetworkPlot(Plot): class NetworkPlot(Plot):
def __init__(self, dataset, **kw): def __init__(self, dataset, **kw):

View File

@ -102,10 +102,10 @@ class TestDataFunction(workflow.Function):
def run(self): def run(self):
logger.log('notice', 'Injecting foo test data') logger.log('notice', 'Injecting foo test data')
x = randn(500,30) x = randn(5000,4)
X = dataset.Dataset(x) X = dataset.Dataset(x)
p = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='p') p = plots.ScatterPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='scatter')
p2 = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='p2') p2 = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='marker')
graph = networkx.XGraph() graph = networkx.XGraph()
for x in 'ABCDEF': for x in 'ABCDEF':
for y in 'ADE': for y in 'ADE':