
Commit 9f704a4

fix merge issues
1 parent 6f1c840 commit 9f704a4

2 files changed: +29 −38 lines


genai_bench/sampling/text.py

Lines changed: 9 additions & 5 deletions
@@ -210,7 +210,8 @@ def _get_current_prefix(self, prefix_length: int) -> str:
         # Get the difference in length between the existing
         # prefix and the desired prefix length

-        current_prefix_length = self.get_token_length(current_prefix)
+        current_prefix_tokens = self.tokenizer.encode(current_prefix)
+        current_prefix_length = len(current_prefix_tokens)
         prefix_length_diff: int = prefix_length - current_prefix_length

         # Generate the prefix if it hasn't been created yet, or add
@@ -221,8 +222,9 @@ def _get_current_prefix(self, prefix_length: int) -> str:

         elif prefix_length_diff < 0:
             # If the prefix is longer than needed, truncate it
-            char_to_token_ratio = len(current_prefix) / current_prefix_length
-            current_prefix = self.prefix[: int(prefix_length * char_to_token_ratio)]
+            current_prefix = self.tokenizer.decode(
+                current_prefix_tokens[:prefix_length]
+            )
         return current_prefix

     def _sample_text(self, num_input_tokens: int) -> str:
@@ -259,10 +261,12 @@ def _sample_text(self, num_input_tokens: int) -> str:

         # Prepend the prefix to all prompts with a randomly picked 4 digits
         prompt = f"{current_prefix}{random.randint(1000,9999)}"
-        left_tokens_to_sample = num_input_tokens - self.get_token_length(prompt)
+
+        prompt_tokens = self.tokenizer.encode(prompt)
+        left_tokens_to_sample = num_input_tokens - len(prompt_tokens)

         if left_tokens_to_sample < 0:
-            return prompt[: self.get_token_length(prompt) + left_tokens_to_sample]
+            return self.tokenizer.decode(prompt_tokens[:num_input_tokens])
         while left_tokens_to_sample > 0:
             random.shuffle(data_copy)
             for line in data_copy:
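All three hunks make the same fix: instead of estimating lengths through a characters-per-token ratio (or, in the over-length branch, slicing a string by a token count), the prompt is truncated by slicing the encoded token list and decoding it back, so the result lands on the requested token budget. A minimal sketch of the difference, assuming a Hugging Face tokenizer; gpt2 is only an illustration, since the diff does not show which tokenizer genai-bench loads:

# Sketch: token-exact truncation vs. the old character-ratio estimate.
# Assumption: a Hugging Face tokenizer; genai-bench may wire this differently.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "The quick brown fox jumps over the lazy dog. " * 20
target = 50  # desired prompt length in tokens

tokens = tokenizer.encode(text)

# Old approach: scale a character index by an average chars-per-token ratio.
ratio = len(text) / len(tokens)
approx = text[: int(target * ratio)]

# New approach: slice the token ids, then decode.
exact = tokenizer.decode(tokens[:target])

print(len(tokenizer.encode(approx)))  # can land a few tokens off target
print(len(tokenizer.encode(exact)))   # 50, modulo decode/re-encode drift

The decode/re-encode round trip can still drift by a token or two for some tokenizers, but it is exact at the id level, unlike the character-ratio estimate, which degrades as token lengths vary across the text.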

tests/sampling/test_text.py

Lines changed: 20 additions & 33 deletions
@@ -235,6 +235,11 @@ def mock_encode(text):
         # Count actual tokens in result
         # Need to handle mixed content (original lines + decoded text)
         total_tokens = 0
+
+        # All prompts start with 4 numbers, which are 1 token
+        total_tokens += 1
+        result = result[4:]
+
         # Split by our test lines to count tokens properly
         remaining = result
         for line in self.test_data:
@@ -270,38 +275,6 @@ def test_sample_text_truncation(self):

         # Verify decode was called with truncated tokens
         self.tokenizer.decode.assert_called_with(line_tokens[:requested_tokens])
-    def test_sample_chat_prefix_request(self):
-        self.tokenizer.encode.side_effect = [
-            [1] * 0,
-            [1] * 11,
-            [1] * 14,
-            [1] * 11,
-            [1] * 11,
-            [1] * 11,
-            [1] * 11,
-        ]
-        scenario = NormalDistribution(
-            mean_input_tokens=20,
-            stddev_input_tokens=0,
-            mean_output_tokens=20,
-            stddev_output_tokens=0,
-        )
-        prefix_sampler = TextSampler(
-            tokenizer=self.tokenizer,
-            model=self.model,
-            output_modality=self.output_modality,
-            data=self.test_data,
-            use_scenario=True,
-            prompt_prefix_ratio=0.5,  # Set a prefix ratio for testing
-        )
-        result = prefix_sampler.sample(scenario)
-        self.assertIsInstance(result, UserChatRequest)
-        self.assertEqual(result.model, self.model)
-        self.assertTrue(isinstance(result.prompt, str))
-        self.assertGreater(len(result.prompt), 0)
-        # The prompt should start with the generated prefix and a 4-digit number
-        self.assertTrue(result.prompt.startswith(prefix_sampler.prefix))
-        self.assertEqual(len(result.prompt), 20)

     def test_sample_chat_prefix_ratio_request(self):
         """Test prefix generation using ratio."""
@@ -343,7 +316,21 @@ def mock_decode(tokens):
         self.assertEqual(len(result.prompt), 20)

     def test_short_prompt_request(self):
-        self.tokenizer.encode.return_value = [1] * 10
+        """Test that short prompts are handled correctly."""
+
+        def mock_encode(text, add_special_tokens=False):
+            return [1] * len(text)
+
+        self.tokenizer.encode = mock_encode
+
+        # Mock decode to return the original text
+        def mock_decode(tokens):
+            if isinstance(tokens, list):
+                return "a" * len(tokens)  # Return 'a' repeated for the token count
+            return "decoded_text"
+
+        self.tokenizer.decode = mock_decode
+
         self.sampler.data = ["2"]

         # Scenario asks for only 1 input token
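The rewritten test_short_prompt_request swaps a fixed encode.return_value for callable stubs, so token counts track input length (one token per character) and decode returns a string whose length mirrors the sliced token list. A standalone sketch of that stub pattern, assuming the suite's tokenizer fixture is a MagicMock (names here are illustrative, not the repo's actual fixtures):

# Sketch of the one-token-per-character tokenizer stub used above.
# Assumption: the test fixture's tokenizer is a MagicMock.
from unittest.mock import MagicMock

tokenizer = MagicMock()
tokenizer.encode = lambda text, add_special_tokens=False: [1] * len(text)
tokenizer.decode = (
    lambda tokens: "a" * len(tokens) if isinstance(tokens, list) else "decoded_text"
)

assert len(tokenizer.encode("hello")) == 5  # length-dependent, unlike return_value
assert tokenizer.decode([1, 1, 1]) == "aaa"  # decode mirrors the token count

Length-dependent stubs matter here because the truncation logic under test slices token lists; a constant return_value would hide off-by-N errors in the budget arithmetic.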
