diff --git a/src/_lcm/simulation/additional_targets.py b/src/_lcm/simulation/additional_targets.py index 637918ea..3479c86e 100644 --- a/src/_lcm/simulation/additional_targets.py +++ b/src/_lcm/simulation/additional_targets.py @@ -99,12 +99,12 @@ def _compute_targets( targets: list[str], regime: Regime, regime_params: FlatRegimeParams, - subject_batch_size: int | None = None, + target_batch_size: int | None = None, ) -> dict[str, FloatND | IntND | BoolND | np.ndarray]: """Compute additional targets for a regime. The target DAG is vmapped over the regime's in-regime subject-period rows. When - `subject_batch_size` is a positive value below the row count, the rows are + `target_batch_size` is a positive value below the row count, the rows are processed in chunks and each chunk's outputs are pulled to host before the next runs, so the fused-DAG device workspace is bounded by the chunk rather than the full population. `0`/`None` (or any value at least the row count) evaluates in a @@ -124,7 +124,7 @@ def _compute_targets( inputs = {k: v for k, v in data.items() if k in variables} n_rows = len(data["period"]) - if not subject_batch_size or subject_batch_size >= n_rows: + if not target_batch_size or target_batch_size >= n_rows: kwargs = {k: jnp.asarray(v) for k, v in inputs.items()} result = vectorized_func(**all_params, **kwargs) return {k: jnp.squeeze(v) for k, v in result.items()} @@ -133,8 +133,8 @@ def _compute_targets( # time. Squeeze the *concatenated* result, never a chunk — an uneven final # chunk of one row would otherwise lose its row axis. 
chunk_outputs: list[dict[str, np.ndarray]] = [] - for start in range(0, n_rows, subject_batch_size): - stop = min(start + subject_batch_size, n_rows) + for start in range(0, n_rows, target_batch_size): + stop = min(start + target_batch_size, n_rows) chunk_kwargs = {k: jnp.asarray(v[start:stop]) for k, v in inputs.items()} chunk_result = vectorized_func(**all_params, **chunk_kwargs) chunk_outputs.append({k: np.asarray(v) for k, v in chunk_result.items()}) diff --git a/src/_lcm/simulation/result_dataframe.py b/src/_lcm/simulation/result_dataframe.py index d1d2fa21..c95b3416 100644 --- a/src/_lcm/simulation/result_dataframe.py +++ b/src/_lcm/simulation/result_dataframe.py @@ -28,7 +28,7 @@ def _create_flat_dataframe( metadata: ResultMetadata, additional_targets: list[str] | None, ages: AgeGrid, - subject_batch_size: int | None = None, + target_batch_size: int | None = None, ) -> pd.DataFrame: """Create a single flat DataFrame from all regime results. @@ -49,7 +49,7 @@ def _create_flat_dataframe( regime_params=flat_params[name], additional_targets=additional_targets, ages=ages, - subject_batch_size=subject_batch_size, + target_batch_size=target_batch_size, ) for name in metadata.regime_names if raw_results[name] @@ -72,7 +72,7 @@ def _process_regime( regime_params: FlatRegimeParams, additional_targets: list[str] | None, ages: AgeGrid, - subject_batch_size: int | None = None, + target_batch_size: int | None = None, ) -> pd.DataFrame: """Process results for a single regime into a DataFrame. 
@@ -115,7 +115,7 @@ def _process_regime( targets=targets_for_regime, regime=regime, regime_params=regime_params, - subject_batch_size=subject_batch_size, + target_batch_size=target_batch_size, ) data.update(target_values) diff --git a/src/lcm/result.py b/src/lcm/result.py index 5364c476..1ea65b9e 100644 --- a/src/lcm/result.py +++ b/src/lcm/result.py @@ -127,6 +127,7 @@ def to_dataframe( additional_targets: list[str] | Literal["all"] | None = None, *, use_labels: bool = True, + target_batch_size: int | None = None, ) -> pd.DataFrame: """Convert simulation results to a flat pandas DataFrame. @@ -137,12 +138,20 @@ def to_dataframe( - "all": Compute all available targets (see `available_targets`) Targets can be any function defined in a regime. Each target is computed for the regimes where it exists; rows from regimes without - that target will have NaN. When `simulate` ran with - `subject_batch_size` set, target evaluation is chunked over subjects - with that batch size (bounding device memory; values are unchanged). + that target will have NaN. use_labels: If True (default), discrete variables (states, actions, and regime) are returned as pandas Categorical dtype with string labels. If False, discrete variables are returned as integer codes. + target_batch_size: Chunk size for the `additional_targets` evaluation. + A positive value below the in-regime row count processes the rows in + chunks, pulling each chunk to host before the next runs, so the fused + target-DAG device workspace is bounded by the chunk rather than the + full population. `None` (default) falls back to the + `subject_batch_size` the simulation ran with. Set it explicitly to + bound target-eval device memory independently of the simulate — e.g. + when the simulate ran single-pass under a distributed grid, where + raising `subject_batch_size` is not available. Values are identical + to the single-pass evaluation. Returns: DataFrame with simulation results. 
@@ -153,6 +162,10 @@ def to_dataframe( available_targets=self.available_targets, ) + effective_target_batch_size = ( + self._subject_batch_size if target_batch_size is None else target_batch_size + ) + df = _create_flat_dataframe( raw_results=self._raw_results, regimes=self._regimes, @@ -160,7 +173,7 @@ def to_dataframe( metadata=self._metadata, additional_targets=resolved_targets, ages=self._ages, - subject_batch_size=self._subject_batch_size, + target_batch_size=effective_target_batch_size, ) if use_labels: diff --git a/tests/simulation/test_subject_batching.py b/tests/simulation/test_subject_batching.py index a8d605bf..51738eb3 100644 --- a/tests/simulation/test_subject_batching.py +++ b/tests/simulation/test_subject_batching.py @@ -155,3 +155,43 @@ def test_raw_results_are_host_resident_jax_arrays_when_batched() -> None: v_arr = result.raw_results["work"][0].V_arr assert isinstance(v_arr, jax.Array) assert v_arr.devices() == {jax.devices("cpu")[0]} + + +def _target_batch_df(*, target_batch_size: int) -> pd.DataFrame: + """Simulate in a single pass, then chunk only the `to_dataframe` target eval.""" + model = get_multi_regime_model(n_periods=6, distribution_type="normal") + params = get_multi_regime_params("normal") + result = model.simulate( + log_level="off", + params=params, + initial_conditions=_INITIAL_CONDITIONS, + period_to_regime_to_V_arr=None, + seed=42, + subject_batch_size=0, + ) + return ( + result.to_dataframe( + additional_targets=["utility"], target_batch_size=target_batch_size + ) + .sort_values(["subject_id", "period"]) + .reset_index(drop=True) + ) + + +@pytest.mark.parametrize("target_batch_size", [2, 3, 100]) +def test_to_dataframe_targets_are_invariant_to_target_batch_size( + target_batch_size: int, +) -> None: + """`to_dataframe(target_batch_size=N)` chunks the target eval on its own knob. 
+ + The post-simulate `additional_targets` DAG (`utility`) is evaluated over the + in-regime rows in chunks of `target_batch_size`, independently of the simulate's + `subject_batch_size` (here `0` — the single-pass case a distributed/sharded + simulate produces, where raising `subject_batch_size` is unavailable). Small + chunk sizes (2 and 3), which may leave a shorter final chunk, and a chunk larger + than the population (100 → single chunk) each reproduce the single-pass `utility` + column exactly. + """ + baseline = _target_batch_df(target_batch_size=0) + batched = _target_batch_df(target_batch_size=target_batch_size) + _assert_columns_invariant(baseline, batched)