import pygtk
import gobject
import gtk

from matplotlib.backends.backend_gtkagg import FigureCanvasGTKAgg
from matplotlib.nxutils import points_inside_poly
from matplotlib.figure import Figure
from matplotlib.collections import LineCollection
from matplotlib.patches import Polygon,Rectangle, Circle
from matplotlib.lines import Line2D
from matplotlib.mlab import prctile
import networkx
import scipy

import fluents
import logger
import view

    
class Plot (view.View):
    """Base class for the plots in this module.

    Embeds a matplotlib Figure in a GTK canvas, owns the plot toolbar and
    keeps track of the current selection. Subclasses override the
    selection callbacks and set_current_selection() to take part in
    interactive selection.
    """
    def __init__(self, title):
        view.View.__init__(self, title)
        logger.log('debug', 'plot %s init' %title)
        # callable(dim_name, ids); see set_selection_listener()
        self.selection_listener = None
        # name of the dimension this plot selects in (None: no selection)
        self.current_dim = None
        # last selection received through selection_changed()
        self._current_selection = None
        self._frozen = False
        self._init_mpl()

    def _init_mpl(self):
        """Create the figure, canvas, axes and toolbar, and show the canvas."""
        # cached canvas region used when blitting (see subclasses)
        self._background = None
        self._use_blit = False
        self.fig = Figure()
        self.canvas = FigureCanvasGTKAgg(self.fig)
        self.axes = self.fig.gca()
        self._toolbar = view.PlotToolbar(self)
        self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK)
        self.add(self.canvas)
        self.canvas.show()

    def set_frozen(self, frozen):
        """A frozen plot will not be updated when the current
        selection is changed.

        Unfreezing redraws the most recently received selection, if any.
        """
        self._frozen = frozen
        if not frozen and self._current_selection is not None:
            self.set_current_selection(self._current_selection)

    def get_title(self):
        """Return the plot title."""
        return self.title

    def get_toolbar(self):
        """Return the PlotToolbar created for this plot."""
        return self._toolbar

    def selection_changed(self, dim_name, selection):
        """ Selection observer handle.

        The selection is always remembered, but it is only drawn if:
        1.) plot is sensitive to selections (not frozen)
        2.) plot is visible
        3.) the selection's dim_name is the plot's dimension.
        """
        self._current_selection = selection
        if self._frozen \
               or not self.get_property('visible') \
               or self.current_dim != dim_name:
            return

        self.set_current_selection(selection)

    def set_selection_listener(self, listener):
        """Allow project to listen to selections.

        The selection will propagate back to all plots through the
        selection_changed() method. The listener will be called as
        listener(dimension_name, ids).
        """
        self.selection_listener = listener

    def update_selection(self, ids, key=None):
        """Return the selection that results from ids and the pressed key.

        key map:
        shift   : union of ids with the current selection
        control : intersection of ids with the current selection
        other   : ids are returned unchanged (they replace the selection)

        If there is no current selection, or it has no entry for this
        plot's dimension, the current selection is treated as empty.
        """
        if key not in ('shift', 'control'):
            return ids
        current = []
        if self._current_selection:
            current = self._current_selection.get(self.current_dim, [])
        if key == 'shift':
            return set(ids).union(current)
        return set(ids).intersection(current)

    def set_current_selection(self, selection):
        """Called whenever the plot should change the selection.

        This method is a dummy method, so that specialized plots that have
        no implemented selection can ignore selections altogether.
        """
        pass

    def rectangle_select_callback(self, *args):
        """Overridden in subclass; the base version only redraws."""
        if hasattr(self, 'canvas'):
            self.canvas.draw()

    def lasso_select_callback(self, *args):
        """Overridden in subclass; the base version only redraws."""
        if hasattr(self, 'canvas'):
            self.canvas.draw()

    def get_index_from_selection(self, dataset, selection):
        """Returns the index vector of current selection in given dim.

        Returns [] when selection is empty/None or holds no identifiers
        for self.current_dim.
        """
        if not selection: return []
        ids = selection.get(self.current_dim, []) # current identifiers
        if not ids : return []
        return dataset.get_indices(self.current_dim, ids)


