import gtk
from laydi import dataset, logger, plots, workflow
#import geneontology
#import gostat
from scipy import array, randn, log, ones
import cPickle
import networkx

class TestWorkflow (workflow.Workflow):
    """General-purpose testing workflow built from the functions in this module."""

    name = 'Test Workflow'
    ident = 'test'
    description = 'Test Gene Ontology Workflow. This workflow currently serves as a general testing workflow.'

    def __init__(self, app):
        workflow.Workflow.__init__(self, app)

        def register(stage_ident, stage_title, *functions):
            # Create one stage, attach its functions in the given order and
            # add the stage to this workflow.
            new_stage = workflow.Stage(stage_ident, stage_title)
            for func in functions:
                new_stage.add_function(func)
            self.add_stage(new_stage)

        register('load', 'Load Data',
                 CelFileImportFunction(),
                 DataLoadTestFunction(self),
                 TestDataFunction(),
                 DatasetLoadFunction(),
                 SelectFunction())

        register('preprocess', 'Preprocessing',
                 DatasetLog(),
                 workflow.Function('rma', 'RMA'))

        register('go', 'Gene Ontology Data',
                 GODistanceFunction(),
                 ImagePlotFunction())

        register('regression', 'Regression',
                 workflow.Function('pls', 'PLS'))

        register('explore', 'Explorative analysis',
                 PCAFunction(self))

        register('save', 'Save Data',
                 DatasetSaveFunction())
        
        
class LoadAnnotationsFunction(workflow.Function):
    """Workflow function that reads a Gene Ontology annotation file.

    NOTE(review): ``Annotations`` is not imported in this module as shown
    (the geneontology/gostat imports are commented out); confirm where it
    is expected to come from.
    """

    def __init__(self):
        workflow.Function.__init__(self, 'load-go-ann', 'Load Annotations')
        # Filled in by load_file(); stays None if no file is loaded.
        self.annotations = None

    def load_file(self, filename):
        """Parse *filename* into ``self.annotations``.

        Each line is split on a space followed by a tab; the first field is
        used as the gene id and the remaining fields as its GO terms. Lines
        with fewer than two fields are skipped.

        NOTE(review): the two-character ' \\t' separator looks unusual for a
        tab-separated file -- confirm against the annotation file format.
        """
        f = open(filename)
        try:
            self.annotations = Annotations('genes', 'go-terms')
            logger.log('notice', 'Loading annotation file: %s' % filename)

            for line in f:
                val = line.split(' \t')

                if len(val) > 1:
                    val = [v.strip() for v in val]
                    self.annotations.add_annotations('genes', val[0],
                                                     'go-terms', set(val[1:]))
        finally:
            f.close()

    def on_response(self, dialog, response):
        """Dialog callback: load the chosen file when the user clicks Open."""
        if response == gtk.RESPONSE_OK:
            logger.log('notice', 'Reading file: %s' % dialog.get_filename())
            self.load_file(dialog.get_filename())

    def run(self):
        """Show a file chooser and return ``[annotations]`` (None on cancel)."""
        btns = ('Open', gtk.RESPONSE_OK, \
                'Cancel', gtk.RESPONSE_CANCEL)
        dialog = gtk.FileChooserDialog('Open GO Annotation File',
                                       buttons=btns)
        dialog.connect('response', self.on_response)
        dialog.run()
        dialog.destroy()
        return [self.annotations]

class GODistanceFunction(workflow.Function):
    """Placeholder for pairwise distances between annotated genes.

    No distances are computed yet: run() returns an all-zero genes x genes
    matrix.
    """

    def __init__(self):
        # NOTE(review): 'go_diatance' looks like a typo for 'go_distance'; it
        # is kept unchanged because the identifier may be referenced elsewhere.
        workflow.Function.__init__(self, 'go_diatance', 'GO Distances')
        self.output = None

    def run(self, data):
        """Return a square zero matrix sized by the 'genes' ids of *data*.

        Returns None when *data* is not an ``Annotations`` instance.
        """
        import numpy

        logger.log('debug', 'datatype: %s' % type(data))
        if not isinstance(data, Annotations):
            return None

        logger.log('debug', 'dimensions: %s' % data.dimensions)

        genes = data.get_ids('genes')
        # array((n, n)) would build the 1-D vector [n, n]; allocate the
        # intended n x n matrix instead.
        gene_distances = numpy.zeros((len(genes), len(genes)))

        return gene_distances


