import os,sys
from itertools import izip
import pygtk
import gobject
import gtk
import fluents
from system import logger
import matplotlib
from matplotlib.backends.backend_gtkagg import FigureCanvasGTKAgg as FigureCanvas
from matplotlib.backend_bases import NavigationToolbar2,cursors
from matplotlib.backends.backend_gtk import FileChooserDialog,cursord
from matplotlib.widgets import SubplotTool,RectangleSelector
from matplotlib.axes import Subplot
from matplotlib.figure import Figure
from matplotlib import cm,cbook
from pylab import Polygon
from matplotlib.collections import LineCollection
from matplotlib.mlab import prctile
import networkx
import scipy
 
# Global active plot mode, holding a PlotMode name such as 'default'.
# NOTE(review): presumably read/written by the plot toolbars so they agree on
# the current mode; those accesses are not in this section — confirm.
active_mode = 'default'

class ObjectTable:
    """A 2D table of elements, indexed as ``table[x, y]``.

    Elements are stored as a list of rows (one row per y). New cells are
    filled by calling ``creator()``, row by row; without a creator they are
    filled with None.
    """

    def __init__(self, xsize=0, ysize=0, creator=None):
        self._elements = []
        self._creator = creator or (lambda : None)
        self.xsize = xsize
        self.ysize = ysize
        self.resize(xsize, ysize)

    def resize(self, xsize, ysize):
        """Resizes the table by removing and creating elements as required.

        Elements inside the new bounds are kept, elements outside are
        dropped, and missing cells are created with the creator. The
        ``xsize``/``ysize`` attributes are updated to the new dimensions.

        Raises ValueError if either size is negative.
        """
        if xsize < 0 or ysize < 0:
            # Negative slice bounds would silently delete trailing cells.
            raise ValueError("table sizes must be non-negative, got (%r, %r)"
                             % (xsize, ysize))

        # Delete surplus rows, then append missing (empty) rows.
        del self._elements[ysize:]
        new_rows = ysize - len(self._elements)
        self._elements.extend([[] for _ in range(new_rows)])

        # Trim or pad every row to exactly xsize columns.
        for row in self._elements:
            del row[xsize:]
            new_elems = xsize - len(row)
            row.extend([self._creator() for _ in range(new_elems)])

        # Keep the public size attributes in sync with the contents.
        self.xsize = xsize
        self.ysize = ysize

    def __getitem__(self, index):
        x, y = index
        return self._elements[y][x]

    def __setitem__(self, index, value):
        x, y = index
        self._elements[y][x] = value


class ViewFrame (gtk.Frame):
    """
    A ViewFrame is a gtk bin that contains a view.
    The ViewFrame is either active or inactive, and belongs to a group of
    ViewFrames of which only one can be active at any time.

    The frame emits 'focus-changed' (frame, focused) whenever focus() or
    unfocus() changes or re-asserts its state. NOTE(review): the signal is
    presumably registered with gobject elsewhere in the module — confirm.
    """

    def __init__(self, view_frames):
        """Build the frame and append it to the shared ``view_frames`` group.

        The first frame added to a group takes focus; later frames start
        unfocused. The frame initially shows its own EmptyView.
        """
        gtk.Frame.__init__(self)
        self.focused = False
        self.view_frames = view_frames
        self.empty_view = EmptyView()
        self._button_event = None

        ## Set up a VBox with a label wrapped in an event box.
        label = gtk.Label()
        ebox = gtk.EventBox()
        ebox.add(label)
        vbox = gtk.VBox()
        vbox.pack_start(ebox, expand=False)
        vbox.pack_start(gtk.HSeparator(), expand=False)

        self._ebox_button_event = ebox.connect("button-press-event",
                                               self.on_button_press_event)
        ## Keep the references for later use.
        self._vbox = vbox
        self._ebox = ebox
        self._view_title = label
        self.add(vbox)

        view_frames.append(self)
        if len(view_frames) == 1:
            self.focus()
        else:
            # unfocus() is a no-op on an unfocused frame, so mark it focused
            # first to make it apply the inactive appearance.
            self.focused = True
            self.unfocus()

        # Accept tree-model rows dropped from within the application.
        self.drag_dest_set(gtk.DEST_DEFAULT_ALL,
                           [("GTK_TREE_MODEL_ROW", gtk.TARGET_SAME_APP, 7)],
                           gtk.gdk.ACTION_LINK)
        self.connect("drag-data-received", self.on_drag_data_received)

        # Start out showing the empty view. The handler id is stored so
        # set_view() can disconnect it when another view is installed.
        self._view = self.empty_view
        self._view._view_frame = self
        self._button_event = self._view.connect("button-press-event",
                                                self.on_button_press_event)
        self._vbox.add(self._view)
        self._view_title.set_text(self._view.title)
        self.show_all()
        self._view.show()

    def focus(self):
        """Gets focus and ensures that no other window is in focus.

        Returns self.
        """
        if self.focused:
            self.emit('focus-changed', self, True)
            return self

        for frame in self.view_frames:
            frame.unfocus()

        self.set_shadow_type(gtk.SHADOW_IN)
        self._ebox.set_state(gtk.STATE_ACTIVE)
        self.focused = True
        self.emit('focus-changed', self, True)
        return self

    def unfocus(self):
        """Removes focus from the ViewFrame. Does nothing if unfocused."""
        if not self.focused:
            return

        self.set_shadow_type(gtk.SHADOW_OUT)
        self._ebox.set_state(gtk.STATE_NORMAL)
        self.focused = False
        self.emit('focus-changed', self, False)

    def set_view(self, view):
        """Show ``view`` in this frame, or the empty view if it is None.

        A view currently shown in another ViewFrame is first detached from
        that frame (which then falls back to its own empty view).
        """
        if view is None:
            view = self.empty_view

        # Nothing to do if the view is already shown here.
        if view is self._view:
            return

        # Detach the view from its current frame.
        if view._view_frame is not None:
            view._view_frame.set_view(None)

        # Move the button-press handler from the old view to the new one.
        if self._button_event is not None:
            self._view.disconnect(self._button_event)
        self._button_event = view.connect("button-press-event",
                                          self.on_button_press_event)

        # Remove the old view and install the new one.
        if self._view is not None:
            self._view.hide()
            self._vbox.remove(self._view)
            self._view._view_frame = None

        self._view_title.set_text(view.title)
        self._vbox.add(view)
        view.show()

        view._view_frame = self
        self._view = view

    def get_view(self):
        """Returns current view, or None if the empty view is set."""
        if self._view is self.empty_view:
            return None
        return self._view

    def on_button_press_event(self, widget, event):
        """Clicking anywhere in the frame (title or view) focuses it."""
        if not self.focused:
            self.focus()

    def on_drag_data_received(self, widget, drag_context, x, y,
                              selection, info, timestamp):
        """Show a dropped Plot (read from column 2 of the dragged row)."""
        treestore, path = selection.tree_get_row_drag_data()
        tree_iter = treestore.get_iter(path)
        obj = treestore.get_value(tree_iter, 2)

        if isinstance(obj, Plot):
            self.set_view(obj)
            self.focus()


