Skip to content

Commit cd1c7b0

Browse files
committed
add tests
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 9a7c68c commit cd1c7b0

File tree

4 files changed

+1040
-0
lines changed

4 files changed

+1040
-0
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import math
15+
from typing import Any, Dict, List
16+
17+
import pytest
18+
import torch
19+
20+
import nemo_automodel.components.datasets.llm.retrieval_collator as rc
21+
22+
23+
class FakeTokenizer:
24+
def __call__(
25+
self,
26+
texts: List[str],
27+
max_length: int,
28+
padding: Any,
29+
truncation: bool,
30+
return_token_type_ids: bool,
31+
) -> Dict[str, List[List[int]]]:
32+
# Simple whitespace tokenizer: ids are range(len(tokens))
33+
input_ids = []
34+
attention_masks = []
35+
for t in texts:
36+
tokens = t.split()
37+
if truncation:
38+
tokens = tokens[:max_length]
39+
ids = list(range(len(tokens)))
40+
mask = [1] * len(ids)
41+
input_ids.append(ids)
42+
attention_masks.append(mask)
43+
return {"input_ids": input_ids, "attention_mask": attention_masks}
44+
45+
def pad(
46+
self,
47+
features: List[Dict[str, List[int]]],
48+
padding: Any,
49+
pad_to_multiple_of: int,
50+
return_tensors: str,
51+
) -> Dict[str, torch.Tensor]:
52+
# Determine max length and round to multiple if requested
53+
max_len = max(len(f["input_ids"]) for f in features) if features else 0
54+
if pad_to_multiple_of and max_len % pad_to_multiple_of != 0:
55+
max_len = int(math.ceil(max_len / pad_to_multiple_of) * pad_to_multiple_of)
56+
input_ids = []
57+
attention_masks = []
58+
for f in features:
59+
ids = list(f["input_ids"])
60+
mask = list(f["attention_mask"])
61+
pad_len = max_len - len(ids)
62+
ids = ids + [0] * pad_len
63+
mask = mask + [0] * pad_len
64+
input_ids.append(ids)
65+
attention_masks.append(mask)
66+
return {
67+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
68+
"attention_mask": torch.tensor(attention_masks, dtype=torch.long),
69+
}
70+
71+
72+
def test_unpack_doc_values():
73+
features = [
74+
{"input_ids": [[1, 2], [3]], "attention_mask": [[1, 1], [1]]},
75+
]
76+
out = rc._unpack_doc_values(features)
77+
assert out == [{"input_ids": [1, 2], "attention_mask": [1, 1]}, {"input_ids": [3], "attention_mask": [1]}]
78+
79+
80+
def test_merge_and_convert_helpers():
81+
collator = rc.RetrievalBiencoderCollator(FakeTokenizer())
82+
query_batch = {"input_ids": [[10], [20]], "attention_mask": [[1], [1]]} # batch_size = 2
83+
# 2 examples * train_n_passages(=2) = 4 document rows
84+
doc_batch = {"input_ids": [[100], [101], [110], [111]], "attention_mask": [[1], [1], [1], [1]]}
85+
merged = collator._merge_batch_dict(query_batch, doc_batch, train_n_passages=2)
86+
# Ensure query keys are prefixed and doc keys reshaped to [batch, passages, seq]
87+
assert "q_input_ids" in merged and "d_input_ids" in merged
88+
assert merged["d_input_ids"] == [[[100], [101]], [[110], [111]]]
89+
# Convert dict-of-lists to list-of-dicts
90+
lst = collator._convert_dict_to_list({"a": [1, 2], "b": [3, 4]})
91+
assert lst == [{"a": 1, "b": 3}, {"a": 2, "b": 4}]
92+
93+
94+
def _make_batch(num_examples: int = 2, docs_per_example: int = 3) -> List[Dict[str, Any]]:
95+
batch = []
96+
for i in range(num_examples):
97+
question = f"what is item {i}"
98+
docs = [f"doc {i}-{j}" for j in range(docs_per_example)]
99+
batch.append({"question": question, "doc_text": docs, "doc_image": [""] * docs_per_example})
100+
return batch
101+
102+
103+
def test_collator_end_to_end_no_prefix():
104+
tok = FakeTokenizer()
105+
collator = rc.RetrievalBiencoderCollator(tokenizer=tok, q_max_len=16, p_max_len=16, padding=True)
106+
batch = _make_batch(num_examples=2, docs_per_example=3)
107+
out = collator(batch)
108+
# Expected keys
109+
for k in ["q_input_ids", "q_attention_mask", "d_input_ids", "d_attention_mask", "labels"]:
110+
assert k in out
111+
# Shapes: queries [B, Lq], docs [B * P, Ld], labels [B]
112+
assert out["q_input_ids"].shape[0] == 2
113+
assert out["d_input_ids"].shape[0] == 2 * 3
114+
assert out["labels"].dtype == torch.long and out["labels"].shape[0] == 2 and torch.all(out["labels"] == 0)
115+
# Ensure attention masks align with input_ids shapes
116+
assert out["q_input_ids"].shape == out["q_attention_mask"].shape
117+
assert out["d_input_ids"].shape == out["d_attention_mask"].shape
118+
119+
120+
def test_collator_with_prefix_and_pad_multiple():
121+
tok = FakeTokenizer()
122+
collator = rc.RetrievalBiencoderCollator(
123+
tokenizer=tok, q_max_len=32, p_max_len=32, query_prefix="Q:", passage_prefix="D:", padding=True, pad_to_multiple_of=4
124+
)
125+
# Make varying lengths so padding is exercised and rounded to multiple-of 4
126+
batch = [
127+
{"question": "short", "doc_text": ["tiny", "a bit longer"], "doc_image": ["", ""]},
128+
{"question": "this is a somewhat longer question", "doc_text": ["short doc", "this is a longish doc text"], "doc_image": ["", ""]},
129+
]
130+
out = collator(batch)
131+
# Verify padding rounded to multiple of 4
132+
assert out["q_input_ids"].shape[1] % 4 == 0
133+
assert out["d_input_ids"].shape[1] % 4 == 0
134+
# Still produces expected label size
135+
assert out["labels"].shape[0] == 2
136+
137+

0 commit comments

Comments
 (0)