import sys, os
import webbrowser

from fluents import logger, plots, workflow, dataset, main
from fluents.lib import blmfuncs, nx_utils, validation, engines, cx_stats, cx_utils
import gobrowser, geneontology
import scipy
import networkx as nx


class SmallTestWorkflow(workflow.Workflow):
    name = 'Smokers'
    ident = 'smokers'
    description = 'A small test workflow for gene expression analysis.'

    def __init__(self):
        workflow.Workflow.__init__(self)        

        # DATA IMPORT
        load = workflow.Stage('load', 'Data')
        load.add_function(DatasetLoadFunctionSmokerSmall())
        load.add_function(DatasetLoadFunctionSmokerMedium())
        load.add_function(DatasetLoadFunctionSmokerFull())
        load.add_function(DatasetLoadFunctionSmokerGO())
        #load.add_function(DatasetLoadFunctionCYCLE())
        self.add_stage(load)
        
        # NETWORK PREPROCESSING
        #net = workflow.Stage('net', 'Network integration')
        #net.add_function(DiffKernelFunction())
        #net.add_function(ModKernelFunction())
        #net.add_function(RandDiffKernelFunction())
        #self.add_stage(net)
        
        # BLMs (bilinear models)
        model = workflow.Stage('models', 'Models')
        model.add_function(blmfuncs.PCA())
        model.add_function(blmfuncs.PLS())
        model.add_function(blmfuncs.LPLS())
        model.add_function(SAM())
        
        #model.add_function(bioconFuncs.SAM(app))
        self.add_stage(model)
        
        query = workflow.Stage('query', 'Gene Query')
        query.add_function(NCBIQuery())
        query.add_function(KEGGQuery())
        self.add_stage(query)
        
        # Gene Ontology
        go = workflow.Stage('go', 'Gene Ontology')
        go.add_function(gobrowser.LoadGOFunction())
        go.add_function(gobrowser.SetICFunction())
        # go.add_function(gobrowser.GOWeightFunction())
        # go.add_function(gobrowser.DistanceToSelectionFunction())
        # go.add_function(gobrowser.TTestFunction())
        go.add_function(gobrowser.PlotDagFunction())
        go.add_function(GoEnrichment())
        go.add_function(GoEnrichmentCond())
        self.add_stage(go)

        # EXTRA PLOTS
        #plt = workflow.Stage('net', 'Network')
        #plt.add_function(nx_analyser.KeggNetworkAnalyser())
        #self.add_stage(plt)
        
        logger.log('debug', 'Small test workflow is now active')


class DatasetLoadFunctionSmokerSmall(workflow.Function):
    """Loader for all ftsv files of smokers small datasets."""
    def __init__(self):
        workflow.Function.__init__(self, 'load_small', 'Smoker (Small)')

    def run(self):
        path = 'data/smokers-small/'
        files = os.listdir(path)
        out = []
        for fname in files:
            if fname.endswith('.ftsv'):
                input_file = open(os.path.join(path, fname))
                out.append(dataset.read_ftsv(input_file))
                input_file.close()
        return out


class DatasetLoadFunctionSmokerMedium(workflow.Function):
    """Loader for all ftsv files of smokers small datasets."""
    def __init__(self):
        workflow.Function.__init__(self, 'load_medium', 'Smoker (Medium)')

    def run(self):
        path = 'data/smokers-medium/'
        files = os.listdir(path)
        out = []
        for fname in files:
            if fname.endswith('.ftsv'):
                input_file = open(os.path.join(path, fname))
                out.append(dataset.read_ftsv(input_file))
                input_file.close()
        return out


class DatasetLoadFunctionSmokerFull(workflow.Function):
    """Loader for all ftsv files of smokers small datasets."""
    def __init__(self):
        workflow.Function.__init__(self, 'load_full', 'Smoker (Full)')

    def run(self):
        path = 'data/smokers-full/'
        files = os.listdir(path)
        out = []
        for fname in files:
            if fname.endswith('.ftsv'):
                input_file = open(os.path.join(path, fname))
                out.append(dataset.read_ftsv(input_file))
                input_file.close()
        return out

class DatasetLoadFunctionSmokerGO(workflow.Function):
    """Loader for all ftsv files of smokers small datasets."""
    def __init__(self):
        workflow.Function.__init__(self, 'load_go', 'Smoker (GO)')

    def run(self):
        path = 'data/smokers-go/'
        files = os.listdir(path)
        out = []
        for fname in files:
            if fname.endswith('.ftsv'):
                input_file = open(os.path.join(path, fname))
                out.append(dataset.read_ftsv(input_file))
                input_file.close()
        return out

class DatasetLoadFunctionCYCLE(workflow.Function):
    """Loader for pickled CYCLE datasets."""
    def __init__(self):
        workflow.Function.__init__(self, 'load_data', 'Cycle')

    def run(self):
        filename = 'fluents/data/CYCLE'
        return dataset.from_file(filename)


##### WORKFLOW SPECIFIC FUNCTIONS ######
class SAM(workflow.Function):
    def __init__(self, id='sam', name='SAM'):
        workflow.Function.__init__(self, id, name)
        
    def run(self, x, y):
        
        n_iter = 50   # B: number of permutations
        alpha = 0.01  # cut-off on q-values
        
        ###############

        # set up R preliminaries for the SAM call below
        import rpy
        rpy.r.library("siggenes")
        rpy.r.library("multtest")
        
        # collapse the dummy-coded class matrix y (samples x 3 indicator
        # matrix) into a single vector of class labels 1, 2, 3
        cl = scipy.dot(y.asarray(), scipy.diag([1, 2, 3])).sum(1)
        data = x.asarray().T
        sam = rpy.r.sam(data, cl=cl, B=n_iter, var_equal=False, med=False,
                        s0=scipy.nan, rand=scipy.nan)
        qvals = scipy.asarray(rpy.r.slot(sam, "q.value"))
        pvals = scipy.asarray(rpy.r.slot(sam, "p.value"))
        
        # indices of the variables significant at the chosen q-value cutoff
        sam_index = (qvals < alpha).nonzero()[0]

        # Update selection object
        dim_name = x.get_dim_name(1)
        sam_selection = x.get_identifiers(dim_name, indices=sam_index)
        main.project.set_selection(dim_name, sam_selection)
        
        sel = dataset.Selection('SAM selection')
        sel.select(dim_name, sam_selection)
        logger.log('notice', 'Number of significant variables (SAM): %s' % len(sam_selection))

        ### OUTPUT ###
        xcolname = x.get_dim_name(1)  # genes
        x_col_ids = [xcolname, x.get_identifiers(xcolname, sorted=True)]
        sing_id = ['_john', ['0']]  # singleton dimension
        D_qvals = dataset.Dataset(qvals, (x_col_ids, sing_id), name='q_vals')
        D_pvals = dataset.Dataset(pvals, (x_col_ids, sing_id), name='p_vals')
        
        # plots
        s_indx = qvals.flatten().argsort()
        s_ids = [x_col_ids[0],[x_col_ids[1][i] for i in s_indx]]
        xindex = scipy.arange(len(qvals))
        qvals_s = qvals.take(s_indx)
        D_qs = dataset.Dataset(qvals_s, (s_ids, sing_id), name="sorted qvals")
        Dind = dataset.Dataset(xindex, (s_ids, sing_id), name="dum")
        st = plots.ScatterPlot(D_qs, Dind, 'gene_ids', '_john', '0', '0', s=10, name='SAM qvals')
        
        return [D_qvals, D_pvals, D_qs, st, sel]
        

class DiffKernelFunction(workflow.Function):
    def __init__(self):
        workflow.Function.__init__(self, 'diffkernel', 'Diffusion')

    def run(self, x, a):
        """x is gene expression data, a is the network.
        """
        # convert the network dataset to a networkx graph and build the
        # adjacency matrix in the same (sorted) gene order as x
        g = a.asnetworkx()
        genes = x.get_identifiers(x.get_dim_name(1), sorted=True)
        W = nx.adj_matrix(g, nodelist=genes)
        X = x.asarray()
        Xc, mn_x = cx_utils.mat_center(X, ret_mn=True)
        out = []
        alpha = 1.0
        beta = 1.0
        K = nx_utils.K_diffusion(W, alpha=alpha, beta=beta, normalised=True)
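        # Smooth the centred expression profiles over the network by
        # right-multiplying with the diffusion kernel, then add the mean back
        # (K_diffusion is assumed to build a normalised diffusion kernel from W).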
        Xp = scipy.dot(Xc, K) + mn_x
        # dataset
        row_ids = (x.get_dim_name(0),
                   x.get_identifiers(x.get_dim_name(0),
                                     sorted=True))
        col_ids = (x.get_dim_name(1),
                   x.get_identifiers(x.get_dim_name(1),
                                     sorted=True))
        
        xout = dataset.Dataset(Xp,
                               (row_ids, col_ids),
                               name=x.get_name()+'_diff'+str(alpha))
        out.append(xout)
        
        return out


class RandDiffKernelFunction(workflow.Function):
    def __init__(self):
        workflow.Function.__init__(self, 'rand_diffkernel', 'Rand. Diff.')

    def run(self, x, a):
        """x is gene expression data, a is the network.
        """
        # convert the network dataset to a networkx graph
        g = a.asnetworkx()
        genes = x.get_identifiers(x.get_dim_name(1))
        # randomise nodelist
        genes = [genes[i] for i in cx_utils.randperm(x.shape[1])]
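        # Permuting the node order breaks the gene-to-node correspondence, so
        # the diffusion below acts as a randomised (null-model) control.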
        W = nx.adj_matrix(g, nodelist=genes)
        X = x.asarray()
        Xc, mn_x = cx_utils.mat_center(X, ret_mn=True)
        out = []
        alpha = 1.0
        beta = 1.0
        K = nx_utils.K_diffusion(W, alpha=alpha, beta=beta, normalised=True)

        Xp = scipy.dot(Xc, K) + mn_x
        # dataset
        row_ids = (x.get_dim_name(0),
                   x.get_identifiers(x.get_dim_name(0),
                                     sorted=True))
        col_ids = (x.get_dim_name(1),
                   x.get_identifiers(x.get_dim_name(1),
                                     sorted=True))
        
        xout = dataset.Dataset(Xp,
                               (row_ids, col_ids),
                               name=x.get_name()+'_diff'+str(alpha))
        out.append(xout)
        
        return out
    

class ModKernelFunction(workflow.Function):
    def __init__(self):
        workflow.Function.__init__(self, 'mokernel', 'Modularity')

    def run(self, x, a):
        X = x.asarray()
        g = a.asnetworkx()
        genes = x.get_identifiers(x.get_dim_name(1), sorted=True)
        W = nx.adj_matrix(g, nodelist=genes)
        out = []
        alpha = 0.2
        Xc, mn_x = cx_utils.mat_center(X, ret_mn=True)
        K = nx_utils.K_modularity(W, alpha=alpha)
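        # Project the centred data onto the modularity-based kernel (assumed
        # semantics of nx_utils.K_modularity) and restore the mean afterwards.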
        Xp = scipy.dot(Xc, K) + mn_x
        
        # dataset
        row_ids = (x.get_dim_name(0),
                   x.get_identifiers(x.get_dim_name(0),
                                     sorted=True))
        col_ids = (x.get_dim_name(1),
                   x.get_identifiers(x.get_dim_name(1),
                                     sorted=True))
        xout = dataset.Dataset(Xp,
                               (row_ids,col_ids),
                               name=x.get_name()+'_mod'+str(alpha))
        out.append(xout)
        return out


class NCBIQuery(workflow.Function):
    def __init__(self, gene_id_name='gene_id'):
        self._gene_id_name = gene_id_name
        workflow.Function.__init__(self, 'query', 'NCBI')

    def run(self, selection):
        if self._gene_id_name not in selection:
            logger.log("notice", "Expected gene ids: %s, but got: %s" % (self._gene_id_name, selection.keys()))
            return None
        if len(selection[self._gene_id_name])==0:
            logger.log("notice", "No selected genes to query")
            return None
        
        logger.log("notice", "No selected genes to query")
        base = 'http://www.ncbi.nlm.nih.gov/entrez/query.fcgi?'
        options = {r'&db=' : 'gene',
                   r'&cmd=' : 'retrieve',
                   r'&dopt=' : 'full_report'}
        gene_str = ''.join([gene + "+" for gene in selection[self._gene_id_name]])
        options[r'&list_uids='] = gene_str[:-1]
        opt_str = ''.join([key+value for key,value in options.items()])
        web_str = base + opt_str
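        # Example of the assembled URL (gene ids are illustrative; option order
        # may vary because dict iteration order is arbitrary):
        # http://www.ncbi.nlm.nih.gov/entrez/query.fcgi?&db=gene&cmd=retrieve&dopt=full_report&list_uids=1234+5678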
        webbrowser.open(web_str)


class KEGGQuery(workflow.Function):
    def __init__(self, org='hsa', gene_id_name='gene_id'):
        self._org=org
        self._gene_id_name = gene_id_name
        workflow.Function.__init__(self, 'query', 'KEGG')

    def run(self, selection):
        if self._gene_id_name not in selection:
            logger.log("notice", "Expected gene ids: %s, but got: %s" % (self._gene_id_name, selection.keys()))
            return None
        if len(selection[self._gene_id_name])==0:
            logger.log("notice", "No selected genes to query")
            return None
        
        base = r'http://www.genome.jp/dbget-bin/www_bget?'
        gene_str = ''.join([gene + "+" for gene in selection[self._gene_id_name]])
        gene_str = gene_str[:-1]
        gene_str = self._org + "+" + gene_str
        web_str = base + gene_str
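        # Example of the assembled URL (gene ids are illustrative):
        # http://www.genome.jp/dbget-bin/www_bget?hsa+1234+5678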
        webbrowser.open(web_str)


class GoEnrichment(workflow.Function):
    def __init__(self):
        workflow.Function.__init__(self, 'goenrich', 'Go Enrichment')

    def run(self, data):
        import rpy
        rpy.r.library("GOstats")
        
        # Get universe
        # Here, we are using a defined dataset to represent the universe
        if 'gene_ids' not in data:
            logger.log('notice', 'No dimension called [gene_ids] in dataset: %s' % data.get_name())
            return
        universe = list(data.get_identifiers('gene_ids'))
        logger.log('notice', 'Universe consists of %s gene ids from %s' %(len(universe), data.get_name()))
        # Get current selection and validate
        curr_sel = main.project.get_selection()
        selected_genes = list(curr_sel['gene_ids'])
        if len(selected_genes)==0:
            logger.log('notice', 'This function needs a current selection!')
            return
        
        # Hypergeometric parameter object
        pval_cutoff = 0.9999
        cond = False
        test_direction = 'over'
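        # GOstats hypergeometric test: finds GO Biological Process terms that
        # are over-represented among the selected genes relative to the
        # universe, using the hgu133a annotation; conditional=False gives the
        # classic, unconditioned test.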
        params = rpy.r.new("GOHyperGParams",
                           geneIds=selected_genes,
                           universeGeneIds=universe,
                           annotation="hgu133a",
                           ontology="BP",
                           pvalueCutoff=pval_cutoff,
                           conditional=cond,
                           testDirection=test_direction
                           )
        # run test
        # result.keys(): ['Count', 'Term', 'OddsRatio', 'Pvalue', 'ExpCount', 'GOBPID', 'Size']
        result = rpy.r.summary(rpy.r.hyperGTest(params))
        
        # dataset
        terms = result['GOBPID']
        pvals = scipy.log(scipy.asarray(result['Pvalue']))  # natural-log p-values
        row_ids = ('go-terms', terms)
        col_ids = ('_john', ['_doe'])
        
        xout = dataset.Dataset(pvals,
                               (row_ids, col_ids),
                               name='P values (enrichment)')
        return [xout]


class GoEnrichmentCond(workflow.Function):
    """ Enrichment conditioned on dag structure."""
    def __init__(self):
        workflow.Function.__init__(self, 'goenrich_cond', 'Go Cond. Enrich.')

    def run(self, data):
        import rpy
        rpy.r.library("GOstats")
        
        # Get universe
        # Here, we are using a defined dataset to represent the universe
        if 'gene_ids' not in data:
            logger.log('notice', 'No dimension called [gene_ids] in dataset: %s' % data.get_name())
            return
        universe = list(data.get_identifiers('gene_ids'))
        logger.log('notice', 'Universe consists of %s gene ids from %s' %(len(universe), data.get_name()))
        # Get current selection and validate
        curr_sel = main.project.get_selection()
        selected_genes = list(curr_sel['gene_ids'])
        if len(selected_genes)==0:
            logger.log('notice', 'This function needs a current selection!')
            return
        
        # Hypergeometric parameter object
        pval_cutoff = 0.9999
        cond = True
        test_direction = 'over'
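        # Same test as in GoEnrichment, but conditional=True: each GO term is
        # tested conditionally on its significant child terms in the GO DAG.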
        params = rpy.r.new("GOHyperGParams",
                           geneIds=selected_genes,
                           universeGeneIds=universe,
                           annotation="hgu133a",
                           ontology="BP",
                           pvalueCutoff=pval_cutoff,
                           conditional=cond,
                           testDirection=test_direction
                           )
        # run test
        # result.keys(): ['Count', 'Term', 'OddsRatio', 'Pvalue', 'ExpCount', 'GOBPID', 'Size']
        result = rpy.r.summary(rpy.r.hyperGTest(params))
        
        # dataset
        terms = result['GOBPID']
        pvals = scipy.log(scipy.asarray(result['Pvalue']))  # natural-log p-values
        row_ids = ('go-terms', terms)
        col_ids = ('_john', ['_doe'])
        
        xout = dataset.Dataset(pvals,
                               (row_ids, col_ids),
                               name='P values (enrichment)')
        return [xout]