class MainView (gtk.Notebook):
    """Central area: a tab-less notebook with a small and a large page.

    Page 0 is a 2x2 grid of ViewFrames sharing one focus group; page 1 is a
    single large ViewFrame. goto_large()/goto_small() move the active small
    frame's view between the two pages.

    NOTE(review): emits 'view-changed' (view_frame) when a small frame gains
    focus; the signal is presumably registered with gobject elsewhere in
    the module — confirm.
    """
    def __init__(self):
        gtk.Notebook.__init__(self)
        self.set_show_tabs(False)
        self.set_show_border(False)

        # All small frames share self._view_frames, so at most one of them
        # is focused; the large frame gets a group of its own.
        self._view_frames = []
        self._views = ObjectTable(2, 2, lambda : ViewFrame(self._view_frames))
        self._small_views = gtk.Table(2, 2, True)
        self._small_views.set_col_spacings(4)
        self._small_views.set_row_spacings(4)
        self._large_view = ViewFrame(list())
        self.update_small_views()

        for vf in self._view_frames:
            vf.connect('focus-changed', self.on_view_focus_changed)

        self.append_page(self._small_views)
        self.append_page(self._large_view)
        self.show()
        self.set_current_page(0)

    def __getitem__(self, index, y=None):
        """Return the small ViewFrame at ``mainview[x, y]``.

        The explicit two-argument call ``__getitem__(x, y)`` is also accepted.
        """
        if y is None:
            x, y = index
        else:
            x = index
        return self._views[x, y]

    def update_small_views(self):
        """Attach every small ViewFrame to its cell of the grid table."""
        for x in range(self._views.xsize):
            for y in range(self._views.ysize):
                child = self._views[x, y]
                self._small_views.attach(child, x, x+1, y, y+1)

    def get_active_small_view(self):
        """Return the focused small ViewFrame, or None if none is focused."""
        for vf in self._view_frames:
            if vf.focused:
                return vf
        return None

    def get_active_view_frame(self):
        """Return the frame that receives new views on the current page."""
        if self.get_current_page() == 0:
            return self.get_active_small_view()
        else:
            return self._large_view

    def goto_large(self):
        """Move the active small view to the large page and show it."""
        if self.get_current_page() == 1:
            return

        vf = self.get_active_small_view()
        if vf is None:
            return
        view = vf.get_view()
        vf.set_view(None)
        self._large_view.set_view(view)
        self.set_current_page(1)

    def goto_small(self):
        """Move the large view back into the active small frame."""
        if self.get_current_page() == 0:
            return

        vf = self.get_active_small_view()
        if vf is None:
            return
        view = self._large_view.get_view()
        self._large_view.set_view(None)
        vf.set_view(view)
        self.set_current_page(0)

    def insert_view(self, view):
        """Show ``view`` in the active frame of the current page."""
        vf = self.get_active_view_frame()
        if vf is not None:
            vf.set_view(view)

    def set_all_plots(self, plots):
        """Fill the small frames from ``plots``; extra frames become empty.

        Note: consumes ``plots`` in place, popping from its end.
        """
        for vf in self._view_frames:
            if plots:
                vf.set_view(plots.pop())
            else:
                vf.set_view(None)

    def show(self):
        for vf in self._view_frames:
            vf.show()
        self._small_views.show()
        gtk.Notebook.show(self)

    def on_view_focus_changed(self, widget, vf, focused):
        """Re-emit focus gains of small frames as 'view-changed'."""
        if focused:
            self.emit('view-changed', vf)
    

class View (gtk.Frame):
    """Base class of everything shown in the central view area of fluents.

    Concrete views normally derive from Plot instead, which takes care of
    freezing and toolbars and creates the matplotlib Figure and Canvas.
    """

    def __init__(self, title):
        gtk.Frame.__init__(self)
        self.set_shadow_type(gtk.SHADOW_NONE)
        # Text shown in the hosting ViewFrame's title label.
        self.title = title
        # The ViewFrame currently displaying this view; None when detached.
        self._view_frame = None

    def get_toolbar(self):
        """Plain views have no toolbar."""
        toolbar = None
        return toolbar


