import sys,os
import os.path
import webbrowser
import cPickle

import scipy
import networkx as nx

from laydi import logger,plots,workflow,dataset,main
from laydi.lib import blmfuncs,nx_utils,cx_utils

import gobrowser


class SmallTestWorkflow(workflow.Workflow):
    """Demo workflow: data loading, models, gene queries and GO tools."""

    name = 'Demo'
    ident = 'demo'
    description = 'A small test workflow for gene expression analysis.'
    chip = 'hgu'

    def __init__(self):
        """Create the stages and register them in display order."""
        workflow.Workflow.__init__(self)

        # Data import: the first loader uses the default (empty) sub-directory,
        # the others read from a named sub-directory.
        data_stage = workflow.Stage('load', 'Data')
        data_stage.add_function(LoadDataFunction('load-small', 'Small', self))
        for fn_ident, fn_label, subdir in (('load-geneid', 'GeneID', 'geneid'),
                                           ('load-full', 'FullChip', 'full')):
            data_stage.add_function(
                LoadDataFunction(fn_ident, fn_label, self, subdir))
        self.add_stage(data_stage)

        # Network preprocessing stage is currently disabled:
        #net = workflow.Stage('net', 'Network integration')
        #net.add_function(DiffKernelFunction())
        #net.add_function(ModKernelFunction())
        #self.add_stage(net)

        # Models
        model_stage = workflow.Stage('models', 'Models')
        for make_function in (blmfuncs.PCA, blmfuncs.PLS, SAM):
            model_stage.add_function(make_function())
        self.add_stage(model_stage)

        # Gene queries
        query_stage = workflow.Stage('query', 'Gene Query')
        for make_function in (NCBIQuery, KEGGQuery, SubgraphQuery):
            query_stage.add_function(make_function())
        self.add_stage(query_stage)

        # Background knowledge (Gene Ontology)
        go_stage = workflow.Stage('go', 'Gene Ontology')
        for make_function in (gobrowser.PlotDagFunction,
                              GoEnrichment,
                              GoEnrichmentCond,
                              MapGO2Gene,
                              MapGene2GO):
            go_stage.add_function(make_function())
        self.add_stage(go_stage)

        # Extra plots stage is currently disabled:
        #plt = workflow.Stage('net', 'Network')
        #plt.add_function(nx_analyser.KeggNetworkAnalyser())
        #self.add_stage(plt)

        logger.log('debug', 'Small test workflow is now active')


class LoadDataFunction(workflow.Function):
    """Workflow function that reads every ``.ftsv`` dataset in one directory.

    The directory is ``<main.options.datadir>/<wf.ident>/<dir>``.
    """

    def __init__(self, ident, label, wf, dir=''):
        workflow.Function.__init__(self, ident, label)
        self._wf = wf
        self._dir = dir

    def run(self):
        """Return a list with one dataset per ``.ftsv`` file in the directory."""
        folder = os.path.join(main.options.datadir, self._wf.ident, self._dir)
        return [dataset.read_ftsv(os.path.join(folder, entry))
                for entry in os.listdir(folder)
                if entry.endswith('.ftsv')]



##### WORKFLOW SPECIFIC FUNCTIONS ######


