Skip to content

Commit c9aca7b

Browse files
authored
[Plot] Remove hidden cap for data and set general scale start from 0 (#66)
1 parent ff02199 commit c9aca7b

File tree

4 files changed

+36
-54
lines changed

4 files changed

+36
-54
lines changed

genai_bench/analysis/flexible_plot_report.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def _add_plot_annotations(
353353
bbox=dict(
354354
boxstyle="round,pad=0.2",
355355
facecolor="white",
356-
alpha=0.7,
356+
alpha=0.1,
357357
edgecolor="none",
358358
),
359359
)
@@ -481,7 +481,7 @@ def _plot_multi_line_metric(
481481
ax.set_xlabel(plot_spec.x_label or self._generate_label(plot_spec.x_field))
482482
ax.set_ylabel(plot_spec.y_label or "Value")
483483
ax.set_title(plot_spec.title)
484-
ax.grid(True, alpha=0.3)
484+
ax.grid(True, alpha=0.1)
485485

486486
# Position legend outside plot area for multi-line plots to avoid overlap
487487
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
@@ -564,7 +564,7 @@ def _save_individual_subplots_multiline(
564564
bbox=dict(
565565
boxstyle="round,pad=0.2",
566566
facecolor="white",
567-
alpha=0.7,
567+
alpha=0.1,
568568
edgecolor="none",
569569
),
570570
)
@@ -593,7 +593,7 @@ def _save_individual_subplots_multiline(
593593
# Copy grid
594594
ax_temp.grid(
595595
ax.get_xgridlines()[0].get_visible() if ax.get_xgridlines() else True,
596-
alpha=0.3,
596+
alpha=0.1,
597597
)
598598
ax_temp.minorticks_on()
599599

genai_bench/analysis/plot_report.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,9 @@ def plot_graph(
5151
else:
5252
x_positions = x_data # type: ignore[assignment]
5353

54-
# If this is TTFT or E2E latency, filter out values outside [0.1, 100]
55-
valid_x = []
56-
valid_y = []
57-
valid_concurrency = []
58-
59-
should_cap = any(
60-
kw in y_label.lower() for kw in ["ttft", "mean e2e", "p90 e2e", "p99 e2e"]
61-
)
62-
63-
if should_cap:
64-
for xx, yy, cc in zip(x_data, y_data, concurrency_levels, strict=False):
65-
if 0.1 <= yy <= 100:
66-
valid_x.append(xx)
67-
valid_y.append(yy)
68-
valid_concurrency.append(cc)
69-
else:
70-
valid_x = x_data
71-
valid_y = y_data
72-
valid_concurrency = concurrency_levels
54+
valid_x = x_data
55+
valid_y = y_data
56+
valid_concurrency = concurrency_levels
7357

7458
# Plot data
7559
if plot_type == "line":
@@ -88,7 +72,7 @@ def plot_graph(
8872
textcoords="offset points",
8973
ha="left",
9074
bbox=dict(
91-
boxstyle="round,pad=0.2", facecolor="white", alpha=0.8, edgecolor="none"
75+
boxstyle="round,pad=0.2", facecolor="white", alpha=0.1, edgecolor="none"
9276
),
9377
)
9478

@@ -101,11 +85,17 @@ def plot_graph(
10185
mticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1, numticks=100)
10286
)
10387

104-
# Cap the y-limits if needed
105-
if should_cap:
106-
ax.set_ylim([0.1, 100])
107-
else:
108-
ax.set_ylim(bottom=0)
88+
# Axis limits handling with autoscale re-enabled every draw
89+
# X-axis: allow Matplotlib to autoscale to include new data, then pin left=0
90+
ax.autoscale(enable=True, axis="x", tight=False)
91+
x_left, x_right = ax.get_xlim()
92+
ax.set_xlim(left=0.0, right=x_right)
93+
94+
# Y-axis: re-autoscale first, then pin bottom=0 for linear scale only
95+
ax.autoscale(enable=True, axis="y", tight=False)
96+
if ax.get_yscale() != "log":
97+
y_bottom, y_top = ax.get_ylim()
98+
ax.set_ylim(bottom=0.0, top=y_top)
10999

110100
ax.set_xlabel(x_label)
111101
ax.set_ylabel(y_label)
@@ -715,6 +705,10 @@ def plot_error_rates(
715705
ax.set_xlabel("Concurrency")
716706
ax.set_ylabel("Error Rate")
717707
ax.set_title("Error Rates by HTTP Status vs Concurrency")
718-
ax.set_ylim(bottom=0)
708+
# Re-enable autoscale for y so subsequent groups can extend the top,
709+
# then pin bottom at 0 (valid for linear scale used here)
710+
ax.autoscale(enable=True, axis="y", tight=False)
711+
y_bottom, y_top = ax.get_ylim()
712+
ax.set_ylim(bottom=0.0, top=y_top)
719713
ax.legend()
720714
ax.grid(True)

genai_bench/cli/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010

1111
from genai_bench.analysis.excel_report import create_workbook
1212
from genai_bench.analysis.experiment_loader import load_one_experiment
13+
from genai_bench.analysis.flexible_plot_report import plot_experiment_data_flexible
1314
from genai_bench.analysis.plot_report import (
1415
plot_single_scenario_inference_speed_vs_throughput,
1516
)
16-
from genai_bench.analysis.flexible_plot_report import plot_experiment_data_flexible
17-
1817
from genai_bench.auth.unified_factory import UnifiedAuthFactory
1918
from genai_bench.cli.option_groups import (
2019
api_options,

tests/analysis/test_plot_report.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ def test_plot_single_scenario_rerank(mock_plot_graph, mock_plt, tmp_path, caplog
250250

251251
def test_plot_graph_line():
252252
ax = MagicMock()
253+
ax.get_xlim.return_value = (0, 10)
254+
ax.get_ylim.return_value = (0, 10)
255+
ax.get_yscale.return_value = "linear"
253256
x_data = [1, 2, 3]
254257
y_data = [10, 20, 30]
255258
x_label = "X Axis"
@@ -280,6 +283,9 @@ def test_plot_graph_line():
280283

281284
def test_plot_graph_scatter():
282285
ax = MagicMock()
286+
ax.get_xlim.return_value = (0, 10)
287+
ax.get_ylim.return_value = (0, 10)
288+
ax.get_yscale.return_value = "linear"
283289
x_data = [1, 2, 3]
284290
y_data = [10, 20, 30]
285291
x_label = "X Axis"
@@ -307,6 +313,9 @@ def test_plot_graph_scatter():
307313
def test_plot_graph_concurrency():
308314
"""When x_label is 'Concurrency', x_data is replaced by evenly spaced positions."""
309315
ax = MagicMock()
316+
ax.get_xlim.return_value = (0, 10)
317+
ax.get_ylim.return_value = (0, 10)
318+
ax.get_yscale.return_value = "linear"
310319
x_data = [10, 20, 30]
311320
y_data = [0.5, 1.0, 2.0]
312321
x_label = "Concurrency"
@@ -324,28 +333,6 @@ def test_plot_graph_concurrency():
324333
ax.plot.assert_called_once()
325334

326335

327-
def test_plot_graph_cap():
328-
"""
329-
When y_label triggers value capping (e.g. contains "ttft"),
330-
only y values in [0.1, 100] are plotted.
331-
"""
332-
ax = MagicMock()
333-
x_data = [0, 1, 2, 3]
334-
y_data = [0.05, 0.5, 50, 150] # only 0.5 and 50 are within the valid range
335-
x_label = "Not Concurrency"
336-
y_label = "TTFT"
337-
title = "TTFT Plot"
338-
concurrency_levels = [10, 20, 30, 40]
339-
label = "CapTest"
340-
341-
plot_graph(ax, x_data, y_data, x_label, y_label, title, concurrency_levels, label)
342-
343-
# The plotting call should use only the two valid data points.
344-
ax.plot.assert_called_once()
345-
# And y-limits should be capped to [0.1, 100]
346-
ax.set_ylim.assert_called_with([0.1, 100])
347-
348-
349336
@patch("genai_bench.analysis.plot_report.plot_graph")
350337
@patch("genai_bench.analysis.plot_report.plot_error_rates")
351338
def test_plot_metrics(mock_plot_error_rates, mock_plot_graph):
@@ -552,6 +539,8 @@ def test_plot_error_rates():
552539
ax = MagicMock()
553540
# Ensure unpacking of legend handles/labels works.
554541
ax.get_legend_handles_labels.return_value = ([], [])
542+
# Provide y-limits for autoscale+pin logic
543+
ax.get_ylim.return_value = (0, 1)
555544

556545
def create_agg(freq, num_requests):
557546
agg = MagicMock()

0 commit comments

Comments
 (0)