class ImagePlotFunction(workflow.Function):
    """Wrap the incoming data in an image plot."""

    def __init__(self):
        workflow.Function.__init__(self, 'image', 'Show Image')

    def run(self, data):
        """Return a one-element list holding an ImagePlot of *data*."""
        image_plot = plots.ImagePlot(data, name='foo')
        return [image_plot]


class TestDataFunction(workflow.Function):
    """Generate a bundle of synthetic datasets and plots for manual testing."""

    def __init__(self):
        workflow.Function.__init__(self, 'test_data', 'Generate Test Data')

    def run(self):
        """Build and return the test objects.

        Returns, in order: a 500x15 randn dataset, a graph dataset, a scatter
        plot, a network plot, a scatter-marker plot of the graph dataset, a
        scatter-marker plot of the random dataset, a 3x3 category dataset of
        ones, a line-view plot and a Venn plot.
        """
        logger.log('notice', 'Injecting foo test data')
        random_ds = dataset.Dataset(randn(500, 15))
        scatter = plots.ScatterPlot(random_ds, random_ds, 'rows', 'rows',
                                    '0_1', '0_2', name='scatter')
        marker = plots.ScatterMarkerPlot(random_ds, random_ds, 'rows', 'rows',
                                         '0_1', '0_2', name='marker')

        # Connect each of A..F to each of A, D and E (self-loops included),
        # passing 3 as the edge data.
        graph = networkx.XGraph()
        for src in 'ABCDEF':
            for dst in 'ADE':
                graph.add_edge(src, dst, 3)
        graph_ds = dataset.GraphDataset(array(networkx.adj_matrix(graph)))
        network = plots.NetworkPlot(graph_ds)

        category_ds = dataset.CategoryDataset(ones([3, 3]))
        graph_scatter = plots.ScatterMarkerPlot(graph_ds, graph_ds,
                                                'rows_0', 'rows_0',
                                                '0_1', '0_2')
        line_view = plots.LineViewPlot(random_ds, major_axis=0)
        venn = plots.VennPlot()
        return [random_ds, graph_ds, scatter, network, graph_scatter,
                marker, category_ds, line_view, venn]

class SelectFunction(workflow.Function):
    """Emit a fixed, hard-coded selection (useful for testing)."""

    def __init__(self):
        workflow.Function.__init__(self, 'select', 'Select')

    def run(self, data):
        """Return a selection of ids '0_1' and '0_2' on 'rows'; *data* is unused."""
        selection = dataset.Selection('Arbitrary selection')
        selection.select('rows', ['0_1', '0_2'])
        return [selection]

class DatasetLog(workflow.Function):
    """Elementwise logarithm of a dataset."""

    def __init__(self):
        workflow.Function.__init__(self, 'log', 'Log')

    def run(self, data):
        """Return ``[Dataset]`` holding ``log`` of *data*'s values, named 'log(<name>)'."""
        logger.log('notice', 'Taking the log of dataset %s' % data.get_name())
        logged = dataset.Dataset(log(data.asarray()),
                                 name='log(%s)' % data.get_name())
        return [logged]

class DatasetLoadFunction(workflow.Function):
    """Loader for previously pickled Datasets."""
    def __init__(self):
        workflow.Function.__init__(self, 'load_data', 'Load Pickled Dataset')

    def run(self):
        """Ask for a pickle file and return ``[unpickled_object]``.

        Returns None when the dialog is cancelled.

        Security: unpickling can execute arbitrary code -- only open files
        from trusted sources.
        """
        chooser = gtk.FileChooserDialog(title="Select pickled data file...", parent=None,
                                        action=gtk.FILE_CHOOSER_ACTION_OPEN,
                                        buttons=(gtk.STOCK_CANCEL, gtk.RESPONSE_CANCEL,
                                                 gtk.STOCK_OPEN, gtk.RESPONSE_OK))
        pkl_filter = gtk.FileFilter()
        pkl_filter.set_name("Python pickled data files (*.pkl)")
        pkl_filter.add_pattern("*.[pP][kK][lL]")
        all_filter = gtk.FileFilter()
        all_filter.set_name("All Files (*.*)")
        all_filter.add_pattern("*")
        chooser.add_filter(pkl_filter)
        chooser.add_filter(all_filter)

        try:
            if chooser.run() == gtk.RESPONSE_OK:
                # Binary mode: protocol-2 pickles are binary and would be
                # corrupted by newline translation on some platforms.
                fd = open(chooser.get_filename(), 'rb')
                try:
                    return [cPickle.load(fd)]
                finally:
                    fd.close()
        finally:
            chooser.destroy()


