From 438dbd358bc014205ca2e0491567463aff75fc86 Mon Sep 17 00:00:00 2001
From: flatberg <flatberg@pvv.ntnu.no>
Date: Wed, 25 Oct 2006 12:56:21 +0000
Subject: [PATCH] Dataset to colour of dragndrop in dataset + minor
 addjustments in networkplot

---
 fluents/plots.py | 88 ++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 74 insertions(+), 14 deletions(-)

diff --git a/fluents/plots.py b/fluents/plots.py
index 0381b0c..e6a7c68 100644
--- a/fluents/plots.py
+++ b/fluents/plots.py
@@ -220,7 +220,14 @@ class ViewFrame (gtk.Frame):
         if isinstance(obj, Plot):
             self.set_view(obj)
             self.focus()
+            
+        elif isinstance(obj, fluents.dataset.Dataset):
+            view = self.get_view()
+            if view.is_mappable_with(obj):
+                view._update_color_from_dataset(obj)
 
+        # add selections below 
+                
 
 class MainView (gtk.Notebook):
     """The MainView class displays the Views in Fluents.
@@ -381,6 +388,10 @@ class View (gtk.Frame):
     def get_toolbar(self):
         return None
 
+    def is_mappable_with(self, dataset):
+        """Override in individual plots."""
+        return False
+
 
 class EmptyView (View):
     """EmptyView is shown in ViewFrames that are unused."""
@@ -654,12 +665,15 @@ class ScatterPlot(Plot):
         self.xaxis_data = dataset_1._array[:, x_index]
         self.yaxis_data = dataset_2._array[:, y_index]
         lw = scipy.zeros(self.xaxis_data.shape)
-        sc = self.ax.scatter(self.xaxis_data, self.yaxis_data, s=s, c=c, linewidth=lw, edgecolor='k', alpha=.6, cmap = cm.jet)
+        self.sc = sc = self.ax.scatter(self.xaxis_data, self.yaxis_data,
+                                       s=s, c=c, linewidth=lw,
+                                       edgecolor='k', alpha=.8,
+                                       cmap=cm.jet)
         if len(c)>1:
-            self.fig.colorbar(sc,ticks=[], fraction=.05)
+            self.fig.colorbar(sc, fraction=.05)
         self.ax.axhline(0, color='k', lw=1., zorder=1)
         self.ax.axvline(0, color='k', lw=1., zorder=1)
-        #self.ax.set_title(self.get_title())
+        
         # collection
         self.coll = self.ax.collections[0]
 
@@ -667,6 +681,41 @@ class ScatterPlot(Plot):
         self.add(self.canvas)
         self.canvas.show()
 
+    def is_mappable_with(self, dataset):
+        """Returns True if dataset is mappable with this plot.
+        """
+        if self.current_dim in dataset.get_dim_name() and dataset.asarray().shape[0] == self.xaxis_data.shape[0]:
+            return True
+            
+    def _update_color_from_dataset(self, data):
+        """Updates the facecolors from a dataset.
+        """
+        array = data.asarray()
+        #only support for 2d-arrays:
+        try:
+            m, n = array.shape
+        except:
+            raise ValueError, "No support for more tha 2 dimensions."
+        # is dataset a vector or matrix?
+        if not n==1:
+            # we have a category dataset
+            if isinstance(data, fluents.dataset.CategoryDataset):
+                map_vec = scipy.dot(array, scipy.diag(scipy.arange(n))).sum(1)
+            else:
+                map_vec = array.sum(1)
+        else:
+            map_vec = array.ravel()
+
+        # normalise mapping vector
+        map_vec = map_vec - map_vec.min()
+        map_vec = map_vec/map_vec.max()
+        # update facecolors
+        self.sc._facecolors = self.sc.to_rgba(map_vec, self.sc._alpha)
+        # draw
+        self.sc._A = None # mean hack 
+        self.ax.draw_artist(self.sc)
+        self.canvas.draw()
+        
     def rectangle_select_callback(self, x1, y1, x2, y2):
         ydata = self.yaxis_data
         xdata = self.xaxis_data
@@ -714,7 +763,9 @@ class NetworkPlot(Plot):
         self.dataset = dataset 
         self.keywords = kw
         self.dim_name = self.dataset.get_dim_name(0)
-        self.current_dim = self.dim_name
+        
+        if not kw.has_key('with_labels'):
+            k w['with_labels'] = False
         if not kw.has_key('name'):
             kw['name'] = self.dataset.get_name()
         if not kw.has_key('prog'):
@@ -722,6 +773,7 @@ class NetworkPlot(Plot):
         if not kw.has_key('pos') or kw['pos']:
             kw['pos'] = networkx.pygraphviz_layout(self.graph, kw['prog'])
         Plot.__init__(self, kw['name'])
+        self.current_dim = self.dim_name
 
         # Keep node size and color as dicts for fast lookup
         self.node_size = {}
@@ -731,7 +783,7 @@ class NetworkPlot(Plot):
                 self.node_size[id] = size
         else:
             for id in dataset[self.dim_name]:
-                self.node_size[id] = 40
+                self.node_size[id] = 30
                 
         self.node_color = {}
         if kw.has_key('node_color') and cb.iterable(kw['node_color']):
@@ -747,9 +799,9 @@ class NetworkPlot(Plot):
             kw.pop('node_color')
 
         self.ax = self.fig.add_subplot(111)
-        self.ax.set_position([0.01,0.01,.99,.99])
         self.ax.set_xticks([])
         self.ax.set_yticks([])
+        self.ax.grid(False)
         # FIXME: ax shouldn't be in kw at all
         if kw.has_key('ax'):
             kw.pop('ax')
@@ -760,11 +812,13 @@ class NetworkPlot(Plot):
 
         # Initial draw
         networkx.draw_networkx(self.graph, ax=self.ax, **kw)
+        print "Current dim is now: %s" %self.current_dim 
 
     def get_toolbar(self):
         return self._toolbar
 
     def rectangle_select_callback(self, x1, y1, x2, y2):
+        print "In select callbak, current dim is now: %s" %self.current_dim
         pos = self.keywords['pos']
         ydata = scipy.zeros((len(pos),), 'l')
         xdata = scipy.zeros((len(pos),), 'l')
@@ -782,9 +836,10 @@ class NetworkPlot(Plot):
         if y1 > y2:
             y1, y2 = y2, y1
         index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))[0]
-        
-
         ids = [node_ids[i] for i in index]
+        print "Updating listener from network with dim: %s" %self.current_dim
+        print "ids: "
+        print ids
         self.selection_listener(self.current_dim, ids)
 
     def set_current_selection(self, selection):
@@ -797,23 +852,28 @@ class NetworkPlot(Plot):
         if self.node_color:
             unselected_colors = [self.node_color[x] for x in unselected_nodes]
         else:
-            unselected_colors = 'red'
+            unselected_colors = 'gray'
 
         if self.node_size:
             unselected_sizes = [self.node_size[x] for x in unselected_nodes]
             selected_sizes = [self.node_size[x] for x in selected_nodes]
 
-        self.ax.clear()
-        networkx.draw_networkx_edges(self.graph, edge_list=self.graph.edges(), \
-            ax=self.ax, **self.keywords)
+        self.ax.collections=[]
+        networkx.draw_networkx_edges(self.graph,
+                                     edge_list=self.graph.edges(),
+                                     ax=self.ax,
+                                     **self.keywords)
+        
         networkx.draw_networkx_labels(self.graph,**self.keywords)
+
         if unselected_nodes:
             networkx.draw_networkx_nodes(self.graph, nodelist=unselected_nodes, \
-                node_color='r', node_size=unselected_sizes, ax=self.ax, **self.keywords)
+                node_color='gray', node_size=unselected_sizes, ax=self.ax, **self.keywords)
 
         if selected_nodes:
             networkx.draw_networkx_nodes(self.graph, nodelist=selected_nodes, \
-             node_color='k', node_size=selected_sizes, ax=self.ax, **self.keywords)
+             node_color='r', node_size=selected_sizes, ax=self.ax, **self.keywords)
+            self.ax.collections[-1].set_zorder(3)
         
         self.canvas.draw()