class SAM(workflow.Function):
    """Runs R's ``sam`` (package ``siggenes``) and selects significant variables."""

    def __init__(self, id='sam', name='SAM'):
        workflow.Function.__init__(self, id, name)

    def run(self, x, y):
        """Run SAM on ``x`` with class labels derived from ``y``.

        x : dataset whose second dimension holds the variables to be tested.
        y : dataset used to build the class-label vector; column ``j`` is
            weighted by ``j`` and the columns are summed
            (presumably ``y`` is a 0/1 class indicator matrix -- TODO confirm).

        Side effect: the project selection along the variable dimension of
        ``x`` is set to the variables with q-value below the cut-off.

        Returns ``[q-value dataset, p-value dataset, sorted q-value dataset,
        scatter plot of sorted q-values, selection]``.
        """
        n_iter = 50   # passed to sam() as B
        alpha = 0.01  # cut-off on q-values

        # R setup
        import rpy
        rpy.r.library("siggenes")
        rpy.r.library("multtest")

        cl = scipy.dot(y.asarray(), scipy.diag(scipy.arange(y.shape[1]))).sum(1)
        data = x.asarray().T
        sam = rpy.r.sam(data, cl=cl, B=n_iter, var_equal=False, med=False,
                        s0=scipy.nan, rand=scipy.nan)
        # Read each slot into the matching variable (these were swapped
        # before, so the cut-off was applied to raw p-values).
        pvals = scipy.asarray(rpy.r.slot(sam, "p.value"))
        qvals = scipy.asarray(rpy.r.slot(sam, "q.value"))

        sam_index = (qvals < alpha).nonzero()[0]

        # Update selection object
        dim_name = x.get_dim_name(1)
        sam_selection = x.get_identifiers(dim_name, indices=sam_index)
        main.project.set_selection(dim_name, sam_selection)

        sel = dataset.Selection('SAM selection')
        sel.select(dim_name, sam_selection)
        logger.log('notice','Number of significant varibles (SAM): %s' %len(sam_selection))

        # Output datasets: one value per variable along a singleton dimension.
        x_col_ids = [dim_name, x.get_identifiers(dim_name, sorted=True)]
        sing_id = ['_john', ['0']]  # singleton
        D_qvals = dataset.Dataset(qvals, (x_col_ids, sing_id), name='q_vals')
        D_pvals = dataset.Dataset(pvals, (x_col_ids, sing_id), name='p_vals')

        # Plot: q-values in ascending order against their rank.
        s_indx = qvals.flatten().argsort()
        s_ids = [x_col_ids[0], [x_col_ids[1][i] for i in s_indx]]
        xindex = scipy.arange(len(qvals))
        qvals_s = qvals.take(s_indx)
        D_qs = dataset.Dataset(qvals_s, (s_ids, sing_id), name="sorted qvals")
        Dind = dataset.Dataset(xindex, (s_ids, sing_id), name="dum")
        # Use the actual dimension name of D_qs rather than a hard-coded
        # 'gene_ids', so the plot also works for other variable dimensions.
        st = plots.ScatterPlot(D_qs, Dind, dim_name, '_john', '0', '0', s=10, name='SAM qvals')

        return [D_qvals, D_pvals, D_qs, st, sel]
        

class DiffKernelFunction(workflow.Function):
    """Smooths expression data with a diffusion kernel built from a network."""

    def __init__(self):
        workflow.Function.__init__(self, 'diffkernel', 'Diffusion')

    def run(self, x, a):
        """Return ``[dataset]`` holding ``x`` projected through the kernel.

        x is gene expression data, a is the network.
        """
        alpha = 1.0
        beta = 1.0

        graph = a.asnetworkx()
        gene_order = x.get_identifiers(x.get_dim_name(1), sorted=True)
        adjacency = nx.adj_matrix(graph, nodelist=gene_order)

        # Centre, apply the kernel, then add the mean back.
        centered, mean_x = cx_utils.mat_center(x.asarray(), ret_mn=True)
        kernel = nx_utils.K_diffusion(adjacency, alpha=alpha, beta=beta,
                                      normalised=True)
        projected = scipy.dot(centered, kernel) + mean_x

        # (dimension name, sorted identifiers) for rows and then columns.
        dims = []
        for axis in (0, 1):
            dims.append((x.get_dim_name(axis),
                         x.get_identifiers(x.get_dim_name(axis), sorted=True)))

        result = dataset.Dataset(projected,
                                 tuple(dims),
                                 name=x.get_name()+'_diff'+str(alpha))
        return [result]


