from scipy import ndarray,atleast_2d,asarray,intersect1d
from scipy import sort as array_sort
from itertools import izip
import shelve
import copy

class Dataset:
    """The Dataset base class.

    A Dataset is an n-way array with defined string identifiers across
    all dimensions.

    example of use:

    ---
    dim_name_rows = 'rows'
    names_rows = ('row_a','row_b')
    ids_1 = [dim_name_rows, names_rows]

    dim_name_cols = 'cols'
    names_cols = ('col_a','col_b','col_c','col_d')
    ids_2 = [dim_name_cols, names_cols]

    Array_X = rand(2,4)
    data = Dataset(Array_X,(ids_1,ids_2),name="Testing")

    dim_names = [dim for dim in data]

    column_identifiers = [id for id in data['cols'].keys()]
    column_index = [index for index in data['cols'].values()]

    'cols' in data -> True

    ---

    data = Dataset(rand(10,20)) (generates dims and ids (no links))

    Every dimension name used by a dataset is added to the class attribute
    ``_all_dims``. Auto-generated dimension names are chosen so that they
    do not collide with names already in that set.
    """
    def __init__(self, array, identifiers=None, name='Unnamed dataset'):
        """Build a dataset from ``array``.

        array       -- anything ``asarray`` accepts. It is made at least 2-D,
                       and a single-row input is transposed so that vectors
                       are stored as column vectors.
        identifiers -- sequence of ``(dim_name, ids)`` pairs, one per axis,
                       or None to auto-generate them. Generated dimension
                       names are 'rows', 'cols', then 'dim', made unique.
                       Generated ids have the form '<axis>_<position>'.
                       The number of ids is not checked against the shape.
        name        -- dataset name.

        Raises ValueError if ``array`` cannot be converted, or if a
        dimension name is repeated within ``identifiers``.
        """
        self._dims = []  # dimension names, in axis order
        self._map = {}   # dim name -> ReverseDict (identifier <--> index)
        self._name = name
        self._identifiers = identifiers
        self._type = 'n'
        try:
            array = atleast_2d(asarray(array))
        except Exception:
            # Raise rather than return early. An early return would leave an
            # object with no ``shape`` and no ``_array``.
            raise ValueError("Cant cast array as numpy-array")
        # vectors are column vectors
        if array.shape[0] == 1:
            array = array.T
        self.shape = array.shape

        if identifiers is None:
            self._identifiers = self._create_identifiers(self.shape,
                                                         self._all_dims)
        self._set_identifiers(self._identifiers, self._all_dims)
        self._array = array

    def __iter__(self):
        """Returns an iterator over dimensions of dataset."""
        return iter(self._dims)

    def __contains__(self, dim):
        """Returns True if dim is a dimension name in dataset."""
        return dim in self._map

    def __len__(self):
        """Returns the number of dimensions in the dataset"""
        return len(self._map)

    def __getitem__(self, dim):
        """Return the identifier <-> index mapping along the dimension dim."""
        return self._map[dim]

    def _create_identifiers(self, shape, all_dims):
        """Creates dimension names and identifier names, and returns
        identifiers as a list of ``(dim_name, ids)`` pairs.

        Each generated dimension name is added to ``all_dims``.
        """
        dim_names = ['rows', 'cols']
        ids = []
        for axis, n in enumerate(shape):
            if axis < 2:
                dim_suggestion = dim_names[axis]
            else:
                dim_suggestion = 'dim'
            dim_suggestion = self._suggest_dim_name(dim_suggestion, all_dims)
            identifier_creation = [str(axis) + "_" + str(i) for i in range(n)]
            ids.append((dim_suggestion, identifier_creation))
            all_dims.add(dim_suggestion)
        return ids

    def _set_identifiers(self, identifiers, all_dims):
        """Creates internal mapping of identifiers structure.

        Each dimension maps identifier -> position, in the order the ids
        are given.

        Raises ValueError on a repeated dimension name. Dimensions processed
        before the repeat are already registered when this happens.
        """
        for dim, ids in identifiers:
            pos_map = ReverseDict()
            if dim not in self._dims:
                self._dims.append(dim)
                all_dims.add(dim)
            else:
                raise ValueError("Dimension names must be unique whitin dataset")

            for pos, id in enumerate(ids):
                pos_map[id] = pos
            self._map[dim] = pos_map

    def _suggest_dim_name(self, dim_name, all_dims):
        """Suggests a unique name for dim and returns it.

        Returns ``dim_name`` if it is free in ``all_dims``. Otherwise returns
        the first free name among ``dim_name_0``, ``dim_name_1``, ...
        """
        c = 0
        new_name = dim_name
        while new_name in all_dims:
            new_name = dim_name + "_" + str(c)
            c += 1
        return new_name

    def asarray(self):
        """Returns the numeric array (data) of dataset"""
        return self._array

    def add_array(self, array):
        """Replace the numeric data of the dataset.

        ``array`` is converted with ``atleast_2d(asarray(...))``. If the
        result does not match the dataset shape but is a single row, it is
        transposed. This matches the column-vector convention of
        ``__init__``, so a 1-D vector can replace an (n, 1) dataset.

        Raises ValueError if the shape differs from ``self.shape``.
        """
        array = atleast_2d(asarray(array))
        if array.shape != self.shape and array.shape[0] == 1:
            array = array.T
        if array.shape != self.shape:
            raise ValueError("Input array must be of similar dimensions as dataset")
        self._array = array

    def get_name(self):
        """Returns dataset name"""
        return self._name

    def get_all_dims(self):
        """Returns all dimensions in project (the shared ``_all_dims`` set)"""
        return self._all_dims

    def get_dim_name(self, axis=None):
        """Returns dim name for an axis, if no axis is provided it
        returns a list of dims"""
        if type(axis) is int:
            return self._dims[axis]
        return list(self._dims)

    def get_identifiers(self, dim, indices=None, sorted=False):
        """Returns identifiers along dim, optionally sorted by position
        (index).

        You can optionally provide a list/ndarray of indices to get
        only the identifiers at those positions, in the given order.

        Identifiers are the unique names (strings) for a variable in a
        given dim. Index (Indices) are the Identifiers position in a
        matrix in a given dim.
        """
        if indices is not None:
            # Checked before touching the map: an empty request yields []
            # whatever ``dim`` is.
            if len(indices) == 0:
                return []
            return [self._map[dim].reverse[i] for i in indices]

        pos_map = self._map[dim]
        if sorted:
            positions = list(pos_map.values())
            positions.sort()
            return [pos_map.reverse[i] for i in positions]
        return list(pos_map.keys())

    def get_indices(self, dim, idents=None):
        """Returns indices for identifiers along dimension, as an int array.

        You can optionally provide a list of identifiers to retrieve an
        index subset. Identifiers not present in ``dim`` are skipped.

        Identifiers are the unique names (strings) for a variable in a
        given dim. Index (Indices) are the Identifiers position in a
        matrix in a given dim. If none of the input identifiers are
        found an empty index is returned.
        """
        pos_map = self._map[dim]
        if idents is None:
            index = list(pos_map.values())
            index.sort()
        else:
            index = [pos_map[key] for key in idents if key in pos_map]
        # dtype=int keeps an empty result usable as an index array.
        return asarray(index, dtype=int)

    def copy(self):
        """ Returns deepcopy of dataset.
        """
        return copy.deepcopy(self)