class EmptyView (View):
    """Placeholder view displayed by ViewFrames that hold no real view."""
    def __init__(self):
        View.__init__(self, 'Empty view')
        self.set_label(None)
        # A "No view" label wrapped in an event box fills the frame.
        placeholder = gtk.Label('No view')
        box = gtk.EventBox()
        box.add(placeholder)
        self.add(box)
        for widget in (placeholder, box, self):
            widget.show()

    
class Plot (View):
    """A View that draws with matplotlib and carries a PlotToolbar.

    Creates ``self.fig`` (a Figure) and ``self.canvas`` (a FigureCanvas)
    and implements freezing. Subclasses must provide
    ``set_current_selection(selection)``, which set_frozen() and
    selection_changed() call to redraw.
    """
    def __init__(self, title):
        View.__init__(self, title)

        self.selection_listener = None
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self._background = None
        self._frozen = False
        self._toolbar = PlotToolbar(self)
        self.canvas.add_events(gtk.gdk.ENTER_NOTIFY_MASK)
        self.current_dim = None
        self._current_selection = None

    def set_frozen(self, frozen):
        """A frozen plot will not be updated when the current selection is changed.

        Unfreezing redraws with the most recently received selection.
        """
        self._frozen = frozen
        if not frozen and self._current_selection is not None:
            self.set_current_selection(self._current_selection)

    def get_title(self):
        return self.title

    def get_toolbar(self):
        return self._toolbar

    def selection_changed(self, dim_name, selection):
        """ Selection observer handle.

        The selection is always remembered, but it is only drawn if:
        1.) plot is sensitive to selections (not frozen)
        2.) plot is visible (has a view)
        3.) the selection's dim_name is the plot's dimension.
        """
        self._current_selection = selection
        if self._frozen \
               or not self.get_property('visible') \
               or self.current_dim != dim_name:
            return

        self.set_current_selection(selection)

    def set_selection_listener(self, listener):
        """Allow project to listen to selections.

        The selection will propagate back to all plots through the
        selection_changed() method. The listener will be called as
        listener(dimension_name, ids).
        """
        self.selection_listener = listener

        
class LineViewPlot(Plot):
    """Line view of current selection, no interaction.

    Only works on 2d-arrays.
    input:
           -- major_axis : dim_number for line dim (see scipy.ndarray for axis def.)
           -- minor_axis : needs definition only for higher order arrays
    Each index along major_axis becomes one line, plotted against the index
    along minor_axis. Selected lines are drawn in red over percentile bands
    computed for each row of the data.
    ps: slow (can't get LineCollection and blit to work)
    """
    def __init__(self, dataset, major_axis=1, minor_axis=None, name="Line view"):
        self.use_blit = False
        self._last_index = []
        self._data = dataset.asarray()
        self.dataset = dataset
        Plot.__init__(self, name)
        self.ax = self.fig.add_subplot(111)
        #self.ax.set_title(self.get_title())
        self.current_dim = self.dataset.get_dim_name(major_axis)
        # For 2d data the line x-axis defaults to the other axis; compare
        # with None so an explicitly passed minor_axis=0 is honoured.
        if len(self._data.shape) == 2 and minor_axis is None:
            minor_axis = major_axis - 1

        # Initial draw: precompute one segment list per line so that
        # selections only need to pick segments by index.
        x_axis = scipy.arange(self._data.shape[minor_axis])
        self.line_segs = []
        for xi in range(self._data.shape[major_axis]):
            yi = self._data.take([xi], major_axis)
            self.line_segs.append([(xx, yy) for xx, yy in izip(x_axis, yi)])

        # Background: for each row i, percentiles of self._data[i, :] at
        # 0, 5, 25, 50, 75, 95 and 100. Every row is evaluated once. Each
        # band is a closed polygon: lower edge forwards, upper edge back.
        xax = scipy.arange(self._data.shape[0])
        pcts = [prctile(self._data[i, :], [0., 5., 25, 50., 75., 95., 100])
                for i in xax]
        fwd = list(zip(xax, pcts))
        back = fwd[::-1]
        verts_0 = ([(i, pp[0]) for i, pp in fwd] +
                   [(i, pp[-1]) for i, pp in back])   # 0 - 100
        verts_1 = ([(i, pp[1]) for i, pp in fwd] +
                   [(i, pp[-2]) for i, pp in back])   # 5 - 95
        verts_2 = ([(i, pp[2]) for i, pp in fwd] +
                   [(i, pp[-3]) for i, pp in back])   # 25 - 75
        # Median, in the same (forward) order as xax.
        med = [pp[3] for pp in pcts]

        bck0 = Polygon(verts_0, alpha=.15, lw=0)
        bck1 = Polygon(verts_1, alpha=.15, lw=0)
        bck2 = Polygon(verts_2, alpha=.15, lw=0)

        self.ax.add_patch(bck0)
        self.ax.add_patch(bck1)
        self.ax.add_patch(bck2)
        self.ax.plot(xax, med, 'b')
        self.ax.autoscale_view()

        self.add(self.canvas)
        self.canvas.show()

        #FIXME: Lineview plot cannot do selections -> disable in toolbar
        # NOTE(review): Plot.__init__ already created a PlotToolbar; this
        # replaces it with a second one while the first stays connected to
        # the canvas. Presumably a placeholder for the FIXME — confirm.
        self._toolbar = PlotToolbar(self)
        self.canvas.mpl_connect('resize_event', self.clear_background)

    def clear_background(self, event):
        """Invalidate the cached blit background (e.g. after a resize)."""
        self._background = None

    def set_current_selection(self, selection):
        """Draw the lines whose ids are in selection[current_dim] in red.

        An empty selection leaves the current drawing untouched.
        """
        ids = selection[self.current_dim] # current identifiers
        index = self.dataset.get_indices(self.current_dim, ids)
        if self.use_blit:
            if self._background is None:
                self._bbox = self.ax.bbox.deepcopy()
                self._background = self.canvas.copy_from_bbox(self.ax.bbox)
            self.canvas.restore_region(self._background)

        if len(index) > 0: # do we have a selection
            # Drop the previous selection's collection.
            if len(self.ax.collections) > 0:
                self.ax.collections = []
            segs = [self.line_segs[i] for i in index]
            line_coll = LineCollection(segs, colors=(1,0,0,1))
            line_coll.set_clip_box(self.ax.bbox)
            self.ax.update_datalim(line_coll.get_verts(self.ax.transData))

            if self.use_blit:
                self.ax.draw_artist(line_coll)
                self.canvas.blit()
            else:
                self.ax.add_collection(line_coll)
                self.canvas.draw()

            
