from scipy import atleast_2d,asarray,ArrayType,shape
from scipy import sort as array_sort
from itertools import izip

class Dataset:
    """The Dataset base class.

    A Dataset is an n-way array with defined string identifiers across
    all dimensions.

    Example of use:

    ---
    dim_name_rows = 'rows'
    names_rows = ('row_a','row_b')
    ids_1 = [dim_name_rows, names_rows]

    dim_name_cols = 'cols'
    names_cols = ('col_a','col_b','col_c','col_d')
    ids_2 = [dim_name_cols, names_cols]

    Array_X = rand(2,4)
    data = Dataset(Array_X,(ids_1,ids_2),name="Testing")

    dim_names = [dim for dim in data]

    column_identifiers = [id for id in data['cols'].keys()]
    column_index = [index for index in data['cols'].values()]

    'cols' in data -> True

    ---

    data = Dataset(rand(10,20)) (generates dims and ids (no links))

    Every dimension name a dataset registers is also added to the
    class-level set ``_all_dims``, which is shared by all datasets.
    """
    def __init__(self, array, identifiers=None, name='Unnamed dataset'):
        """Create a dataset from an array and optional identifiers.

        array       -- an ArrayType instance; it is promoted with
                       atleast_2d, so a 1-d array becomes a single row.
        identifiers -- optional sequence of (dim_name, ids) pairs, one per
                       axis in axis order. When omitted, dimension names
                       ('rows', 'cols', then 'dim', made unique against
                       all registered names) and identifiers of the form
                       '<axis>_<position>' are generated.
        name        -- human readable dataset name.

        Raises ValueError if array is not an ArrayType, if identifiers do
        not match the array shape, if identifiers repeat within a
        dimension, or if a dimension name repeats within the dataset.
        """
        self._dims = []  # dimension names of this dataset, in axis order
        self._map = {}   # dim name -> {identifier: index along that axis}
        self._name = name
        self.has_array = False
        if not isinstance(array, ArrayType):
            raise ValueError("Array input must be of ArrayType")

        array = atleast_2d(asarray(array))
        self.shape = array.shape
        if identifiers is None:
            identifiers = self._create_identifiers(self.shape, self._all_dims)
        else:
            # Materialize each ids sequence once, so generators survive
            # both the validation pass and the mapping pass.
            identifiers = [(dim, list(ids)) for dim, ids in identifiers]
            self._check_identifiers(identifiers, self.shape)
        self._set_identifiers(identifiers, self._all_dims)

        self._array = array
        self.has_array = True

    def __str__(self):
        return self._name + ":\n" + "Dim names: " + str(self._dims)

    def __iter__(self):
        """Returns an iterator over dimensions of dataset."""
        return iter(self._dims)

    def __contains__(self, dim):
        """Returns True if dim is a dimension name in dataset."""
        return dim in self._map

    def __len__(self):
        """Returns the number of dimensions in the dataset"""
        return len(self._map)

    def __getitem__(self, dim):
        """Return the identifier -> index mapping along the dimension dim."""
        return self._map[dim]

    def _check_identifiers(self, identifiers, shape):
        """Raise ValueError unless identifiers fit an array of this shape.

        identifiers must hold one (dim_name, ids) pair per axis. Each ids
        list must have exactly shape[axis] unique entries.
        """
        if len(identifiers) != len(shape):
            raise ValueError("Expected identifiers for %d dimensions, got %d"
                             % (len(shape), len(identifiers)))
        for axis, (dim, ids) in enumerate(identifiers):
            if len(ids) != shape[axis]:
                raise ValueError("Dimension %r has %d identifiers but the "
                                 "array has length %d along axis %d"
                                 % (dim, len(ids), shape[axis], axis))
            if len(set(ids)) != len(ids):
                raise ValueError("Identifiers must be unique within "
                                 "dimension %r" % (dim,))

    def _create_identifiers(self, shape, all_dims):
        """Create dimension names and identifiers for an array shape.

        Returns a list of (dim_name, ids) pairs, one per axis. Each
        generated dimension name is added to all_dims immediately, so later
        axes in the same call get distinct names.
        """
        default_names = ['rows', 'cols']
        ids = []
        for axis, n in enumerate(shape):
            if axis < len(default_names):
                suggestion = default_names[axis]
            else:
                suggestion = 'dim'
            dim = self._suggest_dim_name(suggestion, all_dims)
            ids.append((dim, ['%d_%d' % (axis, pos) for pos in range(n)]))
            all_dims.add(dim)
        return ids

    def _set_identifiers(self, identifiers, all_dims):
        """Build the internal identifier -> index mapping.

        Each dimension name is appended to self._dims and registered in
        all_dims. Raises ValueError if a dimension name repeats within
        this dataset.
        """
        for dim, ids in identifiers:
            if dim in self._dims:
                raise ValueError("Dimension names must be unique whitin dataset")
            self._dims.append(dim)
            all_dims.add(dim)
            self._map[dim] = dict((ident, pos) for pos, ident in enumerate(ids))

    def _suggest_dim_name(self, dim_name, all_dims):
        """Return dim_name, or dim_name_<k> for the first k not in all_dims."""
        c = 0
        new_name = dim_name
        while new_name in all_dims:
            new_name = dim_name + "_" + str(c)
            c += 1
        return new_name

    def asarray(self):
        """Returns the numeric array (data) of dataset.

        Raises ValueError if no array is attached.
        """
        # getattr: subclasses that bypass Dataset.__init__ may lack the flag.
        if not getattr(self, 'has_array', False):
            raise ValueError("Dataset is empty")
        return self._array

    def add_array(self, array):
        """Attach array as the dataset's data.

        The input is converted with atleast_2d before the shape check, so
        a one-dim array is treated as a two-dim array (row-vector).
        Raises ValueError if the converted shape differs from self.shape.
        """
        array = atleast_2d(asarray(array))
        if array.shape != self.shape:
            raise ValueError("Input array must be of similar dimensions as dataset")
        self._array = array
        self.has_array = True

    def get_name(self):
        """Returns dataset name"""
        return self._name

    def get_all_dims(self):
        """Returns the set of dimension names registered by all datasets."""
        return self._all_dims

    def get_dim_name(self, axis=None):
        """Return the dim name for axis, or a list of all dims if axis is None."""
        if axis is None:
            return list(self._dims)
        return self._dims[axis]

    def get_identifiers(self, dim, indices=None, sorted=True):
        """Returns identifiers along dim, optionally sorted by position (index).

        You can optionally provide a list of indices to get only the
        identifiers at those positions. The indices are applied to the
        (sorted or unsorted) identifier list.

        Identifiers are the unique names (strings) for a variable in a given dim.
        Index (Indices) are the Identifiers position in a matrix in a given dim.
        """
        pos_map = self._map[dim]
        ids = list(pos_map.keys())
        if sorted:
            # list.sort is used because the 'sorted' parameter shadows the builtin.
            ids.sort(key=pos_map.get)
        if indices is not None:
            ids = [ids[index] for index in indices]
        return ids

    def get_indices(self, dim, idents=None):
        """Returns indices for identifiers along dimension.

        You can optionally provide a list of identifiers to retrieve an index subset.

        Identifiers are the unique names (strings) for a variable in a given dim.
        Index (Indices) are the Identifiers position in a matrix in a given dim."""
        pos_map = self._map[dim]
        if idents is None:
            index = array_sort(list(pos_map.values()))
        else:
            index = [pos_map[key] for key in idents]
        return asarray(index)
     