class ModKernelFunction(workflow.Function):
    """Smooths expression data with a modularity kernel built from a network."""

    def __init__(self):
        workflow.Function.__init__(self, 'mokernel', 'Modularity')

    def run(self,x,a):
        """Return ``[dataset]`` holding ``x`` projected through the kernel.

        x is the data to smooth, a is the network.
        """
        alpha = .2

        data = x.asarray()
        graph = a.asnetworkx()
        gene_order = x.get_identifiers(x.get_dim_name(1), sorted=True)
        adjacency = nx.adj_matrix(graph, nodelist=gene_order)

        # Centre, apply the kernel, then add the mean back.
        centered, mean_x = cx_utils.mat_center(data, ret_mn=True)
        kernel = nx_utils.K_modularity(adjacency, alpha=alpha)
        projected = scipy.dot(centered, kernel) + mean_x

        # (dimension name, sorted identifiers) for rows and then columns.
        dims = []
        for axis in (0, 1):
            dims.append((x.get_dim_name(axis),
                         x.get_identifiers(x.get_dim_name(axis), sorted=True)))

        result = dataset.Dataset(projected,
                                 tuple(dims),
                                 name=x.get_name()+'_mod'+str(alpha))
        return [result]


class NCBIQuery(workflow.Function):
    """Opens the NCBI gene report for the currently selected genes."""

    def __init__(self, gene_id_name='gene_ids'):
        self._gene_id_name = gene_id_name
        workflow.Function.__init__(self, 'query', 'NCBI')

    def run(self):
        """Build the query URL from the project selection and open a browser.

        Logs a notice and returns None when the selection has no entry for
        the gene dimension or that entry is empty.
        """
        current = main.project.get_selection()
        dim = self._gene_id_name
        if not current.has_key(dim):
            logger.log("notice", "Expected gene ids: %s, but got: %s" %(dim, current.keys()))
            return None
        genes = current[dim]
        if len(genes)==0:
            logger.log("notice", "No selected genes to query")
            return None

        # Every key carries its own leading '&', so the parameters can be
        # concatenated in whatever order the dict yields them.
        options = {r'&db=' : 'gene',
                   r'&cmd=' : 'retrieve',
                   r'&dopt=' : 'full_report'}
        options[r'&list_uids='] = "+".join(genes)
        query = ''.join(key + value for key, value in options.items())
        webbrowser.open('http://www.ncbi.nlm.nih.gov/entrez/query.fcgi?' + query)


class KEGGQuery(workflow.Function):
    """Opens the KEGG dbget page for the selected genes of one organism."""

    def __init__(self, org='hsa', gene_id_name='gene_ids'):
        self._org=org
        self._gene_id_name = gene_id_name
        workflow.Function.__init__(self, 'query', 'KEGG')

    def run(self, selection):
        """Build the dbget URL from ``selection`` and open a browser.

        Logs a notice and returns None when ``selection`` has no entry for
        the gene dimension or that entry is empty.
        """
        dim = self._gene_id_name
        if not selection.has_key(dim):
            logger.log("notice", "Expected gene ids: %s, but got. %s" %(dim, selection.keys()))
            return None
        genes = selection[dim]
        if len(genes)==0:
            logger.log("notice", "No selected genes to query")
            return None

        # Query is "<org>+<gene>+<gene>+..."
        query = self._org + "+" + "+".join(genes)
        webbrowser.open(r'http://www.genome.jp/dbget-bin/www_bget?' + query)


