Skip to content

Commit 41d791d

Browse files
shrutipatel31facebook-github-bot
authored andcommitted
Add Progression Plots for MapMetric experiments to ResultsAnalysis
Summary: This diff adds learning curve visualization (progression plots) to ResultsAnalysis for experiments with MapData and MapMetrics. This is to match the existing `get_standard_plots` function in `ax_sweep_orchestrator.py` which includes MapMetric learning curves. Differential Revision: D89776181 Privacy Context Container: L1307644
1 parent d1d20a4 commit 41d791d

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

ax/analysis/plotly/progression.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626
from plotly import graph_objects as go
2727
from pyre_extensions import none_throws, override
2828

29+
PROGRESSION_CARDGROUP_TITLE = "Learning Curves: Metric progression over trials"
30+
PROGRESSION_CARDGROUP_SUBTITLE = (
31+
"These plots show curve metrics (learning curves) that track the evolution of "
32+
"each metric over the course of the experiment. The plots display how metrics "
33+
"change during trial execution, either by progression (e.g., epochs or steps) "
34+
"or by wallclock time. This is useful for monitoring optimization progress and "
35+
"informing early stopping decisions."
36+
)
37+
2938

3039
@final
3140
class ProgressionPlot(Analysis):

ax/analysis/results.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from ax.analysis.best_trials import BestTrials
1414
from ax.analysis.plotly.arm_effects import ArmEffectsPlot
1515
from ax.analysis.plotly.bandit_rollout import BanditRollout
16+
from ax.analysis.plotly.progression import (
17+
PROGRESSION_CARDGROUP_SUBTITLE,
18+
PROGRESSION_CARDGROUP_TITLE,
19+
ProgressionPlot,
20+
)
1621
from ax.analysis.plotly.scatter import (
1722
SCATTER_CARDGROUP_SUBTITLE,
1823
SCATTER_CARDGROUP_TITLE,
@@ -25,6 +30,8 @@
2530
from ax.core.arm import Arm
2631
from ax.core.batch_trial import BatchTrial
2732
from ax.core.experiment import Experiment
33+
from ax.core.map_data import MapData
34+
from ax.core.map_metric import MapMetric
2835
from ax.core.outcome_constraint import ScalarizedOutcomeConstraint
2936
from ax.core.trial_status import TrialStatus
3037
from ax.core.utils import is_bandit_experiment
@@ -239,6 +246,33 @@ def compute(
239246
adapter=adapter,
240247
)
241248

249+
# Compute progression plots for MapMetrics (learning curves)
250+
progression_group = None
251+
data = experiment.lookup_data()
252+
has_map_data = isinstance(data, MapData)
253+
metrics = experiment.metrics.values()
254+
map_metrics = [m for m in metrics if isinstance(m, MapMetric)]
255+
if has_map_data and len(map_metrics) > 0:
256+
map_metric_names = [m.name for m in map_metrics]
257+
progression_cards = [
258+
ProgressionPlot(
259+
metric_name=metric_name, by_wallclock_time=by_wallclock_time
260+
).compute_or_error_card(
261+
experiment=experiment,
262+
generation_strategy=generation_strategy,
263+
adapter=adapter,
264+
)
265+
for metric_name in map_metric_names
266+
for by_wallclock_time in (False, True)
267+
]
268+
if progression_cards:
269+
progression_group = AnalysisCardGroup(
270+
name="ProgressionAnalysis",
271+
title=PROGRESSION_CARDGROUP_TITLE,
272+
subtitle=PROGRESSION_CARDGROUP_SUBTITLE,
273+
children=progression_cards,
274+
)
275+
242276
return self._create_analysis_card_group(
243277
title=RESULTS_CARDGROUP_TITLE,
244278
subtitle=RESULTS_CARDGROUP_SUBTITLE,
@@ -251,6 +285,7 @@ def compute(
251285
bandit_rollout_card,
252286
best_trials_card,
253287
utility_progression_card,
288+
progression_group,
254289
summary,
255290
)
256291
if child is not None

ax/analysis/tests/test_results.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
get_experiment_with_scalarized_objective_and_outcome_constraint,
3636
get_offline_experiments,
3737
get_online_experiments,
38+
get_test_map_data_experiment,
3839
)
3940
from ax.utils.testing.mock import mock_botorch_optimize
4041
from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node
@@ -499,6 +500,39 @@ def test_offline_experiments(self) -> None:
499500
self.assertIsNotNone(card_group)
500501
self.assertGreater(len(card_group.children), 0)
501502

503+
@mock_botorch_optimize
504+
def test_compute_with_map_data_includes_progression_plots(self) -> None:
505+
# Setup: Create experiment with MapData and MapMetrics
506+
experiment = get_test_map_data_experiment(
507+
num_trials=3, num_fetches=2, num_complete=2
508+
)
509+
generation_strategy = get_default_generation_strategy_at_MBM_node(
510+
experiment=experiment
511+
)
512+
513+
# Execute: Compute ResultsAnalysis
514+
card_group = ResultsAnalysis().compute(
515+
experiment=experiment,
516+
generation_strategy=generation_strategy,
517+
)
518+
519+
# Assert: ProgressionAnalysis group exists with children
520+
progression_group = None
521+
for child in card_group.children:
522+
if child.name == "ProgressionAnalysis":
523+
progression_group = child
524+
break
525+
526+
self.assertIsNotNone(
527+
progression_group,
528+
"ProgressionAnalysis group should be present for MapMetric experiments",
529+
)
530+
self.assertGreater(
531+
len(assert_is_instance(progression_group, AnalysisCardGroup).children),
532+
0,
533+
"ProgressionAnalysis group should have at least one progression plot",
534+
)
535+
502536

503537
class TestArmEffectsPair(TestCase):
504538
@mock_botorch_optimize

0 commit comments

Comments
 (0)