class ScatterMarkerPlot(Plot):
    """The ScatterMarkerPlot is faster than regular scatterplot, but
       has no color and size options.

    Plots column ``dataset_1[sel_dim][id_1]`` of dataset_1 against column
    ``dataset_2[sel_dim][id_2]`` of dataset_2, with rows identified along
    ``id_dim``. Selected points are overplotted as red markers.
    """

    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim,
                 id_1, id_2, s=6, name="Scatter plot"):
        Plot.__init__(self, name)
        self.use_blit = False
        self._background = None
        self.ax = self.fig.add_subplot(111)
        self.ax.axhline(0, color='k', lw=1., zorder=1)
        self.ax.axvline(0, color='k', lw=1., zorder=1)
        self.current_dim = id_dim
        self.dataset_1 = dataset_1
        self.ms = s
        self._selection_line = None
        x_index = dataset_1[sel_dim][id_1]
        y_index = dataset_2[sel_dim][id_2]
        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]
        self.ax.plot(self.xaxis_data, self.yaxis_data, 'o', markeredgewidth=0, markersize=s)
        #self.ax.set_title(self.get_title())
        self.add(self.canvas)
        self.canvas.show()

    def rectangle_select_callback(self, x1, y1, x2, y2):
        """Report points strictly inside the rectangle (x1, y1)-(x2, y2).

        The ids of those points along current_dim are passed to the
        selection listener; nothing happens when no listener is set.
        """
        if self.selection_listener is None:
            return
        ydata = self.yaxis_data
        xdata = self.xaxis_data

        # Normalise the corners, then find indices of the selected area.
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1
        index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Overplot the selected points as red markers.

        An empty selection leaves the previous highlight in place.
        """
        ids = selection[self.current_dim] # current identifiers
        index = self.dataset_1.get_indices(self.current_dim, ids)
        if self.use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.ax.bbox)
            self.canvas.restore_region(self._background)
        if not len(index) > 0:
            return
        xdata_new = self.xaxis_data.take(index) #take data
        ydata_new = self.yaxis_data.take(index)
        #remove old selection
        if self._selection_line:
            self.ax.lines.remove(self._selection_line)

        # linestyle must be the string 'None': a Python None falls back to
        # the default solid line and would join the selected points.
        self._selection_line, = self.ax.plot(xdata_new, ydata_new,
                                             marker='o', markersize=self.ms,
                                             linestyle='None',
                                             markerfacecolor='r')

        if self.use_blit:
            self.ax.draw_artist(self._selection_line)
            self.canvas.blit()
        else:
            self.canvas.draw()


class ScatterPlot(Plot):
    """The ScatterPlot is slower than scattermarker, but has size option.

    Plots column ``dataset_1[sel_dim][id_1]`` against column
    ``dataset_2[sel_dim_2 or sel_dim][id_2]``, with rows identified along
    ``id_dim``. Selected points are highlighted by a thicker marker edge.
    """
    def __init__(self, dataset_1, dataset_2, id_dim, sel_dim, id_1, id_2,c='b',s=30,sel_dim_2=None, name="Scatter plot"):
        Plot.__init__(self, name)
        self.use_blit = False
        self.ax = self.fig.add_subplot(111)
        self.current_dim = id_dim
        self.dataset_1 = dataset_1
        x_index = dataset_1[sel_dim][id_1]
        if sel_dim_2:
            y_index = dataset_2[sel_dim_2][id_2]
        else:
            y_index = dataset_2[sel_dim][id_2]
        self.xaxis_data = dataset_1._array[:, x_index]
        self.yaxis_data = dataset_2._array[:, y_index]
        lw = scipy.zeros(self.xaxis_data.shape)
        sc = self.ax.scatter(self.xaxis_data, self.yaxis_data, s=s, c=c, linewidth=lw, edgecolor='k', alpha=.6, cmap = cm.jet)
        # A colorbar only makes sense for colour values mapped through the
        # colormap; a colour name such as 'red' also has len() > 1.
        if not cbook.is_string_like(c) and len(c) > 1:
            self.fig.colorbar(sc, ticks=[], fraction=.05)
        self.ax.axhline(0, color='k', lw=1., zorder=1)
        self.ax.axvline(0, color='k', lw=1., zorder=1)
        #self.ax.set_title(self.get_title())
        # collection
        self.coll = self.ax.collections[0]

        # add canvas to widget
        self.add(self.canvas)
        self.canvas.show()

    def rectangle_select_callback(self, x1, y1, x2, y2):
        """Report points strictly inside the rectangle (x1, y1)-(x2, y2).

        The ids of those points along current_dim are passed to the
        selection listener; nothing happens when no listener is set.
        """
        if self.selection_listener is None:
            return
        ydata = self.yaxis_data
        xdata = self.xaxis_data

        # Normalise the corners, then find indices of the selected area.
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1

        index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
        ids = self.dataset_1.get_identifiers(self.current_dim, index)
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Give selected points a 2-point edge and all others none.

        An empty selection leaves the current drawing untouched.
        """
        ids = selection[self.current_dim] # current identifiers
        if len(ids) == 0:
            print("nothing selected")
            return
        index = self.dataset_1.get_indices(self.current_dim, ids)
        if self.use_blit:
            if self._background is None:
                self._background = self.canvas.copy_from_bbox(self.ax.bbox)
            self.canvas.restore_region(self._background)
        lw = scipy.zeros(self.xaxis_data.shape)
        if len(index) > 0:
            # Index assignment: ndarray.put's argument order has differed
            # between numpy versions.
            lw[index] = 2.
        self.coll.set_linewidth(lw)

        if self.use_blit:
            # Draw into the canvas first, then blit it to the screen.
            self.ax.draw_artist(self.coll)
            self.canvas.blit()
        else:
            self.canvas.draw()
        
        