class GoEnrichment(workflow.Function):
    """Unconditional GO (BP) over-representation test via R's GOstats."""

    def __init__(self):
        workflow.Function.__init__(self, 'goenrich', 'Go Enrichment')

    def run(self, data):
        """Test the currently selected genes for GO term enrichment.

        data : dataset that must have a 'gene_ids' dimension.

        Returns a one-element list with a dataset of natural-log p-values
        indexed by GO term, or None (after logging a notice) when ``data``
        has no 'gene_ids' dimension or no genes are selected.
        """
        import rpy
        rpy.r.library("GOstats")

        # The dataset is meant to represent the gene universe.
        # NOTE(review): the universe is only logged below; it is not passed
        # on to GOHyperGParams.
        if 'gene_ids' not in data:
            logger.log('notice', 'No dimension called [gene_ids] in dataset: %s' %data.get_name())
            return
        universe = list(data.get_identifiers('gene_ids'))
        logger.log('notice', 'Universe consists of %s gene ids from %s' %(len(universe), data.get_name()))

        # The genes under test come from the current project selection.
        selected_genes = list(main.project.get_selection()['gene_ids'])
        if not selected_genes:
            logger.log('notice', 'This function needs a current selection!')
            return

        # Hypergeometric parameter object
        params = rpy.r.new("GOHyperGParams",
                           geneIds=selected_genes,
                           annotation="hgu133a",
                           ontology="BP",
                           pvalueCutoff=0.9999,
                           conditional=False,
                           testDirection='over')

        # result.keys(): ['Count', 'Term', 'OddsRatio', 'Pvalue', 'ExpCount', 'GOBPID', 'Size']
        result = rpy.r.summary(rpy.r.hyperGTest(params))

        log_pvals = scipy.log(scipy.asarray(result['Pvalue']))
        dims = (('go-terms', result['GOBPID']), ('_john', ['_doe']))
        return [dataset.Dataset(log_pvals, dims, name='P values (enrichment)')]


class GoEnrichmentCond(workflow.Function):
    """ Enrichment conditioned on dag structure."""
    def __init__(self):
        workflow.Function.__init__(self, 'goenrich', 'Go Cond. Enrich.')

    def run(self, data):
        """Run the conditional GO (BP) over-representation test.

        data : dataset that must have a 'gene_ids' dimension.

        Returns a one-element list with a dataset of natural-log p-values
        indexed by GO term, or None (after logging a notice) when ``data``
        has no 'gene_ids' dimension or no genes are selected.
        """
        import rpy
        rpy.r.library("GOstats")

        # Get universe
        # Here, we are using a defined dataset to represent the universe
        # NOTE(review): the universe is only logged; it is not passed on to
        # GOHyperGParams.
        if not 'gene_ids' in data:
            # The name must be %-formatted into the message, as in every other
            # logger.log call in this file (it used to be passed as a third
            # argument).
            logger.log('notice', 'No dimension called [gene_ids] in dataset: %s' %data.get_name())
            return
        universe = list(data.get_identifiers('gene_ids'))
        logger.log('notice', 'Universe consists of %s gene ids from %s' %(len(universe), data.get_name()))
        # Get current selection and validate
        curr_sel = main.project.get_selection()
        selected_genes = list(curr_sel['gene_ids'])
        if len(selected_genes)==0:
            logger.log('notice', 'This function needs a current selection!')
            return

        # Hypergeometric parameter object
        pval_cutoff = 1
        cond = True
        test_direction = 'over'
        params = rpy.r.new("GOHyperGParams",
                           geneIds=selected_genes,
                           annotation="hgu133a",
                           ontology="BP",
                           pvalueCutoff=pval_cutoff,
                           conditional=cond,
                           testDirection=test_direction
                           )
        # run test
        # result.keys(): ['Count', 'Term', 'OddsRatio', 'Pvalue', 'ExpCount', 'GOBPID', 'Size']
        result = rpy.r.summary(rpy.r.hyperGTest(params))

        # dataset
        terms = result['GOBPID']
        pvals = scipy.log(scipy.asarray(result['Pvalue']))
        row_ids = ('go-terms', terms)
        col_ids = ('_john', ['_doe'])

        xout = dataset.Dataset(pvals,
                               (row_ids, col_ids),
                               name='P values (enrichment)')
        return [xout]