class CategoryDataset(Dataset):
    """The category dataset class.

    A dataset for representing class information as binary
    matrices (0/1-matrices).

    The intent is to also support a less memory-demanding representation
    with fast intersection look-ups, which stores the binary matrix as a
    dictionary per dimension. That collection representation is not
    implemented yet (as_collection and add_collection are stubs).
    """

    def __init__(self, array=None, identifiers=None, name='Unnamed dataset'):
        """Create a category dataset.

        All arguments are forwarded to Dataset.__init__, so array must be
        an ArrayType (presumably a 0/1 matrix; this is not checked here).
        """
        Dataset.__init__(self, array, identifiers=identifiers, name=name)
        self.has_collection = False

    def as_array(self):
        """Return the data as a binary matrix.

        Returns the stored array when one is present. Building the matrix
        from a collection is not implemented yet, so None is returned in
        that case.
        """
        array = getattr(self, '_array', None)
        if array is not None:
            return array
        if self.has_collection:
            # TODO: build the numeric array from the collection.
            pass
        return None

    def as_collection(self, dim):
        """Return data as a collection along dim (not implemented; returns None)."""
        # TODO: implement once add_collection stores a collection.
        pass

    def add_collection(self, input_dict):
        """Add category data as a collection (not implemented; does nothing).

        A collection is a data structure that contains a dictionary for
        each pair of dimensions in the dataset. Each dictionary is keyed by
        identifiers, and its values are sets of identifiers in the other
        dimension.
        """
        # TODO: build category data as double dicts.
        pass


