@@ -235,6 +235,11 @@ def mock_encode(text):
235235 # Count actual tokens in result
236236 # Need to handle mixed content (original lines + decoded text)
237237 total_tokens = 0
238+
239+ # All prompts start with 4 numbers, which are 1 token
240+ total_tokens += 1
241+ result = result [4 :]
242+
238243 # Split by our test lines to count tokens properly
239244 remaining = result
240245 for line in self .test_data :
@@ -270,38 +275,6 @@ def test_sample_text_truncation(self):
270275
271276 # Verify decode was called with truncated tokens
272277 self .tokenizer .decode .assert_called_with (line_tokens [:requested_tokens ])
273- def test_sample_chat_prefix_request (self ):
274- self .tokenizer .encode .side_effect = [
275- [1 ] * 0 ,
276- [1 ] * 11 ,
277- [1 ] * 14 ,
278- [1 ] * 11 ,
279- [1 ] * 11 ,
280- [1 ] * 11 ,
281- [1 ] * 11 ,
282- ]
283- scenario = NormalDistribution (
284- mean_input_tokens = 20 ,
285- stddev_input_tokens = 0 ,
286- mean_output_tokens = 20 ,
287- stddev_output_tokens = 0 ,
288- )
289- prefix_sampler = TextSampler (
290- tokenizer = self .tokenizer ,
291- model = self .model ,
292- output_modality = self .output_modality ,
293- data = self .test_data ,
294- use_scenario = True ,
295- prompt_prefix_ratio = 0.5 , # Set a prefix ratio for testing
296- )
297- result = prefix_sampler .sample (scenario )
298- self .assertIsInstance (result , UserChatRequest )
299- self .assertEqual (result .model , self .model )
300- self .assertTrue (isinstance (result .prompt , str ))
301- self .assertGreater (len (result .prompt ), 0 )
302- # The prompt should start with the generated prefix and a 4-digit number
303- self .assertTrue (result .prompt .startswith (prefix_sampler .prefix ))
304- self .assertEqual (len (result .prompt ), 20 )
305278
306279 def test_sample_chat_prefix_ratio_request (self ):
307280 """Test prefix generation using ratio."""
@@ -343,7 +316,21 @@ def mock_decode(tokens):
343316 self .assertEqual (len (result .prompt ), 20 )
344317
345318 def test_short_prompt_request (self ):
346- self .tokenizer .encode .return_value = [1 ] * 10
319+ """Test that short prompts are handled correctly."""
320+
321+ def mock_encode (text , add_special_tokens = False ):
322+ return [1 ] * len (text )
323+
324+ self .tokenizer .encode = mock_encode
325+
326+ # Mock decode to return the original text
327+ def mock_decode (tokens ):
328+ if isinstance (tokens , list ):
329+ return "a" * len (tokens ) # Return 'a' repeated for the token count
330+ return "decoded_text"
331+
332+ self .tokenizer .decode = mock_decode
333+
347334 self .sampler .data = ["2" ]
348335
349336 # Scenario asks for only 1 input token
0 commit comments