From 23d420fbeef4a1c029691b9c14a30d56b354e071 Mon Sep 17 00:00:00 2001 From: Enrique Piqueras Date: Fri, 3 Apr 2026 09:48:30 -0700 Subject: [PATCH] Add a test for zero sized arrays in CPP fast packing. PiperOrigin-RevId: 894122292 --- .../dataset/transformations/testing_util.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/grain/_src/python/dataset/transformations/testing_util.py b/grain/_src/python/dataset/transformations/testing_util.py index 031807435..0c55cd0f1 100644 --- a/grain/_src/python/dataset/transformations/testing_util.py +++ b/grain/_src/python/dataset/transformations/testing_util.py @@ -1414,6 +1414,79 @@ def test_pack_max_sequences_per_bin_invalid_value( max_sequences_per_bin=max_sequences_per_bin, ) + def test_pack_sequences_with_zeros(self): + input_elements = [ + { + "input_tokens": [8, 8], + "input_vectors": np.empty((0, 3), dtype=np.int64), + "targets": np.array([], dtype=np.int64), + }, + { + "input_tokens": [1], + "input_vectors": [[0, 1, 2], [1, 2, 3], [2, 3, 4]], + "targets": [10], + }, + { + "input_tokens": [4, 5], + "input_vectors": [[3, 4, 5], [4, 5, 6]], + "targets": [20, 30, 40], + }, + { + "input_tokens": [6], + "input_vectors": [[5, 6, 7]], + "targets": [50, 60], + }, + { + "input_tokens": np.array([], dtype=np.int64), + "input_vectors": np.empty((0, 3), dtype=np.int64), + "targets": np.array([], dtype=np.int64), + }, + ] + length_struct = {"input_tokens": 5, "input_vectors": 3, "targets": 5} + + expected_elements = [ + { + "input_tokens": [8, 8, 1, 0, 0], + "input_tokens_segment_ids": [1, 1, 2, 0, 0], + "input_tokens_positions": [0, 1, 0, 0, 0], + "input_vectors": [ + [0, 1, 2], + [1, 2, 3], + [2, 3, 4], + ], + "input_vectors_segment_ids": [2, 2, 2], + "input_vectors_positions": [0, 1, 2], + "targets": [10, 0, 0, 0, 0], + "targets_segment_ids": [2, 0, 0, 0, 0], + "targets_positions": [0, 0, 0, 0, 0], + }, + { + "input_tokens": [4, 5, 6, 0, 0], + "input_tokens_segment_ids": [1, 1, 2, 0, 0], + "input_tokens_positions": [0, 1, 0, 0, 0], + "input_vectors": [ + [3, 4, 5], + [4, 5, 6], + [5, 6, 7], + ], + "input_vectors_segment_ids": [1, 1, 2], + "input_vectors_positions": [0, 1, 0], + "targets": [20, 30, 40, 50, 60], + "targets_segment_ids": [1, 1, 1, 2, 2], + "targets_positions": [0, 1, 2, 0, 1], + }, + ] + + _common_test_body( + self.packer_cls, + input_elements, + expected_elements, + length_struct, + kwargs=self.kwargs, + num_packing_bins=2, + max_sequences_per_bin=3, + ) + class BaseBestFitPackIterDatasetTest(BaseFirstFitPackIterDatasetTest): """Base test for the Best-Fit packing algorithm.