Add modular RLSSM simulator framework#278

Open
krishnbera wants to merge 58 commits into
mainfrom
feature/rlssm-simulator

@krishnbera krishnbera commented May 19, 2026

Member

Addresses #279

Summary

Adds a modular ssms.rl framework for simulating reinforcement-learning sequential sampling models (RLSSMs).

The new API composes three pieces:

  • a learning process, such as Rescorla-Wagner delta or dual-alpha learning
  • a decision process, using existing ssm-simulators SSM backends such as angle or ddm
  • a task environment, currently generic Bernoulli/Gaussian bandits
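
As a rough illustration of the two learning rules named above, the sketch below shows a Rescorla-Wagner delta update and a dual-alpha variant that uses separate rates for positive and negative prediction errors. The function names and the parameters `alpha`, `alpha_pos`, and `alpha_neg` are assumptions made for this example; they are not the `ssms.rl.learning` API.

```python
def rw_delta_update(q_values, choice, reward, alpha):
    """Return new Q-values after one Rescorla-Wagner delta update (hypothetical helper)."""
    updated = list(q_values)
    prediction_error = reward - updated[choice]
    updated[choice] += alpha * prediction_error
    return updated


def rw_dual_alpha_update(q_values, choice, reward, alpha_pos, alpha_neg):
    """Same update, but the learning rate depends on the sign of the prediction error."""
    updated = list(q_values)
    prediction_error = reward - updated[choice]
    alpha = alpha_pos if prediction_error >= 0 else alpha_neg
    updated[choice] += alpha * prediction_error
    return updated


q_start = [0.5, 0.5]
q_single = rw_delta_update(q_start, choice=1, reward=1.0, alpha=0.2)
q_dual = rw_dual_alpha_update(q_start, choice=0, reward=0.0, alpha_pos=0.3, alpha_neg=0.1)
```

Only the chosen arm is updated in both variants; the unchosen arm keeps its previous value.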

The simulator runs the trial-wise RLSSM loop:

  1. compute trial-wise SSM parameters from the current learning state
  2. simulate one SSM trial
  3. map the raw SSM response label to a zero-based learning choice
  4. generate or copy task context/feedback
  5. update the learning state
  6. continue to the next trial

This keeps RLSSM simulation in Python and reuses the existing SSM simulator engine; no Cython simulator changes are required.
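
The six steps above can be sketched in plain Python. This is a toy illustration only: `simulate_toy_ddm_trial` is a hypothetical stand-in for the real SSM backend, the drift rule `v = scaler * (Q1 - Q0)` and all parameter names are assumptions, and omissions are not modeled.

```python
import random


def simulate_toy_ddm_trial(v, a, rng, dt=0.01, max_t=5.0):
    """Euler random walk between -a and +a. Returns (rt, response in {-1, 1})."""
    x, t = 0.0, 0.0
    while abs(x) < a and t < max_t:
        x += v * dt + rng.gauss(0.0, 1.0) * dt ** 0.5
        t += dt
    # Toy simplification: a timed-out walk is still labeled by the sign of x.
    return t, (1 if x >= 0 else -1)


def run_toy_rlssm(n_trials, alpha, scaler, a, p_reward, seed=0):
    rng = random.Random(seed)
    response_to_choice = {-1: 0, 1: 1}  # raw SSM label -> zero-based choice
    q = [0.5, 0.5]                      # learning state (Q-values)
    rows = []
    for trial in range(n_trials):
        v = scaler * (q[1] - q[0])                         # 1. SSM params from state
        rt, response = simulate_toy_ddm_trial(v, a, rng)   # 2. simulate one trial
        choice = response_to_choice[response]              # 3. map label to choice
        feedback = 1.0 if rng.random() < p_reward[choice] else 0.0  # 4. task feedback
        q[choice] += alpha * (feedback - q[choice])        # 5. update learning state
        rows.append({"trial_id": trial, "rt": rt, "response": response,
                     "choice": choice, "feedback": feedback})
    return rows                                            # 6. loop continues per trial


rows = run_toy_rlssm(n_trials=50, alpha=0.2, scaler=2.0, a=1.0,
                     p_reward=[0.2, 0.8], seed=1)
```

Seeding a single `random.Random` instance makes the whole trajectory reproducible, which mirrors the reproducibility checks mentioned in the test descriptions.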

Public API

Recommended usage:

import ssms.rl as rl

config = rl.preset.get("2AB_RW_Angle")
sim = rl.Simulator(config)
data = sim.simulate(theta={...}, n_trials=200, n_participants=20)

New public surface:

  • rl.ModelConfig
  • rl.Simulator
  • rl.AssembledModel
  • rl.resolve_model
  • rl.env
  • rl.learning
  • rl.preset

Earlier development names such as RLSSMSimulator, RLSSMModelConfig, get_rlssm_preset, and list_rlssm_presets are intentionally not kept as public aliases.

Reviewer Guide

Suggested review order:

  1. Core learning/task primitives

    • ssms/rl/learning.py
    • ssms/rl/env.py
    • tests/rl/test_learning_process.py
    • tests/rl/test_task_environment.py
  2. Structural model config and presets

    • ssms/rl/config.py
    • ssms/rl/preset.py
    • tests/rl/test_rl_config.py
    • tests/rl/test_hssm_compatibility.py
  3. Trial-wise simulation behavior

    • ssms/rl/simulator.py
    • tests/rl/test_rl_simulator.py
  4. Data validation and PPC support

    • ssms/rl/validation.py
    • tests/rl/test_data_validation.py
  5. HSSM bridge-facing assembled interface

    • ssms/rl/assembled.py
    • tests/rl/test_assembled_model.py
    • tests/rl/test_hssm_bridge_contract.py
  6. User-facing docs and tutorials

    • docs/api/rlssm.md
    • docs/core_tutorials/rlssm_tutorial.ipynb
    • docs/core_tutorials/rlssm_advanced_tutorial.ipynb

The large diff is mostly contained inside the new ssms/rl package, its test suite, docs/tutorial notebooks, and uv.lock.

Design Notes

A few important choices are intentional:

  • response stores the raw SSM response label, while optional choice stores the zero-based learning choice.
  • Task rewards, feedback, conditions, and similar trial variables are represented through generic context_fields.
  • ModelConfig describes structure only; concrete parameter values are passed to Simulator.simulate(...) as theta.
  • AssembledModel exposes a package-neutral interface for downstream HSSM integration without importing HSSM or PyTensor in ssm-simulators.
  • Generative simulation and observed-history-conditioned PPC simulation are both supported through Simulator.simulate(mode=...).
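
To make the split between structure and values concrete, here is a minimal sketch of how a `theta` dict holding scalars or participant-wise sequences could be tiled into one dict per participant. The helper name `expand_participant_theta` and its behavior are assumptions for illustration and may differ from what the simulator does internally.

```python
def expand_participant_theta(theta, n_participants):
    """Tile scalar or participant-wise values into per-participant dicts (hypothetical)."""
    expanded = [{} for _ in range(n_participants)]
    for name, value in theta.items():
        if isinstance(value, (list, tuple)):
            if len(value) != n_participants:
                raise ValueError(
                    f"theta['{name}'] has {len(value)} values, "
                    f"expected {n_participants}"
                )
            values = list(value)
        else:
            # A scalar is shared by every participant.
            values = [value] * n_participants
        for participant_theta, v in zip(expanded, values):
            participant_theta[name] = v
    return expanded


per_participant = expand_participant_theta({"a": 1.5, "alpha": [0.1, 0.2]}, 2)
```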

Deferred Follow-Ups

Some broader architecture work is intentionally deferred to smaller follow-up PRs:

Task-environment expansion beyond the current bandit support is also deferred to a later Milestone 5 PR.

Validation

Checks run during PR prep:

uv run pytest tests/rl -q --no-cov
uv run pytest tests/rl tests/test_hssm_support.py tests/test_simulator.py -q --no-cov
uv run pytest
uv run ruff check .
uv run ruff format --check .
python -m json.tool docs/core_tutorials/rlssm_tutorial.ipynb
python -m json.tool docs/core_tutorials/rlssm_advanced_tutorial.ipynb
MPLCONFIGDIR=/tmp/.mpl uv run --extra docs mkdocs build

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced RLSSM simulation framework with generative and posterior-predictive modes
    • Added bandit task environments with Bernoulli and Gaussian reward distributions
    • Implemented Rescorla-Wagner learning processes with optional dual-alpha variant
    • Provided preset configurations for common RLSSM models
    • Added optional JAX backend for differentiable learning
  • Documentation

    • Added comprehensive RLSSM API reference and tutorials
    • Updated installation guide with JAX setup instructions
  • Dependencies

    • Added optional JAX/JAXLib dependency support for development
  • Tests

    • Added extensive test coverage for RLSSM components and integrations

krishnbera and others added 10 commits May 13, 2026 14:45
Milestone 1, Commit 1: Defines the LearningProcess protocol (the handshake
contract between learning and decision processes) and the first built-in
implementation — RescorlaWagnerDeltaRule — which is numerically equivalent
to HSSM's compute_v_trial_wise(). Includes 13 unit tests covering Q-value
trajectories, drift ordering, HSSM numerical equivalence, and protocol compliance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Milestone 1, Commit 2: Defines the TaskEnvironment protocol for reward
generation, TwoArmedBandit (Bernoulli bandit with configurable per-arm
probabilities), and TaskConfig convenience dataclass for common paradigms.
Includes 14 unit tests covering reward statistics, reproducibility,
input validation, protocol compliance, and TaskConfig builder.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Milestone 1, Commit 3: Defines RLSSMModelConfig — the structural model
specification that resolves the handshake between learning process and
decision process (SSM). Auto-derives list_params, bounds, and defaults
from components. Includes validate() for config consistency checking and
to_hssm_config_dict() for bridging to HSSM's RLSSMConfig. 13 tests cover
auto-derivation, handshake validation, computed_param_mapping, TaskConfig
auto-build, and HSSM dict contract.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Milestone 1, Commit 4: Implements the core RLSSMSimulator class that runs
the trial-by-trial interleaved loop: compute SSM params from learning state,
simulate one SSM trial, observe choice, generate reward, update learning.
Reuses ssm-simulators' existing simulator() with n_samples=1 — all 40+ SSM
models work as decision processes. Includes 15 tests covering DataFrame
output shape, balanced panel, reproducibility, theta validation, edge cases,
and omission handling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Milestone 1, Commit 5: Adds the preset registry (register/get/list_rlssm_preset)
with rlssm1 preset (RW delta rule + angle SSM + two-armed bandit). Wires up the
ssms.rl public API with full __all__ exports and adds `from . import rl` to
ssms/__init__.py. Fixes circular import in rl_simulator.py (OMISSION_SENTINEL).
Includes 13 contract tests for HSSM compatibility (output dtypes, no NaNs,
config dict schema) and registry smoke tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@codecov

codecov Bot commented May 19, 2026

Codecov Report

❌ Patch coverage is 95.37815% with 55 lines in your changes missing coverage. Please review.

Files with missing lines    Patch %    Lines
ssms/rl/config.py           93.79%     16 Missing ⚠️
ssms/rl/learning.py         90.44%     15 Missing ⚠️
ssms/rl/assembled.py        95.21%     9 Missing ⚠️
ssms/rl/env.py              94.48%     8 Missing ⚠️
ssms/rl/validation.py       96.08%     7 Missing ⚠️

Flag         Coverage Δ
unittests    93.38% <95.37%> (+1.09%) ⬆️

Files with missing lines    Coverage Δ
ssms/rl/preset.py           100.00% <100.00%> (ø)
ssms/rl/simulator.py        100.00% <100.00%> (ø)
ssms/rl/validation.py       96.08% <96.08%> (ø)
ssms/rl/env.py              94.48% <94.48%> (ø)
ssms/rl/assembled.py        95.21% <95.21%> (ø)
ssms/rl/learning.py         90.44% <90.44%> (ø)
ssms/rl/config.py           93.79% <93.79%> (ø)

... and 1 file with indirect coverage changes


@krishnbera krishnbera marked this pull request as ready for review May 19, 2026 23:36
Copilot AI review requested due to automatic review settings May 19, 2026 23:36

Copilot AI left a comment

Contributor

Pull request overview

This PR introduces a new modular RLSSM simulation framework under ssms.rl, designed to interleave trial-wise reinforcement learning updates with existing SSM decision simulators and to export HSSM-compatible configuration/data.

Changes:

  • Added ssms.rl core components: ModelConfig, Simulator, task environments (bandits), learning rules (Rescorla–Wagner variants), and a preset registry.
  • Added comprehensive RLSSM-focused tests (learning, env, simulator behavior, and HSSM compatibility/contract checks).
  • Added an MkDocs tutorial entry for the new RLSSM simulator workflow.

Reviewed changes

Copilot reviewed 13 out of 15 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
ssms/rl/config.py Defines the structural RLSSM configuration and HSSM export support.
ssms/rl/simulator.py Implements the interleaved learning + SSM simulation loop and output formatting.
ssms/rl/env.py Adds a task environment protocol plus Bernoulli/Gaussian bandit implementations and task registry.
ssms/rl/learning.py Adds the learning process protocol and Rescorla–Wagner learning rules.
ssms/rl/preset.py Adds an RLSSM preset registry and a built-in rlssm1 preset.
ssms/rl/__init__.py Exposes the public ssms.rl API surface.
ssms/__init__.py Re-exports the rl module at the package top level.
tests/rl/test_task_environment.py Tests bandit environment behavior, validation, and task config building.
tests/rl/test_learning_process.py Tests RW learning rules’ numerical behavior and protocol compliance.
tests/rl/test_rl_config.py Tests config auto-derivation, handshake validation, response mapping, and HSSM dict export.
tests/rl/test_rl_simulator.py Tests simulation output schema, reproducibility, omission handling, and response/action mapping.
tests/rl/test_hssm_compatibility.py Contract tests for HSSM consumability and preset registry behavior.
tests/rl/__init__.py Initializes the RL test package.
mkdocs.yml Adds the RLSSM tutorial notebook to the docs nav.

Comment thread ssms/rl/env.py Outdated
Comment thread ssms/rl/env.py Outdated
Comment thread ssms/rl/config.py Outdated
Comment thread ssms/rl/config.py Outdated
Comment thread ssms/rl/simulator.py Outdated
Comment thread ssms/rl/simulator.py Outdated
Comment thread ssms/rl/learning.py Outdated

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 13 out of 15 changed files in this pull request and generated 2 comments.

Comment thread ssms/rl/config.py
Comment thread ssms/rl/env.py Outdated
krishnbera and others added 7 commits June 8, 2026 16:10
Brings main's docs/branding refresh and simulator updates into the RLSSM
feature branch (PR #278). The only conflict was uv.lock; resolved by taking
main's lockfile as base and regenerating with `uv lock` so it stays consistent
with the merged pyproject.toml. uv.lock remains tracked (repo convention).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@cpaniaguam cpaniaguam left a comment

Collaborator

Not a thorough review -- missing some files I didn't comment on.

Comment thread docs/contributing/README.md
Comment thread ssms/rl/compiled.py Outdated
Comment thread ssms/rl/assembled.py
Comment thread ssms/rl/compiled.py Outdated
Comment thread ssms/rl/compiled.py Outdated
Comment thread ssms/rl/assembled.py
Comment thread ssms/rl/compiled.py Outdated
Comment thread ssms/rl/compiled.py Outdated
Comment thread ssms/rl/compiled.py Outdated
Comment thread ssms/rl/compiled.py Outdated
krishnbera and others added 2 commits June 16, 2026 10:28
Reject empty list_params during ModelConfig validation and stop treating
empty bandit reward strings as the default bernoulli reward.

Co-authored-by: Cursor <cursoragent@cursor.com>
Remove unused compiled.py sentinel, simplify computed-param mapping,
document direct compile validation, and add uv sync --extra jax docs.

Co-authored-by: Cursor <cursoragent@cursor.com>
@coderabbitai

coderabbitai Bot commented Jun 16, 2026

Review Change Stack

Important

Review skipped

Review was skipped as selected files did not have any reviewable changes.

💤 Files selected but had no reviewable changes (1)
  • docs/core_tutorials/rlssm_tutorial.ipynb
⛔ Files ignored due to path filters (1)
  • docs/core_tutorials/assets/rlssm.jpg is excluded by !**/*.jpg
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: bf5b77b6-07c5-4e6c-8951-ed5d16380b4c

📥 Commits

Reviewing files that changed from the base of the PR and between 29db7fd and 16e747f.

📒 Files selected for processing (1)
  • docs/core_tutorials/rlssm_tutorial.ipynb

📝 Walkthrough

Adds the complete ssms.rl subpackage implementing Reinforcement Learning Sequential Sampling Models. The package includes task environment protocols and Bandit implementations, Rescorla-Wagner learning processes with optional JAX gradient support, a ModelConfig dataclass with SSM parameter handshake and panel data validation, generative and PPC simulators, assembled participant functions (Python row-loop or JAX lax.scan), a preset registry with HSSM bridge export, and a corresponding API documentation page.

Changes

ssms.rl RLSSM Framework

Layer / File(s) Summary
Package wiring and JAX optional dependency
pyproject.toml, README.md, docs/contributing/README.md, ssms/__init__.py, ssms/rl/__init__.py
ssms.rl is registered as a public ssms export; ssms/rl/__init__.py defines __all__; jax>=0.4.20/jaxlib>=0.4.20 are added to dev dependencies; README and contributing docs document the optional JAX install.
TaskEnvironment protocol and Bandit implementation
tests/rl/test_task_environment.py
TaskEnvironment and DiscreteChoiceEnvironment protocols, Bernoulli/Gaussian reward distributions, Bandit with seeded RNG and "feedback" context, and a TaskConfig-driven task registry are implemented and tested for statistics, reproducibility, validation errors, and duplicate registration handling.
LearningProcess protocol and Rescorla-Wagner rules
ssms/rl/learning.py, tests/rl/test_learning_process.py
LearningProcess runtime-checkable protocol is defined; RescorlaWagnerDeltaRule (NumPy + JAX) and RescorlaWagnerDualAlphaRule (sign-dependent learning rates) are implemented with reset-before-use enforcement; tests cover trajectory correctness, Python/JAX parity, and differentiability via jax.grad.
ModelConfig dataclass and data validation contract
ssms/rl/config.py, ssms/rl/validation.py, tests/rl/test_rl_config.py, tests/rl/test_data_validation.py
ModelConfig.__post_init__ auto-derives choices, response-to-choice mapping, context fields, backend/gradient policy, and SSM parameter handshake; DataValidationReport and validate_rlssm_data() enforce panel balance, contiguous participants, omission sentinels, response label validity, and RT sanity; to_hssm_config_dict() and participant_contract() expose the HSSM-facing contract.
Generative and PPC Simulator
ssms/rl/simulator.py, tests/rl/test_rl_simulator.py
Simulator.simulate() routes between generative (per-participant per-trial learning + SSM loop) and PPC (conditioned on observed history) modes with scalar/participant-wise theta normalization, omission sentinel handling, and backend-dispatch learning state updates; tests cover output shape, reproducibility, theta validation, omission handling, response-choice mapping, and backend compatibility.
AssembledModel and compiled participant functions
ssms/rl/assembled.py, tests/rl/test_assembled_model.py, tests/rl/test_hssm_bridge_contract.py
AssembledModel.from_config() snapshots validated model metadata and resolves backend/gradient; assemble_participant_fn() produces a Python row-loop or JAX lax.scan callable returning arrays or dicts; resolve_model() accepts preset names or existing configs; tests cover Python/JAX parity, multi-output mapping, and the HSSM bridge contract.
Preset registry and HSSM compatibility tests
ssms/rl/preset.py, tests/rl/test_hssm_compatibility.py
register(), get(), list(), and info() manage named ModelConfig factories; the built-in 2AB_RW_Angle preset is registered; tests validate simulator output dtypes, to_hssm_config_dict() schema, preset info metadata, custom preset registration, and rl.__all__ export surface.
API documentation and site navigation
docs/api/rlssm.md, mkdocs.yml
docs/api/rlssm.md documents quick-start, ModelConfig fields, theta parameter rules, simulation modes, data validation contract, PPC example, assembled-model workflow, and HSSM bridge path; mkdocs.yml adds RLSSM Tutorial, RLSSM Tutorial (Advanced), and the ssms.rl API navigation entry.

Sequence Diagram(s)

sequenceDiagram
  participant User
  participant Simulator
  participant ModelConfig
  participant LearningProcess
  participant TaskEnvironment
  participant SSM as ssms.simulator

  User->>ModelConfig: ModelConfig(decision_process, learning_process, task)
  ModelConfig->>ModelConfig: __post_init__ (derive choices, handshake, backend)
  User->>Simulator: Simulator(config)
  User->>Simulator: simulate(theta, mode="generative", n_participants, n_trials)
  Simulator->>Simulator: validate_theta_keys, expand_participant_theta
  loop per participant
    Simulator->>LearningProcess: reset / init_state
    Simulator->>TaskEnvironment: reset(seed)
    loop per trial
      Simulator->>LearningProcess: compute_python/compute_jax(state, params, context)
      LearningProcess-->>Simulator: SSM params {v, a, ...}
      Simulator->>SSM: simulator(model, merged_params, n_samples=1)
      SSM-->>Simulator: rt, response
      Simulator->>LearningProcess: update_python/update_jax(state, params, context)
      LearningProcess-->>Simulator: new state
    end
  end
  Simulator-->>User: pd.DataFrame (balanced panel)
Loading
sequenceDiagram
  participant HSSM
  participant ModelConfig
  participant AssembledModel
  participant LearningProcess

  HSSM->>ModelConfig: to_hssm_config_dict()
  ModelConfig-->>HSSM: dict (participant_contract, placeholders, learning_process_kind)
  HSSM->>ModelConfig: assemble(backend="jax")
  ModelConfig->>AssembledModel: from_config(config, backend)
  AssembledModel->>AssembledModel: resolve backend, gradient, snapshot metadata
  HSSM->>AssembledModel: assemble_participant_fn(input_fields, output="dict")
  AssembledModel->>LearningProcess: init_jax_state / compute_jax / update_jax
  AssembledModel-->>HSSM: participant_fn (JAX lax.scan callable)

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Poem

🐇 A new subpackage hops into the code,
With bandits and Q-values lighting the road.
JAX gradients flow through each learning update,
While ModelConfig validates trial panel state.
The preset 2AB_RW_Angle stands ready to run,
RLSSM simulation — a framework begun! 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 21.24% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The PR title 'Add modular RLSSM simulator framework' clearly summarizes the main change—introducing a new modular RLSSM simulator framework with learning processes, task environments, configuration, and simulation engine.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 5

🧹 Nitpick comments (2)
ssms/rl/env.py (1)

200-203: ⚡ Quick win

Prevent silent task-name overrides in the registry.

register_task() currently replaces existing entries without signaling. A duplicate registration can silently change TaskConfig behavior at runtime; make override explicit (or reject duplicates).

Suggested change
-def register_task(task: str, builder: TaskEnvironmentBuilder) -> None:
+def register_task(
+    task: str, builder: TaskEnvironmentBuilder, *, overwrite: bool = False
+) -> None:
     """Register a task environment builder for ``TaskConfig``."""
+    if not overwrite and task in _TASK_REGISTRY:
+        raise ValueError(
+            f"Task '{task}' is already registered. "
+            "Pass overwrite=True to replace it."
+        )
     _TASK_REGISTRY[task] = builder
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@ssms/rl/env.py` around lines 200 - 203, The register_task() function silently
overwrites existing entries in _TASK_REGISTRY without any warning or error,
which can cause unexpected behavior changes at runtime. Add a check before
assigning to _TASK_REGISTRY to either reject duplicate registrations by raising
an exception (e.g., ValueError) if the task name already exists, or log a
warning message to explicitly signal that an override is occurring. This ensures
task registration changes are intentional rather than accidental.
tests/rl/test_learning_process.py (1)

162-173: ⚡ Quick win

Add regression tests for invalid action indices.

Please add explicit tests for negative and out-of-range actions on update() / update_python() for both learning rules, so the zero-based action contract is enforced in CI.

Also applies to: 259-267

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/rl/test_learning_process.py` around lines 162 - 173, Add regression
tests to verify that both RescorlaWagnerDeltaRule and the learning rule at lines
259-267 properly validate action indices in their update() and update_python()
methods. Create test methods that assert RuntimeError is raised when negative
action indices are passed to update(), and separate test methods that assert
RuntimeError is raised when out-of-range action indices are passed to update().
Repeat these validation tests for both update() and update_python() methods
across both learning rule classes to ensure the zero-based action contract is
enforced and caught in CI.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@ssms/rl/compiled.py`:
- Around line 283-287: The JAX path in the compute function (lines 283–287)
silently defaults unknown response labels to action 0, whereas the Python path
(line 270) raises a KeyError when encountering unmapped labels. Add pre-scan
validation at the start of the compute function that checks whether all unique
response values in the data exist as keys in self.response_to_choice, and raise
a ValueError listing any unmapped labels found. This ensures consistent error
handling between the two code paths and prevents silent corruption of learning
trajectories.
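
A minimal sketch of the pre-scan validation described in this comment, assuming responses arrive as a plain sequence and the mapping is an ordinary dict. The function name `validate_response_labels` is hypothetical and does not come from the PR.

```python
def validate_response_labels(responses, response_to_choice):
    """Raise ValueError listing any response label missing from the mapping."""
    unmapped = sorted({r for r in responses if r not in response_to_choice})
    if unmapped:
        raise ValueError(
            f"Response labels {unmapped} are not keys of response_to_choice "
            f"{sorted(response_to_choice)}"
        )


mapping = {-1: 0, 1: 1}
validate_response_labels([-1, 1, 1, -1], mapping)  # passes silently

try:
    validate_response_labels([-1, 2, 1], mapping)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Running the check once before the scan keeps the Python and JAX paths consistent, since neither path then reaches an unmapped label.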

In `@ssms/rl/config.py`:
- Around line 301-305: The auto mode fallback in the learning_backend property
returns "python" without verifying that "python" is actually in the available
backends list. If the learning process only supports "jax" and JAX is
unavailable, this will return an unsupported backend and defer the error to
runtime. In the conditional block where learning_backend equals "auto", after
checking if "jax" is available and in the available set, add a check to ensure
"python" is also in the available set before returning it as the fallback. If
"python" is not available either, raise a configuration error immediately with a
clear message indicating that no supported learning backend is available.
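
The fallback logic requested here could look roughly like the following sketch. The function `resolve_learning_backend`, its arguments, and the error wording are assumptions for illustration; the actual property in `config.py` is not shown in this conversation.

```python
def resolve_learning_backend(requested, available, jax_installed):
    """Pick a learning backend, failing fast when nothing usable exists (hypothetical)."""
    available = set(available)
    if requested != "auto":
        if requested not in available:
            raise ValueError(f"Backend '{requested}' is not supported: {sorted(available)}")
        if requested == "jax" and not jax_installed:
            raise ValueError("Backend 'jax' was requested but JAX is not installed")
        return requested
    # Auto mode: prefer JAX when it is both installed and supported.
    if jax_installed and "jax" in available:
        return "jax"
    # Only fall back to Python if the learning process actually supports it.
    if "python" in available:
        return "python"
    raise ValueError("No supported learning backend is available for this learning process")


chosen = resolve_learning_backend("auto", ["python", "jax"], jax_installed=False)
```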

In `@ssms/rl/learning.py`:
- Around line 223-235: Add validation in both update_python() methods (the one
shown in the diff and the one in RescorlaWagnerDualAlphaRule) to verify that the
choice index is within valid bounds before using it to index into q_values.
After extracting and converting choice to int, validate that it is greater than
or equal to 0 and less than the length of q_values. If the choice is invalid,
raise an appropriate error to fail fast rather than silently updating the wrong
arm.
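
A sketch of the bounds check, assuming `q_values` is a list and that a `ValueError` is an acceptable error type (the review does not fix one). `checked_choice` and `update_q` are hypothetical names.

```python
def checked_choice(choice, n_choices):
    """Return choice if it is an int in [0, n_choices), otherwise raise ValueError."""
    # bool is a subclass of int in Python, so it is rejected explicitly.
    if isinstance(choice, bool) or not isinstance(choice, int):
        raise ValueError(f"choice must be an int, got {choice!r}")
    if not 0 <= choice < n_choices:
        raise ValueError(f"choice {choice} is outside [0, {n_choices})")
    return choice


def update_q(q_values, choice, reward, alpha):
    """Delta-rule update that validates the choice index before indexing."""
    choice = checked_choice(choice, len(q_values))
    updated = list(q_values)
    updated[choice] += alpha * (reward - updated[choice])
    return updated


q_after = update_q([0.5, 0.5], choice=1, reward=1.0, alpha=0.5)
```

The explicit lower bound matters because a negative index such as `-1` is valid Python indexing and would otherwise update the last arm without any error.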

In `@ssms/rl/simulator.py`:
- Around lines 139-140: Remove the int() type coercion applied to
participant_id and trial_id at several locations in simulator.py. At lines
139-140, where participant_id is converted with int(participant_id), remove the
int() call to preserve the original identifier type. Apply the same fix at
lines 158-159, where participant_id is similarly coerced, and at line 410,
where trial_id is coerced with int(). This ensures that non-integer IDs are not
corrupted and that the original identifier values are preserved in the
simulator output.

In `@ssms/rl/validation.py`:
- Around line 300-301: The validation logic at lines 300, 313-314, and 339-341
performs type conversions (int(v) and dtype=float) that can raise unhandled
exceptions when encountering non-numeric values in response/rt columns instead
of gracefully reporting validation failures. Wrap the int(v) conversion at line
300 in a try-except block to catch ValueError exceptions for non-numeric values
and collect them as invalid choices to be reported. Similarly, wrap the
dtype=float conversion at line 340 in a try-except block to catch conversion
errors and report them through the DataValidationReport instead of crashing.
Ensure that all malformed or non-numeric values are captured and included in the
validation report's issues list rather than causing hard exceptions that bypass
the validation reporting mechanism.
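
The report-instead-of-crash pattern described above can be sketched as follows. `ToyValidationReport` and `check_rt_values` are hypothetical stand-ins and are not the PR's `DataValidationReport` or `validate_rlssm_data()`.

```python
from dataclasses import dataclass, field


@dataclass
class ToyValidationReport:
    """Collects human-readable issues instead of raising on the first one."""
    issues: list = field(default_factory=list)

    @property
    def ok(self):
        return not self.issues


def check_rt_values(rts, report):
    """Record every non-numeric rt value in the report and return the parsed floats."""
    parsed = []
    for row, value in enumerate(rts):
        try:
            parsed.append(float(value))
        except (TypeError, ValueError):
            # Conversion failures become report entries, not hard exceptions.
            report.issues.append(f"row {row}: rt value {value!r} is not numeric")
    return parsed


report = ToyValidationReport()
parsed_rts = check_rt_values([0.5, "fast", None, 1.2], report)
```

The caller can then inspect `report.ok` and `report.issues` once, after every row has been examined.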


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: 13047bae-35fe-4e77-ab17-f7642f376e75

📥 Commits

Reviewing files that changed from the base of the PR and between 0b48391 and fe39017.

⛔ Files ignored due to path filters (2)
  • docs/core_tutorials/assets/pedersen_frank_2020_rlddm_fig1.png is excluded by !**/*.png
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (25)
  • README.md
  • docs/api/rlssm.md
  • docs/contributing/README.md
  • docs/core_tutorials/rlssm_advanced_tutorial.ipynb
  • docs/core_tutorials/rlssm_tutorial.ipynb
  • mkdocs.yml
  • pyproject.toml
  • ssms/__init__.py
  • ssms/rl/__init__.py
  • ssms/rl/compiled.py
  • ssms/rl/config.py
  • ssms/rl/env.py
  • ssms/rl/learning.py
  • ssms/rl/preset.py
  • ssms/rl/simulator.py
  • ssms/rl/validation.py
  • tests/rl/__init__.py
  • tests/rl/test_compiled_model.py
  • tests/rl/test_data_validation.py
  • tests/rl/test_hssm_bridge_contract.py
  • tests/rl/test_hssm_compatibility.py
  • tests/rl/test_learning_process.py
  • tests/rl/test_rl_config.py
  • tests/rl/test_rl_simulator.py
  • tests/rl/test_task_environment.py

Comment thread ssms/rl/assembled.py
Comment thread ssms/rl/config.py
Comment thread ssms/rl/learning.py
Comment thread ssms/rl/simulator.py Outdated
Comment thread ssms/rl/validation.py Outdated
krishnbera and others added 3 commits June 16, 2026 10:48
Add get_participant_input_fields with a backward-compatible alias,
validate output/backend options via get_args, and unify choice extraction
behind a single _extract_choice dispatcher.

Co-authored-by: Cursor <cursoragent@cursor.com>
Rename _normalize_theta to _expand_participant_theta to reflect that the
method tiles scalar or participant-wise theta into per-participant dicts.

Co-authored-by: Cursor <cursoragent@cursor.com>
Document external-only validate_data usage, simulator PPC as tutorial/smoke
tooling rather than PyMC PPC, and derived participant_contract metadata.

Co-authored-by: Cursor <cursoragent@cursor.com>
@krishnbera krishnbera linked an issue Jun 16, 2026 that may be closed by this pull request
krishnbera and others added 3 commits June 16, 2026 14:11
Raise a clear ValueError when learning_backend='auto' hits a JAX-only
learning process without JAX installed instead of silently falling back.

Co-authored-by: Cursor <cursoragent@cursor.com>
Reject non-int and out-of-range choice values before updating Q-values so
wrong-arm learning updates cannot happen silently.

Co-authored-by: Cursor <cursoragent@cursor.com>
Report invalid response/rt dtypes instead of crashing, validate JAX trial
responses against response_to_choice before scan, and preserve observed
participant_id and trial_id types in PPC output.

Co-authored-by: Cursor <cursoragent@cursor.com>
krishnbera and others added 4 commits June 16, 2026 16:42
Introduce a base TaskEnvironment protocol for context-only tasks and a
DiscreteChoiceEnvironment sub-protocol for response/choice mapping. Bandit
environments implement the latter with an n_arms alias for n_choices.
ModelConfig now requires discrete-choice environments for current RLSSM
models. Also guard register_task() against silent overwrites.

Co-authored-by: Cursor <cursoragent@cursor.com>
Replace compiled.py/CompiledModel/compile() with assembled.py,
AssembledModel, assemble(), and assemble_participant_fn(). Adopt StrEnum
for backend and output modes. Breaking change; HSSM bridge updates tracked
separately in hssm-rlssm-api handoff doc.

Co-authored-by: Cursor <cursoragent@cursor.com>
Clarify that _ssm_config is strictly derived, document TaskEnvironment
protocols and the assembled inference API in docs/api/rlssm.md, and add
helper docstrings in validation.py.

Co-authored-by: Cursor <cursoragent@cursor.com>
Refresh basic and advanced tutorial notebooks to use assemble(),
AssembledModel, and assemble_participant_fn() after the compiled rename.

Co-authored-by: Cursor <cursoragent@cursor.com>

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@ssms/rl/assembled.py`:
- Around line 316-321: The code at lines 316-321 and the corresponding location
at lines 353-359 are silently casting response values to integers using
int(value), which converts non-integer values like 1.9 to 1 instead of rejecting
them. Add validation to check that responses are actually valid integers before
processing them. Instead of directly casting to int, verify that each response
value either is already an integer type or can be losslessly converted to an
integer (e.g., by checking if the value equals its integer conversion), and skip
or reject non-integer values to fail fast rather than silently corrupting the
data. Apply this validation at both affected locations in the code where
responses are being processed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro Plus

Run ID: ff214e8b-4f64-4b57-80a7-ab6884913ea0

📥 Commits

Reviewing files that changed from the base of the PR and between 51fe4a4 and 29db7fd.

📒 Files selected for processing (18)
  • docs/api/rlssm.md
  • docs/core_tutorials/rlssm_advanced_tutorial.ipynb
  • docs/core_tutorials/rlssm_tutorial.ipynb
  • ssms/rl/__init__.py
  • ssms/rl/assembled.py
  • ssms/rl/config.py
  • ssms/rl/env.py
  • ssms/rl/learning.py
  • ssms/rl/simulator.py
  • ssms/rl/validation.py
  • tests/rl/test_assembled_model.py
  • tests/rl/test_data_validation.py
  • tests/rl/test_hssm_bridge_contract.py
  • tests/rl/test_hssm_compatibility.py
  • tests/rl/test_learning_process.py
  • tests/rl/test_rl_config.py
  • tests/rl/test_rl_simulator.py
  • tests/rl/test_task_environment.py
✅ Files skipped from review due to trivial changes (1)
  • docs/api/rlssm.md
🚧 Files skipped from review as they are similar to previous changes (7)
  • ssms/rl/env.py
  • tests/rl/test_hssm_bridge_contract.py
  • tests/rl/test_data_validation.py
  • ssms/rl/learning.py
  • tests/rl/test_rl_config.py
  • ssms/rl/validation.py
  • ssms/rl/simulator.py

Comment thread ssms/rl/assembled.py
Comment on lines +316 to +321
    unmapped = sorted(
        {
            int(value)
            for value in np.unique(responses)
            if int(value) not in mapping_keys
        }

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject non-integer response labels before mapping to choices.

Line 318 and Line 353/Line 359 coerce responses with integer casting, so values like 1.9 are silently treated as 1. That can drive incorrect learning updates instead of failing fast.

Proposed fix
 def _validate_trial_response_labels(
     self,
     subject_trials,
     field_to_idx: dict[str, int],
     response_field: str,
 ) -> None:
     """Raise when trial responses are absent from ``response_to_choice``."""
     trials = np.asarray(subject_trials)
     if trials.ndim == 1:
         trials = trials.reshape(1, -1)
-    responses = trials[:, field_to_idx[response_field]]
+    responses = np.asarray(trials[:, field_to_idx[response_field]], dtype=np.float64)
+    rounded = np.rint(responses)
+    if not np.all(np.equal(responses, rounded)):
+        bad = np.unique(responses[responses != rounded]).tolist()
+        raise ValueError(
+            "Trial responses must be integer-coded labels. "
+            f"Got non-integer values: {bad}."
+        )
+    responses = rounded.astype(np.int64)

     mapping_keys = set(self.response_to_choice.keys())
     unmapped = sorted(
         {
             int(value)
             for value in np.unique(responses)
             if int(value) not in mapping_keys
         }
     )

Also applies to: 353-359
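
As a complement to the NumPy-based fix proposed above, the same lossless-conversion check can be sketched element-wise in plain Python. The helper names `to_integer_label` and `unmapped_labels` are hypothetical and are not part of `ssms/rl/assembled.py`; rejecting booleans is an extra assumption of this sketch.

```python
import math
from numbers import Integral, Real


def to_integer_label(value):
    """Return ``value`` as an int, or raise if the conversion would lose information."""
    if isinstance(value, bool):
        # Assumption: True/False are not accepted as response labels.
        raise ValueError(f"Boolean is not a valid response label: {value!r}")
    if isinstance(value, Integral):
        return int(value)
    if (
        isinstance(value, Real)
        and math.isfinite(value)
        and float(value).is_integer()
    ):
        # Values such as 1.0 convert without loss; 1.9, NaN, and inf do not.
        return int(value)
    raise ValueError(f"Response label is not integer-coded: {value!r}")


def unmapped_labels(responses, response_to_choice):
    """Return the sorted labels that have no entry in ``response_to_choice``."""
    labels = {to_integer_label(value) for value in responses}
    return sorted(labels - set(response_to_choice))


missing = unmapped_labels([-1, 1.0, 2], {-1: 0, 1: 1})
```

With this check in front of the mapping step, a value like 1.9 raises instead of being truncated to 1.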

Development

Successfully merging this pull request may close these issues.

Modular simulators for RLSSM

4 participants