import logger
from scipy import array,take,asarray,shape,nonzero
import project
from itertools import izip 


class Dataset:
    """Dataset base class.

    A Dataset is an n-way array with defined string identifiers across
    all dimensions.

    Attributes set by __init__:
        _data      -- the wrapped array (converted with asarray).
        dims       -- shape of _data.
        def_list   -- the (dim_name, ids) pairs given to the constructor.
        ids        -- {dim_name: {identifier: position along that axis}}.
        _ids_set   -- union of the identifiers of all dimensions.
        _dim_num   -- {dim_name: axis number}.
        _dim_names -- dimension names in axis order.
    """
    def __init__(self,input_array,def_list):
        """Wrap input_array and index every axis by its identifiers.

        :param input_array: array-like data, converted with asarray.
        :param def_list: sequence of (dim_name, ids) pairs, one per axis.
            A dim_name unknown to project.c_p is replaced by
            project.c_p.suggest_dim_name(dim_name); for such a dimension,
            empty/None ids are replaced by generated '<axis>_<i>' names.
        :raises ValueError: if def_list does not have one entry per array
            dimension, or a non-empty ids sequence has the wrong length.
        """
        self._data = asarray(input_array)
        self.dims = shape(self._data)
        self.def_list = def_list
        self._ids_set = set()
        self.ids = {}
        self._dim_num = {}
        self._dim_names = []
        if len(def_list) != len(self.dims):
            raise ValueError("array dims and identifyer mismatch")
        for axis, (dim_name, ids) in enumerate(def_list):
            if dim_name not in project.c_p.dim_names:
                dim_name = project.c_p.suggest_dim_name(dim_name)
                # NOTE(review): identifiers are only generated for
                # unregistered dimensions; a registered dimension given empty
                # ids ends up with no identifiers. Confirm this is intended.
                if not ids:
                    ids = self._create_identifiers(axis)
            # Later duplicates overwrite earlier ones, as before.
            self.ids[dim_name] = {name: num for num, name in enumerate(ids)}
            self._ids_set.update(ids)
            self._dim_num[dim_name] = axis
            self._dim_names.append(dim_name)

        # Check that data and labels match. Axes given without identifiers
        # (None or empty) are skipped; len() is taken only when ids exist.
        for (_, axis_ids), size in zip(def_list, self.dims):
            n_ids = 0 if axis_ids is None else len(axis_ids)
            if n_ids and n_ids != size:
                raise ValueError("dim size and identifyer mismatch")

    def names(self,axis=0):
        """Returns identifier names of a dimension. NB: not in any order!

        :param axis: a dimension name (str) or an axis position. Anything
            that is not a str is used as a position in axis order, so
            numpy integers work as well as Python ints.
        :raises KeyError: for an unknown dimension name.
        :raises IndexError: for an out-of-range axis position.
        """
        if isinstance(axis, str):
            dim_name = axis
        else:
            dim_name = self._dim_names[axis]
        return self.ids[dim_name].keys()

    def extract_data(self,ids,dim_name):
        """Extracts data along a dimension by identifiers.

        :param ids: iterable of identifiers along dim_name; the slices of
            the result follow the iteration order of ids.
        :param dim_name: name of the dimension to slice.
        :returns: a new Dataset holding only the selected slices.
        :raises KeyError: if dim_name or one of the identifiers is unknown.
        :raises ValueError: if the selected positions cannot be taken.
        """
        # Materialize once so index order and stored ids agree even when
        # ids is a set or a one-shot iterator.
        ids = list(ids)
        ids_index = [self.ids[dim_name][id_name] for id_name in ids]
        dim_number = self._dim_num[dim_name]
        try:
            out_data = take(self._data, ids_index, axis=dim_number)
        except (IndexError, TypeError) as err:
            raise ValueError(str(err))
        # Replace the entry rather than assigning into it: a shallow copy
        # shares its inner entries with self.def_list, so item assignment
        # would silently rewrite this dataset's own definition.
        new_def_list = list(self.def_list)
        new_def_list[dim_number] = [self.def_list[dim_number][0], ids]
        return Dataset(out_data, def_list=new_def_list)

    def _create_identifiers(self,axis):
        """Creates identifiers '<axis>_0' .. '<axis>_<n-1>' along an axis."""
        n_dim = self.dims[axis]
        return [str(axis) + '_' + str(i) for i in range(n_dim)]

    def extract_id_from_index(self,dim_name,index):
        """Returns a set of ids from array/list of indexes (or one index).

        :param dim_name: name of the dimension whose identifiers to look up.
        :param index: a single integer (Python or numpy) or an iterable of
            positions along that dimension.
        :returns: set of identifiers found at those positions.
        """
        import numbers  # local import keeps this block self-contained
        if isinstance(index, numbers.Integral):
            index = [index]
        wanted = set(index)  # O(1) membership instead of rescanning index
        return {name for name, ind in self.ids[dim_name].items()
                if ind in wanted}

    def extract_index_from_id(self,dim_name,id):
        """Returns an array of indexes from a set/list of identifiers
        (or a single id). NB: indexes are in dict order, not input order.

        :param dim_name: name of the dimension to look up.
        :param id: a single identifier string or an iterable of them.
        :returns: array of positions of the identifiers that exist.
        """
        if isinstance(id, str):
            # A bare string must match exactly, not act as a container for
            # a substring test.
            id = [id]
        wanted = set(id)
        dim_ids = self.ids[dim_name]
        return array([ind for name, ind in dim_ids.items() if name in wanted])
        
    
class CategoryDataset(Dataset):
    """Dataset relating elements to categories.

    get_elements_by_category treats a nonzero entry as membership of the
    element (on one axis) in the category (on the other axis).
    """
    def __init__(self,array,def_list):
        Dataset.__init__(self,array,def_list)

    def get_elements_by_category(self,dim,category):
        """Returns all elements along input dim belonging to category.

        Assumes a two-dim category data only! `dim` names the element axis;
        the categories are looked up on the other axis.

        :param dim: name of the element dimension.
        :param category: list of category identifiers.
        :returns: list with one set of element ids per matched category;
            order follows the indexes from extract_index_from_id, not the
            order of `category`.
        :raises ValueError: if category is not a list or data is not 2-dim.
        """
        if not isinstance(category, list):
            raise ValueError("category must be list")
        axis_dim = self._dim_num[dim]
        if len(self.dims) != 2 or axis_dim not in (0, 1):
            raise ValueError("Only support for 2-dim data")
        # Categories live on whichever axis is not the element axis.
        cat_dim = self._dim_names[1 - axis_dim]
        cat_index = self.extract_index_from_id(cat_dim, category)
        gene_ids = []
        for ind in cat_index:
            if axis_dim == 0:
                membership = self._data[:, ind]
            else:
                membership = self._data[ind, :]
            # nonzero returns a tuple of index arrays (one per dimension);
            # the 1-d slice needs only the first.
            gene_index = nonzero(membership)[0]
            gene_ids.append(self.extract_id_from_index(dim, gene_index))
        return gene_ids
                
        
        
                
            
        
        

    

class Selection:
    """Handles selected identifiers along each dimension of a dataset"""
    def __init__(self):
        self.current_selection={}