class CategoryDataset(Dataset):
    """The category dataset class.

    A dataset for representing class information as binary
    matrices (0/1-matrices).

    There is support for using a less memory demanding, and
    fast intersection look-ups by representing the binary matrix as a
    dictionary in each dimension.

    Always has linked dimension in first dim:
    ex matrix:
             go_term1    go_term2  ...
    gene_1
    gene_2
    gene_3
    .
    .
    .
    """

    def __init__(self, array, identifiers=None, name='C'):
        Dataset.__init__(self, array, identifiers=identifiers, name=name)
        self.has_dictlists = False
        self._dictlists = None  # filled in by as_dict_lists()
        self._type = 'c'

    def as_dict_lists(self):
        """Returns data as a dict keyed by the first-dim identifiers.

        Each first-dim identifier maps to the list of second-dim identifiers
        whose entry in that row is nonzero, e.g.
        data['gene_id'] = ['map0030','map0010', ...]

        The result is also stored in ``self._dictlists``, and
        ``self.has_dictlists`` is set to True.
        """
        row_dim = self.get_dim_name(0)
        col_dim = self.get_dim_name(1)
        data = {}
        for name, ind in self._map[row_dim].items():
            # nonzero() on a 1-D row returns a 1-tuple; [0] is the index
            # array. tolist() gives plain ints usable as reverse-dict keys.
            cols = self._array[ind, :].nonzero()[0].tolist()
            data[name] = self.get_identifiers(col_dim, cols)
        self._dictlists = data
        self.has_dictlists = True
        return data

    def as_selections(self):
        """Returns data as a list of Selection objects.

        There is one Selection per second-dim identifier, titled with that
        identifier. It selects, along the first dim, the identifiers whose
        entry in that column is nonzero.
        """
        row_dim = self.get_dim_name(0)
        col_dim = self.get_dim_name(1)
        ret_list = []
        for cat_name, ind in self._map[col_dim].items():
            rows = self._array[:, ind].nonzero()[0].tolist()
            ids = self.get_identifiers(row_dim, rows)
            selection = Selection(cat_name)
            selection.select(row_dim, ids)
            ret_list.append(selection)
        return ret_list
    

