diff --git a/system/dataset.py b/system/dataset.py index 35013bd..379e0e6 100644 --- a/system/dataset.py +++ b/system/dataset.py @@ -2,6 +2,7 @@ from scipy import ndarray,atleast_2d,asarray from scipy import sort as array_sort from itertools import izip import shelve +import copy class Dataset: """The Dataset base class. @@ -40,22 +41,23 @@ class Dataset: self._name = name self._identifiers = identifiers self._type = 'n' - if isinstance(array,ndarray): + try: array = atleast_2d(asarray(array)) - # vectors are column vectors - if array.shape[0]==1: - array = array.T - self.shape = array.shape - if identifiers!=None: - self._set_identifiers(identifiers,self._all_dims) - else: - self._identifiers = self._create_identifiers(self.shape,self._all_dims) - self._set_identifiers(self._identifiers,self._all_dims) - - self._array = array - + except: + print "Cant cast array as numpy-array" + return + # vectors are column vectors + if array.shape[0]==1: + array = array.T + self.shape = array.shape + + if identifiers!=None: + self._set_identifiers(identifiers,self._all_dims) else: - raise ValueError, "Array input must be of type ndarray" + self._identifiers = self._create_identifiers(self.shape,self._all_dims) + self._set_identifiers(self._identifiers,self._all_dims) + + self._array = array def __iter__(self): """Returns an iterator over dimensions of dataset.""" @@ -189,7 +191,10 @@ class Dataset: for key in idents if self._map[dim].has_key(key)] return asarray(index) - + def copy(self): + return copy.deepcopy(self) + + class CategoryDataset(Dataset): """The category dataset class. @@ -245,8 +250,9 @@ class CategoryDataset(Dataset): class GraphDataset(Dataset): """The graph dataset class. - A dataset class for representing graphs using an adjacency matrix - (aka. restricted to square symmetric signed integers matrices) + A dataset class for representing graphs using an (weighted) + adjacency matrix + (aka. restricted to square symmetric matrices) If the library NetworkX is installed, there is support for representing the graph as a NetworkX.Graph, or NetworkX.XGraph structure. @@ -254,7 +260,7 @@ class GraphDataset(Dataset): def __init__(self,array=None,identifiers=None,shape=None,all_dims=[],**kwds): Dataset.__init__(self,array=array,identifiers=identifiers,name='A') - self.has_graph = False + self._graph = None self._type = 'g' def asnetworkx(self,nx_type='graph'): @@ -262,37 +268,44 @@ class GraphDataset(Dataset): ids = self.get_identifiers(dim,sorted=True) adj_mat = self.asarray() G = self._graph_from_adj_matrix(adj_mat,labels=ids) - self.has_graph = True + self._graph = G return G - def _graph_from_adj_matrix(self,A,labels=None,nx_type='graph'): - """Creates a networkx graph class from adjacency matrix and - ordered labels. nx_type = ['graph',['xgraph']] labels = None, - results in string-numbered labels + def _graph_from_adj_matrix(self,A,labels=None): + """Creates a networkx graph class from adjacency + (possibly weighted) matrix and ordered labels. + nx_type = ['graph',['xgraph']] + labels = None, results in string-numbered labels """ - import networkx as nx + + try: + import networkx as nx + except: + print "Failed in import of NetworkX" + return m,n = A.shape# adjacency matrix must be of type that evals to true/false for neigbours if m!=n: raise IOError, "Adjacency matrix must be square" - if nx_type=='graph': + + if A[A[:,0].nonzero()[0],0]==1: #unweighted graph G = nx.Graph() - elif nx_type=='x_graph': - G = nx.XGraph() else: - raise IOError, "Unknown graph type: %s" %nx_type + G = nx.XGraph() if labels==None: # if labels not provided mark vertices with numbers labels = [str(i) for i in range(m)] - for nbrs,head in izip(A,labels): for i,nbr in enumerate(nbrs): if nbr: tail = labels[i] - G.add_edge(head,tail) + if type(G)==nx.XGraph: + G.add_edge(head,tail,nbr) + else: + G.add_edge(head,tail) return G - + Dataset._all_dims=set() class ReverseDict(dict): @@ -310,7 +323,10 @@ class ReverseDict(dict): def __setitem__(self, key, value): dict.__setitem__(self, key, value) - self.reverse[value] = key + try: + self.reverse[value] = key + except: + self.reverse = {value:key} def to_file(filepath,dataset,name=None): """Write dataset to file. A file may contain multiple datasets.