class MapGene2GO(workflow.Function):
    """Replaces the 'go-terms' selection with the GO terms of the selected genes."""

    def __init__(self, ont='bp', gene_id_name='gene_ids'):
        self._ont = ont
        self._gene_id_name = gene_id_name
        workflow.Function.__init__(self, 'gene2go', 'gene->GO')
        # Start from an empty mapping so that run() still works (and reports
        # every gene as unmapped) when the pickle cannot be loaded.
        self._gene2go = {}
        # load data at init
        fname = "/home/flatberg/laydi/data/gene2go.pcl"
        try:
            # NOTE: unpickling can execute arbitrary code; this file must be
            # trusted.
            fd = open(fname)
            try:
                self._gene2go = cPickle.load(fd)
            finally:
                fd.close()
        except Exception:
            # Best effort: a missing or unreadable mapping must not prevent
            # the workflow from being constructed.
            logger.log("notice", "could not load mapping")

    def run(self):
        """Map the current gene selection to GO terms.

        Logs a notice and returns None when the selection has no entry for
        the gene dimension or that entry is empty. Otherwise sets the
        project's 'go-terms' selection to the union of the GO terms found
        for the selected genes.
        """
        selection = main.project.get_selection()
        if not selection.has_key(self._gene_id_name):
            logger.log("notice", "Expected gene ids: %s, but got. %s" %(self._gene_id_name, selection.keys()))
            return None
        if len(selection[self._gene_id_name])==0:
            logger.log("notice", "No selected genes to query")
            return None

        go_ids = set()
        for gene in selection[self._gene_id_name]:
            terms = self._gene2go.get(gene, [])
            if not terms:
                logger.log("notice", "Could not find any goterms for %s" %gene)
            go_ids.update(terms)
        main.project.set_selection('go-terms', go_ids)
        logger.log("notice", "GO terms updated")


class MapGO2Gene(workflow.Function):
    """Replaces the 'gene_ids' selection with the genes of the selected GO terms."""

    def __init__(self, ont='bp', gene_id_name='go-terms'):
        self._ont = ont
        # Despite its name, this holds the dimension of the *input* selection,
        # which defaults to 'go-terms'.
        self._gene_id_name = gene_id_name
        workflow.Function.__init__(self, 'go2gene', 'GO->gene')
        # Start from an empty mapping so that run() still works (and reports
        # every term as unmapped) when the pickle cannot be loaded.
        self._go2gene = {}
        # load data at init
        fname = "/home/flatberg/laydi/data/go2gene.pcl"
        try:
            # NOTE: unpickling can execute arbitrary code; this file must be
            # trusted.
            fd = open(fname)
            try:
                self._go2gene = cPickle.load(fd)
            finally:
                fd.close()
        except Exception:
            # Best effort: a missing or unreadable mapping must not prevent
            # the workflow from being constructed.
            logger.log("notice", "could not load mapping")

    def run(self):
        """Map the current GO term selection to gene ids.

        Logs a notice and returns None when the selection has no entry for
        the input dimension or that entry is empty. Otherwise sets the
        project's 'gene_ids' selection to the union of the genes found for
        the selected GO terms.
        """
        selection = main.project.get_selection()
        if not selection.has_key(self._gene_id_name):
            logger.log("notice", "Expected gene ids: %s, but got. %s" %(self._gene_id_name, selection.keys()))
            return None
        if len(selection[self._gene_id_name])==0:
            logger.log("notice", "No selected genes to query")
            return None

        gene_ids = set()
        for go in selection[self._gene_id_name]:
            genes = self._go2gene.get(go, [])
            if not genes:
                logger.log("notice", "Could not find any gene ids for %s" %go)
            gene_ids.update(genes)
        main.project.set_selection('gene_ids', gene_ids)
        # This function updates the gene selection, not the GO terms.
        logger.log("notice", "Gene ids updated")