class GraphDataset(Dataset):
    """The graph dataset class.

    A dataset class for representing graphs using an (weighted)
    adjacency matrix
    (aka. restricted to square symmetric matrices)

    If the library NetworkX is installed, there is support for
    representing the graph as a NetworkX.Graph, or NetworkX.XGraph structure.
    """

    def __init__(self, array=None, identifiers=None, shape=None,
                 all_dims=None, **kwds):
        """Build a graph dataset from an adjacency matrix.

        ``shape`` and ``all_dims`` are accepted for backward compatibility
        but are not used. ``name`` may be given as a keyword, as
        ``from_file`` does; it defaults to 'A'. Other keywords are ignored.
        """
        name = kwds.pop('name', 'A')
        Dataset.__init__(self, array=array, identifiers=identifiers, name=name)
        self._graph = None
        self._type = 'g'
        self._pos = None

    def asnetworkx(self, nx_type='graph'):
        """Build a NetworkX graph from the adjacency matrix.

        Nodes are labelled with the first-dim identifiers in index order.
        The graph is cached in ``self._graph`` and returned. The result is
        None when NetworkX cannot be imported. ``nx_type`` is accepted but
        not used; the graph class is chosen from the matrix values.
        """
        dim = self.get_dim_name(0)
        ids = self.get_identifiers(dim, sorted=True)
        adj_mat = self.asarray()
        G = self._graph_from_adj_matrix(adj_mat, labels=ids)
        self._graph = G
        return G

    def _graph_from_adj_matrix(self, A, labels=None):
        """Creates a networkx graph class from adjacency
        (possibly weighted) matrix and ordered labels.

        Any nonzero entry different from 1 makes the graph weighted
        (nx.XGraph, with the entry as edge data). Otherwise an nx.Graph is
        built. An all-zero matrix gives an unweighted, edgeless graph.
        labels = None, results in string-numbered labels.

        NOTE(review): nx.XGraph belongs to the legacy (pre-1.0) NetworkX API.

        Raises IOError if A is not square. Returns None if NetworkX cannot
        be imported.
        """
        try:
            import networkx as nx
        except ImportError:
            print("Failed in import of NetworkX")
            return None
        m, n = A.shape  # entries must evaluate to true/false for neighbours
        if m != n:
            raise IOError("Adjacency matrix must be square")

        # Look at every nonzero entry, not just one in column 0. Column 0
        # may be all zero (isolated first node).
        weighted = bool((A[A.nonzero()] != 1).any())
        if weighted:
            G = nx.XGraph()
        else:
            G = nx.Graph()

        if labels is None:  # if labels not provided mark vertices with numbers
            labels = [str(i) for i in range(m)]

        for nbrs, head in zip(A, labels):
            for i, nbr in enumerate(nbrs):
                if nbr:
                    tail = labels[i]
                    if weighted:
                        G.add_edge(head, tail, nbr)
                    else:
                        G.add_edge(head, tail)
        return G
        