class NetworkPlot(Plot):
    """Draws the graph of a network dataset with networkx.

    Nodes are identified along the dataset's first dimension. All keyword
    arguments are kept in ``self.keywords`` (the same dict object) and
    forwarded to the networkx drawing functions. Recognised keys:

      name       -- plot title; defaults to ``dataset.get_name()``.
      prog       -- graphviz program for the layout; defaults to 'neato'.
      pos        -- dict node -> (x, y); computed with
                    ``networkx.pygraphviz_layout`` when missing or empty.
      node_size  -- scalar (default 40) or one size per identifier.
      node_color -- one colour per identifier; any other value is dropped
                    and unselected nodes are drawn red.
      ax         -- ignored (removed from the keywords).
    """

    # Keywords set explicitly per call in set_current_selection().
    _EXPLICIT_DRAW_KEYS = ('ax', 'nodelist', 'node_size', 'node_color',
                           'edgelist', 'edge_list')

    def __init__(self, dataset, **kw):
        # Set member variables and call superclass' constructor
        self.graph = dataset.asnetworkx()
        self.dataset = dataset
        self.keywords = kw
        self.dim_name = self.dataset.get_dim_name(0)
        self.current_dim = self.dim_name
        if 'name' not in kw:
            kw['name'] = self.dataset.get_name()
        if 'prog' not in kw:
            kw['prog'] = 'neato'
        # Compute a layout only when no usable one was supplied, so that
        # caller-provided positions are kept.
        if not kw.get('pos'):
            kw['pos'] = networkx.pygraphviz_layout(self.graph, kw['prog'])
        Plot.__init__(self, kw['name'])

        identifiers = self.dataset[self.dim_name]

        # Keep node size and color as dicts for fast lookup. Per-node
        # sequences are removed from kw so they are not forwarded as-is.
        size = kw.get('node_size', 40)
        if cbook.iterable(size) and not cbook.is_string_like(size):
            del kw['node_size']
            self.node_size = dict(zip(identifiers, size))
        else:
            self.node_size = dict((id_, size) for id_ in identifiers)

        color = kw.pop('node_color', None)
        if cbook.iterable(color) and not cbook.is_string_like(color):
            self.node_color = dict(zip(identifiers, color))
        else:
            self.node_color = None

        self.ax = self.fig.add_subplot(111)
        self.ax.set_position([0.01,0.01,.99,.99])
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        # FIXME: ax shouldn't be in kw at all
        kw.pop('ax', None)

        # Add canvas and show
        self.add(self.canvas)
        self.canvas.show()

        # Initial draw
        networkx.draw_networkx(self.graph, ax=self.ax, **kw)

    def get_toolbar(self):
        return self._toolbar

    def rectangle_select_callback(self, x1, y1, x2, y2):
        """Report nodes whose layout position lies strictly inside the
        rectangle (x1, y1)-(x2, y2) to the selection listener; does
        nothing when no listener is set."""
        if self.selection_listener is None:
            return
        pos = self.keywords['pos']
        # Float arrays: node positions are not integers.
        xdata = scipy.zeros((len(pos),), 'd')
        ydata = scipy.zeros((len(pos),), 'd')
        node_ids = []
        for c, (name, (x, y)) in enumerate(pos.items()):
            node_ids.append(name)
            xdata[c] = x
            ydata[c] = y

        # find indices of selected area
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1
        index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]

        ids = [node_ids[i] for i in index]
        self.selection_listener(self.current_dim, ids)

    def set_current_selection(self, selection):
        """Redraw the graph with selected nodes in black.

        Unselected nodes use the per-node colours given at construction, or
        red. Sizes come from the node_size table (40 for unknown nodes).
        """
        ids = selection[self.current_dim] # current identifiers
        node_set = set(self.graph.nodes())

        # Set operations on node_set accept any iterable of ids.
        selected_nodes = list(node_set.intersection(ids))
        unselected_nodes = list(node_set.difference(ids))

        if self.node_color:
            unselected_colors = [self.node_color[x] for x in unselected_nodes]
        else:
            unselected_colors = 'red'

        unselected_sizes = [self.node_size.get(x, 40) for x in unselected_nodes]
        selected_sizes = [self.node_size.get(x, 40) for x in selected_nodes]

        # Forward the stored keywords minus those passed explicitly below;
        # passing a keyword twice raises TypeError.
        draw_kw = dict(self.keywords)
        for key in self._EXPLICIT_DRAW_KEYS:
            draw_kw.pop(key, None)

        self.ax.clear()
        # clear() restores the default ticks; hide them again as in __init__.
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        networkx.draw_networkx_edges(self.graph, ax=self.ax, **draw_kw)
        networkx.draw_networkx_labels(self.graph, ax=self.ax, **draw_kw)
        if unselected_nodes:
            networkx.draw_networkx_nodes(self.graph, nodelist=unselected_nodes,
                node_color=unselected_colors, node_size=unselected_sizes,
                ax=self.ax, **draw_kw)

        if selected_nodes:
            networkx.draw_networkx_nodes(self.graph, nodelist=selected_nodes,
                node_color='k', node_size=selected_sizes, ax=self.ax, **draw_kw)

        self.canvas.draw()


