NetworkPlot now works, but does not allow node colors or sizes to be passed as keyword arguments.

This commit is contained in:
Einar Ryeng 2006-08-03 14:30:06 +00:00
parent 44b676f726
commit 4c716e9428
2 changed files with 34 additions and 20 deletions

View File

@ -479,16 +479,12 @@ class ScatterPlot(Plot):
assert y1<=y2
index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))
print 'index:', index
print 'current_dim', self.current_dim
ids = self.dataset_1.get_identifiers(self.current_dim, index)
print 'ids', ids
self.selection_listener(self.current_dim, ids)
def selection_changed(self, selection):
ids = selection[self.current_dim] # current identifiers
print 'ids: ', ids
index = self.dataset_1.get_indices(self.current_dim, ids)
xdata_new = scipy.take(self.xaxis_data, index) #take data
@ -499,10 +495,10 @@ class ScatterPlot(Plot):
self.canvas.draw()
class NetworkPlot(Plot):
def __init__(self, graph, **kw):
def __init__(self, dataset, **kw):
# Set member variables and call superclass' constructor
self.graph = graph.asnetworkx()
self.dataset = graph
self.graph = dataset.asnetworkx()
self.dataset = dataset
self.keywords = kw
self.dim_name = self.dataset.get_dim_name(0)
@ -515,17 +511,27 @@ class NetworkPlot(Plot):
Plot.__init__(self, kw['name'])
# Keep node size and color as dicts for fast lookup
self.node_size = {}
if kw.has_key('node_size') and cb.iterable(kw['node_size']):
kw.remove('node_size')
self.node_size = {}
for id, size in zip(self.dataset[self.dim_name], kw['node_size']):
self.node_size[id] = size
else:
for id in dataset[self.dim_name]:
self.node_size[id] = 40
self.node_color = {}
if kw.has_key('node_color') and cb.iterable(kw['node_color']):
kw.remove('node_color')
self.node_color = {}
for id, color in zip(self.dataset[self.dim_name], kw['node_color']):
self.node_color[id] = color
else:
self.node_color = None
# for id in self.dataset[self.dim_name]:
# self.node_color[id] = 'red'
if kw.has_key('node_color'):
kw.pop('node_color')
# FIXME: What is figsize?
self.fig = Figure(figsize=(5, 4), dpi=72)
@ -568,29 +574,38 @@ class NetworkPlot(Plot):
x1, x2 = x2, x1
if y1 > y2:
y1, y2 = y2, y1
index = scipy.nonzero((xdata<x1) & (xdata>x2) & (ydata>y1) & (ydata<y2))
index = scipy.nonzero((xdata>x1) & (xdata<x2) & (ydata>y1) & (ydata<y2))
ids = [node_ids[i] for i in index]
self.selection_listener(self.dataset.get_dim_name(0), ids)
def selection_changed(self, selection):
ids = selection[self.dataset.get_dim_name(0)] # current identifiers
selected_nodes = list(ids.intersection(set(self.graph.nodes())))
unselected_nodes = list(ids.difference(selected_nodes))
node_set = set(self.graph.nodes())
unselected_colors = [self.node_color[x] for x in unselected_nodes]
unselected_sizes = [self.node_size[x] for x in unselected_nodes]
selected_sizes = [self.node_size[x] for x in selected_nodes]
selected_nodes = list(ids.intersection(node_set))
unselected_nodes = list(node_set.difference(ids))
if self.node_color:
unselected_colors = [self.node_color[x] for x in unselected_nodes]
else:
unselected_colors = 'red'
if self.node_size:
unselected_sizes = [self.node_size[x] for x in unselected_nodes]
selected_sizes = [self.node_size[x] for x in selected_nodes]
self.ax.clear()
networkx.draw_networkx_edges(self.graph, edge_list=self.graph.edges(), \
ax=self.ax, **self.keywords)
networkx.draw_networkx_nodes(self.graph, self.keywords['pos'], node_list=unselected_nodes, \
node_color=unselected_colors, node_size=unselected_sizes, ax=self.ax, *kw)
if unselected_nodes:
networkx.draw_networkx_nodes(self.graph, nodelist=unselected_nodes, \
node_color='r', node_size=unselected_sizes, ax=self.ax, **self.keywords)
networkx.draw_networkx_nodes(self.graph, self.keywords['pos'], node_list=selected_nodes, \
node_color='black', node_size=selected_sizes, ax=self.ax, *kw)
if selected_nodes:
networkx.draw_networkx_nodes(self.graph, nodelist=selected_nodes, \
node_color='k', node_size=selected_sizes, ax=self.ax, **self.keywords)
self.canvas.draw()

View File

@ -108,7 +108,6 @@ class TestDataFunction(workflow.Function):
for x in 'ABCDEF':
for y in 'ADE':
graph.add_edge(x, y, 3)
print networkx.adj_matrix(graph)
ds = dataset.GraphDataset(array(networkx.adj_matrix(graph)))
ds_plot = plots.NetworkPlot(ds)