# Class-level set shared by every Dataset instance (and subclass instances).
# _set_identifiers adds each dimension name it registers, and
# _create_identifiers adds each name it generates. _suggest_dim_name uses
# the set so that auto-generated names do not collide with names already in it.
Dataset._all_dims=set()

class ReverseDict(dict):
    """
    A dictionary which can lookup values by key, and keys by value.
    All values and keys must be hashable, and unique.

    d = ReverseDict((['a',1],['b',2]))
    d['a'] --> 1
    d.reverse[1] --> 'a'

    ``reverse`` is kept in sync by item assignment and ``del d[key]``.
    Other dict mutators (``update``, ``pop``, ``setdefault``, ...) bypass it.
    """
    def __init__(self, *args, **kw):
        dict.__init__(self, *args, **kw)
        self.reverse = dict((v, k) for k, v in self.items())

    def __setitem__(self, key, value):
        # ``reverse`` can be missing while unpickling. Pickle restores the
        # dict items (through this method) before it restores __dict__.
        reverse = getattr(self, 'reverse', None)
        if reverse is None:
            reverse = self.reverse = {}
        if key in self:
            # Drop the entry of the value being replaced so that
            # ``reverse`` does not keep a stale mapping back to ``key``.
            old_value = dict.__getitem__(self, key)
            if old_value in reverse and reverse[old_value] == key:
                del reverse[old_value]
        dict.__setitem__(self, key, value)
        reverse[value] = key

    def __delitem__(self, key):
        value = dict.__getitem__(self, key)
        dict.__delitem__(self, key)
        reverse = getattr(self, 'reverse', None)
        if reverse is not None and value in reverse and reverse[value] == key:
            del reverse[value]

def to_file(filepath, dataset, name=None):
    """Write dataset to the shelve file at ``filepath``.

    A file may contain multiple datasets, each stored under its own name.
    The file is created if it does not exist; otherwise the dataset is added
    and other entries are kept. An entry with the same name is overwritten
    and a notice is printed.

    name -- key to store under; defaults to ``dataset._name``.

    Each entry is a dict holding ``dataset._array`` ('array'),
    ``dataset._identifiers`` ('idents') and ``dataset._type`` ('type').
    """
    if not name:
        name = dataset._name
    data = shelve.open(filepath, flag='c', protocol=2)
    try:
        if name in data:
            print("Data with name: %s overwritten" % name)
        data[name] = {'array': dataset._array,
                      'idents': dataset._identifiers,
                      'type': dataset._type}
    finally:
        data.close()

def from_file(filepath):
    """Read all datasets stored in the shelve file at ``filepath``.

    Returns a list with one dataset per stored entry, in the order of the
    shelve's keys(). The dataset class is chosen by the stored type code:
    CategoryDataset for 'c', GraphDataset for 'g', and Dataset otherwise.
    Each dataset is named after its key.

    Security: shelve unpickles stored values, so only read trusted files.
    """
    data = shelve.open(filepath, flag='r')
    try:
        out_data = []
        for name in data.keys():
            sub_data = data[name]
            kind = sub_data['type']
            if kind == 'c':
                cls = CategoryDataset
            elif kind == 'g':
                cls = GraphDataset
            else:
                cls = Dataset
            out_data.append(cls(sub_data['array'],
                                identifiers=sub_data['idents'],
                                name=name))
    finally:
        data.close()
    return out_data
    
class Selection(dict):
    """Handles selected identifiers along each dimension of a dataset.

    Maps a dimension name to the identifiers selected along it. Looking up
    a dimension without a selection returns None instead of raising KeyError.
    """

    def __init__(self, title='Unnamed Selecton'):
        dict.__init__(self)
        self.title = title

    def __getitem__(self, key):
        if key not in self:
            return None
        return dict.__getitem__(self, key)

    def dims(self):
        """Return the dimension names that have a selection."""
        return self.keys()

    def axis_len(self, axis):
        """Return the number of identifiers selected along ``axis`` (0 if none)."""
        if axis in self:
            return len(dict.__getitem__(self, axis))
        return 0

    def select(self, axis, labels):
        """Set the selected identifiers along ``axis`` to ``labels``."""
        self[axis] = labels