Fixed subdata method so the correct identifiers are stored
This commit is contained in:
parent
e417547923
commit
b37ebe568f
|
@ -230,16 +230,17 @@ class Dataset(object):
|
||||||
"""
|
"""
|
||||||
ds = self.copy()
|
ds = self.copy()
|
||||||
indices = ds.get_indices(dim, idents)
|
indices = ds.get_indices(dim, idents)
|
||||||
|
idents = ds.get_identifiers(dim, indices=indices)
|
||||||
|
if not idents:
|
||||||
|
raise ValueError("No of identifers from: \n%s \nfound in %s" %(str(idents), ds._name))
|
||||||
ax = [i for i, name in enumerate(ds._dims) if name == dim][0]
|
ax = [i for i, name in enumerate(ds._dims) if name == dim][0]
|
||||||
subarr = ds._array.take(indices, ax)
|
subarr = ds._array.take(indices, ax)
|
||||||
for k, v in ds._map[dim].items():
|
new_indices = range(len(idents))
|
||||||
if k not in idents:
|
ds._map[dim] = ReverseDict(zip(idents, new_indices))
|
||||||
del ds._map[dim][k]
|
|
||||||
ds.shape = tuple(len(ds._map[d]) for d in ds._dims)
|
ds.shape = tuple(len(ds._map[d]) for d in ds._dims)
|
||||||
ds.set_array(subarr)
|
ds.set_array(subarr)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
|
|
||||||
def transpose(self):
|
def transpose(self):
|
||||||
"""Returns a copy of transpose of a dataset.
|
"""Returns a copy of transpose of a dataset.
|
||||||
|
|
||||||
|
@ -390,7 +391,7 @@ class GraphDataset(Dataset):
|
||||||
dim0, dim1 = self.get_dim_name()
|
dim0, dim1 = self.get_dim_name()
|
||||||
node_ids = self.get_identifiers(dim0, sorted=True)
|
node_ids = self.get_identifiers(dim0, sorted=True)
|
||||||
edge_ids = self.get_identifiers(dim1, sorted=True)
|
edge_ids = self.get_identifiers(dim1, sorted=True)
|
||||||
G = self._graph_from_incidence_matrix(self._array, node_ids=node_ids, edge_ids=edge_ids)
|
G, weights = self._graph_from_incidence_matrix(self._array, node_ids=node_ids, edge_ids=edge_ids)
|
||||||
self._graph = G
|
self._graph = G
|
||||||
return G
|
return G
|
||||||
|
|
||||||
|
|
Reference in New Issue