class SubgraphQuery(workflow.Function):
    """Finds network subgraphs whose end points are pairwise close in ``Dw``."""

    def __init__(self, graph='kegg', dim='gene_ids'):
        self._gtype = graph
        self._dim = dim

        workflow.Function.__init__(self, 'keggraph', 'KeggGraph')

    def run(self, Dw, DA):
        """Keep only network edges between similar identifiers.

        Dw : dataset with a ``self._dim`` dimension; pairwise similarity is
             computed between its rows.
        DA : dataset that can be turned into a graph via ``asnetworkx()``.

        Side effects: sets the project's 'gene_ids' selection to the first
        members of the close pairs and constructs a GraphQueryScatterPlot.
        Returns None.
        """
        max_edge_ratio = .20
        neigh_type = 'cosine'  # alternatives: 'cov', 'heat'

        # 1.) Operate on a subset selection
        selection = main.project.get_selection()
        if not selection.has_key(self._dim):
            logger.log("notice", "Expected gene ids: %s, but got. %s" %(self._dim, selection.keys()))
            return None
        if len(selection[self._dim]) == 0:
            # With no selection only the first 100 identifiers are used.
            logger.log("notice", "No selected genes to query, using all")
            Dw = Dw.subdata(self._dim, Dw.get_identifiers(self._dim)[:100])
        else:
            Dw = Dw.subdata(self._dim, selection[self._dim])

        # 2.) Pairwise goodness in loading space
        indices = self._pairsim(Dw, ptype=neigh_type, cut_rat=max_edge_ratio)
        idents1 = Dw.get_identifiers(self._dim, indices[:,0])
        idents2 = Dw.get_identifiers(self._dim, indices[:,1])
        # Set lookup with both orientations: the graph may report an
        # undirected edge as (u, v) or (v, u).
        close_pairs = set()
        for first, second in zip(idents1, idents2):
            close_pairs.add((first, second))
            close_pairs.add((second, first))

        # 3.) Identify close subgraphs
        G = DA.asnetworkx()
        # Iterate over a copy because edges are deleted inside the loop.
        for edge in list(G.edges()):
            # Compare end points only; edges may carry extra data.
            if (edge[0], edge[1]) not in close_pairs:
                G.delete_edge(edge)
        S = list(nx.connected_component_subgraphs(G))
        logger.log('debug', 'Subgraph sizes: %s' % [len(sub) for sub in S])
        # 4.) Rank subgraphs

        # NOTE(review): 'gene_ids' is hard-coded here although the rest of
        # the method uses self._dim -- confirm this is intended.
        main.project.set_selection('gene_ids', idents1)
        logger.log("notice", "Gene ids updated")
        plt = GraphQueryScatterPlot(S, Dw)
        #return [plt]

    def _pairsim(self, Dw, ptype='cosine',cut_rat=.2):
        """Returns close pairs across given dim.
        ptype : ['cov', 'correlation', 'cosine', 'heat', 'euclidean']
        cut_rat : threshold as a fraction of the largest matrix entry.

        Returns an integer array with one (i, j) row-index pair per line.
        The distance metrics only report pairs with i <= j; 'cov' and
        'heat' report every qualifying matrix entry.
        Raises ValueError for an unknown ptype.
        """
        W = Dw.asarray()
        if ptype == 'cov':
            # Centre a copy; the dataset's own array must not be modified.
            Wc = W - W.mean(1)[:,scipy.newaxis]
            wcov = scipy.dot(Wc, Wc.T)/(Wc.shape[1]-1)
            wcov_min = wcov.max()*cut_rat
            indices = scipy.asarray(scipy.where(wcov >= wcov_min)).T
        elif ptype == 'heat':
            from hcluster import pdist, squareform
            D = squareform(pdist(W))
            H = scipy.exp(-D)
            h_min = H.max()*cut_rat
            indices = scipy.asarray(scipy.where(H >= h_min)).T
        elif ptype in ['euclidean', 'cosine', 'correlation']:
            from hcluster import pdist, squareform
            # The metric belongs to pdist; passed to squareform it was
            # ignored and Euclidean distance was always used.
            D = squareform(pdist(W, ptype))
            d_min = D.max()*cut_rat
            rows, cols = scipy.where(D <= d_min)
            upper = rows <= cols  # D is symmetric: keep each pair once
            indices = scipy.asarray([rows[upper], cols[upper]]).T
            logger.log('debug', 'Pairwise similarity input shape: %s' % (W.shape,))
        else:
            raise ValueError("ptype: %s  not valid" %ptype)
        return indices

    def _subgraphsim(self, Dw, idents, stype='dijkstra'):
        """Unfinished stub, not called from within this class.

        NOTE(review): refers to an undefined name ``G`` and calls
        ``add_edge()`` without arguments, so it fails as soon as ``idents``
        is non-empty.
        """
        # subgraph
        Gw = nx.XGraph()
        for edge in idents:
            e = G.get_edge(edge)
            Gw.add_edge()
        if stype == 'dijkstra':
            pass
        
