Add modular RLSSM simulator framework #278
Milestone 1, Commit 1: Defines the LearningProcess protocol (the handshake contract between learning and decision processes) and the first built-in implementation — RescorlaWagnerDeltaRule — which is numerically equivalent to HSSM's compute_v_trial_wise(). Includes 13 unit tests covering Q-value trajectories, drift ordering, HSSM numerical equivalence, and protocol compliance. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
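For readers unfamiliar with the rule, here is a minimal NumPy sketch of a Rescorla–Wagner delta update. It is not the `RescorlaWagnerDeltaRule` code from this PR: the function names `rw_update` and `drift_from_q`, the parameter names `alpha` and `scaler`, and the drift mapping (scaled Q-value difference for a two-arm task) are assumptions made for illustration only.

```python
import numpy as np


def rw_update(q_values: np.ndarray, choice: int, reward: float, alpha: float) -> np.ndarray:
    """Return a new Q-value vector after one delta-rule update of the chosen arm."""
    if not 0 <= choice < len(q_values):
        raise ValueError(f"choice {choice} is out of range for {len(q_values)} arms")
    updated = q_values.copy()
    # Prediction error: observed reward minus the current expectation.
    updated[choice] += alpha * (reward - updated[choice])
    return updated


def drift_from_q(q_values: np.ndarray, scaler: float) -> float:
    """Assumed mapping for a two-arm task: drift is the scaled Q-value difference."""
    return float(scaler * (q_values[1] - q_values[0]))


q = np.array([0.5, 0.5])
q = rw_update(q, choice=1, reward=1.0, alpha=0.2)  # arm 1 was rewarded
q = rw_update(q, choice=0, reward=0.0, alpha=0.2)  # arm 0 was not rewarded
v = drift_from_q(q, scaler=2.0)
```

Only the chosen arm moves on each trial, so the unchosen arm keeps its previous value.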
Milestone 1, Commit 2: Defines the TaskEnvironment protocol for reward generation, TwoArmedBandit (Bernoulli bandit with configurable per-arm probabilities), and TaskConfig convenience dataclass for common paradigms. Includes 14 unit tests covering reward statistics, reproducibility, input validation, protocol compliance, and TaskConfig builder. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
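A Bernoulli bandit with per-arm probabilities and seed-based reproducibility can be sketched as follows. The class `BernoulliBandit` and its `reset`/`reward` methods are hypothetical stand-ins, not the `TwoArmedBandit` or `TaskEnvironment` interfaces added in this commit.

```python
import numpy as np


class BernoulliBandit:
    """Hypothetical Bernoulli bandit; not the PR's TwoArmedBandit."""

    def __init__(self, reward_probs=(0.8, 0.2), seed=None):
        probs = np.asarray(reward_probs, dtype=float)
        if probs.ndim != 1 or len(probs) < 2:
            raise ValueError("reward_probs must list at least two arms")
        if np.any((probs < 0) | (probs > 1)):
            raise ValueError("reward probabilities must lie in [0, 1]")
        self.reward_probs = probs
        self.reset(seed)

    def reset(self, seed=None):
        # Re-seeding makes the reward sequence reproducible across runs.
        self._rng = np.random.default_rng(seed)

    def reward(self, action: int) -> float:
        if not 0 <= action < len(self.reward_probs):
            raise ValueError(f"action {action} is out of range")
        # Reward is 1.0 with the arm's probability, otherwise 0.0.
        return float(self._rng.random() < self.reward_probs[action])


env = BernoulliBandit(reward_probs=(0.8, 0.2), seed=0)
rewards_a = [env.reward(0) for _ in range(5)]
env.reset(seed=0)
rewards_b = [env.reward(0) for _ in range(5)]  # same seed, same sequence
```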
Milestone 1, Commit 3: Defines RLSSMModelConfig — the structural model specification that resolves the handshake between learning process and decision process (SSM). Auto-derives list_params, bounds, and defaults from components. Includes validate() for config consistency checking and to_hssm_config_dict() for bridging to HSSM's RLSSMConfig. 13 tests cover auto-derivation, handshake validation, computed_param_mapping, TaskConfig auto-build, and HSSM dict contract. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Milestone 1, Commit 4: Implements the core RLSSMSimulator class that runs the trial-by-trial interleaved loop: compute SSM params from learning state, simulate one SSM trial, observe choice, generate reward, update learning. Reuses ssm-simulators' existing simulator() with n_samples=1 — all 40+ SSM models work as decision processes. Includes 15 tests covering DataFrame output shape, balanced panel, reproducibility, theta validation, edge cases, and omission handling. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Milestone 1, Commit 5: Adds the preset registry (register/get/list_rlssm_preset) with rlssm1 preset (RW delta rule + angle SSM + two-armed bandit). Wires up the ssms.rl public API with full __all__ exports and adds `from . import rl` to ssms/__init__.py. Fixes circular import in rl_simulator.py (OMISSION_SENTINEL). Includes 13 contract tests for HSSM compatibility (output dtypes, no NaNs, config dict schema) and registry smoke tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
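The register/get/list pattern mentioned above can be illustrated with a small name-to-factory registry. This is a sketch under assumptions: the function names, the `overwrite` flag, and the dictionary returned by the factory are all hypothetical, and the component labels are placeholders rather than the objects the real `rlssm1` preset builds.

```python
from typing import Callable

# Hypothetical registry: maps a preset name to a zero-argument factory.
_PRESETS: dict[str, Callable[[], dict]] = {}


def register_preset(name: str, factory: Callable[[], dict], *, overwrite: bool = False) -> None:
    """Register a factory, refusing silent replacement unless asked."""
    if not overwrite and name in _PRESETS:
        raise ValueError(f"Preset '{name}' is already registered.")
    _PRESETS[name] = factory


def get_preset(name: str) -> dict:
    """Build and return a fresh preset object for ``name``."""
    try:
        factory = _PRESETS[name]
    except KeyError:
        known = ", ".join(sorted(_PRESETS)) or "<none>"
        raise KeyError(f"Unknown preset '{name}'. Known presets: {known}") from None
    # A fresh object per call keeps callers from mutating shared state.
    return factory()


def list_presets() -> list[str]:
    return sorted(_PRESETS)


# Placeholder labels for: RW delta rule + angle SSM + two-armed bandit.
register_preset(
    "rlssm1",
    lambda: {"learning": "rw_delta", "decision_process": "angle", "task": "two_armed_bandit"},
)
```

Storing factories rather than instances means each lookup returns an independent configuration.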
Pull request overview
This PR introduces a new modular RLSSM simulation framework under ssms.rl, designed to interleave trial-wise reinforcement learning updates with existing SSM decision simulators and to export HSSM-compatible configuration/data.
Changes:
- Added `ssms.rl` core components: `ModelConfig`, `Simulator`, task environments (bandits), learning rules (Rescorla–Wagner variants), and a preset registry.
- Added comprehensive RLSSM-focused tests (learning, env, simulator behavior, and HSSM compatibility/contract checks).
- Added an MkDocs tutorial entry for the new RLSSM simulator workflow.
Reviewed changes
Copilot reviewed 13 out of 15 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| ssms/rl/config.py | Defines the structural RLSSM configuration and HSSM export support. |
| ssms/rl/simulator.py | Implements the interleaved learning + SSM simulation loop and output formatting. |
| ssms/rl/env.py | Adds a task environment protocol plus Bernoulli/Gaussian bandit implementations and task registry. |
| ssms/rl/learning.py | Adds the learning process protocol and Rescorla–Wagner learning rules. |
| ssms/rl/preset.py | Adds an RLSSM preset registry and a built-in rlssm1 preset. |
| ssms/rl/__init__.py | Exposes the public ssms.rl API surface. |
| ssms/__init__.py | Re-exports the rl module at the package top level. |
| tests/rl/test_task_environment.py | Tests bandit environment behavior, validation, and task config building. |
| tests/rl/test_learning_process.py | Tests RW learning rules’ numerical behavior and protocol compliance. |
| tests/rl/test_rl_config.py | Tests config auto-derivation, handshake validation, response mapping, and HSSM dict export. |
| tests/rl/test_rl_simulator.py | Tests simulation output schema, reproducibility, omission handling, and response/action mapping. |
| tests/rl/test_hssm_compatibility.py | Contract tests for HSSM consumability and preset registry behavior. |
| tests/rl/__init__.py | Initializes the RL test package. |
| mkdocs.yml | Adds the RLSSM tutorial notebook to the docs nav. |
Brings main's docs/branding refresh and simulator updates into the RLSSM feature branch (PR #278). The only conflict was uv.lock; resolved by taking main's lockfile as base and regenerating with `uv lock` so it stays consistent with the merged pyproject.toml. uv.lock remains tracked (repo convention). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
cpaniaguam left a comment:
Not a thorough review -- missing some files I didn't comment on.
Reject empty list_params during ModelConfig validation and stop treating empty bandit reward strings as the default bernoulli reward. Co-authored-by: Cursor <cursoragent@cursor.com>
Remove unused compiled.py sentinel, simplify computed-param mapping, document direct compile validation, and add uv sync --extra jax docs. Co-authored-by: Cursor <cursoragent@cursor.com>
📝 Walkthrough
Adds the complete
Changes
ssms.rl RLSSM Framework
Sequence Diagram(s)
sequenceDiagram
participant User
participant Simulator
participant ModelConfig
participant LearningProcess
participant TaskEnvironment
participant SSM as ssms.simulator
User->>ModelConfig: ModelConfig(decision_process, learning_process, task)
ModelConfig->>ModelConfig: __post_init__ (derive choices, handshake, backend)
User->>Simulator: Simulator(config)
User->>Simulator: simulate(theta, mode="generative", n_participants, n_trials)
Simulator->>Simulator: validate_theta_keys, expand_participant_theta
loop per participant
Simulator->>LearningProcess: reset / init_state
Simulator->>TaskEnvironment: reset(seed)
loop per trial
Simulator->>LearningProcess: compute_python/compute_jax(state, params, context)
LearningProcess-->>Simulator: SSM params {v, a, ...}
Simulator->>SSM: simulator(model, merged_params, n_samples=1)
SSM-->>Simulator: rt, response
Simulator->>LearningProcess: update_python/update_jax(state, params, context)
LearningProcess-->>Simulator: new state
end
end
Simulator-->>User: pd.DataFrame (balanced panel)
sequenceDiagram
participant HSSM
participant ModelConfig
participant AssembledModel
participant LearningProcess
HSSM->>ModelConfig: to_hssm_config_dict()
ModelConfig-->>HSSM: dict (participant_contract, placeholders, learning_process_kind)
HSSM->>ModelConfig: assemble(backend="jax")
ModelConfig->>AssembledModel: from_config(config, backend)
AssembledModel->>AssembledModel: resolve backend, gradient, snapshot metadata
HSSM->>AssembledModel: assemble_participant_fn(input_fields, output="dict")
AssembledModel->>LearningProcess: init_jax_state / compute_jax / update_jax
AssembledModel-->>HSSM: participant_fn (JAX lax.scan callable)
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 5
🧹 Nitpick comments (2)
ssms/rl/env.py (1)
200-203: ⚡ Quick win: Prevent silent task-name overrides in the registry.
`register_task()` currently replaces existing entries without signaling. A duplicate registration can silently change `TaskConfig` behavior at runtime; make override explicit (or reject duplicates).
Suggested change
-def register_task(task: str, builder: TaskEnvironmentBuilder) -> None:
+def register_task(
+    task: str, builder: TaskEnvironmentBuilder, *, overwrite: bool = False
+) -> None:
     """Register a task environment builder for ``TaskConfig``."""
+    if not overwrite and task in _TASK_REGISTRY:
+        raise ValueError(
+            f"Task '{task}' is already registered. "
+            "Pass overwrite=True to replace it."
+        )
     _TASK_REGISTRY[task] = builder
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@ssms/rl/env.py` around lines 200 - 203, The register_task() function silently overwrites existing entries in _TASK_REGISTRY without any warning or error, which can cause unexpected behavior changes at runtime. Add a check before assigning to _TASK_REGISTRY to either reject duplicate registrations by raising an exception (e.g., ValueError) if the task name already exists, or log a warning message to explicitly signal that an override is occurring. This ensures task registration changes are intentional rather than accidental.
tests/rl/test_learning_process.py (1)
162-173: ⚡ Quick win: Add regression tests for invalid action indices.
Please add explicit tests for negative and out-of-range actions on `update()`/`update_python()` for both learning rules, so the zero-based action contract is enforced in CI.
Also applies to: 259-267
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/rl/test_learning_process.py` around lines 162 - 173, Add regression tests to verify that both RescorlaWagnerDeltaRule and the learning rule at lines 259-267 properly validate action indices in their update() and update_python() methods. Create test methods that assert RuntimeError is raised when negative action indices are passed to update(), and separate test methods that assert RuntimeError is raised when out-of-range action indices are passed to update(). Repeat these validation tests for both update() and update_python() methods across both learning rule classes to ensure the zero-based action contract is enforced and caught in CI.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@ssms/rl/compiled.py`:
- Around line 283-287: The JAX path in the compute function (lines 283–287)
silently defaults unknown response labels to action 0, whereas the Python path
(line 270) raises a KeyError when encountering unmapped labels. Add pre-scan
validation at the start of the compute function that checks whether all unique
response values in the data exist as keys in self.response_to_choice, and raise
a ValueError listing any unmapped labels found. This ensures consistent error
handling between the two code paths and prevents silent corruption of learning
trajectories.
In `@ssms/rl/config.py`:
- Around line 301-305: The auto mode fallback in the learning_backend property
returns "python" without verifying that "python" is actually in the available
backends list. If the learning process only supports "jax" and JAX is
unavailable, this will return an unsupported backend and defer the error to
runtime. In the conditional block where learning_backend equals "auto", after
checking if "jax" is available and in the available set, add a check to ensure
"python" is also in the available set before returning it as the fallback. If
"python" is not available either, raise a configuration error immediately with a
clear message indicating that no supported learning backend is available.
In `@ssms/rl/learning.py`:
- Around line 223-235: Add validation in both update_python() methods (the one
shown in the diff and the one in RescorlaWagnerDualAlphaRule) to verify that the
choice index is within valid bounds before using it to index into q_values.
After extracting and converting choice to int, validate that it is greater than
or equal to 0 and less than the length of q_values. If the choice is invalid,
raise an appropriate error to fail fast rather than silently updating the wrong
arm.
In `@ssms/rl/simulator.py`:
- Around line 139-140: Remove the int() type coercion being applied to
participant_id and trial_id across multiple locations in the simulator.py file.
At line 139-140 where participant_id is converted with int(participant_id),
remove the int() call to preserve the original identifier type. Apply the same
fix at line 158-159 where participant_id is similarly coerced, and at line 410
where trial_id is coerced to int(). This will ensure that non-integer IDs are
not corrupted and that original identifier values are preserved throughout the
codebase.
In `@ssms/rl/validation.py`:
- Around line 300-301: The validation logic at lines 300, 313-314, and 339-341
performs type conversions (int(v) and dtype=float) that can raise unhandled
exceptions when encountering non-numeric values in response/rt columns instead
of gracefully reporting validation failures. Wrap the int(v) conversion at line
300 in a try-except block to catch ValueError exceptions for non-numeric values
and collect them as invalid choices to be reported. Similarly, wrap the
dtype=float conversion at line 340 in a try-except block to catch conversion
errors and report them through the DataValidationReport instead of crashing.
Ensure that all malformed or non-numeric values are captured and included in the
validation report's issues list rather than causing hard exceptions that bypass
the validation reporting mechanism.
---
Nitpick comments:
In `@ssms/rl/env.py`:
- Around line 200-203: The register_task() function silently overwrites existing
entries in _TASK_REGISTRY without any warning or error, which can cause
unexpected behavior changes at runtime. Add a check before assigning to
_TASK_REGISTRY to either reject duplicate registrations by raising an exception
(e.g., ValueError) if the task name already exists, or log a warning message to
explicitly signal that an override is occurring. This ensures task registration
changes are intentional rather than accidental.
In `@tests/rl/test_learning_process.py`:
- Around line 162-173: Add regression tests to verify that both
RescorlaWagnerDeltaRule and the learning rule at lines 259-267 properly validate
action indices in their update() and update_python() methods. Create test
methods that assert RuntimeError is raised when negative action indices are
passed to update(), and separate test methods that assert RuntimeError is raised
when out-of-range action indices are passed to update(). Repeat these validation
tests for both update() and update_python() methods across both learning rule
classes to ensure the zero-based action contract is enforced and caught in CI.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro Plus
Run ID: 13047bae-35fe-4e77-ab17-f7642f376e75
⛔ Files ignored due to path filters (2)
- `docs/core_tutorials/assets/pedersen_frank_2020_rlddm_fig1.png` is excluded by `!**/*.png`
- `uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (25)
- README.md
- docs/api/rlssm.md
- docs/contributing/README.md
- docs/core_tutorials/rlssm_advanced_tutorial.ipynb
- docs/core_tutorials/rlssm_tutorial.ipynb
- mkdocs.yml
- pyproject.toml
- ssms/__init__.py
- ssms/rl/__init__.py
- ssms/rl/compiled.py
- ssms/rl/config.py
- ssms/rl/env.py
- ssms/rl/learning.py
- ssms/rl/preset.py
- ssms/rl/simulator.py
- ssms/rl/validation.py
- tests/rl/__init__.py
- tests/rl/test_compiled_model.py
- tests/rl/test_data_validation.py
- tests/rl/test_hssm_bridge_contract.py
- tests/rl/test_hssm_compatibility.py
- tests/rl/test_learning_process.py
- tests/rl/test_rl_config.py
- tests/rl/test_rl_simulator.py
- tests/rl/test_task_environment.py
Add get_participant_input_fields with a backward-compatible alias, validate output/backend options via get_args, and unify choice extraction behind a single _extract_choice dispatcher. Co-authored-by: Cursor <cursoragent@cursor.com>
Rename _normalize_theta to _expand_participant_theta to reflect that the method tiles scalar or participant-wise theta into per-participant dicts. Co-authored-by: Cursor <cursoragent@cursor.com>
Document external-only validate_data usage, simulator PPC as tutorial/smoke tooling rather than PyMC PPC, and derived participant_contract metadata. Co-authored-by: Cursor <cursoragent@cursor.com>
Raise a clear ValueError when learning_backend='auto' hits a JAX-only learning process without JAX installed instead of silently falling back. Co-authored-by: Cursor <cursoragent@cursor.com>
Reject non-int and out-of-range choice values before updating Q-values so wrong-arm learning updates cannot happen silently. Co-authored-by: Cursor <cursoragent@cursor.com>
Report invalid response/rt dtypes instead of crashing, validate JAX trial responses against response_to_choice before scan, and preserve observed participant_id and trial_id types in PPC output. Co-authored-by: Cursor <cursoragent@cursor.com>
Introduce a base TaskEnvironment protocol for context-only tasks and a DiscreteChoiceEnvironment sub-protocol for response/choice mapping. Bandit environments implement the latter with an n_arms alias for n_choices. ModelConfig now requires discrete-choice environments for current RLSSM models. Also guard register_task() against silent overwrites. Co-authored-by: Cursor <cursoragent@cursor.com>
Replace compiled.py/CompiledModel/compile() with assembled.py, AssembledModel, assemble(), and assemble_participant_fn(). Adopt StrEnum for backend and output modes. Breaking change; HSSM bridge updates tracked separately in hssm-rlssm-api handoff doc. Co-authored-by: Cursor <cursoragent@cursor.com>
Clarify that _ssm_config is strictly derived, document TaskEnvironment protocols and the assembled inference API in docs/api/rlssm.md, and add helper docstrings in validation.py. Co-authored-by: Cursor <cursoragent@cursor.com>
Refresh basic and advanced tutorial notebooks to use assemble(), AssembledModel, and assemble_participant_fn() after the compiled rename. Co-authored-by: Cursor <cursoragent@cursor.com>
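The commit above that rejects non-int and out-of-range choice values describes a fail-fast guard before the Q-value update. A minimal sketch of such a guard is shown here; the helper name `validate_choice` and the exception types (`TypeError`, `ValueError`) are assumptions, not the checks actually implemented in `ssms/rl/learning.py`.

```python
import numbers

import numpy as np


def validate_choice(choice, n_choices: int) -> int:
    """Return ``choice`` as a plain int, or raise if it is not a valid zero-based index."""
    # bool is a subclass of int, so it is excluded explicitly.
    if isinstance(choice, bool) or not isinstance(choice, (numbers.Integral, np.integer)):
        raise TypeError(f"choice must be an integer index, got {type(choice).__name__}")
    index = int(choice)
    # Negative values would otherwise index from the end of the Q-value array.
    if not 0 <= index < n_choices:
        raise ValueError(f"choice {index} is outside the valid range [0, {n_choices - 1}]")
    return index
```

Rejecting negative indices matters in NumPy because `q_values[-1]` is valid indexing and would silently update the last arm.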
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@ssms/rl/assembled.py`:
- Around line 316-321: The code at lines 316-321 and the corresponding location
at lines 353-359 are silently casting response values to integers using
int(value), which converts non-integer values like 1.9 to 1 instead of rejecting
them. Add validation to check that responses are actually valid integers before
processing them. Instead of directly casting to int, verify that each response
value either is already an integer type or can be losslessly converted to an
integer (e.g., by checking if the value equals its integer conversion), and skip
or reject non-integer values to fail fast rather than silently corrupting the
data. Apply this validation at both affected locations in the code where
responses are being processed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro Plus
Run ID: ff214e8b-4f64-4b57-80a7-ab6884913ea0
📒 Files selected for processing (18)
- docs/api/rlssm.md
- docs/core_tutorials/rlssm_advanced_tutorial.ipynb
- docs/core_tutorials/rlssm_tutorial.ipynb
- ssms/rl/__init__.py
- ssms/rl/assembled.py
- ssms/rl/config.py
- ssms/rl/env.py
- ssms/rl/learning.py
- ssms/rl/simulator.py
- ssms/rl/validation.py
- tests/rl/test_assembled_model.py
- tests/rl/test_data_validation.py
- tests/rl/test_hssm_bridge_contract.py
- tests/rl/test_hssm_compatibility.py
- tests/rl/test_learning_process.py
- tests/rl/test_rl_config.py
- tests/rl/test_rl_simulator.py
- tests/rl/test_task_environment.py
✅ Files skipped from review due to trivial changes (1)
- docs/api/rlssm.md
🚧 Files skipped from review as they are similar to previous changes (7)
- ssms/rl/env.py
- tests/rl/test_hssm_bridge_contract.py
- tests/rl/test_data_validation.py
- ssms/rl/learning.py
- tests/rl/test_rl_config.py
- ssms/rl/validation.py
- ssms/rl/simulator.py
unmapped = sorted(
    {
        int(value)
        for value in np.unique(responses)
        if int(value) not in mapping_keys
    }
Reject non-integer response labels before mapping to choices.
Line 318 and Line 353/Line 359 coerce responses with integer casting, so values like 1.9 are silently treated as 1. That can drive incorrect learning updates instead of failing fast.
Proposed fix
 def _validate_trial_response_labels(
     self,
     subject_trials,
     field_to_idx: dict[str, int],
     response_field: str,
 ) -> None:
     """Raise when trial responses are absent from ``response_to_choice``."""
     trials = np.asarray(subject_trials)
     if trials.ndim == 1:
         trials = trials.reshape(1, -1)
-    responses = trials[:, field_to_idx[response_field]]
+    responses = np.asarray(trials[:, field_to_idx[response_field]], dtype=np.float64)
+    rounded = np.rint(responses)
+    if not np.all(np.equal(responses, rounded)):
+        bad = np.unique(responses[responses != rounded]).tolist()
+        raise ValueError(
+            "Trial responses must be integer-coded labels. "
+            f"Got non-integer values: {bad}."
+        )
+    responses = rounded.astype(np.int64)
     mapping_keys = set(self.response_to_choice.keys())
     unmapped = sorted(
         {
             int(value)
             for value in np.unique(responses)
             if int(value) not in mapping_keys
         }
     )
Also applies to: 353-359
Addresses #279
Summary
Adds a modular `ssms.rl` framework for simulating reinforcement-learning sequential sampling models (RLSSMs).
The new API composes three pieces:
- `ssm-simulators` SSM backends such as `angle` or `ddm`
The simulator runs the trial-wise RLSSM loop:
This keeps RLSSM simulation in Python and reuses the existing SSM simulator engine; no Cython simulator changes are required.
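The trial-wise loop described above (compute SSM parameters from the learning state, simulate one trial, observe the choice, generate a reward, update the learning state) can be sketched in plain Python. Everything here is illustrative: `toy_decision` is a stand-in for a single SSM trial and is not the `ssm-simulators` engine, and the function names, the identity response-to-choice mapping, and the output column names are assumptions.

```python
import numpy as np


def toy_decision(v: float, rng: np.random.Generator) -> tuple[float, int]:
    """Stand-in for one SSM trial (NOT the ssm-simulators engine)."""
    p_upper = 1.0 / (1.0 + np.exp(-v))  # higher drift favours response 1
    response = int(rng.random() < p_upper)
    rt = 0.3 + rng.exponential(0.5)  # arbitrary offset plus a random decision time
    return rt, response


def simulate_participant(alpha, scaler, reward_probs, n_trials, seed=None):
    """Run the interleaved learning/decision loop for one participant."""
    rng = np.random.default_rng(seed)
    q = np.full(len(reward_probs), 0.5)  # initial Q-values
    rows = []
    for trial in range(n_trials):
        v = scaler * (q[1] - q[0])  # 1. learning state -> SSM parameter
        rt, response = toy_decision(v, rng)  # 2. simulate one decision trial
        choice = response  # 3. observe choice (identity mapping assumed)
        reward = float(rng.random() < reward_probs[choice])  # 4. generate reward
        q[choice] += alpha * (reward - q[choice])  # 5. update learning state
        rows.append({"trial_id": trial, "rt": rt, "response": response, "feedback": reward})
    return rows


rows = simulate_participant(alpha=0.3, scaler=2.0, reward_probs=(0.2, 0.8), n_trials=50, seed=1)
```

Because each trial's drift depends on the Q-values left by the previous trial, the loop cannot be vectorized across trials within a participant, which is why the simulator calls the decision process one trial at a time.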
Public API
Recommended usage:
New public surface:
Old developmental names such as RLSSMSimulator, RLSSMModelConfig, get_rlssm_preset, and list_rlssm_presets are intentionally not kept as public aliases.
Reviewer Guide
Suggested review order:
Core learning/task primitives
Structural model config and presets
Trial-wise simulation behavior
Data validation and PPC support
HSSM bridge-facing assembled interface
User-facing docs and tutorials
The large diff is mostly contained inside the new ssms/rl package, its test suite, docs/tutorial notebooks, and uv.lock.
Design Notes
A few important choices are intentional:
Deferred Follow-Ups
Some broader architecture work is intentionally deferred to smaller follow-up PRs:
Task-environment expansion beyond the current bandit support is also deferred to a later Milestone 5 PR.
Validation
Checks run during PR prep:
uv run pytest tests/rl -q --no-cov
uv run pytest tests/rl tests/test_hssm_support.py tests/test_simulator.py -q --no-cov
uv run pytest
uv run ruff check .
uv run ruff format --check .
python -m json.tool docs/core_tutorials/rlssm_tutorial.ipynb
python -m json.tool docs/core_tutorials/rlssm_advanced_tutorial.ipynb
MPLCONFIGDIR=/tmp/.mpl uv run --extra docs mkdocs build
Summary by CodeRabbit
Release Notes
New Features
Documentation
Dependencies
Tests