class DatasetSaveFunction(workflow.Function):
    """QND way to save data to file for later import to this program."""
    def __init__(self):
        workflow.Function.__init__(self, 'save_data', 'Save Pickled Dataset')

    def run(self, data=None):
        """Pickle the first element of *data* to a user-chosen file.

        data: indexable sequence whose first element is saved; that element
            must provide ``get_name()``, used for the suggested file name.
            When *data* is empty or None, nothing is saved.
        """
        if not data:
            logger.log("notice", "No data to save.")
            return
        data = data[0]

        chooser = gtk.FileChooserDialog(title="Save pickled data...", parent=None,
                                        action=gtk.FILE_CHOOSER_ACTION_SAVE,
                                        buttons=(gtk.STOCK_CANCEL, gtk.RESPONSE_CANCEL,
                                                 gtk.STOCK_SAVE, gtk.RESPONSE_OK))
        pkl_filter = gtk.FileFilter()
        pkl_filter.set_name("Python pickled data files (*.pkl)")
        pkl_filter.add_pattern("*.[pP][kK][lL]")
        all_filter = gtk.FileFilter()
        all_filter.set_name("All Files (*.*)")
        all_filter.add_pattern("*")
        chooser.add_filter(pkl_filter)
        chooser.add_filter(all_filter)
        chooser.set_current_name(data.get_name() + ".pkl")

        try:
            if chooser.run() == gtk.RESPONSE_OK:
                filename = chooser.get_filename()
                # Protocol 2 is a binary format, so open in binary mode.
                fd = open(filename, "wb")
                try:
                    cPickle.dump(data, fd, protocol=2)
                finally:
                    fd.close()
                logger.log("notice", "Saved data to %r." % filename)
        finally:
            chooser.destroy()
                

class CelFileImportFunction(workflow.Function):
    """Import Affymetrix .CEL files into a Dataset.

    The selected files are read with R's ``affy`` package (via rpy), passed
    through ``rma``, and the resulting expression matrix is returned as a
    Dataset together with a line plot.
    """
    def __init__(self):
        workflow.Function.__init__(self, 'cel_import', 'Import Affy')

    def run(self, data):
        """Let the user pick .CEL files and import them through R/affy.

        data: not used by this function; the name is rebound to the new
            Dataset below.

        Returns [dataset, line_plot] on success; returns None (implicitly)
        when the dialog is cancelled or ``m`` is falsy.
        """
        import rpy
        chooser = gtk.FileChooserDialog(title="Select cel files...", parent=None,
                                        action=gtk.FILE_CHOOSER_ACTION_OPEN,
                                        buttons=(gtk.STOCK_CANCEL, gtk.RESPONSE_CANCEL,
                                                 gtk.STOCK_OPEN, gtk.RESPONSE_OK))
        chooser.set_select_multiple(True)
        cel_filter = gtk.FileFilter()
        cel_filter.set_name("Cel Files (*.cel)")
        cel_filter.add_pattern("*.[cC][eE][lL]")
        all_filter = gtk.FileFilter()
        all_filter.set_name("All Files (*.*)")
        all_filter.add_pattern("*")
        chooser.add_filter(cel_filter)
        chooser.add_filter(all_filter)

        try:
            if chooser.run() == gtk.RESPONSE_OK:
                rpy.r.library("affy")
    
                # rpy NO_CONVERSION mode: results are left as R object
                # references instead of being converted to Python values.
                silent_eval = rpy.with_mode(rpy.NO_CONVERSION, rpy.r)
                # NOTE(review): file names are interpolated into R source
                # without escaping; a name containing '"' breaks this call.
                silent_eval('E <- ReadAffy(filenames=c("%s"))' % '", "'.join(chooser.get_filenames()))
                silent_eval('E <- rma(E)')

                # Expression matrix, converted with rpy's default mode.
                m = rpy.r('m <- E@exprs')
    
                vector_eval = rpy.with_mode(rpy.VECTOR_CONVERSION, rpy.r)
                rownames = vector_eval('rownames(m)') 
                colnames = vector_eval('colnames(m)') 

                # We should be nice and clean up after ourselves
                rpy.r.rm(["E", "m"])
                
                # NOTE(review): if rpy converts m to a multi-element numpy
                # array, its truth value is ambiguous and this test raises
                # ValueError -- confirm what rpy returns here.
                if m:
                    # Row ids are labelled 'ids', column ids 'filename'.
                    data = dataset.Dataset(m, (('ids', rownames), ('filename', colnames)), name="AffyMatrix Data")
                    plot = plots.LinePlot(data, "Gene profiles")
                    return [data, plot]
                else:
                    logger.log("notice", "No data loaded from importer.")
        finally:
            chooser.destroy()