class LineViewPlot(Plot):
    """Line view plot with percentiles.

    A line view of vectors across a specified dimension of input dataset.
    No selection interaction is defined from the plot itself, but selected
    lines are drawn in red on top of a percentile background.
    Only support for 2d-arrays.
    input:
           -- major_axis : dim_number for line dim (see scipy.ndarray for axis def.)
           -- minor_axis : needs definition only for higher order arrays

    fixme: slow
    """
    def __init__(self, dataset, major_axis=1, minor_axis=None, name="Line view"):
        Plot.__init__(self, name)
        self.dataset = dataset
        self._data = dataset.asarray()
        if len(self._data.shape) == 2 and minor_axis is None:
            # For 2d data the minor axis is the other axis
            # (major_axis=0 gives -1, which addresses axis 1).
            minor_axis = major_axis - 1
        self.major_axis = major_axis
        self.minor_axis = minor_axis
        self.current_dim = self.dataset.get_dim_name(major_axis)

        # initial draw: one line segment list per slice along major_axis
        self.line_coll = None
        self.line_segs = []
        x_axis = scipy.arange(self._data.shape[minor_axis])
        for xi in range(self._data.shape[major_axis]):
            yi = self._data.take([xi], major_axis).ravel()
            self.line_segs.append(list(zip(x_axis, yi)))

        # draw background
        self._set_background(self.axes)

        # Disable selection modes
        self._toolbar.freeze_button.set_sensitive(False)
        self._toolbar.set_mode_sensitive('select', False)
        self._toolbar.set_mode_sensitive('lassoselect', False)

    def _set_background(self, ax):
        """Add three patches representing the [min, max], [5, 95] and [25, 75]
        percentiles, and a line at the median, to ax.

        Nothing is drawn when there are fewer than 10 points along the
        minor axis.
        """
        n_points = self._data.shape[self.minor_axis]
        if n_points < 10:
            return
        # settings
        patch_color = 'b' #blue
        patch_lw = 0 #no edges
        patch_alpha = .15 # transparency
        median_color = 'b' #blue
        median_width = 1.5 #linewidth
        percentiles = [0, 5, 25, 50, 75, 95, 100]
        # (lower, upper) positions in `percentiles` for each band:
        # [0, 100], [5, 95], [25, 75]
        bands = [(0, 6), (1, 5), (2, 4)]
        median_pos = 3

        # ordinate
        xax = scipy.arange(n_points)
        # percentiles are computed once per point along the minor axis
        rows = [(i, prctile(self._data.take([i], self.minor_axis), percentiles))
                for i in xax]

        for lo, hi in bands:
            # Lower edge left-to-right, then upper edge right-to-left, so
            # the vertices trace the outline of the band.
            verts = [(i, prct[lo]) for i, prct in rows]
            verts += [(i, prct[hi]) for i, prct in reversed(rows)]
            ax.add_patch(Polygon(verts, alpha=patch_alpha, lw=patch_lw,
                                 facecolor=patch_color))

        # median line
        med = [prct[median_pos] for i, prct in rows]
        ax.plot(xax, med, median_color, linewidth=median_width)

    def set_current_selection(self, selection):
        """Draws the current selection as red lines."""
        index = self.get_index_from_selection(self.dataset, selection)

        if self.line_coll:
            self.axes.collections.remove(self.line_coll)
        segs = [self.line_segs[i] for i in index]
        self.line_coll = LineCollection(segs, colors=(1,0,0,1))
        self.axes.add_collection(self.line_coll)

        #draw
        if self._use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            # restore the clean background before drawing the new lines
            self.canvas.restore_region(self._background)
            self.axes.draw_artist(self.line_coll)
            self.canvas.blit()
        else:
            self.canvas.draw()
            