class GraphDataset(Dataset):
    """The graph dataset class.

    A dataset class for representing graphs using an adjacency matrix.
    It is intended for square (symmetric) matrices; only squareness is
    checked, and only when converting to a graph.

    If the library NetworkX is installed, there is support for
    representing the graph as a NetworkX.Graph, or NetworkX.XGraph structure.
    """
    def __init__(self, array=None, identifiers=None, shape=None, all_dims=None, **kwds):
        """Create a graph dataset from an adjacency matrix.

        array and identifiers are forwarded to Dataset.__init__. The name
        defaults to 'A' and can be overridden with the name keyword. The
        shape and all_dims arguments, and any other keywords, are accepted
        for backward compatibility but are not used.
        """
        Dataset.__init__(self, array=array, identifiers=identifiers,
                         name=kwds.get('name', 'A'))
        self.has_graph = False

    def asnetworkx(self, nx_type='graph'):
        """Return the dataset as a networkx graph.

        Vertices are labelled with the identifiers of the first dimension,
        in index order. nx_type is 'graph' (networkx.Graph) or 'xgraph'
        (networkx.XGraph).
        """
        dim = self.get_dim_name(0)
        ids = self.get_identifiers(dim)
        adj_mat = self.asarray()
        G = self._graph_from_adj_matrix(adj_mat, labels=ids, nx_type=nx_type)
        self.has_graph = True
        return G

    def _graph_from_adj_matrix(self, A, labels=None, nx_type='graph'):
        """Create a networkx graph from an adjacency matrix and ordered labels.

        A       -- square matrix; entry A[i][j] adds the edge
                   labels[i] - labels[j] when it evaluates to true.
        labels  -- one label per row/column. None gives the string-numbered
                   labels '0'..'m-1'.
        nx_type -- 'graph' or 'xgraph' ('x_graph' is also accepted).

        Every label becomes a vertex, including labels with no edges.
        Raises IOError for a non-square matrix, an unknown nx_type, or a
        label count that differs from the matrix size.
        """
        import networkx as nx
        m, n = shape(A)
        if m != n:
            raise IOError("Adjacency matrix must be square")
        if nx_type == 'graph':
            G = nx.Graph()
        elif nx_type in ('xgraph', 'x_graph'):
            # NOTE(review): XGraph exists only in old networkx releases.
            G = nx.XGraph()
        else:
            raise IOError("Unknown graph type: %s" % nx_type)

        if labels is None:  # mark vertices with numbers
            labels = [str(i) for i in range(m)]
        elif len(labels) != m:
            raise IOError("Expected %d labels, got %d" % (m, len(labels)))

        # Add all vertices up front so isolated ones are not lost.
        G.add_nodes_from(labels)
        for head, nbrs in izip(labels, A):
            for tail, nbr in izip(labels, nbrs):
                if nbr:
                    G.add_edge(head, tail)
        return G
# Registry of every dimension name used by any Dataset. It lives on the
# class, so all instances (and subclasses) share the same set.
setattr(Dataset, '_all_dims', set())

class Selection:
    """Handles selected identifiers along each dimension of a dataset"""
    def __init__(self):
        self.current_selection={}