From 0b58c5ea286262796a5dccd518f68a92d3b4b2c0 Mon Sep 17 00:00:00 2001
From: flatberg <flatberg@pvv.ntnu.no>
Date: Thu, 31 Aug 2006 10:04:19 +0000
Subject: [PATCH] Switch to numpy (scipy_version>1.0) changes

---
 system/dataset.py          | 21 ++++++++++-----------
 system/plots.py            | 22 +++++++++-------------
 workflows/test_workflow.py |  6 +++---
 3 files changed, 22 insertions(+), 27 deletions(-)

diff --git a/system/dataset.py b/system/dataset.py
index 2c4bc49..9d39665 100644
--- a/system/dataset.py
+++ b/system/dataset.py
@@ -1,4 +1,4 @@
-from scipy import atleast_2d,asarray,ArrayType,shape,nonzero,io,transpose
+from scipy import ndarray,atleast_2d,asarray
 from scipy import sort as array_sort
 from itertools import izip
 import shelve
@@ -40,11 +40,11 @@ class Dataset:
         self._name = name
         self._identifiers = identifiers
         self._type = 'n'
-        if isinstance(array,ArrayType):
+        if isinstance(array,ndarray):
             array = atleast_2d(asarray(array))
             # vectors are column vectors 
             if array.shape[0]==1:
-                array = transpose(array)
+                array = array.T
             self.shape = array.shape
             if identifiers!=None:
                 self._set_identifiers(identifiers,self._all_dims)
@@ -55,10 +55,7 @@ class Dataset:
             self._array = array
             
         else:
-            raise ValueError, "Array input must be of ArrayType"
-                        
-    #def __str__(self):
-    #    return self._name + ":\n" + "Dim names: " +  self._dims.__str__()
+            raise ValueError, "Array input must be of type ndarray"
 
     def __iter__(self):
         """Returns an iterator over dimensions of dataset."""
@@ -177,7 +174,9 @@ class Dataset:
         
         
         Identifiers are the unique names (strings) for a variable in a given dim.
-        Index (Indices) are the Identifiers position in a matrix in a given dim."""
+        Index (Indices) are the Identifiers position in a matrix in a given dim.
+        If none of the input identifiers are found an empty index is returned
+        """
         if idents==None:
             index = array_sort(self._map[dim].values())
         else:
@@ -217,7 +216,7 @@ class CategoryDataset(Dataset):
         """
         data={}
         for name,ind in self._map[self.get_dim_name(0)].items():
-            data[name] = self.get_identifiers(self.get_dim_name(1),list(nonzero(self._array[ind,:])))
+            data[name] = self.get_identifiers(self.get_dim_name(1),list(self._array[ind,:].nonzero()))
         self._dictlists = data
         self.has_dictlists = True
         return data
@@ -227,7 +226,7 @@ class CategoryDataset(Dataset):
         """
         ret_list = []
         for cat_name,ind in self._map[self.get_dim_name(1)].items():
-            ids = self.get_identifiers(self.get_dim_name(0),nonzero(self._array[:,ind]))
+            ids = self.get_identifiers(self.get_dim_name(0),self._array[:,ind].nonzero())
             selection = Selection(cat_name)
             selection.select(cat_name,ids)
             ret_list.append(selection)
@@ -263,7 +262,7 @@ class GraphDataset(Dataset):
         
         """
         import networkx as nx
-        m,n = shape(A)# adjacency matrix must be of type that evals to true/false for neigbours
+        m,n = A.shape# adjacency matrix must be of type that evals to true/false for neigbours
         if m!=n:
             raise IOError, "Adjacency matrix must be square"
         if nx_type=='graph':
diff --git a/system/plots.py b/system/plots.py
index d74b55d..cb66c58 100644
--- a/system/plots.py
+++ b/system/plots.py
@@ -418,7 +418,7 @@ class LineViewPlot(Plot):
         self.line_collection = {}
         x_axis = scipy.arrayrange(self._data.shape[minor_axis])
         for i in range(self._data.shape[major_axis]):
-            yi = scipy.take(self._data,[i],axis=major_axis)
+            yi = self._data.take([i],major_axis)
             if self.use_blit:
                 l,=self.ax.plot(x_axis,yi,'k',alpha=.05,animated=True)
             else:
@@ -513,16 +513,15 @@ has no color and size options."""
             y1, y2 = y2, y1
         assert x1<=x2
         assert y1<=y2
-
-        index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))
+        index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
         ids = self.dataset_1.get_identifiers(self.current_dim, index)
         self.selection_listener(self.current_dim, ids)
 
     def set_current_selection(self, selection):
         ids = selection[self.current_dim] # current identifiers
         index = self.dataset_1.get_indices(self.current_dim, ids)
-        xdata_new = scipy.take(self.xaxis_data, index) #take data
-        ydata_new = scipy.take(self.yaxis_data, index)
+        xdata_new = self.xaxis_data.take(index) #take data
+        ydata_new = self.yaxis_data.take(index)
         #remove old selection
         if self._selection_line:
             self.ax.lines.remove(self._selection_line)
@@ -545,11 +544,11 @@ has no color and size options."""
         y_index = dataset_2[sel_dim][id_2]
         self.xaxis_data = dataset_1._array[:,x_index]
         self.yaxis_data = dataset_2._array[:,y_index]
-        lw = scipy.zeros(self.xaxis_data.shape,'f')
+        lw = scipy.zeros(self.xaxis_data.shape)
         self.ax.scatter(self.xaxis_data,self.yaxis_data,s=s,c=c,linewidth=lw,edgecolor='k',alpha=.6,cmap = cm.Set1)
         self.ax.set_title(self.get_title())
         # collection
-        self.coll = ax.collections[0]
+        self.coll = self.ax.collections[0]
 
         # add canvas to widget
         self.add(self.canvas)
@@ -575,22 +574,19 @@ has no color and size options."""
         assert x1<=x2
         assert y1<=y2
 
-        index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))
-
+        index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
         ids = self.dataset_1.get_identifiers(self.current_dim, index)
         self.selection_listener(self.current_dim, ids)
 
     def set_current_selection(self, selection):
         ids = selection[self.current_dim] # current identifiers
         index = self.dataset_1.get_indices(self.current_dim, ids)
-        lw = scipy.zeros(self.xaxis_data.shape,'f')
+        lw = scipy.zeros(self.xaxis_data.shape)
         scipy.put(lw,index,2.)
-        zo = lw.copy() + 1 #z-order, selected on top
         self.coll.set_linewidth(lw)
-        self.coll.set_zorder(zo)
         self._toolbar.forward() #update data lims before draw
         self.canvas.draw()
-
+        
         
 class NetworkPlot(Plot):
     def __init__(self, dataset, **kw):
diff --git a/workflows/test_workflow.py b/workflows/test_workflow.py
index 799060b..272a274 100644
--- a/workflows/test_workflow.py
+++ b/workflows/test_workflow.py
@@ -102,10 +102,10 @@ class TestDataFunction(workflow.Function):
 
     def run(self):
         logger.log('notice', 'Injecting foo test data')
-        x = randn(500,30)
+        x = randn(5000,4)
         X = dataset.Dataset(x)
-        p = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='p')
-        p2 = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='p2')
+        p = plots.ScatterPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='scatter')
+        p2 = plots.ScatterMarkerPlot(X, X, 'rows', 'rows', '0_1', '0_2',name='marker')
         graph = networkx.XGraph()
         for x in 'ABCDEF':
             for y in 'ADE':