class DataLoadTestFunction(workflow.Function):
    """Load the 'smoker-x.ftsv' test dataset located via the workflow."""
    def __init__(self, wf):
        workflow.Function.__init__(self, 'datadirload', 'Load from datadir')
        self._wf = wf

    def run(self):
        """Return ``[dataset]`` read from smoker-x.ftsv, or [] if not found.

        The path is resolved with the workflow's ``get_data_file_name``;
        a falsy result is treated as "file not found".
        """
        name = 'smoker-x.ftsv'
        fn = self._wf.get_data_file_name(name)
        logger.log('debug', 'Data file for %s: %s' % (name, fn))
        if not fn:
            # Report the requested name; fn itself is empty/None here.
            logger.log('notice', 'Cannot find file %s' % name)
            return []

        fd = open(fn)
        try:
            ds = dataset.read_ftsv(fd)
        finally:
            fd.close()
        return [ds]

class PCAFunction(workflow.Function):
    """Generic PCA function (R ``prcomp`` via rpy)."""
    def __init__(self, wf):
        workflow.Function.__init__(self, 'pca', 'PCA')
        self._workflow = wf

    def run(self, data):
        """Run PCA on *data*; return ``[T, P, loading_plot, score_plot]``.

        ``prcomp`` is applied to the transpose of ``data.asarray()``. The
        scores T are labelled with the second dimension of *data*, and the
        loadings P with the first dimension. This assumes asarray() rows
        follow the first dimension name -- confirm. Both plots show
        components '1' and '2', so at least two components are needed.
        """
        import rpy

        # dim_2 = first dimension (variables -> loadings P),
        # dim_1 = second dimension (observations -> scores T).
        dim_2, dim_1 = data.get_dim_names()

        silent_eval = rpy.with_mode(rpy.NO_CONVERSION, rpy.r)
        rpy.with_mode(rpy.NO_CONVERSION, rpy.r.assign)("m", data.asarray())
        try:
            silent_eval("t = prcomp(t(m))")

            n_components = rpy.r("dim(t$x)")[1]
            T_ids = [str(i) for i in range(1, n_components + 1)]
            T = dataset.Dataset(rpy.r("t$x"),
                                [(dim_1, data.get_identifiers(dim_1)),
                                 ("component", T_ids)], name="T")
            P = dataset.Dataset(rpy.r("t$rotation"),
                                [(dim_2, data.get_identifiers(dim_2)),
                                 ("component", T_ids)], name="P")
        finally:
            # Remove the temporaries from R's global environment even if
            # the PCA or the conversion fails.
            rpy.r.rm(["t", "m"])

        # Use the datasets' actual dimension names rather than the
        # CEL-import-specific 'ids'/'filename'. For CEL-imported data these
        # are the same names, so the result is unchanged there.
        loading_plot = plots.ScatterMarkerPlot(P, P, dim_2, 'component', '1', '2', "Loadings")
        score_plot = plots.ScatterMarkerPlot(T, T, dim_1, 'component', '1', '2', "Scores")

        return [T, P, loading_plot, score_plot]