""" Module for Gene ontology related functions called in R"""
import scipy
import rpy
silent_eval = rpy.with_mode(rpy.NO_CONVERSION, rpy.r)
import collections

def goterms_from_gene(genelist, ontology='BP', garbage=None):
    """ Returns the go-terms from a specified genelist (Entrez id).

    Recalculates the information content if needed based on selected evidence codes.
    
    """
    rpy.r.library("GOSim")
    _CODES = {"IMP" : "inferred from mutant phenotype",
              "IGI" : "inferred from genetic interaction",
              "IPI" :"inferred from physical interaction",
              "ISS" : "inferred from sequence similarity",
              "IDA" : "inferred from direct assay",
              "IEP" : "inferred from expression pattern",
              "IEA" : "inferred from electronic annotation",
               "TAS" : "traceable author statement",
              "NAS" : "non-traceable author statement",
              "ND" : "no biological data available",
              "IC" : "inferred by curator"
              }
    _ONTOLOGIES = ['BP', 'CC', 'MF']
    #assert(scipy.all([(code in _CODES) for code in garbage]) or garbage==None)
    assert(ontology in _ONTOLOGIES)
    dummy = rpy.r.setOntology(ontology)
    if ontology == 'BP' and garbage is not None:
        # This is for ont=BP and garbage =['IEA', 'ISS', 'ND']
        rpy.r.load("ICsBPIMP_IGI_IPI_ISS_IDA_IEP_TAS_NAS_IC.rda")
        ic = rpy.r.assign("IC",rpy.r.IC, envir=rpy.r.GOSimEnv)
        print len(ic)
    else:
        ic = rpy.r('get("IC", envir=GOSimEnv)')
    print "loading GO definitions environment"
    
    gene2terms = {}
    for gene in genelist:
        info = rpy.r('GOENTREZID2GO[["' + str(gene) + '"]]')
        #print info
        if info:
            for term, desc in info.items():
                # Filter out terms with missing or infinite information
                # content, and terms from a different ontology branch.
                skip = False
                term_ic = ic.get(term)
                if term_ic is None:
                    #print "\nHave no IC on this GO term %s for this gene: %s" %(term, gene)
                    skip = True
                elif scipy.isinf(term_ic):
                    print "\nIC is Inf on this GO term %s for this gene: %s" %(term, gene)
                    skip = True
                if desc['Ontology'] != ontology:
                    #print "\nThis GO term %s belongs to: %s" %(term, desc['Ontology'])
                    skip = True
                if not skip:
                    gene2terms.setdefault(gene, []).append(term)
        else:
            print "\nHave no annotation on this gene: %s" % gene

    return gene2terms
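
# Hedged usage sketch (not part of the original pipeline): fetch BP annotations
# for a couple of placeholder Entrez ids and attach readable descriptions via
# goterm2desc(). Requires an R installation with the GOSim and GO packages.
def _example_gene_annotations():
    gene2terms = goterms_from_gene(['1017', '1026'], ontology='BP')
    terms = sorted(set(t for ts in gene2terms.values() for t in ts))
    term2desc = goterm2desc(terms)
    for gene, gene_terms in gene2terms.items():
        for term in gene_terms:
            print "%s\t%s\t%s" % (gene, term, term2desc[term])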

def genego_matrix(goterms, tmat, gene_ids, term_ids, func=max):
    """Build a (genes x terms) matrix from a (terms x terms) similarity matrix.

    For each gene in `goterms`, the rows of `tmat` corresponding to its GO
    terms are combined column-wise with `func` (default: max). Genes without
    usable terms are dropped; the indices (into `gene_ids`) of the kept genes
    are returned together with the matrix.
    """
    ngenes = len(gene_ids)
    nterms = len(term_ids)
    gene2indx = {}
    for i, gene_id in enumerate(gene_ids):
        gene2indx[gene_id] = i
    term2indx = {}
    for i, term_id in enumerate(term_ids):
        term2indx[term_id] = i
    #G = scipy.empty((nterms, ngenes),'d')
    G = []
    new_gene_index = []
    for gene, terms in goterms.items():
        g_ind = gene2indx[gene]
        if len(terms) > 0:
            t_ind = [term2indx[term] for term in terms if term in term2indx]
            # only keep genes that map to at least one known term, so that
            # new_gene_index stays aligned with the rows of G
            if not t_ind:
                continue
            new_gene_index.append(g_ind)
            subsim = tmat[t_ind, :]
            gene_vec = scipy.apply_along_axis(func, 0, subsim)
            G.append(gene_vec)

    return scipy.asarray(G), new_gene_index
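
# Hedged sketch of genego_matrix() on synthetic data (placeholder ids): shows how
# `func` (here max) collapses the rows of the term-term similarity matrix that
# belong to one gene into a single row per gene. Runs without R.
def _example_genego_matrix():
    term_ids = ['GO:0000001', 'GO:0000002', 'GO:0000003']
    gene_ids = ['111', '222']
    goterms = {'111': ['GO:0000001', 'GO:0000003'], '222': ['GO:0000002']}
    tmat = scipy.array([[1.0, 0.2, 0.5],
                        [0.2, 1.0, 0.1],
                        [0.5, 0.1, 1.0]])
    G, kept = genego_matrix(goterms, tmat, gene_ids, term_ids, func=max)
    # G has one row per kept gene and one column per entry of term_ids
    print G.shape, kept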

def genego_sim(gene2go, gene_ids, all_go_terms, STerm, go_term_sim="OA", term_sim="Lin", verbose=False):
    """Returns go-terms x genes similarity matrix.

    :input:
           - gene2go: dict: keys: gene_id, values: go_terms
           - gene_ids: list of gene ids (entrez ids)
           - STerm: (go_terms x go_terms) similarity matrix
           - go_terms_sim: similarity measure between a gene and multiple go terms (max, mean, OA)
           - term_sim: similarity measure between two go-terms
           - verbose
    """
    rpy.r.library("GOSim")

    #gene_ids = gene2go.keys()
    GG = scipy.empty((len(all_go_terms), len(gene_ids)), 'd')
    for j,gene in enumerate(gene_ids):
        for i,go_term in enumerate(all_go_terms):
            if verbose:
                print "\nAssigning similarity from %s to terms(gene): %s" %(go_term,gene)
            GG_ij = rpy.r.getGSim(go_term, gene2go[gene], similarity=go_term_sim,
                                  similarityTerm=term_sim, STerm=STerm, verbose=verbose)
            GG[i,j] = GG_ij
    return GG
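
# Hedged usage sketch for genego_sim(): assumes an R session with GOSim loaded and
# a term-term similarity matrix STerm computed beforehand (e.g. with GOSim's
# getTermSim). gene2go would typically come from goterms_from_gene() above.
def _example_genego_sim(gene2go, STerm):
    gene_ids = gene2go.keys()
    all_go_terms = sorted(set(t for ts in gene2go.values() for t in ts))
    GG = genego_sim(gene2go, gene_ids, all_go_terms, STerm,
                    go_term_sim="OA", term_sim="Lin")
    return GG, gene_ids, all_go_terms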

def goterm2desc(gotermlist):
    """Returns the go-terms description keyed by go-term.
    """
    rpy.r.library("GO")
    term2desc = {}
    for term in gotermlist:
        try:
            desc = rpy.r('Term(GOTERM[["' +str(term)+ '"]])')
            term2desc[str(term)] = desc
        except:
            raise Warning("Description not found for %s\n Mapping incomplete" %term)
    return term2desc

def parents_dag(go_terms, ontology=['BP']):
    """Return a dict-of-lists edge representation of the GO parent DAG of go_terms.

    A networkx graph can be built from the result (see the sketch after this
    function):
        G = networkx.from_dict_of_lists(edge_dict, create_using=networkx.DiGraph())
    """
    try:
        rpy.r.library("GOstats")
    except Exception:
        raise ImportError("could not load the GOstats R library")
    assert(go_terms[0][:3]=='GO:')

    # go valid namespace
    go_env = {'BP':rpy.r.GOBPPARENTS, 'MF':rpy.r.GOMFPARENTS, 'CC': rpy.r.GOCCPARENTS}
    graph = rpy.r.GOGraph(go_terms, go_env[ontology[0]])
    edges = rpy.r.edges(graph)
    edges.pop('all')
    edge_dict = {}
    for head, neighbours in edges.items():
        for nn in neighbours.values():
            edge_dict.setdefault(nn, []).append(head)
    return edge_dict
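
# Hedged sketch: turning the edge dict from parents_dag() into a networkx DiGraph,
# as suggested in the docstring above. Assumes networkx is installed; the GO ids
# below are placeholders.
def _example_parents_graph():
    import networkx
    edge_dict = parents_dag(['GO:0006950', 'GO:0008152'], ontology=['BP'])
    G = networkx.from_dict_of_lists(edge_dict, create_using=networkx.DiGraph())
    print G.number_of_nodes(), G.number_of_edges()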

def gene_GO_hypergeo_test(genelist, universe="entrezUniverse", ontology="BP",
                          chip="hgu133a", pval_cutoff=0.01, cond=False,
                          test_direction="over"):
    """Hypergeometric test for GO term enrichment of genelist against universe.

    Returns the summary of the GOstats hyperGTest result together with the
    parameter object used.
    """
    #assert(scipy.alltrue([True for i in genelist if i in universe]))
    rpy.r.library("GOstats")
    params = rpy.r.new("GOHyperGParams",
                       geneIds=genelist,
                       universeGeneIds=universe,
                       annotation=chip,
                       ontology=ontology,
                       pvalueCutoff=pval_cutoff,
                       conditional=cond,
                       testDirection=test_direction
                       )
    result = rpy.r.hyperGTest(params)
    return rpy.r.summary(result), params
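
# Hedged usage sketch for gene_GO_hypergeo_test(): the Entrez ids are placeholders;
# in practice the universe would be all Entrez ids represented on the chip (e.g.
# from data_aff2loc_hgu133a below). Requires R with GOstats and the hgu133a
# annotation package.
def _example_hypergeo():
    study_genes = ['1017', '1026', '1029']
    universe_genes = ['1017', '1026', '1029', '1031', '1050']
    summary, params = gene_GO_hypergeo_test(study_genes, universe=universe_genes,
                                            ontology="BP", chip="hgu133a",
                                            pval_cutoff=0.05)
    print summary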

def data_aff2loc_hgu133a(X, aff_ids, verbose=False):
    """Map hgu133a probeset columns of X onto Entrez gene ids.

    Probesets without an Entrez id are dropped; when several probesets map to
    the same Entrez id, the probeset with the largest variance across samples
    is kept. Returns the reduced data matrix and the list of Entrez ids
    (column order).
    """
    aff_ids = scipy.asarray(aff_ids)
    if verbose:
        print "\nNumber of probesets in affy list: %s" % len(aff_ids)
    rpy.r.library("hgu133a")
    trans_table = rpy.r.as_list(rpy.r.hgu133aENTREZID)
    if verbose:
        print "Number of entrez ids: %d" %(scipy.asarray(trans_table.values())>0).sum()
    enz2aff = collections.defaultdict(list)
    #aff2enz = collections.defaultdict(list)
    for aff, enz in trans_table.items():
        if int(enz)>0 and (aff in aff_ids):
            enz2aff[enz].append(aff)
            #aff2enz[aff].append(enz)
    if verbose:
        print "\nNumber of translated entrez ids: %d" %len(enz2aff)
    aff2ind = dict(zip(aff_ids, scipy.arange(len(aff_ids))))
    var_x = X.var(0)
    new_data = []
    new_ids = []
    m = 0
    s = 0
    for enz, aff_id_list in enz2aff.items():
        index = [aff2ind[aff_id] for aff_id in aff_id_list]
        if len(index) > 1:
            m += 1
            #if verbose:
            #    print "\nEntrez id: %s has %d probesets" %(enz, len(index))
            #    print index
            # keep the probeset with the largest variance across samples
            xsub = X[:, index]
            choose_this = scipy.argmax(var_x[index])
            new_data.append(xsub[:, choose_this].ravel())
        else:
            s+=1
            new_data.append(X[:,index].ravel())
        new_ids.append(enz)
    if verbose:
        print "Ids with multiple probesets: %d" %m
        print "Ids with unique probeset: %d" %s
    X = scipy.asarray(new_data).T
    return X, new_ids
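
# Hedged end-to-end sketch: collapse an expression matrix from hgu133a probesets to
# Entrez genes, then fetch GO annotations for the surviving ids. X and probeset_ids
# stand in for real data; requires R with the hgu133a and GOSim packages.
def _example_probesets_to_go(X, probeset_ids):
    X_gene, entrez_ids = data_aff2loc_hgu133a(X, probeset_ids, verbose=True)
    gene2terms = goterms_from_gene(entrez_ids, ontology='BP')
    return X_gene, entrez_ids, gene2terms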