diff --git a/system/dataset.py b/system/dataset.py index 568ae39..7fcba54 100644 --- a/system/dataset.py +++ b/system/dataset.py @@ -183,7 +183,9 @@ class Dataset: backitems.sort() sorted_ids=[ backitems[i][1] for i in range(0,len(backitems))] - if indices != None: + # we use id as scipy-arrays return a new array on boolean + # operations + if id(indices) != id(None): return [sorted_ids[index] for index in indices] else: return sorted_ids diff --git a/test/system/datasettest.py b/test/system/datasettest.py index 3add265..52ee2da 100644 --- a/test/system/datasettest.py +++ b/test/system/datasettest.py @@ -2,7 +2,7 @@ import unittest import sys sys.path.append('../..') from system.dataset import * -from scipy import rand,shape +from scipy import rand,shape, array class DatasetTest(unittest.TestCase): @@ -37,9 +37,12 @@ class DatasetTest(unittest.TestCase): self.assertEquals(['gene_a', 'gene_b', 'gene_c'], data.get_identifiers('genes', [0, 1, 2])) # "advanced" lookup self.assertEquals(['gene_c', 'gene_a'], data.get_identifiers('genes', [2, 0])) + # handle empty matrix of indices + self.assertEquals([], data.get_identifiers('samples', array([]))) # other dimension self.assertEquals(['sample_a', 'sample_b'], data.get_identifiers('samples', [0, 1])) + #def testExtraction(self): # ids = ['gene_a','gene_b'] # dim_name = 'genes'