class PlotMode:
    """A PlotMode object corresponds to a mouse mode in a plot.

    When a mode is selected in the toolbar, the PlotMode corresponding
    to the toolbar button is switched on. Subclasses hook in through
    activate() and deactivate().
    """
    def __init__(self, plot, name, tooltip, image_file):
        self.name = name
        self.tooltip = tooltip
        self.image_file = image_file
        self.plot = plot
        self.canvas = plot.canvas

    def get_icon(self):
        """Returns a gtk.Image loaded from ``image_file`` in fluents.ICONDIR."""
        fname = os.path.join(fluents.ICONDIR, self.image_file)
        image = gtk.Image()
        image.set_from_file(fname)
        return image

    def activate(self):
        """Subclasses of PlotMode should do their initialization here.

        The activate method is called when a mode is activated, and is
        used primarily to set up callback functions on events in the
        canvas.
        """
        pass

    def deactivate(self):
        """Subclasses of PlotMode should do their cleanup here.

        The deactivate method is used primarily by subclasses of PlotMode
        to remove any callbacks they might have on the matplotlib canvas.
        """
        pass

    def _mpl_disconnect_all(self):
        """Disconnects all matplotlib callbacks defined on the canvas.

        This is a hack because the RectangleSelector in matplotlib does
        not store its callbacks, so we need a workaround to remove them.
        Note that it removes every callback on the canvas, including ones
        registered by the plot itself.

        NOTE(review): assumes ``canvas.callbacks`` is a dict mapping event
        name -> {cid: func}, as in older matplotlib — confirm.
        """
        callbacks = self.plot.canvas.callbacks

        for callbackd in callbacks.values():
            # Snapshot the keys: deleting from a dict while iterating a
            # live key view raises RuntimeError on Python 3.
            for cid in list(callbackd.keys()):
                del callbackd[cid]


class DefaultPlotMode (PlotMode):
    """Plain pointer mode; it installs no canvas callbacks of its own."""
    def __init__(self, plot):
        name, tooltip, icon = 'default', 'Default mode', 'cursor.png'
        PlotMode.__init__(self, plot, name, tooltip, icon)


class PanPlotMode (PlotMode):
    """Pan the axes with the left mouse button, zoom with the right.

    The drag logic appears adapted from matplotlib's NavigationToolbar2
    pan/zoom. Modifier keys while dragging: 'control' makes dx == dy,
    'x'/'y' restrict movement to that axis, 'shift' snaps the motion to
    horizontal, vertical or diagonal.
    """
    def __init__(self, plot):
        PlotMode.__init__(self, plot, 'pan',
                          'Pan axes with left mouse, zoom with right',
                          'move.png')

        # Holds handler IDs for callbacks.
        self._button_press = None
        self._button_release = None
        self._motion_notify = None

        # Mouse button (1 or 3) that started the current drag, else None.
        self._button_pressed = None
        # One (x, y, axes, axes_index, limits, transData copy) tuple per
        # navigable axes under the cursor at button press.
        self._xypress = []

    def activate(self):
        """Start listening for button presses/releases on the canvas."""
        self._button_press = self.canvas.mpl_connect(
            'button_press_event', self._on_button_press)
        self._button_release = self.canvas.mpl_connect(
            'button_release_event', self._on_button_release)

    def deactivate(self):
        """Disconnect every canvas handler installed by this mode."""
        for attr in ('_button_press', '_button_release', '_motion_notify'):
            cid = getattr(self, attr)
            if cid is not None:
                self.canvas.mpl_disconnect(cid)
                setattr(self, attr, None)

    def _on_button_press(self, event):
        """Record the press state of every navigable axes under the cursor
        and start tracking mouse motion."""
        if event.button in (1, 3):
            self._button_pressed = event.button
        else:
            self._button_pressed = None
            return

        x, y = event.x, event.y

        self._xypress = []
        for i, a in enumerate(self.canvas.figure.get_axes()):
            if x is not None and y is not None and a.in_axes(x, y) \
                   and a.get_navigate():
                xmin, xmax = a.get_xlim()
                ymin, ymax = a.get_ylim()
                lim = xmin, xmax, ymin, ymax
                self._xypress.append((x, y, a, i, lim, a.transData.deepcopy()))

        if self._xypress:
            # Replace any previous motion handler with a fresh one.
            if self._motion_notify is not None:
                self.canvas.mpl_disconnect(self._motion_notify)
            self._motion_notify = self.canvas.mpl_connect(
                'motion_notify_event', self._on_motion_notify)

    def _on_motion_notify(self, event):
        """The drag callback in pan/zoom mode"""
        import warnings

        if not self._xypress:
            return

        def format_deltas(event, dx, dy):
            """Returns the correct dx and dy based on the modifier keys"""
            if event.key=='control':
                if(abs(dx)>abs(dy)):
                    dy = dx
                else:
                    dx = dy
            elif event.key=='x':
                dy = 0
            elif event.key=='y':
                dx = 0
            elif event.key=='shift' and (dx or dy):
                # (dx, dy) == (0, 0) is skipped: the diagonal branches
                # below would divide by zero.
                if 2*abs(dx) < abs(dy):
                    dx=0
                elif 2*abs(dy) < abs(dx):
                    dy=0
                elif(abs(dx)>abs(dy)):
                    dy=dy/abs(dy)*abs(dx)
                else:
                    dx=dx/abs(dx)*abs(dy)
            return (dx,dy)

        for cur_xypress in self._xypress:
            lastx, lasty, a, ind, lim, trans = cur_xypress
            xmin, xmax, ymin, ymax = lim

            #safer to use the recorded button at the press than current button:
            #multiple button can get pressed during motion...
            if self._button_pressed==1:
                # Pan: shift the limits recorded at press by the data-space
                # distance moved since the press.
                lastx, lasty = trans.inverse_xy_tup( (lastx, lasty) )
                x, y = trans.inverse_xy_tup( (event.x, event.y) )
                if a.get_xscale()=='log':
                    dx=1-lastx/x
                else:
                    dx=x-lastx
                if a.get_yscale()=='log':
                    dy=1-lasty/y
                else:
                    dy=y-lasty

                dx,dy=format_deltas(event,dx,dy)

                if a.get_xscale()=='log':
                    xmin *= 1-dx
                    xmax *= 1-dx
                else:
                    xmin -= dx
                    xmax -= dx
                if a.get_yscale()=='log':
                    ymin *= 1-dy
                    ymax *= 1-dy
                else:
                    ymin -= dy
                    ymax -= dy

            elif self._button_pressed==3:
                # Zoom about the press point by a factor exponential in the
                # pixel distance moved relative to the axes size.
                try:
                    dx=(lastx-event.x)/float(a.bbox.width())
                    dy=(lasty-event.y)/float(a.bbox.height())
                    dx,dy=format_deltas(event,dx,dy)
                    if a.get_aspect() != 'auto':
                        dx = 0.5*(dx + dy)
                        dy = dx
                    alphax = pow(10.0,dx)
                    alphay = pow(10.0,dy)
                    lastx, lasty = trans.inverse_xy_tup( (lastx, lasty) )
                    if a.get_xscale()=='log':
                        xmin = lastx*(xmin/lastx)**alphax
                        xmax = lastx*(xmax/lastx)**alphax
                    else:
                        xmin = lastx+alphax*(xmin-lastx)
                        xmax = lastx+alphax*(xmax-lastx)
                    if a.get_yscale()=='log':
                        ymin = lasty*(ymin/lasty)**alphay
                        ymax = lasty*(ymax/lasty)**alphay
                    else:
                        ymin = lasty+alphay*(ymin-lasty)
                        ymax = lasty+alphay*(ymax-lasty)

                except OverflowError:
                    warnings.warn('Overflow while panning')
                    return

            a.set_xlim(xmin, xmax)
            a.set_ylim(ymin, ymax)

        self.canvas.draw()

    def _on_button_release(self, event):
        'the release mouse button callback in pan/zoom mode'
        if self._motion_notify is not None:
            self.canvas.mpl_disconnect(self._motion_notify)
            self._motion_notify = None
        if not self._xypress:
            return
        self._xypress = []
        self._button_pressed = None
        self.canvas.draw()


