diff --git a/fluents/dataset.py b/fluents/dataset.py index 141198f..09fb58b 100644 --- a/fluents/dataset.py +++ b/fluents/dataset.py @@ -272,7 +272,20 @@ class CategoryDataset(Dataset): def __init__(self, array, identifiers=None, name='C'): Dataset.__init__(self, array, identifiers=identifiers, name=name) - + + def asspmatrix(self): + if isinstance(self._array, sparse.spmatrix): + return self._array + else: + arr = self.asarray() + return sparse.csr_matrix(arr.astype('i')) + + def tospmatrix(self): + if isinstance(self._array, sparse.spmatrix): + self._array = self._array.tocsr() + else: + self._array = sparse.csr_matrix(self._array) + def as_dict_lists(self): """Returns data as dict of identifiers along first dim. @@ -423,7 +436,7 @@ class Selection(dict): def select(self, axis, labels): self[axis] = labels -def write_ftsv(fd, ds, decimals=7, sep='\t', fmt=None): +def write_ftsv(fd, ds, decimals=7, sep='\t', fmt=None, sp_format=True): """Writes a dataset in fluents tab separated values (ftsv) form. @param fd: An open file descriptor to the output file. @@ -466,7 +479,10 @@ def write_ftsv(fd, ds, decimals=7, sep='\t', fmt=None): print >> fd, "# name: %s" % ds.get_name() # Write data - m = ds.asarray() + if hasattr(ds, "asspmatrix") and sp_format == True: + m = ds.asspmatrix() + else: + m = ds.asarray() if isinstance(m, sparse.spmatrix): _write_sparse_elements(fd, m, fmt, sep) else: