from scipy import ndarray, atleast_2d, asarray, intersect1d, zeros
from scipy import empty, sparse, where
from scipy import sort as array_sort
from itertools import izip
import shelve
import copy
import re

class Universe(object):
    """Reference-counted set of identifiers along one named dimension.

    Any Dimension object whose ``name`` equals the universe name can be
    registered. Each registration increments the count of every identifier
    the dimension holds, and unregistering decrements it. An identifier is
    dropped from the universe as soon as its count reaches zero.
    """

    def __init__(self, name):
        self.name = name
        # identifier -> number of registrations currently holding it
        self._ids = {}

    def register(self, dim):
        """Increase reference count for identifiers in Dimension object dim.

        Dimensions whose name differs from the universe name are ignored.
        """
        if dim.name == self.name:
            counts = self._ids
            for ident in dim:
                counts[ident] = counts.get(ident, 0) + 1

    def unregister(self, dim):
        """Decrease reference counts for the identifiers in Dimension dim.

        Identifiers whose count drops to zero are removed, since by
        definition they no longer exist. Dimensions whose name differs from
        the universe name are ignored. An identifier that was never
        registered raises KeyError.
        """
        if dim.name != self.name:
            return
        counts = self._ids
        for ident in dim:
            if counts[ident] == 1:
                del counts[ident]
            else:
                counts[ident] -= 1

    def __str__(self):
        total = sum(self._ids.values())
        return "%s: %i elements, %i references" % (self.name, len(self._ids), total)

    def __contains__(self, element):
        return element in self._ids

    def __len__(self):
        return len(self._ids)

    def intersection(self, dim):
        """Return the set of identifiers present both here and in dim.idset."""
        return set(self._ids).intersection(dim.idset)


class Dimension(object):
    """A Dimension represents the set of identifiers an object has along an axis.

    Identifiers are kept both in their original order (``idlist``) and as a
    set (``idset``) for fast membership tests.
    """
    def __init__(self, name, ids=()):
        """Create a dimension.

        @param name: The dimension name.
        @param ids: Iterable of identifiers. It is consumed exactly once, so
            generators are accepted. The default is an immutable empty tuple
            rather than a shared mutable list.
        """
        self.name = name
        # Materialize first: building the set directly from ``ids`` would
        # exhaust a generator and leave idlist empty.
        self.idlist = list(ids)
        self.idset = set(self.idlist)

    def __getitem__(self, element):
        return self.idlist[element]

    def __getslice__(self, start, end):
        # Python 2 slicing hook; Python 3 routes slices through __getitem__.
        return self.idlist[start:end]

    def __contains__(self, element):
        return element in self.idset

    def __str__(self):
        return "%s: %s" % (self.name, str(self.idlist))

    def __len__(self):
        return len(self.idlist)

    def __iter__(self):
        return iter(self.idlist)

    def intersection(self, dim):
        """Return the set of identifiers shared with Dimension dim."""
        return self.idset.intersection(dim.idset)


class Dataset(object):
    """The Dataset base class.

    A Dataset is an n-way array with defined string identifiers across
    all dimensions.

    example of use:

    ---
    dim_name_rows = 'rows'
    names_rows = ('row_a','row_b')
    ids_1 = [dim_name_rows, names_rows]

    dim_name_cols = 'cols'
    names_cols = ('col_a','col_b','col_c','col_d')
    ids_2 = [dim_name_cols, names_cols]

    Array_X = rand(2,4)
    data = Dataset(Array_X,(ids_1,ids_2),name="Testing")

    dim_names = [dim for dim in data]

    column_identifiers = [id for id in data['cols'].keys()]
    column_index = [index for index in data['cols'].values()]

    'cols' in data -> True

    ---

    data = Dataset(rand(10,20)) (generates dims and ids (no links))
    """

    def __init__(self, array, identifiers=None, name='Unnamed dataset'):
        """Create a dataset.

        @param array: Array-like data or a scipy sparse matrix. Dense input
            is converted with atleast_2d(asarray(...)). If the first axis has
            length 1, the array is transposed (vectors are stored as columns).
        @param identifiers: Sequence of (dim_name, ids) pairs, one per axis.
            When None, dimension names and identifiers are generated.
        @param name: Dataset name.
        @raise ValueError: If the identifiers are not unique or do not match
            the array shape.
        """
        self._dims = []  # dimension names, in axis order
        self._map = {}  # dim name -> ReverseDict(identifier <-> index)
        self._name = name
        self._identifiers = identifiers

        if not isinstance(array, sparse.spmatrix):
            array = atleast_2d(asarray(array))
        # vectors are stored as columns
        if array.shape[0] == 1:
            array = array.T
        self.shape = array.shape

        if identifiers is not None:
            self._validate_identifiers(identifiers)
        else:
            identifiers = self._create_identifiers(self.shape, self._all_dims)
            self._identifiers = identifiers
        self._set_identifiers(identifiers, self._all_dims)
        self._array = array

    def __iter__(self):
        """Returns an iterator over dimensions of dataset."""
        return iter(self._dims)

    def __contains__(self, dim):
        """Returns True if dim is a dimension name in dataset."""
        return dim in self._map

    def __len__(self):
        """Returns the number of dimensions in the dataset"""
        return len(self._map)

    def __getitem__(self, dim):
        """Return the identifier -> index mapping along the dimension dim."""
        return self._map[dim]

    def _create_identifiers(self, shape, all_dims):
        """Create dimension names and identifiers for an array of shape.

        The first two axes are named 'rows' and 'cols' and later axes 'dim',
        each made unique against all_dims. Identifiers along axis k are
        'k_0', 'k_1', ... The chosen names are added to all_dims.

        @return: A list of (dim_name, identifiers) pairs.
        """
        default_names = ['rows', 'cols']
        ids = []
        for axis, n in enumerate(shape):
            if axis < len(default_names):
                suggestion = default_names[axis]
            else:
                suggestion = 'dim'
            dim_name = self._suggest_dim_name(suggestion, all_dims)
            idents = [str(axis) + "_" + str(i) for i in range(n)]
            ids.append((dim_name, idents))
            all_dims.add(dim_name)
        return ids

    def _set_identifiers(self, identifiers, all_dims):
        """Build the internal identifier <-> index mapping for each dimension.

        @raise ValueError: If a dimension name occurs twice.
        """
        for dim, ids in identifiers:
            pos_map = ReverseDict()
            if dim not in self._dims:
                self._dims.append(dim)
                all_dims.add(dim)
            else:
                raise ValueError("Dimension names must be unique whitin dataset")
            for pos, ident in enumerate(ids):
                pos_map[ident] = pos
            self._map[dim] = pos_map

    def _suggest_dim_name(self, dim_name, all_dims):
        """Suggest a name based on dim_name that is not in all_dims, and return it."""
        c = 0
        new_name = dim_name
        while new_name in all_dims:
            new_name = dim_name + "_" + str(c)
            c += 1
        return new_name

    def asarray(self):
        """Returns the numeric array (data) of dataset.

        Sparse data is returned as a dense copy.
        """
        if isinstance(self._array, sparse.spmatrix):
            return self._array.toarray()
        return self._array

    def set_array(self, array):
        """Replace the data array of the dataset.

        The new array must be of the same type as the current one (a dense
        array, or the same sparse matrix class) and have the same shape.
        Dense input is stored as at least 2-D. Sparse input is stored as-is.

        @raise ValueError: On a type or shape mismatch.
        """
        if not isinstance(array, type(self._array)):
            raise ValueError("Input array of type: %s does not match existing array type: %s"
                             % (type(array), type(self._array)))
        if self.shape != array.shape:
            raise ValueError("Input array must be of similar dimensions as dataset")
        if not isinstance(array, sparse.spmatrix):
            # asarray() on a sparse matrix would wrap it in an object array
            array = atleast_2d(asarray(array))
        self._array = array

    def get_name(self):
        """Returns dataset name"""
        return self._name

    def get_all_dims(self):
        """Returns the set of dimension names shared by all datasets."""
        return self._all_dims

    def get_dim_name(self, axis=None):
        """Return the dim name for axis. If no axis is given, return a list of all dim names."""
        if axis is None:
            return list(self._dims)
        return self._dims[axis]

    def common_dims(self, ds):
        """Returns a list of the common dimensions in the two datasets."""
        ds_dims = ds.get_dim_name()
        return [d for d in self.get_dim_name() if d in ds_dims]

    def get_identifiers(self, dim, indices=None, sorted=False):
        """Return the identifiers along dim, optionally sorted by position (index).

        You can optionally provide a list/ndarray of indices to get
        only the identifiers at the given positions.

        Identifiers are the unique names (strings) for a variable in a
        given dim. Index (Indices) are the Identifiers position in a
        matrix in a given dim.
        """
        if indices is not None:
            # explicit "is not None": "!= None" is elementwise on ndarrays
            if len(indices) == 0:
                return []
            reverse = self._map[dim].reverse
            return [reverse[i] for i in indices]
        pos_map = self._map[dim]
        if sorted:
            return [pos_map.reverse[i] for i in array_sort(list(pos_map.values()))]
        return list(pos_map.keys())

    def get_indices(self, dim, idents=None):
        """Returns indices for identifiers along dimension.

        You can optionally provide a list/set of identifiers to retrieve an
        index subset. When idents is None, all indices are returned in
        sorted order.

        Identifiers are the unique names (strings) for a variable in a
        given dim. Index (Indices) are the Identifiers position in a
        matrix in a given dim. If none of the input identifiers are
        found, an empty index is returned.

        @raise ValueError: If idents is neither None, a list, nor a set.
        """
        pos_map = self._map[dim]
        if idents is None:
            return asarray(array_sort(list(pos_map.values())))
        if not isinstance(idents, (list, set)):
            raise ValueError("idents needs to be a list/set got: %s" % type(idents))
        return asarray([pos_map[key] for key in idents if key in pos_map])

    def existing_identifiers(self, dim, idents):
        """Filters a list of identifiers to find those that are present in the
        dataset.

        The most common use of this function is to get a list of
        identifiers who correspond one to one with the list of indices produced
        when get_indices is given an identifier list. That is
        ds.get_indices(dim, idents) and ds.existing_identifiers(dim, idents)
        will have the same order.

        @param dim: A dimension present in the dataset.
        @param idents: A list of identifiers along the given dimension.
        @return: A list of identifiers in the same order as idents, but
        without elements not present in the dataset.
        """
        if not isinstance(idents, (list, set)):
            raise ValueError("idents needs to be a list/set got: %s" % type(idents))
        pos_map = self._map[dim]
        return [key for key in idents if key in pos_map]

    def copy(self):
        """ Returns deepcopy of dataset.
        """
        return copy.deepcopy(self)

    def subdata(self, dim, idents):
        """Return a new dataset restricted to idents along dimension dim.

        The kept identifiers keep their original relative order and are
        renumbered from 0.

        @raise ValueError: If none of idents are present along dim.
        """
        ds = self.copy()
        indices = array_sort(ds.get_indices(dim, idents))

        kept = ds.get_identifiers(dim, indices=indices)
        if not kept:
            # report the identifiers that were asked for, not the empty result
            raise ValueError("No of identifers from: \n%s \nfound in %s" % (str(idents), ds._name))
        ax = ds._dims.index(dim)
        if isinstance(ds._array, sparse.spmatrix):
            # Sparse matrices have no take(). Index rows/columns instead and
            # restore the original format so set_array's type check passes.
            fmt = ds._array.format
            if ax == 0:
                subarr = ds._array.tocsr()[indices, :]
            else:
                subarr = ds._array.tocsc()[:, indices]
            subarr = subarr.asformat(fmt)
        else:
            subarr = ds._array.take(indices, ax)
        ds._map[dim] = ReverseDict(zip(kept, range(len(kept))))
        ds.shape = tuple(len(ds._map[d]) for d in ds._dims)
        ds.set_array(subarr)
        return ds

    def transpose(self):
        """Returns a copy of transpose of a dataset.

        As for the moment: only support for 2D-arrays.
        """
        assert(len(self.shape) == 2)
        ds = self.copy()
        ds._array = ds._array.T
        ds._dims.reverse()
        ds.shape = ds._array.shape
        return ds

    def _validate_identifiers(self, identifiers):
        """Check that identifiers are unique per dim and match self.shape.

        @raise ValueError: On duplicates, a wrong number of dimensions, or a
            length mismatch along any axis.
        """
        for dim_name, ids in identifiers:
            if len(set(ids)) != len(ids):
                raise ValueError("Identifiers not unique in : %s" % dim_name)
        identifier_shape = [len(ids) for _, ids in identifiers]
        if len(identifier_shape) != len(self.shape):
            raise ValueError("Identifier list length must equal array dims")
        for ni, na in zip(identifier_shape, self.shape):
            if ni != na:
                raise ValueError("Identifier-array mismatch: %s: (idents: %s, array: %s)"
                                 % (self._name, ni, na))


class CategoryDataset(Dataset):
    """The category dataset class.

    A dataset for representing class information as binary
    matrices (0/1-matrices).

    There is support for using a less memory demanding, sparse format. The
    preferred (default) format for a category dataset is the compressed sparse
    row format (csr).

    Always has linked dimension in first dim:
    ex matrix:
    .        go_term1    go_term2  ...
    gene_1
    gene_2
    gene_3
    .
    .
    .

    """

    def __init__(self, array, identifiers=None, name='C'):
        Dataset.__init__(self, array, identifiers=identifiers, name=name)

    def as_spmatrix(self):
        """Return the data as a scipy sparse matrix.

        Sparse data is returned as-is (not copied). Dense data is converted
        to an integer CSR matrix.
        """
        if isinstance(self._array, sparse.spmatrix):
            return self._array
        return sparse.csr_matrix(self.asarray().astype('i'))

    def to_spmatrix(self):
        """Convert the stored array to CSR format in place."""
        if isinstance(self._array, sparse.spmatrix):
            self._array = self._array.tocsr()
        else:
            self._array = sparse.csr_matrix(self._array)

    def as_dictlists(self):
        """Returns data as dict of identifiers along first dim.

        ex: data['gene_1'] = ['map0030','map0010', ...]

        Rows with no non-zero entries are left out. The result is also
        stored in self._dictlists.

        fixme: Deprecated?
        """
        row_dim = self.get_dim_name(0)
        col_dim = self.get_dim_name(1)
        array = self._array
        is_sparse = isinstance(array, sparse.spmatrix)
        if is_sparse and not isinstance(array, sparse.csr_matrix):
            # convert once, not once per row
            array = array.tocsr()

        data = {}
        for name, ind in self._map[row_dim].items():
            if is_sparse:
                indices = array[ind, :].indices
            else:
                indices = array[ind, :].nonzero()[0]
            if len(indices) == 0:  # should we allow categories with no members?
                continue
            data[name] = self.get_identifiers(col_dim, indices)
        self._dictlists = data
        return data

    def as_selections(self):
        """Returns data as a list of Selection objects.

        One Selection per category (second dim) with at least one member.
        It is titled by the category and selects the member identifiers
        along the first dim. The list is not ordered (sorted) by any means.
        The stored array is left untouched.
        """
        row_dim = self.get_dim_name(0)
        array = self._array
        is_sparse = isinstance(array, sparse.spmatrix)
        if is_sparse and not isinstance(array, sparse.csc_matrix):
            # Use a local CSC copy. Replacing self._array would silently
            # change the dataset's storage format.
            array = array.tocsc()

        ret_list = []
        for cat_name, ind in self._map[self.get_dim_name(1)].items():
            if is_sparse:
                indices = array[:, ind].indices
            else:
                indices = array[:, ind].nonzero()[0]
            if len(indices) == 0:
                continue
            ids = self.get_identifiers(row_dim, indices)
            selection = Selection(cat_name)
            selection.select(row_dim, ids)
            ret_list.append(selection)
        return ret_list
    

class GraphDataset(Dataset):
    """The graph dataset class.

    A dataset class for representing graphs. The constructor may use an
    incidence matrix (possibly sparse) or (if networkx installed) a
    networkx.(X)Graph structure.

    If the networkx library is installed, there is support for
    representing the graph as a networkx.Graph, or networkx.XGraph structure.
    """

    def __init__(self, input, identifiers=None, name='A', nodepos=None):
        """Create a graph dataset from an incidence matrix.

        @param input: Incidence matrix (nodes x edges), dense array-like or
            scipy sparse.
        @param identifiers: (dim_name, ids) pairs, as for Dataset.
        @param name: Dataset name.
        @param nodepos: Optional mapping node identifier -> (x, y).
        @raise ValueError: If input cannot be converted to an array.
        """
        if isinstance(input, sparse.spmatrix):
            arr = input
        else:
            try:
                arr = asarray(input)
            except Exception:
                raise ValueError("Could not identify input")
        Dataset.__init__(self, array=arr, identifiers=identifiers, name=name)
        self._graph = None  # cached result of asnetworkx()
        self.nodepos = nodepos

    def as_spmatrix(self):
        """Return the data as a scipy sparse matrix (dense data -> integer CSR)."""
        if isinstance(self._array, sparse.spmatrix):
            return self._array
        return sparse.csr_matrix(self.asarray().astype('i'))

    def to_spmatrix(self):
        """Convert the stored array to CSR format in place."""
        if isinstance(self._array, sparse.spmatrix):
            self._array = self._array.tocsr()
        else:
            self._array = sparse.csr_matrix(self._array)

    def asnetworkx(self):
        """Return the graph as a networkx graph, building and caching it on first use.

        Returns None if networkx cannot be imported.
        """
        if self._graph is not None:
            return self._graph
        node_dim, edge_dim = self.get_dim_name()
        node_ids = self.get_identifiers(node_dim, sorted=True)
        edge_ids = self.get_identifiers(edge_dim, sorted=True)
        G, weights = self._graph_from_incidence_matrix(self._array, node_ids=node_ids, edge_ids=edge_ids)
        self._graph = G
        return G

    @classmethod
    def from_networkx(cls, G, node_dim, edge_dim, sp_format=True):
        """Create graph dataset from networkx graph.

        When G is a Graph/Digraph, edge identifiers will be created.
        Otherwise (XGraph/XDigraph) the edge attributes are assumed to be
        the edge identifiers; edges with a None attribute get 'e_<i>'.
        The incidence matrix has shape (nodes, edges): +1 at the head, and
        -1 (directed) or +1 (undirected) at the tail.
        """
        import networkx as nx
        n = G.number_of_nodes()
        m = G.number_of_edges()

        if isinstance(G, nx.DiGraph):
            G = nx.XDiGraph(G)
        elif isinstance(G, nx.Graph):
            G = nx.XGraph(G)

        # a real list: it is iterated here and stored in the identifiers
        node_ids = [str(node) for node in G.nodes()]
        n2ind = dict((node, ind) for ind, node in enumerate(node_ids))

        if sp_format:
            I = sparse.lil_matrix((n, m))
        else:
            I = zeros((n, m), dtype='i')  # same (nodes, edges) shape as sparse

        directed = G.is_directed()
        edge_ids = []
        for i, (h, t, eid) in enumerate(G.edges()):
            if eid is not None:
                edge_ids.append(eid)
            else:
                edge_ids.append('e_' + str(i))
            hind = n2ind[str(h)]
            tind = n2ind[str(t)]
            I[hind, i] = 1
            if directed:
                I[tind, i] = -1
            else:
                I[tind, i] = 1
        idents = [[node_dim, node_ids], [edge_dim, edge_ids]]
        if G.name != '':
            name = G.name
        else:
            name = 'A'
        return GraphDataset(I, idents, name)

    def _incidence2adjacency(self, I):
        """Incidence to adjacency matrix.

        I*I.T - eye(n)?
        """
        raise NotImplementedError

    def _graph_from_incidence_matrix(self, I, node_ids, edge_ids):
        """Creates a networkx graph class from incidence
        (possibly weighted) matrix and ordered labels.

        Every column of I must hold exactly two non-zero entries.

        @return: (graph, edge weights), or (None, None) if networkx cannot
            be imported.
        """
        try:
            import networkx as nx
        except ImportError:
            print("Failed in import of NetworkX")
            return None, None

        m, n = I.shape
        assert(m == len(node_ids))
        assert(n == len(edge_ids))
        weights = []
        directed = False
        G = nx.XDiGraph(name=self._name)
        is_sparse = isinstance(I, sparse.spmatrix)
        if is_sparse:
            # Read columns straight from CSC storage. Iterating over I.T
            # yields rows whose .indices are not node indices.
            I = I.tocsc()
        for j, ename in enumerate(edge_ids):
            if is_sparse:
                start, end = I.indptr[j], I.indptr[j + 1]
                rows = I.indices[start:end]
                order = rows.argsort()  # ascending node order, as in the dense path
                node_ind = rows[order]
                w1, w2 = I.data[start:end][order]
            else:
                col = I[:, j]
                node_ind = where(col != 0)[0]
                w1, w2 = col[node_ind]
            node1 = node_ids[node_ind[0]]
            node2 = node_ids[node_ind[1]]
            if w1 < 0:  # w1 is tail
                directed = True
                assert(w2 > 0 and (w1 + w2) == 0)
                G.add_edge(node2, node1, ename)
                weights.append(w2)
            else:  # w2 is tail or graph is undirected
                assert(w1 > 0)
                if w2 < 0:
                    directed = True
                G.add_edge(node1, node2, ename)
                weights.append(w1)
        if not directed:
            G = G.to_undirected()
        return G, asarray(weights)

# Class-level set shared by every Dataset (and subclass) instance.
# Dataset.__init__ passes it to _create_identifiers/_set_identifiers, which
# add each dataset's dimension names to it. _suggest_dim_name consults it so
# generated dimension names do not reuse a name already seen.
Dataset._all_dims = set()


class ReverseDict(dict):
    """A dictionary which can lookup values by key, and keys by value.

    All values and keys must be hashable, and unique.

    example:
    >>d = ReverseDict((['a',1],['b',2]))
    >>print d['a'] --> 1
    >>print d.reverse[1] --> 'a'

    Only item assignment (d[key] = value) keeps ``reverse`` in sync. Other
    mutators inherited from dict (update, pop, del, ...) do not.
    """
    def __init__(self, *args, **kw):
        dict.__init__(self, *args, **kw)
        self.reverse = dict((v, k) for k, v in self.items())

    def __setitem__(self, key, value):
        try:
            reverse = self.reverse
        except AttributeError:
            # __init__ was bypassed (for example while unpickling, where dict
            # items can be restored before the instance attributes).
            reverse = self.reverse = {}
        if key in self:
            old_value = dict.__getitem__(self, key)
            # drop the stale value -> key entry left by the previous mapping
            if old_value in reverse and reverse[old_value] == key:
                del reverse[old_value]
        dict.__setitem__(self, key, value)
        reverse[value] = key


class Selection(dict):
    """Handles selected identifiers along each dimension of a dataset.

    Maps a dimension name to the identifiers selected along it. Looking up
    a dimension without a selection returns None instead of raising KeyError.
    """

    def __init__(self, title='Unnamed Selecton'):
        dict.__init__(self)
        self.title = title

    def __getitem__(self, key):
        if key not in self:
            return None
        return dict.__getitem__(self, key)

    def dims(self):
        """Return the dimension names that have a selection."""
        return self.keys()

    def axis_len(self, axis):
        """Return the number of identifiers selected along axis (0 if none)."""
        if axis in self:
            return len(dict.__getitem__(self, axis))
        return 0

    def select(self, axis, labels):
        """Set the selected identifiers along axis to labels."""
        self[axis] = labels


def write_ftsv(fd, ds, decimals=7, sep='\t', fmt=None, sp_format=True):
    """Writes a dataset in laydi tab separated values (ftsv) form.

    @param fd: An open file object, or a file name to create. A file
        opened here is always closed, even if writing fails.
    @param ds: The dataset to be written.
    @param decimals: Number of decimals, only supported for dataset.
    @param sep: Value separator for dense data.
    @param fmt: String formatting for values. For a plain Dataset it is the
        conversion suffix appended to the '%.<decimals>' precision.
    @param sp_format: Write CategoryDataset/GraphDataset data as sparse
        coordinate triplets.
    @raise Exception: If ds is not a Dataset.

    The function handles datasets of these classes:
    Dataset, CategoryDataset and GraphDataset
    """
    opened = False
    if isinstance(fd, str):
        fd = open(fd, 'w')
        opened = True

    try:
        # Header: dataset type and default value format.
        if isinstance(ds, CategoryDataset):
            ds_type = 'category'
            if fmt is None:
                fmt = '%d'
        elif isinstance(ds, GraphDataset):
            ds_type = 'network'
            if fmt is None:
                fmt = '%d'
        elif isinstance(ds, Dataset):
            ds_type = 'dataset'
            if fmt is None:
                fmt = '%%.%df' % decimals
            else:
                fmt = '%%.%d' % decimals + fmt
        else:
            raise Exception("Unknown object type")
        fd.write('# type: %s' % ds_type + '\n')

        for dim in ds.get_dim_name():
            fd.write("# dimension: %s" % dim)
            for ident in ds.get_identifiers(dim, sorted=True):
                fd.write(" " + ident)
            fd.write("\n")

        fd.write("# name: %s" % ds.get_name() + '\n')
        # xy-node-positions
        if ds_type == 'network' and ds.nodepos is not None:
            fd.write("# nodepos:")
            node_dim = ds.get_dim_name(0)
            for ident in ds.get_identifiers(node_dim, sorted=True):
                fd.write(" %s,%s" % ds.nodepos[ident])
            fd.write("\n")

        # Write data
        if hasattr(ds, "as_spmatrix") and sp_format == True:
            m = ds.as_spmatrix()
        else:
            m = ds.asarray()
        if isinstance(m, sparse.spmatrix):
            _write_sparse_elements(fd, m, fmt, sep)
        else:
            _write_elements(fd, m, fmt, sep)
    finally:
        if opened:
            fd.close()

def read_ftsv(fd, sep=None):
    """Read a dataset in laydi tab separated values (ftsv) form and return it.

    The header is a run of lines of the form ``# key: value`` with keys
    type, dimension, name, sp_format and nodepos. The first line that does
    not match ends the header and is discarded. The remaining lines are the
    data: dense rows, or ``row col value`` triplets when sp_format is true.

    @param fd: An open file object, or a file name to open. A file opened
        here is always closed, even if reading fails.
    @param sep: Currently unused; dense values are split on whitespace.
        Kept for backward compatibility.
    @return: A Dataset, CategoryDataset or GraphDataset depending on the information
    read.
    @raise ValueError: If the sp_format header value is not recognised.
    """
    opened = False
    if isinstance(fd, str):
        fd = open(fd)
        opened = True

    try:
        split_re = re.compile(r'^#\s*(\w+)\s*:\s*(.+)')
        dimensions = []
        identifiers = {}
        ds_type = 'dataset'
        name = 'Unnamed dataset'
        sp_format = False
        nodepos = None

        # Read header lines from file.
        line = fd.readline()
        while line:
            m = split_re.match(line)
            if not m:
                break
            key, val = m.groups()
            val = val.strip()  # drop trailing '\r'/blanks from the value

            if key == 'dimension':
                # dimension: dimname id1 id2 id3 ...
                values = val.split()
                dimensions.append(values[0])
                identifiers[values[0]] = values[1:]
            elif key == 'type':
                # Should be dataset, category, or network
                ds_type = val
            elif key == 'name':
                name = val
            elif key == 'sp_format':
                # if sp_format is True then use coordinate triplets
                if val in ['False', 'false', '0', 'F', 'f']:
                    sp_format = False
                elif val in ['True', 'true', '1', 'T', 't']:
                    sp_format = True
                else:
                    raise ValueError("sp_format: %s not valid " % val)
            elif key == 'nodepos':
                idents = identifiers[dimensions[0]]
                nodepos = {}
                for node_id, xy in zip(idents, val.split()):
                    x, y = map(float, xy.split(","))
                    nodepos[node_id] = (x, y)
            line = fd.readline()

        # Dimensions in the form [(dim1, [id1, id2, id3 ..) ...]
        dims = [(x, identifiers[x]) for x in dimensions]
        shape = tuple(len(identifiers[x]) for x in dimensions)

        if sp_format:
            # lil_matrix needs the shape as a tuple; a list is read as dense data
            matrix = _read_sparse_elements(fd, sparse.lil_matrix(shape))
        else:
            dtype = 'i' if ds_type == 'category' else float
            matrix = _read_elements(fd, zeros(shape, dtype=dtype))

        # Create dataset of specified type
        if ds_type == 'category':
            ds = CategoryDataset(matrix, dims, name)
        elif ds_type == 'network':
            ds = GraphDataset(matrix, dims, name=name, nodepos=nodepos)
        else:
            ds = Dataset(matrix, dims, name)
    finally:
        if opened:
            fd.close()

    return ds

