Fixed subdata method so the correct identifiers are stored
This commit is contained in:
parent
e417547923
commit
b37ebe568f
@ -230,16 +230,17 @@ class Dataset(object):
|
||||
"""
|
||||
ds = self.copy()
|
||||
indices = ds.get_indices(dim, idents)
|
||||
idents = ds.get_identifiers(dim, indices=indices)
|
||||
if not idents:
|
||||
raise ValueError("No of identifers from: \n%s \nfound in %s" %(str(idents), ds._name))
|
||||
ax = [i for i, name in enumerate(ds._dims) if name == dim][0]
|
||||
subarr = ds._array.take(indices, ax)
|
||||
for k, v in ds._map[dim].items():
|
||||
if k not in idents:
|
||||
del ds._map[dim][k]
|
||||
new_indices = range(len(idents))
|
||||
ds._map[dim] = ReverseDict(zip(idents, new_indices))
|
||||
ds.shape = tuple(len(ds._map[d]) for d in ds._dims)
|
||||
ds.set_array(subarr)
|
||||
return ds
|
||||
|
||||
|
||||
def transpose(self):
|
||||
"""Returns a copy of transpose of a dataset.
|
||||
|
||||
@ -390,7 +391,7 @@ class GraphDataset(Dataset):
|
||||
dim0, dim1 = self.get_dim_name()
|
||||
node_ids = self.get_identifiers(dim0, sorted=True)
|
||||
edge_ids = self.get_identifiers(dim1, sorted=True)
|
||||
G = self._graph_from_incidence_matrix(self._array, node_ids=node_ids, edge_ids=edge_ids)
|
||||
G, weights = self._graph_from_incidence_matrix(self._array, node_ids=node_ids, edge_ids=edge_ids)
|
||||
self._graph = G
|
||||
return G
|
||||
|
||||
|
Reference in New Issue
Block a user