Skip to content

Commit 32cbcb5

Browse files
committed
fix: single label w asterisk
1 parent 006b1e6 commit 32cbcb5

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

ditec_wdn_dataset/core/datasets_large.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,11 +1004,19 @@ def stack_features(
10041004

10051005
merging_arrs.append(my_arr)
10061006

1007-
if max_dim == -1: # no non-asterisk attr
1008-
max_dim = sum(a.shape[1] for a in merging_arrs)
1009-
10101007
required_padding = any(has_asterisks)
10111008

1009+
if max_dim == -1: # no non-asterisk attr
1010+
if required_padding:
1011+
# max_dim hasn't been found but still require padding => we are in label/ edge-label cases and required shape from previous node/edge cases
1012+
if is_node:
1013+
max_dim = root.node_mask.shape[0] if root.node_mask is not None else -1
1014+
else:
1015+
max_dim = root.edge_mask.shape[0] if root.edge_mask is not None else -1
1016+
# double-check to prevent in case that node(edge)mask are unavailable for any reason
1017+
if max_dim == -1: # otherwise, max dim is the sum of all available dims
1018+
max_dim = sum(a.shape[1] for a in merging_arrs)
1019+
10121020
if required_padding:
10131021
for i in range(len(merging_arrs)):
10141022
feature: np.ndarray = merging_arrs[i]

ditec_wdn_dataset/hf/dataset.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,11 +1162,19 @@ def stack_features(
11621162

11631163
merging_arrs.append(my_arr)
11641164

1165-
if max_dim == -1: # no non-asterisk attr
1166-
max_dim = sum(a.shape[1] for a in merging_arrs)
1167-
11681165
required_padding = any(has_asterisks)
11691166

1167+
if max_dim == -1: # no non-asterisk attr
1168+
if required_padding:
1169+
# max_dim hasn't been found but still require padding => we are in label/ edge-label cases and required shape from previous node/edge cases
1170+
if is_node:
1171+
max_dim = root.node_mask.shape[0] if root.node_mask is not None else -1
1172+
else:
1173+
max_dim = root.edge_mask.shape[0] if root.edge_mask is not None else -1
1174+
# double-check to prevent in case that node(edge)mask are unavailable for any reason
1175+
if max_dim == -1: # otherwise, max dim is the sum of all available dims
1176+
max_dim = sum(a.shape[1] for a in merging_arrs)
1177+
11701178
if required_padding:
11711179
for i in range(len(merging_arrs)):
11721180
feature: np.ndarray = merging_arrs[i]

0 commit comments

Comments
 (0)