class ZoomPlotMode (PlotMode):
    """Zoom an axes to a rectangle dragged with the mouse."""
    def __init__(self, plot):
        PlotMode.__init__(self, plot, 'zoom',
                          'Zoom to rectangle','zoom_to_rect.png')
        # RectangleSelector -> the axes it was created for.
        self._selectors = {}

    def activate(self):
        """Attach a rectangle selector to every axes of the figure."""
        for ax in self.canvas.figure.get_axes():
            props = dict(facecolor = 'blue',
                         edgecolor = 'black',
                         alpha = 0.3,
                         fill = True)

            rs = RectangleSelector(ax, self._on_select, drawtype='box',
                                   useblit=True, rectprops = props)
            self._selectors[rs] = ax
        # One redraw after all selectors are connected suffices.
        if self._selectors:
            self.canvas.draw()

    def deactivate(self):
        """Drop all canvas callbacks and forget the selectors."""
        self._mpl_disconnect_all()
        self._selectors = {}

    def _on_select(self, start, end):
        """Set the axes limits to the dragged rectangle."""
        ax = start.inaxes
        coords = (start.xdata, end.xdata, start.ydata, end.ydata)
        if ax is None or None in coords:
            return
        xmin, xmax = min(start.xdata, end.xdata), max(start.xdata, end.xdata)
        ymin, ymax = min(start.ydata, end.ydata), max(start.ydata, end.ydata)
        # A click without dragging yields a zero-size rectangle; zooming to
        # it would make the axis limits singular.
        if xmin == xmax or ymin == ymax:
            return

        ax.set_xlim((xmin, xmax))
        ax.set_ylim((ymin, ymax))
        self.canvas.draw()


class SelectPlotMode (PlotMode):
    """Plot mode that reports a rectangle dragged with the mouse to the plot."""

    def __init__(self, plot):
        PlotMode.__init__(self, plot, 'select',
                          'Select within rectangle', 'select.png')
        # Maps each RectangleSelector created by activate() to its axes.
        self._selectors = {}

    def activate(self):
        """Attach a rectangle selector to every axes of the figure."""
        # Switch off selectors from a previous activate() so that repeated
        # activation does not stack several selectors on the same axes.
        self._release_selectors()
        axes = self.canvas.figure.get_axes()
        for ax in axes:
            props = dict(facecolor = 'blue',
                         edgecolor = 'black',
                         alpha = 0.3,
                         fill = True)

            rs = RectangleSelector(ax, self._on_select, drawtype='box',
                                   useblit=True, rectprops = props)
            self._selectors[rs] = ax
        if axes:
            # Redraw once after all selectors exist instead of once per axes.
            self.canvas.draw()

    def deactivate(self):
        """Disconnect this mode's callbacks and switch off all selectors."""
        self._mpl_disconnect_all()
        self._release_selectors()

    def _release_selectors(self):
        """Deactivate and forget every selector created by activate()."""
        for rs in self._selectors:
            # Clearing our dict alone may leave the selector's canvas
            # callbacks registered, so switch it off explicitly.
            # NOTE(review): guarded because very old matplotlib releases
            # may lack RectangleSelector.set_active — confirm target version.
            set_active = getattr(rs, 'set_active', None)
            if set_active is not None:
                set_active(False)
        self._selectors = {}

    def _on_select(self, start, end):
        """Forward the selected rectangle to the plot.

        start, end: the press and release mouse events passed by
        RectangleSelector.  The data coordinates are passed to
        plot.rectangle_select_callback(x0, y0, x1, y1).  Events lacking data
        coordinates are ignored rather than forwarded as None.
        """
        coords = (start.xdata, start.ydata, end.xdata, end.ydata)
        if None in coords:
            return
        self.plot.rectangle_select_callback(*coords)