class GraphQueryScatterPlot(plots.ScatterPlot):
    """Scatter plot of ``Dw`` with network subgraphs drawn on top."""

    def __init__(self, subgraphs, Dw, *args, **kw):
        """Create the scatter plot and overlay ``subgraphs``.

        Extra positional and keyword arguments are accepted but not used.
        """
        self._subgraphs = subgraphs
        # Artists returned by networkx, kept so they can be removed again.
        self._nx_nodes = []
        self._nx_edges = []
        self._init_scatter(Dw)
        self.overlay_subgraphs()

    def _init_scatter(self, Dw):
        """Initialise the base plot on the first two identifiers of the
        second dimension of ``Dw``."""
        self._Dw = Dw
        id_dim, sel_dim = Dw.get_dim_name()
        self._dim = id_dim
        id_1, = Dw.get_identifiers(sel_dim, [0])
        id_2, = Dw.get_identifiers(sel_dim, [1])
        plots.ScatterPlot.__init__(self, Dw, Dw, id_dim, sel_dim, id_1, id_2, c='g', s=50,name="Hypo", alpha=.5)

    def overlay_subgraphs(self):
        """Draw every subgraph at the current scatter coordinates."""
        all_nodes = self._Dw.get_identifiers(self._dim, sorted=True)
        # Node -> (x, y). The same for every subgraph, so built once here;
        # it was never initialised before, which raised a NameError.
        # assumes xaxis_data/yaxis_data follow the sorted identifier order
        # -- TODO confirm
        pos = dict(zip(all_nodes, zip(self.xaxis_data, self.yaxis_data)))
        for subgraph in self._subgraphs:
            nn = nx.draw_networkx_nodes(subgraph, pos, node_size=200, ax=self.axes, zorder=10)
            ee = nx.draw_networkx_edges(subgraph, pos, ax=self.axes, zorder=9)
            # Drawing may return None (e.g. a subgraph without edges);
            # only real artists can be removed later.
            if nn is not None:
                self._nx_nodes.append(nn)
            if ee is not None:
                self._nx_edges.append(ee)

    def _delete_networks(self):
        """Remove all overlaid network artists from the axes."""
        for artists in (self._nx_nodes, self._nx_edges):
            # Iterate first and clear afterwards: removing from the list
            # while looping over it skipped every other artist.
            for artist in artists:
                self.axes.collections.remove(artist)
            del artists[:]

    # NOTE(review): in both setters the overlay is redrawn before the base
    # class changes the axis; confirm it is not drawn at stale coordinates.
    def set_ordinate(self, sb):
        self._delete_networks()
        self.overlay_subgraphs()
        plots.ScatterPlot.set_ordinate(self, sb)

    def set_absicca(self, sb):
        self._delete_networks()
        self.overlay_subgraphs()
        plots.ScatterPlot.set_absicca(self, sb)


class CAsinglesel(workflow.Function):
    """ Modified non-symmetric correspondence analysis.

    Setup multiple selections:

    Input : - a subset(s) along a dimension (selection) of `interesting` identifiers.
            - Predefined subsets (categories) along the same dimension.

    1.) The co-occurrence matrix of interesting identifiers and categories is made.
    2.) The variables are scaled to represent the relative frequencies.

    Not implemented yet: run() is a placeholder.
    """

    def run(self, X, Ckegg):
        # `self` added so the method can be called on an instance like every
        # other workflow function in this file.
        pass


# This class used to be indented into the body of CAsinglesel, which made
# its run() a second, shadowing definition of CAsinglesel.run.
class CASingleSelDouble(workflow.Function):
    """Placeholder; not implemented yet."""

    def run(self, X, Ckegg):
        pass