def write_csv(fd, ds, decimals=7, sep='\t'):
    """Write a 2-D dataset as comma/tab/whatever delimited data.

    The first line holds the row dimension name followed by the column
    identifiers. Each following line holds a row identifier followed by
    that row's values. Identifiers are written in index order so labels
    line up with the data. Every field is followed by sep.

    @param fd: An open file object, or a file name to create. A file
        opened here is always closed, even if writing fails.
    @param ds: The dataset to be written (exactly two dimensions).
    @param decimals: Number of digits after the decimal point.
    @param sep: Value separator
    """
    opened = False
    if isinstance(fd, str):
        fd = open(fd, 'w')
        opened = True

    try:
        rowdim, coldim = ds.get_dim_name()
        # sorted=True: identifiers by index, so rowids[j] labels row j
        rowids = ds.get_identifiers(rowdim, sorted=True)
        colids = ds.get_identifiers(coldim, sorted=True)
        a = ds.asarray()
        n_rows, n_cols = a.shape
        fmt = '%%.%df' % decimals  # precision, not field width

        # Header
        fd.write(rowdim)
        fd.write(sep)
        for ident in colids:
            fd.write(ident)
            fd.write(sep)
        fd.write('\n')

        # Matrix data
        for j in range(n_rows):
            fd.write(rowids[j])
            fd.write(sep)
            for i in range(n_cols):
                fd.write(fmt % (a[j, i],))
                fd.write(sep)
            fd.write('\n')
    finally:
        if opened:
            fd.close()

def _write_sparse_elements(fd, arr, fmt='%d', sep=None):
    """Write sparse matrix arr as coordinate triplets.

    Writes the '# sp_format: True' header line and a blank line. Then
    writes one 'row col value' line per stored entry, in CSR (row-major)
    storage order; value is formatted with fmt. sep is unused: triplets
    are always space separated.
    """
    fd.write('# sp_format: True\n\n')
    line_fmt = '%d %d ' + fmt + '\n'
    # csr -> coo keeps CSR storage order and exposes parallel row/col/data
    # arrays (the old csr.rowcol/getdata accessors no longer exist).
    coo = arr.tocsr().tocoo()
    for row, col, value in zip(coo.row, coo.col, coo.data):
        fd.write(line_fmt % (row, col, value))

def _write_elements(fd, arr, fmt='%f', sep='\t'):
    """Write 2-D array arr in standard value separated form.

    A blank line comes first. Then each row goes on its own line, with
    every value (including the last) formatted by fmt and followed by sep.
    """
    cell_fmt = fmt + sep
    fd.write('\n')
    n_rows, n_cols = arr.shape
    for r in range(n_rows):
        cells = [cell_fmt % arr[r, c] for c in range(n_cols)]
        fd.write(''.join(cells) + '\n')

def _read_elements(fd, arr, sep=None):
    """Fill dense 2-D array arr row by row from the remaining lines of fd.

    Each non-blank line is one row. Values are split on sep (any whitespace
    when None) and converted with float(). Empty fields, such as the one
    after the trailing separator _write_elements emits, are ignored. Blank
    lines are skipped rather than consuming a row.

    @return: arr, filled in place.
    """
    row = 0
    line = fd.readline()
    while line:
        values = [v for v in line.split(sep) if v.strip()]
        if values:
            for col, val in enumerate(values):
                arr[row, col] = float(val)
            row += 1
        line = fd.readline()
    return arr

def _read_sparse_elements(fd, arr, sep=None):
    """Fill sparse matrix arr from 'row col value' triplet lines and return it as CSR.

    Blank lines (e.g. a trailing newline at end of file) are skipped. sep
    is unused: triplets are split on whitespace, matching the format
    _write_sparse_elements produces.
    """
    line = fd.readline()
    while line:
        fields = line.split()
        if fields:
            i, j, val = fields
            arr[int(i), int(j)] = float(val)
        line = fd.readline()
    return arr.tocsr()