Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/_lcm/simulation/additional_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def _compute_targets(
targets: list[str],
regime: Regime,
regime_params: FlatRegimeParams,
subject_batch_size: int | None = None,
target_batch_size: int | None = None,
) -> dict[str, FloatND | IntND | BoolND | np.ndarray]:
"""Compute additional targets for a regime.

The target DAG is vmapped over the regime's in-regime subject-period rows. When
`subject_batch_size` is a positive value below the row count, the rows are
`target_batch_size` is a positive value below the row count, the rows are
processed in chunks and each chunk's outputs are pulled to host before the next
runs, so the fused-DAG device workspace is bounded by the chunk rather than the
full population. `0`/`None` (or any value at least the row count) evaluates in a
Expand All @@ -124,7 +124,7 @@ def _compute_targets(
inputs = {k: v for k, v in data.items() if k in variables}
n_rows = len(data["period"])

if not subject_batch_size or subject_batch_size >= n_rows:
if not target_batch_size or target_batch_size >= n_rows:
kwargs = {k: jnp.asarray(v) for k, v in inputs.items()}
result = vectorized_func(**all_params, **kwargs)
return {k: jnp.squeeze(v) for k, v in result.items()}
Expand All @@ -133,8 +133,8 @@ def _compute_targets(
# time. Squeeze the *concatenated* result, never a chunk — an uneven final
# chunk of one row would otherwise lose its row axis.
chunk_outputs: list[dict[str, np.ndarray]] = []
for start in range(0, n_rows, subject_batch_size):
stop = min(start + subject_batch_size, n_rows)
for start in range(0, n_rows, target_batch_size):
stop = min(start + target_batch_size, n_rows)
chunk_kwargs = {k: jnp.asarray(v[start:stop]) for k, v in inputs.items()}
chunk_result = vectorized_func(**all_params, **chunk_kwargs)
chunk_outputs.append({k: np.asarray(v) for k, v in chunk_result.items()})
Expand Down
8 changes: 4 additions & 4 deletions src/_lcm/simulation/result_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _create_flat_dataframe(
metadata: ResultMetadata,
additional_targets: list[str] | None,
ages: AgeGrid,
subject_batch_size: int | None = None,
target_batch_size: int | None = None,
) -> pd.DataFrame:
"""Create a single flat DataFrame from all regime results.

Expand All @@ -49,7 +49,7 @@ def _create_flat_dataframe(
regime_params=flat_params[name],
additional_targets=additional_targets,
ages=ages,
subject_batch_size=subject_batch_size,
target_batch_size=target_batch_size,
)
for name in metadata.regime_names
if raw_results[name]
Expand All @@ -72,7 +72,7 @@ def _process_regime(
regime_params: FlatRegimeParams,
additional_targets: list[str] | None,
ages: AgeGrid,
subject_batch_size: int | None = None,
target_batch_size: int | None = None,
) -> pd.DataFrame:
"""Process results for a single regime into a DataFrame.

Expand Down Expand Up @@ -115,7 +115,7 @@ def _process_regime(
targets=targets_for_regime,
regime=regime,
regime_params=regime_params,
subject_batch_size=subject_batch_size,
target_batch_size=target_batch_size,
)
data.update(target_values)

Expand Down
21 changes: 17 additions & 4 deletions src/lcm/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def to_dataframe(
additional_targets: list[str] | Literal["all"] | None = None,
*,
use_labels: bool = True,
target_batch_size: int | None = None,
) -> pd.DataFrame:
"""Convert simulation results to a flat pandas DataFrame.

Expand All @@ -137,12 +138,20 @@ def to_dataframe(
- "all": Compute all available targets (see `available_targets`)
Targets can be any function defined in a regime. Each target is
computed for the regimes where it exists; rows from regimes without
that target will have NaN. When `simulate` ran with
`subject_batch_size` set, target evaluation is chunked over subjects
with that batch size (bounding device memory; values are unchanged).
that target will have NaN.
use_labels: If True (default), discrete variables (states, actions, and
regime) are returned as pandas Categorical dtype with string labels.
If False, discrete variables are returned as integer codes.
target_batch_size: Chunk size for the `additional_targets` evaluation.
A positive value below the in-regime row count processes the rows in
chunks, pulling each chunk to host before the next runs, so the fused
target-DAG device workspace is bounded by the chunk rather than the
full population. `None` (default) falls back to the
`subject_batch_size` the simulation ran with. Set it explicitly to
bound target-eval device memory independently of the simulate — e.g.
when the simulate ran single-pass under a distributed grid, where
raising `subject_batch_size` is not available. Values are identical
to the single-pass evaluation.

Returns:
DataFrame with simulation results.
Expand All @@ -153,14 +162,18 @@ def to_dataframe(
available_targets=self.available_targets,
)

effective_target_batch_size = (
self._subject_batch_size if target_batch_size is None else target_batch_size
)

df = _create_flat_dataframe(
raw_results=self._raw_results,
regimes=self._regimes,
flat_params=self._flat_params,
metadata=self._metadata,
additional_targets=resolved_targets,
ages=self._ages,
subject_batch_size=self._subject_batch_size,
target_batch_size=effective_target_batch_size,
)

if use_labels:
Expand Down
40 changes: 40 additions & 0 deletions tests/simulation/test_subject_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,43 @@ def test_raw_results_are_host_resident_jax_arrays_when_batched() -> None:
v_arr = result.raw_results["work"][0].V_arr
assert isinstance(v_arr, jax.Array)
assert v_arr.devices() == {jax.devices("cpu")[0]}


def _target_batch_df(*, target_batch_size: int) -> pd.DataFrame:
    """Build the sorted result DataFrame with a chunked `utility` target eval.

    The model is simulated with `subject_batch_size=0`, so the only batching
    knob exercised here is the `target_batch_size` forwarded to
    `to_dataframe`.

    Args:
        target_batch_size: Chunk size passed to `to_dataframe` for the
            `additional_targets` evaluation.

    Returns:
        The DataFrame sorted by `subject_id` and `period`, with a fresh
        0-based index so two calls can be compared row by row.
    """
    model = get_multi_regime_model(n_periods=6, distribution_type="normal")
    simulated = model.simulate(
        log_level="off",
        params=get_multi_regime_params("normal"),
        initial_conditions=_INITIAL_CONDITIONS,
        period_to_regime_to_V_arr=None,
        seed=42,
        subject_batch_size=0,
    )
    frame = simulated.to_dataframe(
        additional_targets=["utility"],
        target_batch_size=target_batch_size,
    )
    # Sort and re-index so the row order does not depend on how rows were
    # produced.
    ordered = frame.sort_values(["subject_id", "period"])
    return ordered.reset_index(drop=True)


@pytest.mark.parametrize("target_batch_size", [2, 3, 100])
def test_to_dataframe_targets_are_invariant_to_target_batch_size(
    target_batch_size: int,
) -> None:
    """`to_dataframe(target_batch_size=N)` must not change the resulting columns.

    Both frames come from `_target_batch_df`, which simulates with
    `subject_batch_size=0`, so the only difference between them is the
    `target_batch_size` handed to `to_dataframe`. The reference frame uses
    `target_batch_size=0`; each parametrized chunk size (2, 3, and 100) is
    compared against it with `_assert_columns_invariant`.

    NOTE(review): the three sizes are presumably chosen to cover several
    chunks, a short trailing chunk, and a chunk at least as large as the row
    count — confirm against the row count of the fixture population.
    """
    single_pass = _target_batch_df(target_batch_size=0)
    chunked = _target_batch_df(target_batch_size=target_batch_size)
    _assert_columns_invariant(single_pass, chunked)