6 changes: 6 additions & 0 deletions docs/user-guide/run-benchmark.md
@@ -180,6 +180,12 @@ For heavier traffic scenarios, like `D(16000,200)` or `D(128000,200)`, use the f
--num-concurrency 32 \
```

To benchmark with prefix caching, use `--prompt-prefix-ratio` to make a given fraction of each prompt a common prefix shared across all requests. For example, to make the first half of each prompt a common prefix, use:

```shell
--prompt-prefix-ratio 0.5 \
```
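With `D(16000,200)` and a ratio of 0.5, for example, roughly the first 8,000 input tokens of every request are identical, so a serving engine with prefix caching enabled can reuse its cached prefill for that shared portion. The flag stacks with the other options in this section; appended to the full genai-bench command shown earlier in this guide, the relevant fragment looks like:

```shell
  --num-concurrency 32 \
  --prompt-prefix-ratio 0.5 \
```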

## Distributed Benchmark

If you see the message below in the genai-bench logs, it indicates that a single process is insufficient to generate the desired load.
2 changes: 2 additions & 0 deletions genai_bench/cli/cli.py
@@ -120,6 +120,7 @@ def benchmark(
spawn_rate,
upload_results,
namespace,
prompt_prefix_ratio,
# Storage auth options
storage_provider,
storage_bucket,
@@ -284,6 +285,7 @@ def benchmark(
data=data,
additional_request_params=additional_request_params,
dataset_config=dataset_config_obj,
        prompt_prefix_ratio=prompt_prefix_ratio,
)

# If user did not provide scenarios but provided a dataset, default to dataset mode
8 changes: 8 additions & 0 deletions genai_bench/cli/option_groups.py
@@ -368,6 +368,14 @@ def server_options(func):

# Group experiment-related options
def experiment_options(func):
func = click.option(
"--prompt-prefix-ratio",
type=click.FloatRange(0.0, 1.0),
default=0.0,
help="The ratio of prefix length to overall input length "
"to prepend to all inputs to test prefix caching. "
"Value should be between 0.0 and 1.0. ",
)(func)
func = click.option(
"--experiment-folder-name",
type=str,
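Because the option is declared with `click.FloatRange(0.0, 1.0)`, out-of-range ratios are rejected at argument-parsing time instead of surfacing later inside the sampler. A minimal standalone sketch of that behavior (a throwaway click command, not the real genai-bench CLI wiring):

```python
import click
from click.testing import CliRunner

@click.command()
@click.option("--prompt-prefix-ratio", type=click.FloatRange(0.0, 1.0), default=0.0)
def cmd(prompt_prefix_ratio: float) -> None:
    # Echo the parsed value so the runner output shows what was accepted.
    click.echo(f"ratio={prompt_prefix_ratio}")

runner = CliRunner()
print(runner.invoke(cmd, ["--prompt-prefix-ratio", "0.5"]).output)     # ratio=0.5
print(runner.invoke(cmd, ["--prompt-prefix-ratio", "1.5"]).exit_code)  # 2: rejected as out of range
```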
93 changes: 91 additions & 2 deletions genai_bench/sampling/text.py
@@ -31,6 +31,7 @@ def __init__(
model: str,
output_modality: str,
data: List[str],
prompt_prefix_ratio: float = 0.0,
additional_request_params: Optional[Dict[str, Any]] = None,
dataset_config: Optional[DatasetConfig] = None,
**kwargs,
@@ -41,6 +42,8 @@

self.data = data
self.batch_size = 1 # Default batch size
self.prompt_prefix_ratio = prompt_prefix_ratio
self.prefix = ""

def sample(self, scenario: Optional[Scenario]) -> UserRequest:
"""
@@ -167,6 +170,68 @@ def _validate_scenario(self, scenario: Optional[Scenario]) -> None:
f"{type(scenario.scenario_type)}"
)

    def _sample_prefix(self, current_prefix_length: int) -> str:
"""
Generates prefix of length current_prefix_length to be
prepended to all input prompts.
"""

        if not self.data:
            raise ValueError("Cannot generate prefix from an empty dataset")

        data_copy = self.data.copy()

prefix = ""
prefix_tokens_len = 0
# Generate the prefix
while prefix_tokens_len < current_prefix_length:
random.shuffle(data_copy)
for line in data_copy:
line_tokens = self.tokenizer.encode(line)
num_line_tokens = len(line_tokens)
if prefix_tokens_len + num_line_tokens > current_prefix_length:
remaining_prefix_len = current_prefix_length - prefix_tokens_len
truncated_text = self.tokenizer.decode(
line_tokens[:remaining_prefix_len]
)
prefix += truncated_text
return prefix
prefix += line
prefix_tokens_len = len(self.tokenizer.encode(prefix))

return prefix

def _get_current_prefix(self, prefix_length: int) -> str:
"""
Returns the prefix for the current prompt of the specified length.

Args:
            prefix_length (int): The desired length of the prefix in tokens.
"""

# Prefix of the current prompt being generated
current_prefix: str = self.prefix

# Get the difference in length between the existing
# prefix and the desired prefix length

current_prefix_tokens = self.tokenizer.encode(current_prefix)
current_prefix_length = len(current_prefix_tokens)
prefix_length_diff: int = prefix_length - current_prefix_length

# Generate the prefix if it hasn't been created yet, or add
# to its length if it's not long enough
if prefix_length_diff > 0:
self.prefix += self._sample_prefix(prefix_length_diff)
current_prefix = self.prefix

elif prefix_length_diff < 0:
# If the prefix is longer than needed, truncate it
current_prefix = self.tokenizer.decode(
current_prefix_tokens[:prefix_length]
)
return current_prefix

def _sample_text(self, num_input_tokens: Optional[int]) -> str:
"""
Samples text from a list of lines based on the specified number of
@@ -176,16 +241,40 @@ def _sample_text(self, num_input_tokens: Optional[int]) -> str:
Args:
num_input_tokens (int): The target number of input tokens.

        Raises:
            ValueError: If the dataset is empty or the requested number of
                input tokens is not greater than the prefix length.

Returns:
str: A text prompt containing the desired number of tokens.
"""
        if not self.data:
            raise ValueError("Cannot sample text from an empty dataset")

        if not num_input_tokens:
            return random.choice(self.data)

        # Calculate the prefix length in tokens from the prefix ratio
current_prefix_length = 0
if self.prompt_prefix_ratio > 0.0:
current_prefix_length = round(num_input_tokens * self.prompt_prefix_ratio)

data_copy = self.data.copy()

if num_input_tokens <= current_prefix_length:
raise ValueError("Prefix length must be shorter than total input length")

# Get the prompt prefix
current_prefix: str = self._get_current_prefix(current_prefix_length)

        # Start each prompt with the shared prefix, then 4 random digits so
        # prompts stay unique and only the prefix can be served from cache
        prompt = f"{current_prefix}{random.randint(1000, 9999)}"

prompt_tokens = self.tokenizer.encode(prompt)
left_tokens_to_sample = num_input_tokens - len(prompt_tokens)

if left_tokens_to_sample < 0:
return self.tokenizer.decode(prompt_tokens[:num_input_tokens])
while left_tokens_to_sample > 0:
random.shuffle(data_copy)
for line in data_copy:
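Taken together, the sampler turns `--prompt-prefix-ratio` into a token budget (`round(num_input_tokens * ratio)`), lazily grows one shared prefix that every request reuses, and appends four random digits so only the prefix, not the whole prompt, can be served from cache. A simplified, self-contained sketch of that flow, using a character-level stand-in for the tokenizer and illustrative helper names rather than the real class:

```python
import random
from typing import List

# Character-level stand-in for the tokenizer so the sketch runs on its own.
def encode(text: str) -> List[str]:
    return list(text)

def decode(tokens: List[str]) -> str:
    return "".join(tokens)

def sample_prompt(
    data: List[str],
    num_input_tokens: int,
    prompt_prefix_ratio: float,
    shared_prefix: List[str],
) -> str:
    """Build one prompt: shared prefix + 4 random digits + random filler lines."""
    # Prefix budget is a fraction of the total input length.
    prefix_len = round(num_input_tokens * prompt_prefix_ratio)

    # Grow the shared prefix lazily; every call reuses (and extends) the same list.
    while len(shared_prefix) < prefix_len:
        shared_prefix.extend(encode(random.choice(data)))
    prefix = decode(shared_prefix[:prefix_len])

    # Four random digits make each prompt unique, so only the prefix is cacheable.
    prompt = f"{prefix}{random.randint(1000, 9999)}"

    # Fill the remainder with sampled lines, then truncate to the target length.
    while len(encode(prompt)) < num_input_tokens:
        prompt += random.choice(data)
    return decode(encode(prompt)[:num_input_tokens])

# Two prompts built from the same shared prefix agree on their first half.
data = ["the quick brown fox ", "jumps over the lazy dog "]
shared: List[str] = []
p1 = sample_prompt(data, 40, 0.5, shared)
p2 = sample_prompt(data, 40, 0.5, shared)
assert p1[:20] == p2[:20]
```

The real implementation differs in detail (it re-encodes with the model tokenizer, shuffles the dataset, and raises if the dataset is empty or the prefix would consume the whole budget), but the caching-relevant structure of the prompts is the same.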
91 changes: 89 additions & 2 deletions tests/sampling/test_text.py
@@ -221,6 +221,11 @@ def mock_encode(text, add_special_tokens=False):
# Count actual tokens in result
# Need to handle mixed content (original lines + decoded text)
total_tokens = 0

        # All prompts start with 4 random digits, which count as 1 token here
total_tokens += 1
result = result[4:]

# Split by our test lines to count tokens properly
remaining = result
for line in self.test_data:
@@ -255,6 +260,88 @@ def test_sample_text_truncation(self):
_ = self.sampler._sample_text(requested_tokens)

# Verify decode was called with truncated tokens
        self.tokenizer.decode.assert_called_with(line_tokens[:requested_tokens])

def test_sample_chat_prefix_ratio_request(self):
"""Test prefix generation using ratio."""

# Mock encode to return list with length equal to number of characters in input
def mock_encode(text, add_special_tokens=False):
# ignore space
encoded_text = [1] * len(text.replace(" ", ""))
return encoded_text

self.tokenizer.encode = mock_encode

# Mock decode to return the original text
def mock_decode(tokens, skip_special_tokens=True):
if isinstance(tokens, list):
return "a" * len(tokens) # Return 'a' repeated for the token count
return "decoded_text"

self.tokenizer.decode = mock_decode

scenario = NormalDistribution(
mean_input_tokens=20,
stddev_input_tokens=0,
mean_output_tokens=20,
stddev_output_tokens=0,
)
prefix_sampler = TextSampler(
tokenizer=self.tokenizer,
model=self.model,
output_modality=self.output_modality,
data=self.test_data,
prompt_prefix_ratio=0.5, # 50% of 20 tokens = 10 tokens
)
result = prefix_sampler.sample(scenario)
self.assertIsInstance(result, UserChatRequest)
self.assertEqual(result.model, self.model)
self.assertTrue(isinstance(result.prompt, str))
self.assertGreater(len(result.prompt), 0)
self.assertTrue(result.prompt.startswith(prefix_sampler.prefix))
self.assertEqual(len(mock_encode(result.prompt)), 20)

def test_short_prompt_request(self):
"""Test that short prompts are handled correctly."""

def mock_encode(text, add_special_tokens=False):
return [1] * len(text)

self.tokenizer.encode = mock_encode

# Mock decode to return the original text
def mock_decode(tokens):
if isinstance(tokens, list):
return "a" * len(tokens) # Return 'a' repeated for the token count
return "decoded_text"

self.tokenizer.decode = mock_decode

self.sampler.data = ["2"]

# Scenario asks for only 1 input token
scenario = NormalDistribution(1, 0, 1, 0)

result = self.sampler.sample(scenario)
self.assertIsInstance(result, UserChatRequest)
        # The prompt (4 random digits) is truncated to a single token
self.assertEqual(len(result.prompt), 1)
self.assertGreater(len(result.prompt), 0)

def test_empty_dataset(self):
"""Test sampling from an empty dataset."""
empty_sampler = TextSampler(
tokenizer=self.tokenizer,
model=self.model,
output_modality=self.output_modality,
data=[],
)
scenario = NormalDistribution(10, 0, 10, 0)

with self.assertRaises(ValueError) as context:
empty_sampler.sample(scenario)

self.assertEqual(
str(context.exception), "Cannot sample text from an empty dataset"
)