class ScatterMarkerPlot(Plot):
    """The ScatterMarkerPlot is faster than regular scatterplot, but
       has no color and size options.

    Points are drawn as line markers; the selected points are drawn on top
    as a separate red-marker Line2D.
    """

    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim,
                 id_1, id_2, s=6, name="Scatter plot"):
        Plot.__init__(self, name)
        self.current_dim = id_dim
        self.dataset_1 = dataset_1
        self.ms = s
        # presumably dataset[sel_dim] maps identifiers to column indices
        x_index = dataset_1[sel_dim][id_1]
        y_index = dataset_2[sel_dim][id_2]
        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]

        # init draw
        self._selection_line = None
        self.line = self.axes.plot(self.xaxis_data, self.yaxis_data, 'o',
                                   markeredgewidth=0, markersize=s)
        self.axes.axhline(0, color='k', lw=1., zorder=1)
        self.axes.axvline(0, color='k', lw=1., zorder=1)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the points strictly inside the rectangle (x1,y1)-(x2,y2)."""
        xdata = self.xaxis_data
        ydata = self.yaxis_data

        # normalise so (x1, y1) is the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        inside = (xdata > x1) & (xdata < x2) & (ydata > y1) & (ydata < y2)
        index = scipy.nonzero(inside)[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the points inside the lasso polygon verts."""
        xys = scipy.c_[self.xaxis_data[:,scipy.newaxis], self.yaxis_data[:,scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Draw the selected points as red markers."""
        # remove old selection
        if self._selection_line is not None:
            self.axes.lines.remove(self._selection_line)
            self._selection_line = None
        index = self.get_index_from_selection(self.dataset_1, selection)
        if len(index) == 0:
            # no selection
            self.canvas.draw()
            return

        xdata_new = self.xaxis_data.take(index) #take data
        ydata_new = self.yaxis_data.take(index)
        self._selection_line = Line2D(xdata_new, ydata_new,
                                      marker='o', markersize=self.ms,
                                      linewidth=0, markerfacecolor='r',
                                      markeredgewidth=1.0)
        self.axes.add_line(self._selection_line)

        if self._use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            self.axes.draw_artist(self._selection_line)
            self.canvas.blit()
        else:
            self.canvas.draw()


class ScatterPlot(Plot):
    """The ScatterPlot is slower than scattermarker, but has size option.

    Points are a scatter collection; the selection is shown by giving the
    matching points of a second, overlaid collection a non-zero edge width.
    """
    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2, c='b', s=30, sel_dim_2=None, name="Scatter plot"):

        Plot.__init__(self, name)
        self.dataset_1 = dataset_1
        self.s = s
        self.c = c
        self.current_dim = id_dim

        # y may be looked up along another dimension than x (sel_dim_2)
        x_index = dataset_1[sel_dim][id_1]
        y_index = dataset_2[sel_dim_2 or sel_dim][id_2]
        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]

        # init draw
        self.init_draw()

        # signals to enable correct use of blit
        self.connect('zoom-changed', self.onzoom)
        self.connect('pan-changed', self.onpan)
        self.need_redraw = False
        self.canvas.mpl_connect('resize_event', self.onresize)

    def onzoom(self, widget, mode):
        """'zoom-changed' handler: refresh the blit background."""
        logger.log('notice', 'Zoom in widget: %s' %widget)
        self.clean_redraw()

    def onpan(self, widget, mode):
        """'pan-changed' handler: refresh the blit background."""
        logger.log('notice', 'Pan in widget: %s' %widget)
        self.clean_redraw()

    def onresize(self, widget):
        """matplotlib 'resize_event' handler: refresh the blit background."""
        logger.log('notice', 'resize event')
        self.clean_redraw()

    def clean_redraw(self):
        """Recapture the blit background after the view changed.

        With blitting, the selection is cleared (full redraw), the clean
        background is copied, and the current selection is drawn again.
        Without blitting the cached background is simply dropped.
        """
        if self._use_blit:
            logger.log('notice', 'blit -> clean redraw ')
            self.set_current_selection(None)
            self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.set_current_selection(self._current_selection)
        else:
            self._background = None

    def init_draw(self):
        """Draw the points, the zero axes and the (transparent) selection layer."""
        lw = scipy.zeros(self.xaxis_data.shape)
        self.sc = self.axes.scatter(self.xaxis_data, self.yaxis_data,
                                    s=self.s, c=self.c, linewidth=lw, zorder=3)
        self.axes.axhline(0, color='k', lw=1., zorder=1)
        self.axes.axvline(0, color='k', lw=1., zorder=1)
        # overlay whose per-point linewidth marks selected points
        self.selection_collection = self.axes.scatter(self.xaxis_data,
                                                      self.yaxis_data,
                                                      alpha=0,
                                                      c='w',s=self.s,
                                                      linewidth=0,
                                                      zorder=4)
        self._background = self.canvas.copy_from_bbox(self.axes.bbox)

    def is_mappable_with(self, obj):
        """Returns True if dataset/selection is mappable with this plot,
        False otherwise.

        A Dataset is mappable when it has this plot's dimension and as many
        rows as there are points; a Selection when it has this dimension.
        """
        if isinstance(obj, fluents.dataset.Dataset):
            return (self.current_dim in obj.get_dim_name()
                    and obj.asarray().shape[0] == self.xaxis_data.shape[0])

        if isinstance(obj, fluents.dataset.Selection):
            if self.current_dim in obj.get_dim_name():
                logger.log('debug', 'Selection is mappable')
                return True

        return False

    def _update_color_from_dataset(self, data):
        """Updates the facecolors from a dataset.

        A column vector is used directly as color values. For a matrix,
        a CategoryDataset is mapped to the column index of each row's
        membership; any other dataset is mapped to its row sums.

        Raises ValueError if the dataset is not 2-dimensional.
        """
        array = data.asarray()
        #only support for 2d-arrays:
        if len(array.shape) != 2:
            raise ValueError("No support for more than 2 dimensions.")
        m, n = array.shape
        # is dataset a vector or matrix?
        if n != 1:
            # we have a category dataset
            if isinstance(data, fluents.dataset.CategoryDataset):
                map_vec = scipy.dot(array, scipy.diag(scipy.arange(n))).sum(1)
            else:
                map_vec = array.sum(1)
        else:
            map_vec = array.ravel()

        # update facecolors
        self.sc.set_array(map_vec)
        self.sc.set_clim(map_vec.min(), map_vec.max())
        self.sc.update_scalarmappable() #sets facecolors from array
        self.canvas.draw()

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the points strictly inside the rectangle (x1,y1)-(x2,y2)."""
        xdata = self.xaxis_data
        ydata = self.yaxis_data

        # normalise so (x1, y1) is the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        inside = (xdata > x1) & (xdata < x2) & (ydata > y1) & (ydata < y2)
        index = scipy.nonzero(inside)[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the points inside the lasso polygon verts."""
        xys = scipy.c_[self.xaxis_data[:,scipy.newaxis], self.yaxis_data[:,scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Outline the selected points (edge width 1, others 0)."""
        linewidth = scipy.zeros(self.xaxis_data.shape, 'f')
        index = self.get_index_from_selection(self.dataset_1, selection)
        if len(index) > 0:
            # fancy-index assignment; ndarray.put takes (indices, values)
            linewidth[index] = 1
        self.selection_collection.set_linewidth(linewidth)

        if self._use_blit and len(index) > 0:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.axes.bbox)
            self.canvas.restore_region(self._background)
            self.axes.draw_artist(self.selection_collection)
            self.canvas.blit()
        else:
            self.canvas.draw()
        
        
class ImagePlot(Plot):
    """Shows a dataset as an image; selection is not supported."""

    def __init__(self, dataset, **kw):
        title = kw.get('name', 'Image Plot')
        Plot.__init__(self, title)
        self.dataset = dataset

        # Initial draw: nearest-neighbour image fitted tightly to the axes
        image = dataset.asarray()
        ax = self.axes
        ax.grid(False)
        ax.imshow(image, interpolation='nearest')
        ax.axis('tight')

        # No selection interaction for images
        toolbar = self._toolbar
        toolbar.freeze_button.set_sensitive(False)
        for mode in ('select', 'lassoselect'):
            toolbar.set_mode_sensitive(mode, False)

    
class HistogramPlot(Plot):
    """ Histogram plot.
    If dataset is 1-dim the current_dim is set and selections may
    be performed. For dataset> 1.dim the histogram is over all values
    and selections are not defined."""

    def __init__(self, dataset, **kw):
        Plot.__init__(self, kw.get('name', 'Histogram Plot'))
        self.dataset = dataset
        self._data = dataset.asarray()

        # If dataset is 1-dim we may do selections
        if dataset.shape[0] == 1:
            self.current_dim = dataset.get_dim_name(1)
        if dataset.shape[1] == 1:
            self.current_dim = dataset.get_dim_name(0)

        # Initial draw. The histogram is over all values, so flatten; for a
        # 1-dim dataset positions in `values` are then indices along
        # current_dim for both row and column vectors.
        values = self._data.ravel()
        self.axes.grid(False)
        bins = min(values.size, 20)
        _, lims, self.patches = self.axes.hist(values, bins=bins)

        # Add identifiers to the individual patches
        if self.current_dim is not None:
            last = len(self.patches) - 1
            for i, patch in enumerate(self.patches):
                if i == last:
                    # last bin is closed: take everything from its lower edge
                    in_bin = values >= lims[i]
                else:
                    # other bins are half-open [lims[i], lims[i+1])
                    in_bin = (values >= lims[i]) & (values < lims[i+1])
                patch.index = scipy.nonzero(in_bin)[0]

        if self.current_dim is None:
            # Disable selection modes
            self._toolbar.freeze_button.set_sensitive(False)
            self._toolbar.set_mode_sensitive('select', False)
            self._toolbar.set_mode_sensitive('lassoselect', False)

    def _select_patches(self, patches, key):
        """Send the identifiers of the data in patches to the listener."""
        ids = set()
        for patch in patches:
            ids.update(self.dataset.get_identifiers(self.current_dim,
                                                    patch.index))
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the bars overlapped by the rectangle (x1,y1)-(x2,y2)."""
        if self.current_dim is None: return
        # make (x1, y1) the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        self.active_patches = []
        for patch in self.patches:
            xmin = patch.xy[0]
            xmax = xmin + patch.get_width()
            ymin, ymax = 0, patch.get_height()
            # bar spans [xmin, xmax] x [0, height]; keep it if they overlap
            if xmax > x1 and xmin < x2 and ymax > y1 and ymin < y2:
                self.active_patches.append(patch)
        if not self.active_patches: return

        self._select_patches(self.active_patches, key)

    def lasso_select_callback(self, verts, key=None):
        """Select the bars that contain any vertex of the lasso."""
        if self.current_dim is None: return
        self.active_patches = [patch for patch in self.patches
                               if scipy.any(points_inside_poly(verts, patch.get_verts()))]
        if not self.active_patches: return

        self._select_patches(self.active_patches, key)

    def set_current_selection(self, selection):
        """Color bars holding at least one selected item red, others blue."""
        if self.current_dim is None:
            # patches carry no index when selection is not defined
            return
        index = self.get_index_from_selection(self.dataset, selection)
        for patch in self.patches:
            selected = scipy.intersect1d(patch.index, index).size > 0
            patch.set_facecolor('r' if selected else 'b')
        self.canvas.draw()


class BarPlot(Plot):
    """Bar plot.

    Ordinary bar plot for (column) vectors.
    For matrices there is one color for each row; bars of the same column
    are grouped next to each other.
    """
    def __init__(self, dataset, **kw):
        # not imported at module level in this file
        import matplotlib.cm

        Plot.__init__(self, kw.get('name', 'Bar Plot'))
        self.dataset = dataset

        # Initial draw
        self.axes.grid(False)
        n, m = dataset.shape
        data = dataset.asarray()
        if m > 1:
            sm = matplotlib.cm.ScalarMappable()
            clrs = sm.to_rgba(range(n))
            for i, row in enumerate(data):
                # row i goes at positions i+1, i+1+n, ... (one per column)
                left = scipy.arange(i+1, m*n+1, n)
                color = clrs[i]
                self.axes.bar(left, row, color=tuple(color[:3]))
        else:
            height = data.ravel()
            # one position per value: 1..n
            left = scipy.arange(1, n+1)
            self.axes.bar(left, height)

        # Disable selection modes
        self._toolbar.freeze_button.set_sensitive(False)
        self._toolbar.set_mode_sensitive('select', False)
        self._toolbar.set_mode_sensitive('lassoselect', False)
        

class NetworkPlot(Plot):
    """Draws dataset.asnetworkx() as a node/edge diagram.

    Node positions come from pos, or from networkx.graphviz_layout with
    the given graphviz prog. Selected nodes are outlined.
    """
    def __init__(self, dataset, pos=None, nodecolor='b', nodesize=40,
                 prog='neato', with_labels=False, name='Network Plot'):

        Plot.__init__(self, name)
        self.dataset = dataset
        self.graph = dataset.asnetworkx()
        self._prog = prog
        self._pos = pos
        self._nodesize = nodesize
        self._nodecolor = nodecolor
        self._with_labels = with_labels

        self.current_dim = self.dataset.get_dim_name(0)

        if not self._pos:
            self._pos = networkx.graphviz_layout(self.graph, self._prog)
        # node coordinates ordered like the sorted identifiers of current_dim
        self._xy = scipy.asarray([self._pos[node] for node in self.dataset.get_identifiers(self.current_dim, sorted=True)])
        self.xaxis_data = self._xy[:,0]
        self.yaxis_data = self._xy[:,1]

        # Initial draw
        self.default_props = {'nodesize': 50, 'nodecolor':'gray'}
        self.node_labels = None
        lw = scipy.zeros(self.xaxis_data.shape)
        self.node_collection = self.axes.scatter(self.xaxis_data, self.yaxis_data,
                                                 s=self._nodesize,
                                                 c=self._nodecolor,
                                                 linewidth=lw,
                                                 zorder=3)
        # transparent overlay whose per-node linewidth marks the selection
        self.selected_nodes = self.axes.scatter(self.xaxis_data,
                                                self.yaxis_data,
                                                s=self._nodesize,
                                                c=self._nodecolor,
                                                linewidth=lw,
                                                zorder=4,
                                                alpha=0)

        self.edge_collection = networkx.draw_networkx_edges(self.graph,
                                                            self._pos,
                                                            ax=self.axes,
                                                            edge_color='gray')
        if self._with_labels:
            self.node_labels = networkx.draw_networkx_labels(self.graph,
                                                            self._pos,
                                                            ax=self.axes)

        # remove axes, frame and grid
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.grid(False)
        self.axes.set_frame_on(False)
        self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the nodes strictly inside the rectangle (x1,y1)-(x2,y2)."""
        xdata = self.xaxis_data
        ydata = self.yaxis_data

        # normalise so (x1, y1) is the lower left corner
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        inside = (xdata > x1) & (xdata < x2) & (ydata > y1) & (ydata < y2)
        index = scipy.nonzero(inside)[0]
        ids = self.dataset.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def lasso_select_callback(self, verts, key=None):
        """Select the nodes inside the lasso polygon verts."""
        xys = scipy.c_[self.xaxis_data[:,scipy.newaxis], self.yaxis_data[:,scipy.newaxis]]
        index = scipy.nonzero(points_inside_poly(xys, verts))[0]
        ids = self.dataset.get_identifiers(self.current_dim, index)
        ids = self.update_selection(ids, key)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Outline the selected nodes (edge width 2, others 0)."""
        linewidth = scipy.zeros(self.xaxis_data.shape, 'f')
        index = self.get_index_from_selection(self.dataset, selection)
        if len(index) > 0:
            # fancy-index assignment; ndarray.put takes (indices, values)
            linewidth[index] = 2
        self.selected_nodes.set_linewidth(linewidth)
        self.canvas.draw()


class VennPlot(Plot):
    """Interactive three-circle Venn diagram.

    Selecting (lasso or rectangle) marks the region containing the centroid
    of the selection vertices and sets active_elements to the elements of
    that region. Holding shift adds to the previous selection. The legend
    shows the total number of elements and the number selected.
    """
    def __init__(self, name="Venn diagram"):
        Plot.__init__(self, name)

        # init draw
        self._init_bck()
        for c in self._venn_patches:
            self.axes.add_patch(c)
        for mrk in self._markers:
            self.axes.add_patch(mrk)
        self.axes.set_xlim([-3, 3])
        self.axes.set_ylim([-2.5, 3.5])
        self._last_active = set()
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.axis('equal')
        self.axes.grid(False)
        self.axes.set_frame_on(False)
        self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

    def _init_bck(self):
        """Create the circles, the (hidden) region markers and the legend."""
        a = .5    # circle alpha
        r = 1.5   # circle radius
        mr = .2   # marker radius
        self.c1 = c1 =  Circle((-1,0), radius=r, alpha=a, facecolor='b')
        self.c2 = c2 = Circle((1,0), radius=r, alpha=a, facecolor='r')
        self.c3 = c3 = Circle((0, scipy.sqrt(3)), radius=r, alpha=a, facecolor='g')

        # one marker per region; alpha=0 hides it until the region is selected
        self.c1marker = Circle((-1.25, -.25), radius=mr, facecolor='y', alpha=0)
        self.c2marker = Circle((1.25, -.25), radius=mr, facecolor='y', alpha=0)
        self.c3marker = Circle((0, scipy.sqrt(3)+.25), radius=mr, facecolor='y', alpha=0)
        self.c1c2marker = Circle((0, -.15), radius=mr, facecolor='y', alpha=0)

        self.c1c3marker = Circle((-scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0)
        self.c2c3marker = Circle((scipy.sqrt(2)/2, 1), radius=mr, facecolor='y', alpha=0)
        self.c1c2c3marker = Circle((0, .6), radius=mr, facecolor='y', alpha=0)

        # placeholder element sets
        c1.elements = set(['a', 'b', 'c', 'f'])
        c2.elements = set(['a', 'c', 'd', 'e'])
        c3.elements = set(['a', 'e', 'f', 'g'])
        self.active_elements = set()
        self.all_elements = c1.elements.union(c2.elements).union(c3.elements)

        c1.active = False
        c2.active = False
        c3.active = False

        c1.name = 'Blue'
        c2.name = 'Red'
        c3.name = 'Green'

        self._venn_patches = [c1, c2, c3]
        self._markers = [self.c1marker, self.c2marker, self.c3marker,
                         self.c1c2marker, self.c1c3marker,
                         self.c2c3marker, self.c1c2c3marker]

        self._tot_label = 'Tot: ' + str(len(self.all_elements))
        self._sel_label = 'Sel: ' + str(len(self.active_elements))
        self._legend = self.axes.legend((self._tot_label, self._sel_label),
                                       loc='upper right')

    def set_selection(self, selection, patch=None):
        """Use selection as the element set of a circle.

        If patch is given it receives the selection. Otherwise the first
        circle without elements receives it, or the first circle if all
        circles already have elements.
        """
        if patch is None:
            empty = [p for p in self._venn_patches if len(p.elements) == 0]
            patch = empty[0] if empty else self._venn_patches[0]
        patch.elements = selection

    def lasso_select_callback(self, verts, key=None):
        """Select the region containing the centroid of the lasso vertices."""
        if verts is None:
            # NOTE(review): self._event is not assigned anywhere in this
            # file; presumably set by the toolbar on a click — confirm.
            verts = [(self._event.xdata, self._event.ydata)]
        self._select_region(verts, key)

    def rectangle_select_callback(self, x1, y1, x2, y2, key):
        """Select the region containing the center of the rectangle."""
        self._select_region([(x1, y1), (x2, y2)], key)

    def _select_region(self, verts, key):
        """Shared lasso/rectangle logic: mark the hit region and update
        active_elements, the 'Sel:' legend entry and the canvas."""
        if key != 'shift':
            for m in self._markers:
                m.set_alpha(0)

        self._patches_within_verts(verts, key)
        e1, e2, e3 = [p.elements for p in self._venn_patches]
        # (hit c1, hit c2, hit c3) -> (marker, elements of that region).
        # Each region is exclusive of the circles that were not hit.
        regions = {
            (True, False, False): (self.c1marker, e1.difference(e2.union(e3))),
            (False, True, False): (self.c2marker, e2.difference(e1.union(e3))),
            (False, False, True): (self.c3marker, e3.difference(e1.union(e2))),
            (True, True, False): (self.c1c2marker, e1.intersection(e2).difference(e3)),
            (True, False, True): (self.c1c3marker, e1.intersection(e3).difference(e2)),
            (False, True, True): (self.c2c3marker, e2.intersection(e3).difference(e1)),
            (True, True, True): (self.c1c2c3marker, e1.intersection(e2).intersection(e3)),
        }
        active = tuple([p.active for p in self._venn_patches])
        # outside every circle: nothing is selected
        marker, elements = regions.get(active, (None, set()))
        if marker is not None:
            marker.set_alpha(1)
        self.active_elements = elements

        if key == 'shift':
            self.active_elements = self.active_elements.union(self._last_active)
        self._last_active = self.active_elements.copy()
        self._sel_label = 'Sel: ' + str(len(self.active_elements))
        self._legend.texts[1].set_text(self._sel_label)
        self.axes.figure.canvas.draw()

    def _patches_within_verts(self, verts, key):
        """Set .active on each circle that contains the centroid of verts."""
        xy = scipy.array(verts).mean(0)
        for venn_patch in self._venn_patches:
            venn_patch.active = self._distance(venn_patch.center, xy) < venn_patch.radius

    def _distance(self, p1, p2):
        """Euclidean distance between the 2-d points p1 and p2."""
        (x1, y1), (x2, y2) = p1, p2
        return scipy.sqrt( (x2-x1)**2 + (y2-y1)**2 )


# Register the custom GObject signals on Plot: 'zoom-changed' and
# 'pan-changed' (ScatterPlot connects handlers to both) and
# 'plot-resize-changed'. Each signal carries a single Python object argument.
for _signal_name in ('zoom-changed', 'pan-changed', 'plot-resize-changed'):
    gobject.signal_new(_signal_name, Plot, gobject.SIGNAL_RUN_LAST, None,
                       (gobject.TYPE_PYOBJECT,))
del _signal_name