class PlotToolbar(gtk.Toolbar):
    """Toolbar with one radio button per PlotMode plus a freeze toggle.

    The selected mode name is mirrored into the module-level ``active_mode``.
    When the pointer enters this toolbar's canvas, that mode is re-applied,
    which keeps several plot toolbars in the same mode.
    """

    def __init__(self, plot):
        gtk.Toolbar.__init__(self)
        self.plot = plot
        self.canvas = plot.canvas
        self._current_mode = None
        self.tooltips = gtk.Tooltips()

        ## Maps toolbar buttons to PlotMode objects.
        self._mode_buttons = {}

        self.set_property('show-arrow', False)

        self.canvas.connect('enter-notify-event', self.on_enter_notify)
        self.show()
        self.add_mode(DefaultPlotMode(self.plot))
        self.add_mode(PanPlotMode(self.plot))
        self.add_mode(ZoomPlotMode(self.plot))
        self.add_mode(SelectPlotMode(self.plot))

        self.insert(gtk.SeparatorToolItem(), -1)
        self.set_style(gtk.TOOLBAR_ICONS)

        # Set up freeze button
        btn = gtk.ToggleToolButton()

        fname = os.path.join(fluents.ICONDIR, "freeze.png")
        image = gtk.Image()
        image.set_from_file(fname)

        btn.set_icon_widget(image)
        btn.connect('toggled', self._on_freeze_toggle)
        self.insert(btn, -1)

        self.show_all()

    def add_mode(self, mode):
        """Adds a new mode to the toolbar.

        The first mode added becomes the current mode.  It is not activated
        here.
        """
        # All mode buttons share one radio group; joining any existing
        # member of the group is sufficient.
        if self._mode_buttons:
            group = next(iter(self._mode_buttons))
        else:
            group = None

        btn = gtk.RadioToolButton(group)
        btn.set_icon_widget(mode.get_icon())
        btn.set_tooltip(self.tooltips, mode.tooltip, 'Private')
        btn.connect('toggled', self._on_mode_toggle)

        self._mode_buttons[btn] = mode
        self.insert(btn, -1)

        if self._current_mode is None:
            self._current_mode = mode

    def get_mode(self):
        """Returns the active mode name."""
        if self._current_mode:
            return self._current_mode.name
        return None

    def get_mode_by_name(self, mode_name):
        """Returns the mode with the given name or None."""
        for m in self._mode_buttons.values():
            if m.name == mode_name:
                return m
        return None

    def get_button(self, mode_name):
        """Returns the button that corresponds to a mode name."""
        for b, m in self._mode_buttons.items():
            if m.name == mode_name:
                return b
        return None

    def set_mode(self, mode_name):
        """Sets a mode by name.

        Returns the newly activated mode.  Returns None when *mode_name* is
        already the current mode, or when no such mode exists.  In the
        unknown-mode case a warning is logged, and both the current mode and
        the global ``active_mode`` are left untouched.
        """
        global active_mode

        current = self._current_mode
        if current is not None and current.name == mode_name:
            return None

        # Resolve the new mode before touching the current one, so an
        # unknown name cannot leave the toolbar with a deactivated mode.
        new_mode = self.get_mode_by_name(mode_name)
        if new_mode is None:
            logger.log('warning', 'No such mode: %s' % mode_name)
            return None

        if current is not None:
            current.deactivate()
        new_mode.activate()
        self._current_mode = new_mode

        button = self.get_button(mode_name)
        if button is not None and not button.get_active():
            # set_active emits 'toggled' -> _on_mode_toggle -> set_mode,
            # which returns early because the mode is already current.
            button.set_active(True)

        active_mode = mode_name
        return new_mode


    def _on_mode_toggle(self, button):
        # Radio buttons emit 'toggled' both when pressed and when released;
        # only react to the button that became active.
        if button.get_active():
            self.set_mode(self._mode_buttons[button].name)

    def _on_freeze_toggle(self, button):
        self.plot.set_frozen(button.get_active())

    def on_enter_notify(self, widget, event):
        # Adopt the mode most recently chosen on any toolbar.
        self.set_mode(active_mode)


# Custom GObject signals used by this module.
#
# 'view-changed' on MainView should be emitted whenever the active view
# changes.  Handlers receive one arbitrary Python object as the signal
# argument.
gobject.signal_new('view-changed', MainView, gobject.SIGNAL_RUN_LAST,
                   gobject.TYPE_NONE, (gobject.TYPE_PYOBJECT,))

# 'focus-changed' on ViewFrame carries a Python object and a boolean.
gobject.signal_new('focus-changed', ViewFrame, gobject.SIGNAL_RUN_LAST,
                   gobject.TYPE_NONE,
                   (gobject.TYPE_PYOBJECT, gobject.TYPE_BOOLEAN))