diff --git a/docs/multiagent.md b/docs/multiagent.md new file mode 100644 index 000000000..4359ad4b4 --- /dev/null +++ b/docs/multiagent.md @@ -0,0 +1,553 @@ +# Multi-Agent Environments + +This guide covers building multi-agent environments where multiple actors interact. See [Environments](environments.md) for single-agent basics and [API Reference](reference.md) for type definitions. + +## Table of Contents + +- [Overview](#overview) +- [Core Components](#core-components) + - [Actor](#actor) + - [Protocol](#protocol) + - [MultiAgentEnv](#multiagentenv) + - [MultiAgentRubric](#multiagentubric) +- [Building a Multi-Agent Environment](#building-a-multi-agent-environment) + - [Turn Management](#turn-management) + - [Building Prompts](#building-prompts) + - [Game Logic](#game-logic) + - [Ending the Game](#ending-the-game) +- [Per-Actor Rewards](#per-actor-rewards) + - [State Splitting](#state-splitting) + - [Per-Actor GRPO](#per-actor-grpo) + - [Frozen Actors](#frozen-actors) +- [Hierarchical Spawning](#hierarchical-spawning) +- [Examples](#examples) + - [Alternating Turns (Twenty Questions)](#alternating-turns-twenty-questions) + - [Simultaneous Moves (Rock Paper Scissors)](#simultaneous-moves-rock-paper-scissors) + - [Hierarchical (Proposer-Solver)](#hierarchical-proposer-solver) + - [Complex Game (Multi-Player Poker)](#complex-game-multi-player-poker) + +## Overview + +Multi-agent environments enable training multiple actors that interact with each other. Key capabilities: + +- **Multiple actors** with distinct system prompts and configurations +- **Turn management** for alternating or simultaneous moves +- **Per-actor rewards** for credit assignment +- **Per-actor GRPO** advantages computed within actor groups +- **Hierarchical spawning** for complex multi-level games + +The architecture separates concerns: + +| Component | Responsibility | +|-----------|----------------| +| `Actor` | Configuration (system prompt, model, trainability) | +| `Protocol` | Registry (wires actors to environments, enables spawning) | +| `MultiAgentEnv` | Game logic (turn order, prompts, win conditions) | +| `MultiAgentRubric` | Scoring (per-actor rewards, per-actor GRPO) | + +## Core Components + +### Actor + +An actor is a trainable entity with a distinct identity: + +```python +from verifiers.envs.actor import Actor + +player1 = Actor( + id="player1", + system_prompt="You are Player 1 in a game...", + max_tokens=512, + is_trainable=True, +) + +judge = Actor( + id="judge", + system_prompt="You are a fair judge...", + is_trainable=False, # Frozen - no gradient updates + model="gpt-4o-mini", # Can use different model +) +``` + +**Fields:** + +| Field | Type | Description | +|-------|------|-------------| +| `id` | `str` | Unique identifier (e.g., `"player1"`, `"guesser"`) | +| `system_prompt` | `str` | The actor's persona/instructions | +| `max_tokens` | `int` | Max response length (default: 4096) | +| `is_trainable` | `bool` | Whether to compute GRPO advantages (default: True) | +| `sampling_args` | `dict` | Per-actor sampling settings | +| `model` | `str \| None` | Model override (None = use default) | +| `client` | `AsyncOpenAI \| None` | Client override (None = use default) | + +### Protocol + +Protocol wires actors to environments and enables cross-environment spawning: + +```python +from verifiers.envs.protocol import Protocol +from verifiers.envs.actor import Actor + +# Define actors +player1 = Actor("player1", system_prompt="...") +player2 = Actor("player2", system_prompt="...") + +# Define environment 
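+# (MyGameEnv is the example MultiAgentEnv subclass shown in the MultiAgentEnv
+#  section below; rubric is assumed to be a MultiAgentRubric instance)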
+env = MyGameEnv(rubric=rubric, max_turns=10) + +# Wire them together +protocol = Protocol( + actors=[player1, player2], + envs=[env], +) +# Now env.protocol is set, and env.get_actor("player1") works +``` + +**Methods:** + +| Method | Returns | Description | +|--------|---------|-------------| +| `get_actor(actor_id)` | `Actor` | Look up actor by ID | +| `get_env(name)` | `Environment` | Look up environment by name | +| `spawn(inputs, client, model, ...)` | `list[State]` | Run child rollouts in target environments | + +### MultiAgentEnv + +`MultiAgentEnv` extends `MultiTurnEnv` with multi-actor support: + +```python +class MyGameEnv(vf.MultiAgentEnv): + name = "MyGame" # For protocol registration + actors = ["player1", "player2"] # Actor IDs this env uses + + # Required: implement these four methods + def get_initial_actor(self, state) -> str: ... + def get_next_actor(self, state) -> str: ... + async def build_actor_prompt(self, actor_id, state) -> Messages: ... + async def on_turn_complete(self, state) -> None: ... + + # Optional: override for simultaneous moves + def get_active_actors(self, state) -> list[str]: ... + + # Optional: final metrics before scoring + async def on_game_end(self, state) -> None: ... +``` + +**Key differences from MultiTurnEnv:** + +| MultiTurnEnv | MultiAgentEnv | +|--------------|---------------| +| Single actor | Multiple actors | +| `env_response()` after each turn | `on_turn_complete()` for game logic | +| Single prompt throughout | `build_actor_prompt()` per actor | +| — | Turn order via `get_initial_actor()` / `get_next_actor()` | + +### MultiAgentRubric + +`MultiAgentRubric` extends `Rubric` with per-actor rewards: + +```python +from verifiers.rubrics.multiagent_rubric import MultiAgentRubric + +rubric = MultiAgentRubric() + +# Global reward (applies to all actors) +rubric.add_reward_func(game_completed) + +# Per-actor rewards +rubric.add_actor_reward_func("player1", player1_win_bonus) +rubric.add_actor_reward_func("player2", player2_win_bonus) + +# Per-actor metrics (weight=0) +rubric.add_actor_metric("guesser", questions_asked) +``` + +**Key differences from Rubric:** + +| Rubric | MultiAgentRubric | +|--------|------------------| +| Single GRPO across all states | Per-actor GRPO (solver vs solver) | +| Global reward functions only | Per-actor reward functions | +| — | Children scored before parents | + +## Building a Multi-Agent Environment + +### Turn Management + +Implement turn order with two methods: + +```python +class TwentyQuestionsEnv(vf.MultiAgentEnv): + actors = ["guesser", "thinker"] + + def get_initial_actor(self, state) -> str: + """Who goes first.""" + return "guesser" + + def get_next_actor(self, state) -> str: + """Who goes next (alternating).""" + current = state["extras"]["current_actor_id"] + return "thinker" if current == "guesser" else "guesser" +``` + +For **simultaneous moves**, override `get_active_actors`: + +```python +class RPSEnv(vf.MultiAgentEnv): + actors = ["player1", "player2"] + + def get_active_actors(self, state) -> list[str]: + """Both players move each round.""" + return ["player1", "player2"] +``` + +### Building Prompts + +`build_actor_prompt` is called before each model response: + +```python +async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """Build the prompt for this actor's turn.""" + actor = self.get_actor(actor_id) # Get Actor config from Protocol + + # Build context from game state + history = self.format_game_history(state) + + return [ + {"role": "system", "content": 
actor.system_prompt}, + {"role": "user", "content": f"Game history:\n{history}\n\nYour turn:"}, + ] +``` + +The actor's `model`, `client`, and `sampling_args` are automatically used when generating. + +### Game Logic + +`on_turn_complete` is called after each model response: + +```python +async def on_turn_complete(self, state: State) -> None: + """Update game state after a turn.""" + # Get the response that was just added + last_step = state["trajectory"][-1] + response_text = last_step["completion"][-1]["content"] + actor_id = last_step["extras"]["actor_id"] + + # Parse and process + move = self.parse_move(response_text) + state["extras"]["moves"].append((actor_id, move)) + + # Check win condition + if self.check_winner(state): + state["extras"]["winner"] = actor_id +``` + +### Ending the Game + +To end a game, set `state["final_env_response"]`: + +```python +async def on_turn_complete(self, state: State) -> None: + if self.check_game_over(state): + winner = state["extras"]["winner"] + state["final_env_response"] = [ + {"role": "assistant", "content": f"Game over! {winner} wins!"} + ] +``` + +This triggers the `has_final_env_response` stop condition. You can also add custom stop conditions: + +```python +@vf.stop +async def game_won(self, state: State) -> bool: + return state["extras"].get("won", False) +``` + +## Per-Actor Rewards + +### State Splitting + +After a game completes, `MultiAgentEnv.run_group()` splits each game state into per-actor states: + +``` +Game State (6 turns): +trajectory: [p1, p2, p1, p2, p1, p2] + │ + ┌───────────┴───────────┐ + ▼ ▼ + Player1 State Player2 State + traj: [p1, p1, p1] traj: [p2, p2, p2] + prompt: "You are P1" prompt: "You are P2" +``` + +This is handled internally by `create_actor_states()`. Each actor state contains: + +- Only that actor's trajectory steps +- The actor's prompt (from their first turn) +- Shared references to `client`, `model`, `trajectory_id` +- Fresh `reward`, `advantage`, `metrics` fields for scoring + +### Per-Actor GRPO + +`MultiAgentRubric.score_group()` computes advantages within actor groups: + +``` +Without per-actor grouping (bad): + Solver reward=0.8, Proposer reward=0.2 + Mean = 0.5 + Solver advantage = +0.3, Proposer advantage = -0.3 + Unfair comparison across different roles + +With per-actor grouping (good): + Solvers compared to other solvers (mean = 0.75) + Proposers compared to other proposers (mean = 0.25) + Fair comparison within same role +``` + +### Frozen Actors + +Actors with `is_trainable=False` are scored but don't receive gradients: + +```python +thinker = Actor( + "thinker", + system_prompt="Answer yes/no about the secret word.", + is_trainable=False, # Frozen - just answers questions +) + +guesser = Actor( + "guesser", + system_prompt="Guess the word in 20 questions.", + is_trainable=True, # Learning to ask good questions +) +``` + +Frozen actors get `advantage=0` during GRPO computation. 
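+
+The two behaviors above (grouping by actor and zeroing advantages for frozen actors) can be sketched roughly as follows. This is a minimal illustration, not the library's exact internals: the function name, the `extras["actor_id"]` lookup, and the `actors_by_id` mapping are assumptions for the sketch.
+
+```python
+from collections import defaultdict
+
+def per_actor_advantages(actor_states: list[dict], actors_by_id: dict) -> None:
+    """Compute mean-baseline advantages within each actor's group (sketch)."""
+    groups: dict[str, list[dict]] = defaultdict(list)
+    for s in actor_states:
+        groups[s["extras"]["actor_id"]].append(s)
+
+    for actor_id, states in groups.items():
+        baseline = sum(s["reward"] for s in states) / len(states)
+        trainable = actors_by_id[actor_id].is_trainable
+        for s in states:
+            # Frozen actors are scored but receive zero advantage (no gradient signal)
+            s["advantage"] = (s["reward"] - baseline) if trainable else 0.0
+```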
+ +## Hierarchical Spawning + +Protocol enables spawning child rollouts in other environments: + +```python +class ProposerEnv(vf.MultiAgentEnv): + async def on_turn_complete(self, state: State) -> None: + if self.proposer_submitted_problem(state): + problem = self.extract_problem(state) + + # Spawn solver attempts in SolverEnv + child_states = await self.protocol.spawn( + inputs=[ + {"task": "SolverEnv", "prompt": problem}, + {"task": "SolverEnv", "prompt": problem}, + {"task": "SolverEnv", "prompt": problem}, + ], + client=state["client"], + model=state["model"], + ) + + # Store for later access + state["child_states"] = child_states + + # Score proposer based on solver success + solver_rewards = [s["reward"] for s in child_states] + state["extras"]["solver_success_rate"] = sum(solver_rewards) / len(solver_rewards) +``` + +Child states are automatically included in `run_group()` output for training. + +## Examples + +### Alternating Turns (Twenty Questions) + +```python +class TwentyQuestionsEnv(vf.MultiAgentEnv): + name = "TwentyQuestions" + actors = ["guesser", "thinker"] + + def get_initial_actor(self, state) -> str: + return "guesser" + + def get_next_actor(self, state) -> str: + current = state["extras"]["current_actor_id"] + return "thinker" if current == "guesser" else "guesser" + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + actor = self.get_actor(actor_id) + secret = state["info"]["secret_word"] + history = self.format_qa_history(state) + + if actor_id == "guesser": + content = f"History:\n{history}\n\nAsk a yes/no question or make a final guess." + else: + content = f"The secret word is: {secret}\n\nHistory:\n{history}\n\nAnswer yes or no." + + return [ + {"role": "system", "content": actor.system_prompt}, + {"role": "user", "content": content}, + ] + + async def on_turn_complete(self, state: State) -> None: + actor_id = state["extras"]["current_actor_id"] + response = state["trajectory"][-1]["completion"][-1]["content"] + + if actor_id == "guesser" and "FINAL GUESS:" in response: + guess = self.extract_guess(response) + secret = state["info"]["secret_word"] + state["extras"]["won"] = (guess.lower() == secret.lower()) + state["final_env_response"] = [ + {"role": "assistant", "content": f"{'Correct!' 
if state['extras']['won'] else 'Wrong!'}"} + ] +``` + +### Simultaneous Moves (Rock Paper Scissors) + +```python +class RPSEnv(vf.MultiAgentEnv): + name = "RPS" + actors = ["player1", "player2"] + + def get_active_actors(self, state) -> list[str]: + return ["player1", "player2"] # Both move each round + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + actor = self.get_actor(actor_id) + history = self.format_history_for(state, actor_id) # Hide opponent's pending move + + return [ + {"role": "system", "content": actor.system_prompt}, + {"role": "user", "content": f"Previous rounds:\n{history}\n\nChoose: rock, paper, or scissors"}, + ] + + async def on_turn_complete(self, state: State) -> None: + # Check if both players have moved this round + recent = state["trajectory"][-2:] + actors_moved = {step["extras"]["actor_id"] for step in recent} + + if actors_moved == {"player1", "player2"}: + # Resolve the round + p1_move = self.parse_move(recent[0]["completion"][-1]["content"]) + p2_move = self.parse_move(recent[1]["completion"][-1]["content"]) + winner = self.determine_winner(p1_move, p2_move) + + state["extras"]["rounds"].append({ + "p1": p1_move, "p2": p2_move, "winner": winner + }) + + if len(state["extras"]["rounds"]) >= 3: + state["final_env_response"] = [ + {"role": "assistant", "content": "Best of 3 complete!"} + ] +``` + +### Hierarchical (Proposer-Solver) + +```python +class ProposerSolverEnv(vf.MultiAgentEnv): + name = "ProposerSolver" + actors = ["proposer"] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # SolverEnv must also be registered with Protocol + + async def on_turn_complete(self, state: State) -> None: + response = state["trajectory"][-1]["completion"][-1]["content"] + problem = self.extract_problem(response) + + # Spawn solvers + child_states = await self.protocol.spawn( + inputs=[ + {"task": "SolverEnv", "prompt": self.format_problem(problem)} + for _ in range(4) # 4 solver attempts + ], + client=state["client"], + model=state["model"], + ) + + state["child_states"] = child_states + state["extras"]["num_solved"] = sum(1 for s in child_states if s["reward"] > 0) + state["final_env_response"] = [{"role": "assistant", "content": "Solvers finished."}] + + +# Reward proposer based on solver success +async def proposer_reward(state: State) -> float: + num_solved = state["extras"].get("num_solved", 0) + return num_solved / 4.0 # Fraction of solvers that succeeded +``` + +Wire with Protocol: + +```python +proposer = Actor("proposer", system_prompt="Generate a challenging math problem.") +solver = Actor("solver", system_prompt="Solve the given problem.") + +proposer_env = ProposerSolverEnv(rubric=proposer_rubric) +solver_env = SolverEnv(rubric=solver_rubric) + +protocol = Protocol( + actors=[proposer, solver], + envs=[proposer_env, solver_env], +) +``` + +### Complex Game (Multi-Player Poker) + +See [`environments/poker_multi/poker_multi.py`](../environments/poker_multi/poker_multi.py) for a full example demonstrating: + +- **Dynamic player count** - 2-9 players via dynamic `self.actors` list +- **Position-based turns** - UTG acts first preflop, dealer button rotates +- **Per-player model configs** - Different models/strategies per player +- **Multiple stop conditions** - `one_player_left`, `hand_complete`, `max_actions_hit` +- **Game phases** - Preflop → flop → turn → river → showdown with betting rounds +- **Per-actor rewards** - Chip profit/loss as fraction of starting stack + +Key patterns from the implementation: + +```python 
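+# Simplified excerpts; state fields like extras["dealer_idx"] and extras["folded"]
+# are assumed to be initialized during state setup (not shown here).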
+class PokerMultiEnv(vf.MultiAgentEnv): + name = "poker_multi" + + def __init__(self, num_players: int = 6, **kwargs): + super().__init__(**kwargs) + # Dynamic actor list based on player count + self.actors = [f"player{i}" for i in range(1, num_players + 1)] + + def get_initial_actor(self, state) -> str: + """Position-based: UTG (dealer+3) acts first preflop.""" + dealer_idx = state["extras"]["dealer_idx"] + if self.num_players == 2: + return self.actors[dealer_idx] # Heads-up: dealer acts first + return self.actors[(dealer_idx + 3) % self.num_players] # UTG + + def get_next_actor(self, state) -> str: + """Next player clockwise who hasn't folded and isn't all-in.""" + current_idx = self.actors.index(state["extras"]["current_actor_id"]) + for i in range(1, self.num_players + 1): + candidate = self.actors[(current_idx + i) % self.num_players] + if candidate not in state["extras"]["folded"]: + return candidate + return state["extras"]["current_actor_id"] + + @vf.stop + async def one_player_left(self, state) -> bool: + """End when all others fold.""" + active = [p for p in self.actors if p not in state["extras"]["folded"]] + return len(active) == 1 +``` + +```python +# Per-player model/strategy configuration +PLAYER_CONFIGS = [ + {"endpoint": "model-a", "strategy": "aggressive", "is_trainable": True}, + {"endpoint": "model-b", "strategy": "conservative", "is_trainable": False}, +] + +# Per-actor reward based on chip profit +def player_reward(actor_id: str): + def reward_func(state: State) -> float: + starting = state["extras"]["starting_chips"] + final = state["extras"]["chips"].get(actor_id, starting) + return (final - starting) / starting + return reward_func +``` diff --git a/environments/poker/poker.py b/environments/poker/poker.py new file mode 100644 index 000000000..80a23409b --- /dev/null +++ b/environments/poker/poker.py @@ -0,0 +1,884 @@ +""" +Poker: Heads-Up No-Limit Texas Hold'em. + +This environment demonstrates: +- Hidden information (each player sees only their own hole cards) +- Complex state management (chips, pot, betting rounds) +- Multi-phase gameplay (preflop, flop, turn, river, showdown) +- JSON-structured action parsing +- Different models per player (small vs large) + +Game flow: +1. Post blinds (dealer=small blind, other=big blind) +2. Deal 2 hole cards each (hidden) +3. Preflop betting (dealer acts first in heads-up) +4. Deal flop (3 community cards) +5. Postflop betting (non-dealer acts first) +6. Deal turn (1 card) +7. Turn betting +8. Deal river (1 card) +9. River betting +10. Showdown - compare hands, award pot +11. Repeat for num_hands (dealer rotates) +""" + +import json +import random +import re +from collections import Counter +from itertools import combinations +from datasets import Dataset + +from verifiers import Actor, MultiAgentEnv, MultiAgentRubric, Protocol +from verifiers.types import Messages, State +from verifiers.utils.client_utils import get_actor_client +import verifiers as vf + +# ============================================================================= +# Model Configuration +# ============================================================================= +# Change these to use different models for each player. +# Set to None to use the default model from the eval command. 
+# +# Small models: "olmo3-7b-i", "trinity-mini", "haiku", "gemini-3-flash" +# Large models: "sonnet", "opus", "qwen3-235b-i", "gemini-3-pro" +# ============================================================================= + +PLAYER1_ENDPOINT = "olmo3-7b-i" # Small model +PLAYER2_ENDPOINT = "qwen3-235b-i" # Large model + +player1_client, player1_model = get_actor_client(PLAYER1_ENDPOINT) +player2_client, player2_model = get_actor_client(PLAYER2_ENDPOINT) + + +# ============================================================================= +# Card Utilities +# ============================================================================= + +RANKS = "23456789TJQKA" +SUITS = "hdcs" # hearts, diamonds, clubs, spades +RANK_VALUES = {r: i for i, r in enumerate(RANKS, 2)} # 2=2, 3=3, ..., A=14 + + +def create_deck() -> list[str]: + """Create a shuffled 52-card deck.""" + deck = [f"{r}{s}" for r in RANKS for s in SUITS] + random.shuffle(deck) + return deck + + +def card_rank(card: str) -> int: + """Get numeric rank value (2-14).""" + return RANK_VALUES[card[0]] + + +def card_suit(card: str) -> str: + """Get suit character.""" + return card[1] + + +def format_cards(cards: list[str]) -> str: + """Format cards for display.""" + return ", ".join(cards) if cards else "None" + + +# ============================================================================= +# Hand Evaluation (Simple built-in evaluator) +# ============================================================================= + +# Hand rankings (higher = better) +HAND_RANKS = { + "high_card": 1, + "pair": 2, + "two_pair": 3, + "three_of_a_kind": 4, + "straight": 5, + "flush": 6, + "full_house": 7, + "four_of_a_kind": 8, + "straight_flush": 9, + "royal_flush": 10, +} + + +def evaluate_five_cards(cards: list[str]) -> tuple[int, list[int]]: + """ + Evaluate a 5-card hand. + Returns (hand_rank, tiebreakers) where higher is better. 
+ """ + ranks = sorted([card_rank(c) for c in cards], reverse=True) + suits = [card_suit(c) for c in cards] + rank_counts = Counter(ranks) + + is_flush = len(set(suits)) == 1 + + # Check for straight (including A-2-3-4-5 wheel) + unique_ranks = sorted(set(ranks), reverse=True) + is_straight = False + straight_high = 0 + + if len(unique_ranks) == 5: + if unique_ranks[0] - unique_ranks[4] == 4: + is_straight = True + straight_high = unique_ranks[0] + # Wheel: A-2-3-4-5 + elif unique_ranks == [14, 5, 4, 3, 2]: + is_straight = True + straight_high = 5 # 5-high straight + + # Get counts for pair/trips/etc detection + counts = sorted(rank_counts.values(), reverse=True) + + # Determine hand type + if is_straight and is_flush: + if straight_high == 14 and 13 in ranks: # Royal flush + return (HAND_RANKS["royal_flush"], [14]) + return (HAND_RANKS["straight_flush"], [straight_high]) + + if counts == [4, 1]: + quad_rank = [r for r, c in rank_counts.items() if c == 4][0] + kicker = [r for r, c in rank_counts.items() if c == 1][0] + return (HAND_RANKS["four_of_a_kind"], [quad_rank, kicker]) + + if counts == [3, 2]: + trip_rank = [r for r, c in rank_counts.items() if c == 3][0] + pair_rank = [r for r, c in rank_counts.items() if c == 2][0] + return (HAND_RANKS["full_house"], [trip_rank, pair_rank]) + + if is_flush: + return (HAND_RANKS["flush"], ranks) + + if is_straight: + return (HAND_RANKS["straight"], [straight_high]) + + if counts == [3, 1, 1]: + trip_rank = [r for r, c in rank_counts.items() if c == 3][0] + kickers = sorted([r for r, c in rank_counts.items() if c == 1], reverse=True) + return (HAND_RANKS["three_of_a_kind"], [trip_rank] + kickers) + + if counts == [2, 2, 1]: + pairs = sorted([r for r, c in rank_counts.items() if c == 2], reverse=True) + kicker = [r for r, c in rank_counts.items() if c == 1][0] + return (HAND_RANKS["two_pair"], pairs + [kicker]) + + if counts == [2, 1, 1, 1]: + pair_rank = [r for r, c in rank_counts.items() if c == 2][0] + kickers = sorted([r for r, c in rank_counts.items() if c == 1], reverse=True) + return (HAND_RANKS["pair"], [pair_rank] + kickers) + + return (HAND_RANKS["high_card"], ranks) + + +def evaluate_hand(hole_cards: list[str], community: list[str]) -> tuple[int, list[int], str]: + """ + Find best 5-card hand from 7 cards. + Returns (hand_rank, tiebreakers, hand_name). + """ + all_cards = hole_cards + community + best_score = (0, []) + best_name = "high_card" + + # Try all 21 combinations of 5 cards from 7 + for five_cards in combinations(all_cards, 5): + score = evaluate_five_cards(list(five_cards)) + if score > best_score: + best_score = score + # Find hand name + for name, rank in HAND_RANKS.items(): + if rank == score[0]: + best_name = name + break + + return (best_score[0], best_score[1], best_name) + + +def compare_hands( + hole1: list[str], hole2: list[str], community: list[str] +) -> tuple[str, str, str]: + """ + Compare two hands. + Returns (winner, hand1_name, hand2_name) where winner is "player1", "player2", or "tie". 
+ """ + eval1 = evaluate_hand(hole1, community) + eval2 = evaluate_hand(hole2, community) + + score1 = (eval1[0], eval1[1]) + score2 = (eval2[0], eval2[1]) + + if score1 > score2: + return ("player1", eval1[2], eval2[2]) + elif score2 > score1: + return ("player2", eval1[2], eval2[2]) + else: + return ("tie", eval1[2], eval2[2]) + + +# ============================================================================= +# Actors +# ============================================================================= + +AGGRESSIVE_PROMPT = """You are an AGGRESSIVE poker player in Heads-Up No-Limit Texas Hold'em. + +Rules: +- You and your opponent each have 2 hole cards (hidden from each other) +- 5 community cards are dealt face-up over multiple rounds +- Best 5-card hand from your 7 cards wins + +YOUR STRATEGY - Play aggressively: +- RAISE frequently, especially preflop with any decent hand +- BLUFF often - bet and raise even with mediocre hands to pressure opponent +- NEVER fold preflop unless you have absolute garbage (like 2-7 offsuit) +- When in doubt, RAISE rather than call or check +- Put maximum pressure on your opponent + +On your turn, output ONLY a JSON object with your action: +- Fold: {"action": "fold"} +- Check (if no bet to call): {"action": "check"} +- Call (match current bet): {"action": "call"} +- Raise to amount: {"action": "raise", "amount": 100} +- All-in: {"action": "allin"} + +Output ONLY the JSON, nothing else.""" + +CONSERVATIVE_PROMPT = """You are a SMART poker player in Heads-Up No-Limit Texas Hold'em. + +Rules: +- You and your opponent each have 2 hole cards (hidden from each other) +- 5 community cards are dealt face-up over multiple rounds +- Best 5-card hand from your 7 cards wins + +YOUR STRATEGY - Play solid poker: +- CALL or RAISE with good hands (pairs, high cards like A/K/Q, suited connectors) +- CHECK when you can to see free cards +- Only FOLD when facing a big bet with a truly weak hand +- If you have a strong hand (pair or better), RAISE to build the pot +- Go to SHOWDOWN when possible to see who wins + +On your turn, output ONLY a JSON object with your action: +- Fold: {"action": "fold"} +- Check (if no bet to call): {"action": "check"} +- Call (match current bet): {"action": "call"} +- Raise to amount: {"action": "raise", "amount": 100} +- All-in: {"action": "allin"} + +Output ONLY the JSON, nothing else.""" + +PLAYER1 = Actor( + id="player1", + system_prompt=CONSERVATIVE_PROMPT, # Small model plays smart + max_tokens=50, + is_trainable=True, + model=player1_model, + client=player1_client, +) + +PLAYER2 = Actor( + id="player2", + system_prompt=AGGRESSIVE_PROMPT, # Big model plays aggressive + max_tokens=50, + is_trainable=True, + model=player2_model, + client=player2_client, +) + + +# ============================================================================= +# Environment +# ============================================================================= + +class PokerEnv(MultiAgentEnv): + """Heads-Up No-Limit Texas Hold'em.""" + + name = "poker" + actors = ["player1", "player2"] + + def __init__( + self, + num_hands: int = 1, + max_actions_per_hand: int = 20, + starting_chips: int = 1000, + small_blind: int = 5, + big_blind: int = 10, + **kwargs, + ): + super().__init__(**kwargs) + self.num_hands = num_hands + self.max_actions_per_hand = max_actions_per_hand + self.starting_chips = starting_chips + self.small_blind = small_blind + self.big_blind = big_blind + + # ------------------------------------------------------------------------- + # Turn Management + # 
------------------------------------------------------------------------- + + def get_initial_actor(self, state: State) -> str: + """Dealer (small blind) acts first preflop in heads-up.""" + return state["extras"]["dealer"] + + def get_next_actor(self, state: State) -> str: + """Alternate between players, skip folded.""" + current = state["extras"]["current_actor_id"] + next_actor = "player2" if current == "player1" else "player1" + + # Skip if folded + if next_actor in state["extras"]["folded"]: + return current + return next_actor + + # ------------------------------------------------------------------------- + # Stop Conditions + # ------------------------------------------------------------------------- + + @vf.stop + async def player_folded(self, state: State) -> bool: + """One player folded - hand over.""" + return len(state["extras"]["folded"]) > 0 + + @vf.stop + async def hand_complete(self, state: State) -> bool: + """Showdown complete or all hands played.""" + return state["extras"]["phase"] == "complete" + + @vf.stop + async def max_actions_hit(self, state: State) -> bool: + """Safety limit - force showdown.""" + if state["extras"]["actions_this_hand"] >= self.max_actions_per_hand: + await self._force_showdown(state) + return True + return False + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """Initialize poker game state.""" + state = await super().setup_state(state) + + # Session tracking + state["extras"]["hands_played"] = 0 + state["extras"]["starting_chips"] = self.starting_chips + state["extras"]["chips"] = { + "player1": self.starting_chips, + "player2": self.starting_chips, + } + + # Start first hand + state["extras"]["dealer"] = "player1" # Will rotate each hand + await self._start_new_hand(state) + + return state + + async def _start_new_hand(self, state: State) -> None: + """Initialize state for a new hand.""" + extras = state["extras"] + + # Create and shuffle deck + extras["deck"] = create_deck() + + # Deal hole cards + extras["hole_cards"] = { + "player1": [extras["deck"].pop(), extras["deck"].pop()], + "player2": [extras["deck"].pop(), extras["deck"].pop()], + } + + # Reset hand state + extras["community_cards"] = [] + extras["pot"] = 0 + extras["current_bet"] = 0 + extras["bets_this_round"] = {"player1": 0, "player2": 0} + extras["phase"] = "preflop" + extras["folded"] = [] + extras["actions_this_hand"] = 0 + extras["actions_this_round"] = {"player1": 0, "player2": 0} + extras["last_aggressor"] = None + + # Post blinds + dealer = extras["dealer"] + non_dealer = "player2" if dealer == "player1" else "player1" + + # Dealer posts small blind + sb_amount = min(self.small_blind, extras["chips"][dealer]) + extras["chips"][dealer] -= sb_amount + extras["bets_this_round"][dealer] = sb_amount + extras["pot"] += sb_amount + + # Non-dealer posts big blind + bb_amount = min(self.big_blind, extras["chips"][non_dealer]) + extras["chips"][non_dealer] -= bb_amount + extras["bets_this_round"][non_dealer] = bb_amount + extras["pot"] += bb_amount + extras["current_bet"] = bb_amount + + # ------------------------------------------------------------------------- + # Prompt Building + # ------------------------------------------------------------------------- + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """Build prompt showing only this player's hole cards.""" + 
extras = state["extras"] + + my_cards = extras["hole_cards"][actor_id] + community = extras["community_cards"] + pot = extras["pot"] + my_chips = extras["chips"][actor_id] + opponent_id = "player2" if actor_id == "player1" else "player1" + opponent_chips = extras["chips"][opponent_id] + + to_call = extras["current_bet"] - extras["bets_this_round"][actor_id] + + # Build situation description + phase = extras["phase"].upper() + hand_num = extras["hands_played"] + 1 + + position = "Dealer (Small Blind)" if actor_id == extras["dealer"] else "Big Blind" + + situation = f"""=== HAND {hand_num} - {phase} === +Position: {position} + +Your hole cards: {format_cards(my_cards)} +Community cards: {format_cards(community)} + +Pot: {pot} +Your chips: {my_chips} +Opponent chips: {opponent_chips} + +Current bet: {extras['current_bet']} +Your bet this round: {extras['bets_this_round'][actor_id]} +To call: {to_call} + +Your action?""" + + actor = self.get_actor(actor_id) + return [ + {"role": "system", "content": actor.system_prompt}, + {"role": "user", "content": situation}, + ] + + # ------------------------------------------------------------------------- + # Action Parsing + # ------------------------------------------------------------------------- + + def _parse_action(self, text: str, state: State, actor_id: str) -> dict: + """Parse JSON action from model output.""" + try: + # Strip markdown code blocks if present + clean = text.strip() + if "```" in clean: + match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", clean, re.DOTALL) + if match: + clean = match.group(1) + + # Try to find JSON object in text + json_match = re.search(r"\{[^}]+\}", clean) + if json_match: + clean = json_match.group(0) + + action = json.loads(clean) + + # Validate required field + if "action" not in action: + raise ValueError("Missing 'action' key") + + action_type = action["action"].lower() + + # Validate action type + valid_actions = ["fold", "check", "call", "raise", "allin"] + if action_type not in valid_actions: + raise ValueError(f"Invalid action: {action_type}") + + # Raise requires amount + if action_type == "raise": + if "amount" not in action: + raise ValueError("Raise requires 'amount'") + action["amount"] = int(action["amount"]) + + action["action"] = action_type + return action + + except (json.JSONDecodeError, ValueError, KeyError): + # FALLBACK: Safe default action + return self._get_fallback_action(state, actor_id) + + def _get_fallback_action(self, state: State, actor_id: str) -> dict: + """When parsing fails, do the safest legal action.""" + to_call = state["extras"]["current_bet"] - state["extras"]["bets_this_round"][actor_id] + if to_call == 0: + return {"action": "check"} + else: + return {"action": "fold"} + + def _validate_and_adjust_action(self, action: dict, state: State, actor_id: str) -> dict: + """Validate action and adjust if needed (clamp raises, etc.).""" + extras = state["extras"] + my_chips = extras["chips"][actor_id] + to_call = extras["current_bet"] - extras["bets_this_round"][actor_id] + + action_type = action["action"] + + # Can't check if there's a bet to call + if action_type == "check" and to_call > 0: + # Convert to call or fold based on having chips + if my_chips >= to_call: + return {"action": "call"} + else: + return {"action": "allin"} + + # Can't call more than we have + if action_type == "call": + if my_chips <= to_call: + return {"action": "allin"} + return action + + # Handle raise + if action_type == "raise": + amount = action.get("amount", 0) + min_raise = 
extras["current_bet"] + self.big_blind # Minimum raise + max_raise = my_chips + extras["bets_this_round"][actor_id] # All-in + + # Clamp to valid range + if amount >= max_raise: + return {"action": "allin"} + if amount < min_raise: + amount = min_raise + if amount > max_raise: + amount = max_raise + + return {"action": "raise", "amount": amount} + + # All-in is always valid + if action_type == "allin": + return action + + # Fold is always valid + return action + + # ------------------------------------------------------------------------- + # Game Logic + # ------------------------------------------------------------------------- + + async def on_turn_complete(self, state: State) -> None: + """Process action after each turn.""" + if not state["trajectory"]: + return + + last_step = state["trajectory"][-1] + last_completion = last_step.get("completion", []) + if not last_completion: + return + + # Get actor who just played + actor_id = last_step.get("extras", {}).get("actor_id") + if not actor_id: + return + + # Parse action + content = last_completion[-1].get("content", "") if isinstance(last_completion[-1], dict) else str(last_completion[-1]) + action = self._parse_action(content, state, actor_id) + action = self._validate_and_adjust_action(action, state, actor_id) + + extras = state["extras"] + extras["actions_this_hand"] += 1 + extras["actions_this_round"][actor_id] += 1 + + # Process action + await self._process_action(action, state, actor_id) + + # Check if betting round is complete + if self._is_betting_round_complete(state): + await self._advance_phase(state) + + async def _process_action(self, action: dict, state: State, actor_id: str) -> None: + """Process a validated action.""" + extras = state["extras"] + action_type = action["action"] + + if action_type == "fold": + extras["folded"].append(actor_id) + # Award pot to opponent + opponent = "player2" if actor_id == "player1" else "player1" + extras["chips"][opponent] += extras["pot"] + extras["pot"] = 0 + extras["winner"] = opponent + extras["win_reason"] = "fold" + extras["phase"] = "complete" + + elif action_type == "check": + pass # No chip movement + + elif action_type == "call": + to_call = extras["current_bet"] - extras["bets_this_round"][actor_id] + call_amount = min(to_call, extras["chips"][actor_id]) + extras["chips"][actor_id] -= call_amount + extras["bets_this_round"][actor_id] += call_amount + extras["pot"] += call_amount + + elif action_type == "raise": + amount = action["amount"] + # Amount is total bet, not additional + current_bet = extras["bets_this_round"][actor_id] + additional = amount - current_bet + additional = min(additional, extras["chips"][actor_id]) + + extras["chips"][actor_id] -= additional + extras["bets_this_round"][actor_id] += additional + extras["pot"] += additional + extras["current_bet"] = extras["bets_this_round"][actor_id] + extras["last_aggressor"] = actor_id + + # Reset opponent's action count (they need to respond to raise) + opponent = "player2" if actor_id == "player1" else "player1" + extras["actions_this_round"][opponent] = 0 + + elif action_type == "allin": + all_in_amount = extras["chips"][actor_id] + extras["bets_this_round"][actor_id] += all_in_amount + extras["pot"] += all_in_amount + extras["chips"][actor_id] = 0 + + if extras["bets_this_round"][actor_id] > extras["current_bet"]: + extras["current_bet"] = extras["bets_this_round"][actor_id] + extras["last_aggressor"] = actor_id + # Reset opponent's action count + opponent = "player2" if actor_id == "player1" else "player1" + 
extras["actions_this_round"][opponent] = 0 + + def _is_betting_round_complete(self, state: State) -> bool: + """Check if current betting round is complete.""" + extras = state["extras"] + + # Hand already over + if extras["phase"] == "complete" or extras["folded"]: + return False + + # Both players must have acted at least once this round + p1_actions = extras["actions_this_round"]["player1"] + p2_actions = extras["actions_this_round"]["player2"] + + if p1_actions == 0 or p2_actions == 0: + return False + + # Bets must be equal (or someone is all-in) + p1_bet = extras["bets_this_round"]["player1"] + p2_bet = extras["bets_this_round"]["player2"] + p1_chips = extras["chips"]["player1"] + p2_chips = extras["chips"]["player2"] + + bets_equal = p1_bet == p2_bet + p1_allin = p1_chips == 0 + p2_allin = p2_chips == 0 + + return bets_equal or p1_allin or p2_allin + + async def _advance_phase(self, state: State) -> None: + """Move to next phase of the hand.""" + extras = state["extras"] + current_phase = extras["phase"] + + # Reset for new betting round + extras["bets_this_round"] = {"player1": 0, "player2": 0} + extras["current_bet"] = 0 + extras["actions_this_round"] = {"player1": 0, "player2": 0} + extras["last_aggressor"] = None + + # Determine next phase + phase_order = ["preflop", "flop", "turn", "river", "showdown"] + current_idx = phase_order.index(current_phase) + next_phase = phase_order[current_idx + 1] + extras["phase"] = next_phase + + # Deal community cards + if next_phase == "flop": + # Burn and deal 3 + extras["deck"].pop() # Burn + extras["community_cards"].extend([ + extras["deck"].pop(), + extras["deck"].pop(), + extras["deck"].pop(), + ]) + elif next_phase == "turn": + # Burn and deal 1 + extras["deck"].pop() # Burn + extras["community_cards"].append(extras["deck"].pop()) + elif next_phase == "river": + # Burn and deal 1 + extras["deck"].pop() # Burn + extras["community_cards"].append(extras["deck"].pop()) + elif next_phase == "showdown": + await self._resolve_showdown(state) + + # Update current actor for post-flop (non-dealer acts first) + if next_phase in ["flop", "turn", "river"]: + dealer = extras["dealer"] + non_dealer = "player2" if dealer == "player1" else "player1" + # Only change if non-dealer hasn't folded + if non_dealer not in extras["folded"]: + extras["current_actor_id"] = non_dealer + + async def _force_showdown(self, state: State) -> None: + """Force showdown when max actions reached.""" + extras = state["extras"] + + # Deal remaining community cards + while len(extras["community_cards"]) < 5: + if extras["deck"]: + extras["deck"].pop() # Burn + if extras["deck"]: + extras["community_cards"].append(extras["deck"].pop()) + + extras["phase"] = "showdown" + await self._resolve_showdown(state) + + async def _resolve_showdown(self, state: State) -> None: + """Compare hands and award pot.""" + extras = state["extras"] + + p1_hole = extras["hole_cards"]["player1"] + p2_hole = extras["hole_cards"]["player2"] + community = extras["community_cards"] + + winner, p1_hand, p2_hand = compare_hands(p1_hole, p2_hole, community) + + extras["player1_hand"] = { + "hole_cards": p1_hole, + "hand_name": p1_hand.replace("_", " ").title(), + } + extras["player2_hand"] = { + "hole_cards": p2_hole, + "hand_name": p2_hand.replace("_", " ").title(), + } + + if winner == "tie": + # Split pot + half = extras["pot"] // 2 + extras["chips"]["player1"] += half + extras["chips"]["player2"] += extras["pot"] - half + extras["winner"] = "tie" + extras["win_reason"] = "split pot" + else: + 
extras["chips"][winner] += extras["pot"] + extras["winner"] = winner + loser = "player2" if winner == "player1" else "player1" + winner_hand = extras[f"{winner}_hand"]["hand_name"] + loser_hand = extras[f"{loser}_hand"]["hand_name"] + extras["win_reason"] = f"{winner_hand} beats {loser_hand}" + + extras["pot"] = 0 + extras["hands_played"] += 1 + + # Check if session continues + if extras["hands_played"] < self.num_hands: + # Check if both players have chips + if extras["chips"]["player1"] > 0 and extras["chips"]["player2"] > 0: + # Rotate dealer and start new hand + extras["dealer"] = "player2" if extras["dealer"] == "player1" else "player1" + await self._start_new_hand(state) + return + + # Session complete + extras["phase"] = "complete" + + async def on_game_end(self, state: State) -> None: + """Compute final session results.""" + extras = state["extras"] + + # Determine overall winner + p1_chips = extras["chips"]["player1"] + p2_chips = extras["chips"]["player2"] + + if p1_chips > p2_chips: + extras["session_winner"] = "player1" + elif p2_chips > p1_chips: + extras["session_winner"] = "player2" + else: + extras["session_winner"] = "tie" + + # Calculate profit/loss + starting = extras["starting_chips"] + extras["player1_profit"] = p1_chips - starting + extras["player2_profit"] = p2_chips - starting + + +# ============================================================================= +# Rubric (Scoring) +# ============================================================================= + +def create_rubric() -> MultiAgentRubric: + """Create rubric based on chip profit/loss.""" + rubric = MultiAgentRubric() + + def player_reward(actor_id: str): + def reward_func(state: State, **kwargs) -> float: + extras = state.get("extras", {}) + starting = extras.get("starting_chips", 1000) + final = extras.get("chips", {}).get(actor_id, starting) + # Normalize to [-1, 1] range (losing all = -1, doubling up = +1) + return (final - starting) / starting + return reward_func + + def hands_played_metric(state: State, **kwargs) -> float: + return float(state.get("extras", {}).get("hands_played", 0)) + + def showdowns_metric(state: State, **kwargs) -> float: + # Count non-fold endings + winner = state.get("extras", {}).get("winner") + reason = state.get("extras", {}).get("win_reason", "") + return 0.0 if "fold" in reason else 1.0 + + rubric.add_actor_reward_func("player1", player_reward("player1"), weight=1.0) + rubric.add_actor_reward_func("player2", player_reward("player2"), weight=1.0) + rubric.add_reward_func(hands_played_metric, weight=0.0) + rubric.add_reward_func(showdowns_metric, weight=0.0) + + return rubric + + +# ============================================================================= +# Dataset +# ============================================================================= + +def create_dataset(num_games: int = 10) -> Dataset: + """Create dataset for poker games.""" + return Dataset.from_list([ + { + "example_id": i, + "prompt": [{"role": "user", "content": "play"}], + "answer": "", + "task": "poker", + } + for i in range(num_games) + ]) + + +# ============================================================================= +# Environment Loader +# ============================================================================= + +def load_environment( + num_hands: int = 5, + max_actions_per_hand: int = 20, + starting_chips: int = 1000, + small_blind: int = 5, + big_blind: int = 10, + num_examples: int = -1, +) -> PokerEnv: + """Factory function to create a fully configured Poker environment.""" + 
dataset = create_dataset() + if num_examples > 0: + dataset = dataset.select(range(min(num_examples, len(dataset)))) + + env = PokerEnv( + num_hands=num_hands, + max_actions_per_hand=max_actions_per_hand, + starting_chips=starting_chips, + small_blind=small_blind, + big_blind=big_blind, + rubric=create_rubric(), + max_turns=num_hands * max_actions_per_hand + 10, + dataset=dataset, + ) + + Protocol(actors=[PLAYER1, PLAYER2], envs=[env]) + + return env diff --git a/environments/poker/pyproject.toml b/environments/poker/pyproject.toml new file mode 100644 index 000000000..1ade5d9f9 --- /dev/null +++ b/environments/poker/pyproject.toml @@ -0,0 +1,24 @@ +[project] +name = "poker" +description = "Heads-Up No-Limit Texas Hold'em Poker" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.9", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["poker.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 1 +rollouts_per_example = 1 +num_hands = 3 +max_actions_per_hand = 20 +starting_chips = 1000 +small_blind = 5 +big_blind = 10 diff --git a/environments/poker_multi/poker_multi.py b/environments/poker_multi/poker_multi.py new file mode 100644 index 000000000..d79c2b512 --- /dev/null +++ b/environments/poker_multi/poker_multi.py @@ -0,0 +1,1145 @@ +""" +Poker Multi: Multi-player No-Limit Texas Hold'em (2-9 players). + +This environment extends the heads-up poker to support: +- Configurable number of players (2-9) +- Full side pot tracking for all-in situations +- Proper dealer button rotation with SB/BB posting +- Turn rotation through active (non-folded) players + +Game flow: +1. Post blinds (SB = dealer+1, BB = dealer+2) +2. Deal 2 hole cards each (hidden) +3. Preflop betting (UTG = dealer+3 acts first, or SB in heads-up) +4. Deal flop (3 community cards) +5. Postflop betting (first active player left of dealer) +6. Deal turn (1 card) +7. Turn betting +8. Deal river (1 card) +9. River betting +10. Showdown - compare hands, award pots (including side pots) +11. Repeat for num_hands (dealer rotates) +""" + +import json +import random +import re +from collections import Counter +from itertools import combinations +from datasets import Dataset + +from verifiers import Actor, MultiAgentEnv, MultiAgentRubric, Protocol +from verifiers.types import Messages, State +from verifiers.utils.client_utils import get_actor_client +import verifiers as vf + + +# ============================================================================= +# Player Configuration +# ============================================================================= +# Configure each player's model and strategy. +# Players without a config use the default model from the -m flag. 
+# +# Available endpoints: See configs/endpoints.py (e.g., "olmo3-7b-i", "sonnet") +# Available strategies: "aggressive", "conservative", "balanced" +# ============================================================================= + +PLAYER_CONFIGS = [ + {"endpoint": "olmo3-7b-i", "strategy": "aggressive", "is_trainable": True}, + {"endpoint": "qwen3-235b-i", "strategy": "conservative", "is_trainable": False}, + {"endpoint": "qwen3-30b-i", "strategy": "balanced", "is_trainable": False}, + {"endpoint": "qwen3-235b-i", "strategy": "aggressive", "is_trainable": False}, + {"endpoint": "olmo3-7b-i", "strategy": "conservative", "is_trainable": False}, + # Player 6+ use default model from -m flag with "balanced" strategy +] + +STRATEGY_PROMPTS = { + "aggressive": """ +YOUR STRATEGY - Play aggressively: +- RAISE frequently, especially preflop with any decent hand +- BLUFF often - bet and raise even with mediocre hands to pressure opponents +- NEVER fold preflop unless you have absolute garbage (like 2-7 offsuit) +- When in doubt, RAISE rather than call or check +- Put maximum pressure on your opponents""", + + "conservative": """ +YOUR STRATEGY - Play solid poker: +- CALL or RAISE with good hands (pairs, high cards like A/K/Q, suited connectors) +- CHECK when you can to see free cards +- Only FOLD when facing a big bet with a truly weak hand +- If you have a strong hand (pair or better), RAISE to build the pot +- Go to SHOWDOWN when possible to see who wins""", + + "balanced": """ +YOUR STRATEGY - Play balanced poker: +- Mix up your play - sometimes raise, sometimes call, occasionally bluff +- RAISE with strong hands (pairs, AK, AQ) to build pots +- CALL with speculative hands (suited connectors, small pairs) to see flops +- FOLD weak hands when facing significant bets +- Pay attention to pot odds and position""", +} + + +# ============================================================================= +# Card Utilities +# ============================================================================= + +RANKS = "23456789TJQKA" +SUITS = "hdcs" # hearts, diamonds, clubs, spades +RANK_VALUES = {r: i for i, r in enumerate(RANKS, 2)} # 2=2, 3=3, ..., A=14 + + +def create_deck() -> list[str]: + """Create a shuffled 52-card deck.""" + deck = [f"{r}{s}" for r in RANKS for s in SUITS] + random.shuffle(deck) + return deck + + +def card_rank(card: str) -> int: + """Get numeric rank value (2-14).""" + return RANK_VALUES[card[0]] + + +def card_suit(card: str) -> str: + """Get suit character.""" + return card[1] + + +def format_cards(cards: list[str]) -> str: + """Format cards for display.""" + return ", ".join(cards) if cards else "None" + + +# ============================================================================= +# Hand Evaluation +# ============================================================================= + +HAND_RANKS = { + "high_card": 1, + "pair": 2, + "two_pair": 3, + "three_of_a_kind": 4, + "straight": 5, + "flush": 6, + "full_house": 7, + "four_of_a_kind": 8, + "straight_flush": 9, + "royal_flush": 10, +} + + +def evaluate_five_cards(cards: list[str]) -> tuple[int, list[int]]: + """ + Evaluate a 5-card hand. + Returns (hand_rank, tiebreakers) where higher is better. 
+ """ + ranks = sorted([card_rank(c) for c in cards], reverse=True) + suits = [card_suit(c) for c in cards] + rank_counts = Counter(ranks) + + is_flush = len(set(suits)) == 1 + + # Check for straight (including A-2-3-4-5 wheel) + unique_ranks = sorted(set(ranks), reverse=True) + is_straight = False + straight_high = 0 + + if len(unique_ranks) == 5: + if unique_ranks[0] - unique_ranks[4] == 4: + is_straight = True + straight_high = unique_ranks[0] + # Wheel: A-2-3-4-5 + elif unique_ranks == [14, 5, 4, 3, 2]: + is_straight = True + straight_high = 5 # 5-high straight + + counts = sorted(rank_counts.values(), reverse=True) + + # Determine hand type + if is_straight and is_flush: + if straight_high == 14 and 13 in ranks: + return (HAND_RANKS["royal_flush"], [14]) + return (HAND_RANKS["straight_flush"], [straight_high]) + + if counts == [4, 1]: + quad_rank = [r for r, c in rank_counts.items() if c == 4][0] + kicker = [r for r, c in rank_counts.items() if c == 1][0] + return (HAND_RANKS["four_of_a_kind"], [quad_rank, kicker]) + + if counts == [3, 2]: + trip_rank = [r for r, c in rank_counts.items() if c == 3][0] + pair_rank = [r for r, c in rank_counts.items() if c == 2][0] + return (HAND_RANKS["full_house"], [trip_rank, pair_rank]) + + if is_flush: + return (HAND_RANKS["flush"], ranks) + + if is_straight: + return (HAND_RANKS["straight"], [straight_high]) + + if counts == [3, 1, 1]: + trip_rank = [r for r, c in rank_counts.items() if c == 3][0] + kickers = sorted([r for r, c in rank_counts.items() if c == 1], reverse=True) + return (HAND_RANKS["three_of_a_kind"], [trip_rank] + kickers) + + if counts == [2, 2, 1]: + pairs = sorted([r for r, c in rank_counts.items() if c == 2], reverse=True) + kicker = [r for r, c in rank_counts.items() if c == 1][0] + return (HAND_RANKS["two_pair"], pairs + [kicker]) + + if counts == [2, 1, 1, 1]: + pair_rank = [r for r, c in rank_counts.items() if c == 2][0] + kickers = sorted([r for r, c in rank_counts.items() if c == 1], reverse=True) + return (HAND_RANKS["pair"], [pair_rank] + kickers) + + return (HAND_RANKS["high_card"], ranks) + + +def evaluate_hand(hole_cards: list[str], community: list[str]) -> tuple[int, list[int], str]: + """ + Find best 5-card hand from 7 cards. + Returns (hand_rank, tiebreakers, hand_name). + """ + all_cards = hole_cards + community + best_score = (0, []) + best_name = "high_card" + + for five_cards in combinations(all_cards, 5): + score = evaluate_five_cards(list(five_cards)) + if score > best_score: + best_score = score + for name, rank in HAND_RANKS.items(): + if rank == score[0]: + best_name = name + break + + return (best_score[0], best_score[1], best_name) + + +def find_best_hand( + players: list[str], + hole_cards: dict[str, list[str]], + community: list[str], +) -> tuple[list[str], dict[str, tuple]]: + """ + Find the best hand(s) among a list of players. + Returns (winners, evaluations) where winners may be multiple (split pot). 
+ """ + evaluations = {} + for player in players: + eval_result = evaluate_hand(hole_cards[player], community) + evaluations[player] = eval_result + + # Find best score + best_score = max((evaluations[p][0], evaluations[p][1]) for p in players) + + # Find all players with that score (for split pots) + winners = [ + p for p in players + if (evaluations[p][0], evaluations[p][1]) == best_score + ] + + return winners, evaluations + + +# ============================================================================= +# Actors +# ============================================================================= + +def create_player_system_prompt(num_players: int, strategy: str = "balanced") -> str: + """Create system prompt with strategy instructions.""" + strategy_text = STRATEGY_PROMPTS.get(strategy, STRATEGY_PROMPTS["balanced"]) + + return f"""You are playing {num_players}-player No-Limit Texas Hold'em Poker. + +Rules: +- Each player has 2 hole cards (hidden from others) +- 5 community cards are dealt face-up over multiple rounds +- Best 5-card hand from your 7 cards wins +{strategy_text} + +On your turn, output ONLY a JSON object with your action: +- Fold: {{"action": "fold"}} +- Check (if no bet to call): {{"action": "check"}} +- Call (match current bet): {{"action": "call"}} +- Raise to amount: {{"action": "raise", "amount": 100}} +- All-in: {{"action": "allin"}} + +Output ONLY the JSON, nothing else.""" + + +def create_actors(num_players: int) -> list[Actor]: + """ + Create player actors with per-player model and strategy configs. + + Players with a config in PLAYER_CONFIGS get that model/strategy. + Players without a config use the default model from -m flag with "balanced" strategy. + """ + actors = [] + for i in range(num_players): + # Get config if exists, otherwise use defaults + config = PLAYER_CONFIGS[i] if i < len(PLAYER_CONFIGS) else None + + if config: + client, model = get_actor_client(config.get("endpoint")) + strategy = config.get("strategy", "balanced") + is_trainable = config.get("is_trainable", False) + else: + # No config = use default model from -m flag + client, model = None, None + strategy = "balanced" + is_trainable = False + + system_prompt = create_player_system_prompt(num_players, strategy) + + actors.append(Actor( + id=f"player{i+1}", + system_prompt=system_prompt, + max_tokens=50, + is_trainable=is_trainable, + model=model, + client=client, + )) + return actors + + +# ============================================================================= +# Side Pot Management +# ============================================================================= + +class SidePotManager: + """ + Manages side pots for all-in situations. + + When a player goes all-in for less than the current bet, + we split the pot so they can only win from those who matched their bet. + """ + + def __init__(self, players: list[str]): + self.players = players + # Track total amount each player has put in across all rounds + self.contributions: dict[str, int] = {p: 0 for p in players} + # Players still in the hand (not folded) + self.active: set[str] = set(players) + + def add_contribution(self, player: str, amount: int) -> None: + """Record chips put into pot by player.""" + self.contributions[player] += amount + + def fold(self, player: str) -> None: + """Mark player as folded (can't win, contributions stay).""" + self.active.discard(player) + + def calculate_pots(self) -> list[dict]: + """ + Calculate main pot and side pots based on contributions. 
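+
+        Example: with contributions {p1: 100, p2: 250, p3: 250}, p1 all-in and
+        nobody folded, this produces a 300-chip main pot that all three players
+        can win and a 300-chip side pot that only p2 and p3 can win.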
+ + Returns list of pots, each with: + - amount: chips in this pot + - eligible: players who can win this pot + """ + if not self.active: + return [] + + # Get unique contribution levels from active players + active_contributions = { + p: self.contributions[p] for p in self.active + } + + # All contribution levels (including folded players who added chips) + all_contributions = list(self.contributions.values()) + + # Unique levels sorted ascending + levels = sorted(set(c for c in all_contributions if c > 0)) + + if not levels: + return [] + + pots = [] + prev_level = 0 + + for level in levels: + # How much each player contributes to this tier + tier_amount = level - prev_level + + # Players who contributed at least this much + contributors = [ + p for p in self.players + if self.contributions[p] >= level + ] + + # Eligible winners = active players who contributed + eligible = [p for p in contributors if p in self.active] + + if eligible and tier_amount > 0: + pot_amount = tier_amount * len(contributors) + pots.append({ + "amount": pot_amount, + "eligible": eligible, + "level": level, + }) + + prev_level = level + + return pots + + def total_pot(self) -> int: + """Get total chips in all pots.""" + return sum(self.contributions.values()) + + +# ============================================================================= +# Environment +# ============================================================================= + +class PokerMultiEnv(MultiAgentEnv): + """Multi-player No-Limit Texas Hold'em (2-9 players).""" + + name = "poker_multi" + + def __init__( + self, + num_players: int = 6, + num_hands: int = 1, + max_actions_per_hand: int = 50, + starting_chips: int = 1000, + small_blind: int = 5, + big_blind: int = 10, + **kwargs, + ): + # Validate player count + if num_players < 2 or num_players > 9: + raise ValueError("num_players must be between 2 and 9") + + super().__init__(**kwargs) + self.num_players = num_players + self.num_hands = num_hands + self.max_actions_per_hand = max_actions_per_hand + self.starting_chips = starting_chips + self.small_blind = small_blind + self.big_blind = big_blind + + # Dynamic actor list + self.actors = [f"player{i}" for i in range(1, num_players + 1)] + + # ------------------------------------------------------------------------- + # Turn Management + # ------------------------------------------------------------------------- + + def _get_seat_order(self, state: State) -> list[str]: + """Get players in seat order starting from dealer.""" + extras = state["extras"] + dealer_idx = extras["dealer_idx"] + n = self.num_players + return [self.actors[(dealer_idx + i) % n] for i in range(n)] + + def _get_active_players(self, state: State) -> list[str]: + """Get non-folded players in seat order.""" + extras = state["extras"] + return [p for p in self._get_seat_order(state) if p not in extras["folded"]] + + def _get_players_who_can_act(self, state: State) -> list[str]: + """Get players who can still act (not folded, not all-in).""" + extras = state["extras"] + return [ + p for p in self._get_seat_order(state) + if p not in extras["folded"] and extras["chips"][p] > 0 + ] + + def get_initial_actor(self, state: State) -> str: + """UTG acts first preflop (dealer+3), or dealer+1 for heads-up.""" + extras = state["extras"] + dealer_idx = extras["dealer_idx"] + n = self.num_players + + if n == 2: + # Heads-up: dealer (SB) acts first preflop + start_idx = dealer_idx + else: + # UTG = dealer + 3 (after SB and BB) + start_idx = (dealer_idx + 3) % n + + # Find first player who 
can act starting from this position + for i in range(n): + candidate_idx = (start_idx + i) % n + candidate = self.actors[candidate_idx] + if candidate not in extras["folded"] and extras["chips"][candidate] > 0: + return candidate + + # Fallback - shouldn't happen if game state is valid + return self.actors[start_idx] + + def get_next_actor(self, state: State) -> str: + """Get next player to act in clockwise order.""" + extras = state["extras"] + current = extras["current_actor_id"] + current_idx = self.actors.index(current) + n = self.num_players + + # Find next player who can act + for i in range(1, n + 1): + next_idx = (current_idx + i) % n + next_player = self.actors[next_idx] + + # Skip folded players + if next_player in extras["folded"]: + continue + # Skip all-in players (can't act but still in hand) + if extras["chips"][next_player] == 0: + continue + + return next_player + + # No one can act (everyone folded or all-in) + return current + + # ------------------------------------------------------------------------- + # Stop Conditions + # ------------------------------------------------------------------------- + + @vf.stop + async def one_player_left(self, state: State) -> bool: + """Only one player remains (all others folded).""" + active = self._get_active_players(state) + if len(active) == 1: + # Award pot to remaining player + winner = active[0] + extras = state["extras"] + extras["chips"][winner] += extras["pot_manager"].total_pot() + extras["winner"] = winner + extras["win_reason"] = "all others folded" + extras["hands_played"] += 1 + + # Check for next hand + if extras["hands_played"] < self.num_hands: + players_with_chips = [p for p in self.actors if extras["chips"][p] > 0] + if len(players_with_chips) >= 2: + # Rotate dealer to next active player + extras["dealer_idx"] = (extras["dealer_idx"] + 1) % self.num_players + extras["dealer_idx"] = self._next_active_seat(state, extras["dealer_idx"]) + await self._start_new_hand(state) + return False + + extras["phase"] = "complete" + return True + return False + + @vf.stop + async def hand_complete(self, state: State) -> bool: + """Showdown complete or all hands played.""" + return state["extras"]["phase"] == "complete" + + @vf.stop + async def max_actions_hit(self, state: State) -> bool: + """Safety limit - force showdown.""" + if state["extras"]["actions_this_hand"] >= self.max_actions_per_hand: + await self._force_showdown(state) + return True + return False + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """Initialize poker game state.""" + state = await super().setup_state(state) + + extras = state["extras"] + extras["hands_played"] = 0 + extras["starting_chips"] = self.starting_chips + extras["chips"] = {p: self.starting_chips for p in self.actors} + + # Start first hand with player1 as dealer + extras["dealer_idx"] = 0 + await self._start_new_hand(state) + + return state + + def _next_active_seat(self, state: State, start_idx: int) -> int: + """Find next seat index with an active (non-busted) player.""" + extras = state["extras"] + n = self.num_players + for i in range(n): + idx = (start_idx + i) % n + if extras["chips"][self.actors[idx]] > 0: + return idx + return start_idx # Fallback + + async def _start_new_hand(self, state: State) -> None: + """Initialize state for a new hand.""" + extras = state["extras"] + n = self.num_players + + # 
Ensure dealer is on an active player + extras["dealer_idx"] = self._next_active_seat(state, extras["dealer_idx"]) + dealer_idx = extras["dealer_idx"] + + # Create and shuffle deck + extras["deck"] = create_deck() + + # Deal hole cards to all players + extras["hole_cards"] = {} + for player in self.actors: + if extras["chips"][player] > 0: + extras["hole_cards"][player] = [ + extras["deck"].pop(), + extras["deck"].pop(), + ] + else: + extras["hole_cards"][player] = [] # Busted player + + # Reset hand state + extras["community_cards"] = [] + extras["current_bet"] = 0 + extras["bets_this_round"] = {p: 0 for p in self.actors} + extras["phase"] = "preflop" + extras["folded"] = [p for p in self.actors if extras["chips"][p] == 0] + extras["actions_this_hand"] = 0 + extras["actions_this_round"] = {p: 0 for p in self.actors} + extras["last_aggressor"] = None + + # Initialize pot manager + active_players = [p for p in self.actors if extras["chips"][p] > 0] + extras["pot_manager"] = SidePotManager(active_players) + + # Find SB and BB seats (skip busted players) + num_active = len(active_players) + if num_active == 2: + # Heads-up: dealer is SB + sb_idx = dealer_idx + bb_idx = self._next_active_seat(state, (dealer_idx + 1) % n) + else: + # Regular: SB is left of dealer, BB is left of SB + sb_idx = self._next_active_seat(state, (dealer_idx + 1) % n) + bb_idx = self._next_active_seat(state, (sb_idx + 1) % n) + + sb_player = self.actors[sb_idx] + bb_player = self.actors[bb_idx] + + # Post small blind + sb_amount = min(self.small_blind, extras["chips"][sb_player]) + extras["chips"][sb_player] -= sb_amount + extras["bets_this_round"][sb_player] = sb_amount + extras["pot_manager"].add_contribution(sb_player, sb_amount) + + # Post big blind + bb_amount = min(self.big_blind, extras["chips"][bb_player]) + extras["chips"][bb_player] -= bb_amount + extras["bets_this_round"][bb_player] = bb_amount + extras["current_bet"] = bb_amount + extras["pot_manager"].add_contribution(bb_player, bb_amount) + + # ------------------------------------------------------------------------- + # Prompt Building + # ------------------------------------------------------------------------- + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """Build prompt showing only this player's hole cards.""" + extras = state["extras"] + + my_cards = extras["hole_cards"].get(actor_id, []) + community = extras["community_cards"] + my_chips = extras["chips"][actor_id] + to_call = extras["current_bet"] - extras["bets_this_round"][actor_id] + + phase = extras["phase"].upper() + hand_num = extras["hands_played"] + 1 + total_pot = extras["pot_manager"].total_pot() + + # Build opponent info + opponent_info = [] + for p in self.actors: + if p != actor_id: + status = "folded" if p in extras["folded"] else ( + "all-in" if extras["chips"][p] == 0 else "active" + ) + opponent_info.append( + f" {p}: {extras['chips'][p]} chips ({status})" + ) + + # Determine position + dealer_idx = extras["dealer_idx"] + my_idx = self.actors.index(actor_id) + relative_pos = (my_idx - dealer_idx) % self.num_players + + if self.num_players == 2: + position = "Dealer/SB" if relative_pos == 0 else "BB" + else: + positions = ["Dealer", "SB", "BB"] + [f"UTG+{i}" for i in range(6)] + position = positions[relative_pos] if relative_pos < len(positions) else f"Seat {relative_pos}" + + situation = f"""=== HAND {hand_num} - {phase} === +Players: {self.num_players} +Position: {position} + +Your hole cards: {format_cards(my_cards)} +Community cards: 
{format_cards(community)} + +Pot: {total_pot} +Your chips: {my_chips} +Opponents: +{chr(10).join(opponent_info)} + +Current bet: {extras['current_bet']} +Your bet this round: {extras['bets_this_round'][actor_id]} +To call: {to_call} + +Your action?""" + + actor = self.get_actor(actor_id) + return [ + {"role": "system", "content": actor.system_prompt}, + {"role": "user", "content": situation}, + ] + + # ------------------------------------------------------------------------- + # Action Parsing + # ------------------------------------------------------------------------- + + def _parse_action(self, text: str, state: State, actor_id: str) -> dict: + """Parse JSON action from model output.""" + try: + clean = text.strip() + if "```" in clean: + match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", clean, re.DOTALL) + if match: + clean = match.group(1) + + json_match = re.search(r"\{[^}]+\}", clean) + if json_match: + clean = json_match.group(0) + + action = json.loads(clean) + + if "action" not in action: + raise ValueError("Missing 'action' key") + + action_type = action["action"].lower() + valid_actions = ["fold", "check", "call", "raise", "allin"] + if action_type not in valid_actions: + raise ValueError(f"Invalid action: {action_type}") + + if action_type == "raise": + if "amount" not in action: + raise ValueError("Raise requires 'amount'") + action["amount"] = int(action["amount"]) + + action["action"] = action_type + return action + + except (json.JSONDecodeError, ValueError, KeyError): + return self._get_fallback_action(state, actor_id) + + def _get_fallback_action(self, state: State, actor_id: str) -> dict: + """When parsing fails, do the safest legal action.""" + to_call = state["extras"]["current_bet"] - state["extras"]["bets_this_round"][actor_id] + if to_call == 0: + return {"action": "check"} + else: + return {"action": "fold"} + + def _validate_and_adjust_action(self, action: dict, state: State, actor_id: str) -> dict: + """Validate action and adjust if needed.""" + extras = state["extras"] + my_chips = extras["chips"][actor_id] + to_call = extras["current_bet"] - extras["bets_this_round"][actor_id] + + action_type = action["action"] + + if action_type == "check" and to_call > 0: + if my_chips >= to_call: + return {"action": "call"} + else: + return {"action": "allin"} + + if action_type == "call": + if my_chips <= to_call: + return {"action": "allin"} + return action + + if action_type == "raise": + amount = action.get("amount", 0) + min_raise = extras["current_bet"] + self.big_blind + max_raise = my_chips + extras["bets_this_round"][actor_id] + + if amount >= max_raise: + return {"action": "allin"} + if amount < min_raise: + amount = min_raise + if amount > max_raise: + amount = max_raise + + return {"action": "raise", "amount": amount} + + return action + + # ------------------------------------------------------------------------- + # Game Logic + # ------------------------------------------------------------------------- + + async def on_turn_complete(self, state: State) -> None: + """Process action after each turn.""" + if not state["trajectory"]: + return + + last_step = state["trajectory"][-1] + last_completion = last_step.get("completion", []) + if not last_completion: + return + + actor_id = last_step.get("extras", {}).get("actor_id") + if not actor_id: + return + + content = last_completion[-1].get("content", "") if isinstance(last_completion[-1], dict) else str(last_completion[-1]) + action = self._parse_action(content, state, actor_id) + action = 
self._validate_and_adjust_action(action, state, actor_id) + + extras = state["extras"] + extras["actions_this_hand"] += 1 + extras["actions_this_round"][actor_id] += 1 + + await self._process_action(action, state, actor_id) + + if self._is_betting_round_complete(state): + await self._advance_phase(state) + + async def _process_action(self, action: dict, state: State, actor_id: str) -> None: + """Process a validated action.""" + extras = state["extras"] + action_type = action["action"] + pot_mgr = extras["pot_manager"] + + if action_type == "fold": + extras["folded"].append(actor_id) + pot_mgr.fold(actor_id) + + elif action_type == "check": + pass + + elif action_type == "call": + to_call = extras["current_bet"] - extras["bets_this_round"][actor_id] + call_amount = min(to_call, extras["chips"][actor_id]) + extras["chips"][actor_id] -= call_amount + extras["bets_this_round"][actor_id] += call_amount + pot_mgr.add_contribution(actor_id, call_amount) + + elif action_type == "raise": + amount = action["amount"] + current_bet = extras["bets_this_round"][actor_id] + additional = amount - current_bet + additional = min(additional, extras["chips"][actor_id]) + + extras["chips"][actor_id] -= additional + extras["bets_this_round"][actor_id] += additional + extras["current_bet"] = extras["bets_this_round"][actor_id] + extras["last_aggressor"] = actor_id + pot_mgr.add_contribution(actor_id, additional) + + # Reset action counts for players who need to respond + for p in self.actors: + if p != actor_id and p not in extras["folded"] and extras["chips"][p] > 0: + extras["actions_this_round"][p] = 0 + + elif action_type == "allin": + all_in_amount = extras["chips"][actor_id] + extras["bets_this_round"][actor_id] += all_in_amount + extras["chips"][actor_id] = 0 + pot_mgr.add_contribution(actor_id, all_in_amount) + + if extras["bets_this_round"][actor_id] > extras["current_bet"]: + extras["current_bet"] = extras["bets_this_round"][actor_id] + extras["last_aggressor"] = actor_id + # Reset action counts for players who need to respond + for p in self.actors: + if p != actor_id and p not in extras["folded"] and extras["chips"][p] > 0: + extras["actions_this_round"][p] = 0 + + def _is_betting_round_complete(self, state: State) -> bool: + """Check if current betting round is complete.""" + extras = state["extras"] + + if extras["phase"] == "complete": + return False + + active = self._get_active_players(state) + if len(active) <= 1: + return False # Will be handled by one_player_left stop + + # Players who can still act + can_act = self._get_players_who_can_act(state) + + # If no one can act, round is complete + if not can_act: + return True + + # Check if all can-act players have acted and bets are equal + for player in can_act: + # Player hasn't acted + if extras["actions_this_round"][player] == 0: + return False + # Player's bet doesn't match (and they're not all-in) + if extras["bets_this_round"][player] < extras["current_bet"]: + return False + + return True + + async def _advance_phase(self, state: State) -> None: + """Move to next phase of the hand.""" + extras = state["extras"] + current_phase = extras["phase"] + + # Reset for new betting round + extras["bets_this_round"] = {p: 0 for p in self.actors} + extras["current_bet"] = 0 + extras["actions_this_round"] = {p: 0 for p in self.actors} + extras["last_aggressor"] = None + + phase_order = ["preflop", "flop", "turn", "river", "showdown"] + current_idx = phase_order.index(current_phase) + next_phase = phase_order[current_idx + 1] + extras["phase"] = 
next_phase + + if next_phase == "flop": + extras["deck"].pop() # Burn + extras["community_cards"].extend([ + extras["deck"].pop(), + extras["deck"].pop(), + extras["deck"].pop(), + ]) + elif next_phase == "turn": + extras["deck"].pop() + extras["community_cards"].append(extras["deck"].pop()) + elif next_phase == "river": + extras["deck"].pop() + extras["community_cards"].append(extras["deck"].pop()) + elif next_phase == "showdown": + await self._resolve_showdown(state) + return + + # Set first actor for post-flop (first active player left of dealer) + if next_phase in ["flop", "turn", "river"]: + dealer_idx = extras["dealer_idx"] + for i in range(1, self.num_players + 1): + candidate_idx = (dealer_idx + i) % self.num_players + candidate = self.actors[candidate_idx] + if candidate not in extras["folded"] and extras["chips"][candidate] > 0: + extras["current_actor_id"] = candidate + break + + async def _force_showdown(self, state: State) -> None: + """Force showdown when max actions reached.""" + extras = state["extras"] + + while len(extras["community_cards"]) < 5: + if extras["deck"]: + extras["deck"].pop() + if extras["deck"]: + extras["community_cards"].append(extras["deck"].pop()) + + extras["phase"] = "showdown" + await self._resolve_showdown(state) + + async def _resolve_showdown(self, state: State) -> None: + """Compare hands and award pots (including side pots).""" + extras = state["extras"] + pot_mgr = extras["pot_manager"] + community = extras["community_cards"] + + # Calculate pots + pots = pot_mgr.calculate_pots() + + # Track winnings + winnings = {p: 0 for p in self.actors} + pot_results = [] + + for pot in pots: + eligible = pot["eligible"] + amount = pot["amount"] + + if len(eligible) == 1: + # Only one player eligible + winner = eligible[0] + winnings[winner] += amount + pot_results.append({ + "amount": amount, + "winners": [winner], + "reason": "sole eligible" + }) + else: + # Compare hands + winners, evaluations = find_best_hand( + eligible, + extras["hole_cards"], + community, + ) + + # Split pot among winners + share = amount // len(winners) + remainder = amount % len(winners) + + for i, w in enumerate(winners): + winnings[w] += share + if i < remainder: + winnings[w] += 1 # Odd chips go to first winners + + hand_name = evaluations[winners[0]][2].replace("_", " ").title() + pot_results.append({ + "amount": amount, + "winners": winners, + "hand": hand_name, + "reason": "best hand" if len(winners) == 1 else "split pot" + }) + + # Award chips + for player, amount in winnings.items(): + extras["chips"][player] += amount + + # Store results + extras["pot_results"] = pot_results + + # Determine overall winner (most chips won this hand) + if winnings: + max_won = max(winnings.values()) + winners = [p for p, w in winnings.items() if w == max_won and w > 0] + if len(winners) == 1: + extras["winner"] = winners[0] + else: + extras["winner"] = "split" + else: + extras["winner"] = None + + extras["win_reason"] = "showdown" + + # Store hand evaluations for all active players + active = self._get_active_players(state) + extras["final_hands"] = {} + for p in active: + ev = evaluate_hand(extras["hole_cards"][p], community) + extras["final_hands"][p] = { + "hole_cards": extras["hole_cards"][p], + "hand_name": ev[2].replace("_", " ").title(), + } + + extras["hands_played"] += 1 + + # Check if session continues + if extras["hands_played"] < self.num_hands: + players_with_chips = [p for p in self.actors if extras["chips"][p] > 0] + if len(players_with_chips) >= 2: + # Rotate dealer to next 
active player + next_dealer = (extras["dealer_idx"] + 1) % self.num_players + extras["dealer_idx"] = self._next_active_seat(state, next_dealer) + await self._start_new_hand(state) + return + + extras["phase"] = "complete" + + async def on_game_end(self, state: State) -> None: + """Compute final session results.""" + extras = state["extras"] + + # Rank players by final chips + chip_ranking = sorted( + self.actors, + key=lambda p: extras["chips"][p], + reverse=True, + ) + + extras["final_ranking"] = chip_ranking + extras["session_winner"] = chip_ranking[0] + + # Calculate profits + starting = extras["starting_chips"] + extras["profits"] = { + p: extras["chips"][p] - starting + for p in self.actors + } + + +# ============================================================================= +# Rubric (Scoring) +# ============================================================================= + +def create_rubric(num_players: int) -> MultiAgentRubric: + """Create rubric based on chip profit/loss.""" + rubric = MultiAgentRubric() + + def player_reward(actor_id: str): + def reward_func(state: State, **kwargs) -> float: + extras = state.get("extras", {}) + starting = extras.get("starting_chips", 1000) + final = extras.get("chips", {}).get(actor_id, starting) + return (final - starting) / starting + return reward_func + + def hands_played_metric(state: State, **kwargs) -> float: + return float(state.get("extras", {}).get("hands_played", 0)) + + def showdowns_metric(state: State, **kwargs) -> float: + reason = state.get("extras", {}).get("win_reason", "") + return 0.0 if "fold" in reason else 1.0 + + for i in range(1, num_players + 1): + player_id = f"player{i}" + rubric.add_actor_reward_func(player_id, player_reward(player_id), weight=1.0) + + rubric.add_reward_func(hands_played_metric, weight=0.0) + rubric.add_reward_func(showdowns_metric, weight=0.0) + + return rubric + + +# ============================================================================= +# Dataset +# ============================================================================= + +def create_dataset(num_games: int = 10) -> Dataset: + """Create dataset for poker games.""" + return Dataset.from_list([ + { + "example_id": i, + "prompt": [{"role": "user", "content": "play"}], + "answer": "", + "task": "poker_multi", + } + for i in range(num_games) + ]) + + +# ============================================================================= +# Environment Loader +# ============================================================================= + +def load_environment( + num_players: int = 6, + num_hands: int = 1, + max_actions_per_hand: int = 50, + starting_chips: int = 1000, + small_blind: int = 5, + big_blind: int = 10, + num_examples: int = -1, +) -> PokerMultiEnv: + """Factory function to create a fully configured multi-player Poker environment.""" + dataset = create_dataset() + if num_examples > 0: + dataset = dataset.select(range(min(num_examples, len(dataset)))) + + env = PokerMultiEnv( + num_players=num_players, + num_hands=num_hands, + max_actions_per_hand=max_actions_per_hand, + starting_chips=starting_chips, + small_blind=small_blind, + big_blind=big_blind, + rubric=create_rubric(num_players), + max_turns=num_hands * max_actions_per_hand * num_players + 10, + dataset=dataset, + ) + + actors = create_actors(num_players) + Protocol(actors=actors, envs=[env]) + + return env diff --git a/environments/poker_multi/pyproject.toml b/environments/poker_multi/pyproject.toml new file mode 100644 index 000000000..d010c8ca6 --- /dev/null +++ 
b/environments/poker_multi/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "poker_multi" +description = "Multi-player No-Limit Texas Hold'em Poker (2-9 players)" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.9", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["poker_multi.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 1 +rollouts_per_example = 1 +num_players = 6 +num_hands = 3 +max_actions_per_hand = 50 +starting_chips = 1000 +small_blind = 5 +big_blind = 10 diff --git a/environments/proposer_solver/proposer_solver.py b/environments/proposer_solver/proposer_solver.py new file mode 100644 index 000000000..0387903f5 --- /dev/null +++ b/environments/proposer_solver/proposer_solver.py @@ -0,0 +1,460 @@ +""" +Proposer-Solver: Hierarchical episode spawning with parent-child relationships. + +This environment demonstrates: +- Hierarchical spawning: Proposer generates problems, spawns Solver children +- Cross-environment communication via Protocol.spawn() +- Parent reward derived from child performance (solve_rate) +- Both actors trainable with per-actor GRPO advantages + +Game flow: +1. Proposer generates a math problem (e.g., "What is 7 + 5?") +2. Multiple Solver instances are spawned as child episodes +3. Each Solver attempts to answer independently +4. Proposer's reward = solver success rate (incentivizes good problems) +5. Solver rewards = correctness (1.0 if right, 0.0 if wrong) + +Training dynamics: +- Proposer learns to generate problems that solvers can solve (not too hard) +- Solvers learn to solve arithmetic problems correctly +- GRPO advantages computed per-actor (proposers vs proposers, solvers vs solvers) +""" + +import re +from datasets import Dataset + +import verifiers as vf +from verifiers import Actor, MultiAgentEnv, MultiAgentRubric, Protocol +from verifiers.types import Messages, State + + +# ============================================================================= +# Actors +# ============================================================================= + +PROPOSER = Actor( + id="proposer", + system_prompt="""You are a Math Problem Proposer. Generate a simple arithmetic problem. + +Rules: +1. Create a problem using +, -, or * with two numbers between 1-20 +2. Format: Just state the problem, e.g., "What is 7 + 5?" +3. Make it solvable but not trivial + +Output only the problem, nothing else.""", + max_tokens=50, + is_trainable=True, +) + +SOLVER = Actor( + id="solver", + system_prompt="""You are a Math Solver. Solve the given problem. + +Rules: +1. Calculate the answer +2. Output ONLY the numeric answer, nothing else +3. Example: If asked "What is 3 + 4?", output just: 7""", + max_tokens=20, + is_trainable=True, +) + + +# ============================================================================= +# Solver Environment (child episodes) +# ============================================================================= + +class SolverEnv(MultiAgentEnv): + """ + Single-turn solver environment using MultiAgentEnv for proper actor_id tagging. + + Why MultiAgentEnv instead of SingleTurnEnv? + - MultiAgentEnv automatically tags trajectory steps with actor_id + - This enables per-actor GRPO advantages in MultiAgentRubric + - Consistent state structure with parent ProposerSolverEnv + + With max_turns=1, this behaves like SingleTurnEnv but with full + multi-agent integration. 
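+
+    A solver child is created via Protocol.spawn() with an input whose "task" field is
+    "solver" (see ProposerSolverEnv.on_turn_complete), which routes the rollout to this env.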
+ + Note: We compute reward directly in on_turn_complete instead of using + a Rubric with funcs=[] to avoid metric name collisions when results + are collected across mixed proposer/solver states. + """ + + name = "solver" + actors = ["solver"] # Single actor for this env + + def __init__(self, **kwargs): + # Dummy dataset - actual problems come from spawn() + dummy_ds = Dataset.from_dict({ + "prompt": [[{"role": "user", "content": "dummy"}]], + "answer": ["0"], + "example_id": [0], + "task": ["solver"], + }) + + super().__init__( + dataset=dummy_ds, + max_turns=1, # Single turn - just answer the problem + **kwargs + ) + + # --- Turn Management --- + + def get_initial_actor(self, state: State) -> str: + """Solver is the only actor.""" + return "solver" + + def get_next_actor(self, state: State) -> str: + """Never called (max_turns=1), but required by ABC.""" + return "solver" + + # --- The Two Hooks --- + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """ + Build prompt for solver. + + The prompt comes from the spawn() input - it contains the problem + generated by the proposer. + """ + return list(state["prompt"]) + + async def on_turn_complete(self, state: State) -> None: + """ + Compute solver's reward based on correctness. + + We compute reward here instead of using a Rubric to avoid + metric name collisions in results collection. + """ + if not state["trajectory"]: + return + + # Get completion text + last_step = state["trajectory"][-1] + completion = last_step.get("completion", []) + if not completion: + state["reward"] = 0.0 + return + + last_msg = completion[-1] + answer_text = last_msg.get("content", "") if isinstance(last_msg, dict) else str(last_msg) + + # Get expected answer + expected = str(state.get("answer", "")) + + # Check correctness + numbers = re.findall(r'-?\d+', str(answer_text)) + if numbers and numbers[0] == expected: + state["reward"] = 1.0 + else: + state["reward"] = 0.0 + + +# ============================================================================= +# Proposer Environment (parent, spawns solvers) +# ============================================================================= + +class ProposerSolverEnv(MultiAgentEnv): + """ + Proposer environment that generates problems and spawns solver children. + + Flow: + 1. Proposer generates a math problem (one turn) + 2. on_turn_complete() parses problem and spawns N solver episodes + 3. Solver children run in parallel, get scored by SolverEnv.rubric + 4. Proposer's solve_rate computed from child rewards + 5. Stop condition triggers, rollout complete + + The proposer is rewarded based on how well solvers do - this creates + an incentive to generate problems that are solvable but challenging. 
+ """ + + name = "proposer_solver" + actors = ["proposer"] # Only proposer takes turns here; solver is separate env + + def __init__(self, num_solvers: int = 3, **kwargs): + """ + Args: + num_solvers: How many solver instances to spawn per problem + """ + super().__init__(**kwargs) + self.num_solvers = num_solvers + + # --- Turn Management --- + + def get_initial_actor(self, state: State) -> str: + return "proposer" + + def get_next_actor(self, state: State) -> str: + # After proposer generates, on_turn_complete handles spawning + # Then stop condition triggers - this won't be called + return "proposer" + + # --- Stop Condition --- + + @vf.stop + async def problem_generated(self, state: State) -> bool: + """Stop after proposer generates and solvers complete.""" + return state.get("extras", {}).get("solvers_done", False) + + # --- State Setup --- + + async def setup_state(self, state: State) -> State: + """Initialize proposer-solver specific state fields.""" + state = await super().setup_state(state) + state["extras"]["solvers_done"] = False + state["extras"]["solver_results"] = [] # Individual solver rewards + state["extras"]["solve_rate"] = 0.0 # Fraction of solvers correct + state["extras"]["problem"] = "" # Generated problem text + state["extras"]["expected_answer"] = "" # Computed answer + return state + + # --- The Two Hooks --- + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """ + Build prompt for proposer. + + Uses the prompt from the dataset which asks proposer to generate + a math problem. + """ + return list(state["prompt"]) + + async def on_turn_complete(self, state: State) -> None: + """ + After proposer generates a problem, spawn solver episodes. + + This is where the hierarchical spawning happens: + 1. Parse proposer's generated problem + 2. Compute the expected answer + 3. Create N solver inputs with the problem + 4. Spawn via Protocol.spawn() - runs in parallel + 5. 
Collect results and compute solve_rate + """ + if not state["trajectory"]: + return + + # --- Extract proposer's problem --- + last_step = state["trajectory"][-1] + last_completion = last_step.get("completion", []) + if not last_completion: + return + + last_msg = last_completion[-1] + problem = last_msg.get("content", "") if isinstance(last_msg, dict) else str(last_msg) + state["extras"]["problem"] = problem + + # --- Compute expected answer --- + answer = self._solve_problem(problem) + answer_str = str(answer) if answer is not None else "" + state["extras"]["expected_answer"] = answer_str + + # --- Build solver inputs --- + solver_inputs = [] + for i in range(self.num_solvers): + solver_inputs.append({ + "prompt": [ + {"role": "system", "content": SOLVER.system_prompt}, + {"role": "user", "content": problem} + ], + "answer": answer_str, + "example_id": state["input"]["example_id"], + "task": "solver", # Routes to SolverEnv + }) + + # --- Spawn child episodes --- + # Protocol.spawn() runs rollouts in parallel + # Reward is computed by SolverEnv.on_turn_complete (not rubric) + solver_states = await self.protocol.spawn( + solver_inputs, + client=state["client"], + model=state["model"], + sampling_args=state.get("sampling_args"), + score=False, # Reward already computed in on_turn_complete + ) + + # --- Collect results --- + solver_rewards = [] + for solver_state in solver_states: + reward = solver_state.get("reward") + reward = float(reward) if reward is not None else 0.0 + solver_rewards.append(reward) + # Add as child for flattening later + state["child_states"].append(solver_state) + + state["extras"]["solver_results"] = solver_rewards + state["extras"]["solve_rate"] = ( + sum(solver_rewards) / len(solver_rewards) if solver_rewards else 0.0 + ) + state["extras"]["num_solvers"] = len(solver_rewards) # Store count for metrics + state["extras"]["solvers_done"] = True + + # --- Helper: Parse and Solve Math Problem --- + + def _solve_problem(self, problem: str) -> int | None: + """ + Parse a simple arithmetic problem and compute the answer. + + Handles: "What is 7 + 5?", "Calculate 12 * 3", "7+5", etc. + """ + match = re.search(r'(\d+)\s*([+\-*])\s*(\d+)', problem) + if not match: + return None + + a, op, b = int(match.group(1)), match.group(2), int(match.group(3)) + + if op == '+': + return a + b + elif op == '-': + return a - b + elif op == '*': + return a * b + return None + + +# ============================================================================= +# Rubric +# ============================================================================= + +def create_rubric() -> MultiAgentRubric: + """ + Create rubric with per-actor reward functions. + + Scoring strategy: + - Proposer: Rewarded based on solver success rate (solve_rate) + - This incentivizes generating solvable problems + - solve_rate is computed during on_turn_complete from child rewards + - Solver: Already scored by SolverEnv.rubric during spawn() + - We preserve that reward (don't re-score) + + GRPO advantages are computed per-actor group: + - Proposers compared to other proposers (across batch) + - Solvers compared to other solvers (across batch) + """ + rubric = MultiAgentRubric() + + # --- Proposer Reward --- + def proposer_reward(state: State, **kwargs) -> float: + """ + Proposer reward = solver success rate. 
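+
+        For example, if 2 of 3 spawned solvers answer correctly, solve_rate = 2/3
+        and the proposer's reward is ~0.67.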
+ + This creates interesting training dynamics: + - Too easy problems: solvers always win, proposer gets 1.0 + - Too hard problems: solvers fail, proposer gets 0.0 + - Optimal: challenging but solvable problems + """ + extras = state.get("extras") or {} + solve_rate = extras.get("solve_rate") + return float(solve_rate) if solve_rate is not None else 0.0 + + # --- Solver Reward (for states that weren't pre-scored) --- + def solver_reward(state: State, **kwargs) -> float: + """ + Solver reward - returns existing reward if already scored. + + Solver states are scored by SolverEnv.rubric during spawn(). + This function is a fallback for edge cases. + """ + # Return existing reward if present (don't re-compute) + reward = state.get("reward") + return float(reward) if reward is not None else 0.0 + + # --- Proposer Metrics (weight=0, for logging only) --- + def solve_rate_metric(state: State, **kwargs) -> float: + """Track solve rate (only meaningful for proposer states).""" + extras = state.get("extras") or {} + solve_rate = extras.get("solve_rate") + return float(solve_rate) if solve_rate is not None else 0.0 + + def num_solvers_metric(state: State, **kwargs) -> float: + """Track number of solver children (stored in extras during on_turn_complete).""" + extras = state.get("extras") or {} + return float(extras.get("num_solvers", 0)) + + # Register reward functions (weight=1.0 for actual rewards) + rubric.add_actor_reward_func("proposer", proposer_reward, weight=1.0) + rubric.add_actor_reward_func("solver", solver_reward, weight=1.0) + + # Register proposer-specific metrics (weight=0 = tracked but not in reward) + # Using add_actor_reward_func so they only apply to proposer states + rubric.add_actor_reward_func("proposer", solve_rate_metric, weight=0.0) + rubric.add_actor_reward_func("proposer", num_solvers_metric, weight=0.0) + + return rubric + + +# ============================================================================= +# Dataset +# ============================================================================= + +def create_dataset() -> Dataset: + """ + Create dataset for proposer-solver games. + + Each row is a seed for the proposer to generate a problem. + The actual problem content is generated by the proposer model. + """ + items = [ + { + "prompt": [ + {"role": "system", "content": PROPOSER.system_prompt}, + {"role": "user", "content": "Generate a math problem."} + ], + "answer": "", # Not used - proposer generates the problem + "example_id": i, + "task": "proposer_solver", + } + for i in range(10) + ] + return Dataset.from_list(items) + + +# ============================================================================= +# Environment Loader +# ============================================================================= + +def load_environment( + num_solvers: int = 3, + num_examples: int = -1, +) -> ProposerSolverEnv: + """ + Factory function to create a fully configured Proposer-Solver environment. 
+ + Args: + num_solvers: Number of solver instances to spawn per problem (default 3) + num_examples: Number of problems to generate (-1 = all 10) + + Returns: + Ready-to-use ProposerSolverEnv with Protocol wired up + + Example: + env = load_environment(num_solvers=3, num_examples=5) + results = await env.evaluate(client, model, num_examples=5) + """ + dataset = create_dataset() + if num_examples > 0: + dataset = dataset.select(range(min(num_examples, len(dataset)))) + + rubric = create_rubric() + + # Create proposer environment (parent) + proposer_env = ProposerSolverEnv( + num_solvers=num_solvers, + rubric=rubric, + max_turns=2, # Proposer turn + potential continuation + dataset=dataset, + ) + + # Create solver environment (children spawned here) + solver_env = SolverEnv() + + # Wire everything together via Protocol + # - Registers both actors (PROPOSER, SOLVER) + # - Registers both environments + # - Injects protocol reference into each env + Protocol( + actors=[PROPOSER, SOLVER], + envs=[proposer_env, solver_env], + ) + + return proposer_env diff --git a/environments/proposer_solver/pyproject.toml b/environments/proposer_solver/pyproject.toml new file mode 100644 index 000000000..af98802bf --- /dev/null +++ b/environments/proposer_solver/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "proposer-solver" +description = "Proposer-Solver multi-agent game with hierarchical spawning" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.9", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["proposer_solver.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 3 +rollouts_per_example = 1 diff --git a/environments/rock_paper_scissors/pyproject.toml b/environments/rock_paper_scissors/pyproject.toml new file mode 100644 index 000000000..aa03d251c --- /dev/null +++ b/environments/rock_paper_scissors/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "rock-paper-scissors" +description = "Rock-Paper-Scissors with simultaneous moves" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.9", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["rock_paper_scissors.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 3 +rollouts_per_example = 1 +num_rounds = 3 diff --git a/environments/rock_paper_scissors/rock_paper_scissors.py b/environments/rock_paper_scissors/rock_paper_scissors.py new file mode 100644 index 000000000..c348e0f93 --- /dev/null +++ b/environments/rock_paper_scissors/rock_paper_scissors.py @@ -0,0 +1,282 @@ +""" +Rock-Paper-Scissors: Multi-agent environment with simultaneous moves. + +This environment demonstrates: +- Simultaneous moves via get_active_actors() returning both players +- Per-actor reward functions (competitive scoring) +- Round-based game with history tracking + +Game flow: +1. Both players see the round number and previous results +2. Both make their choice (simultaneously from game perspective) +3. Round is resolved, scores updated +4. Repeat for num_rounds +5. 
Split into per-actor states for scoring +""" + +from datasets import Dataset + +from verifiers import Actor, MultiAgentEnv, MultiAgentRubric, Protocol +from verifiers.types import Messages, State +import verifiers as vf + + +# ============================================================================= +# Actors +# ============================================================================= + +PLAYER1 = Actor( + id="player1", + system_prompt="""You are Player 1 in Rock-Paper-Scissors. + +Choose ONE of: rock, paper, or scissors + +Output ONLY your choice (one word, lowercase). Nothing else.""", + max_tokens=10, + is_trainable=True, +) + +PLAYER2 = Actor( + id="player2", + system_prompt="""You are Player 2 in Rock-Paper-Scissors. + +Choose ONE of: rock, paper, or scissors + +Output ONLY your choice (one word, lowercase). Nothing else.""", + max_tokens=10, + is_trainable=True, +) + + +# ============================================================================= +# Environment +# ============================================================================= + +class RockPaperScissorsEnv(MultiAgentEnv): + """Rock-Paper-Scissors with simultaneous moves.""" + + name = "rock_paper_scissors" + actors = ["player1", "player2"] + + def __init__(self, num_rounds: int = 3, **kwargs): + super().__init__(**kwargs) + self.num_rounds = num_rounds + + # ------------------------------------------------------------------------- + # Turn Management + # ------------------------------------------------------------------------- + + def get_initial_actor(self, state: State) -> str: + return "player1" + + def get_next_actor(self, state: State) -> str: + return "player1" # Not really used since get_active_actors returns both + + def get_active_actors(self, state: State) -> list[str]: + """Both players act simultaneously each round.""" + return ["player1", "player2"] + + # ------------------------------------------------------------------------- + # Stop Condition + # ------------------------------------------------------------------------- + + @vf.stop + async def game_over(self, state: State) -> bool: + """Stop after all rounds played.""" + return state.get("extras", {}).get("round", 0) >= self.num_rounds + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """Initialize RPS-specific game state.""" + state = await super().setup_state(state) + state["extras"]["round"] = 0 + state["extras"]["p1_score"] = 0 + state["extras"]["p2_score"] = 0 + state["extras"]["history"] = [] + state["extras"]["p1_choice"] = None + state["extras"]["p2_choice"] = None + return state + + # ------------------------------------------------------------------------- + # Prompt Building + # ------------------------------------------------------------------------- + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """Build fresh prompt for this player.""" + actor = self.get_actor(actor_id) + round_num = state["extras"]["round"] + 1 + + # Build history from this player's perspective + history = state["extras"]["history"] + history_str = "" + if history: + history_str = "\n\nPrevious rounds:\n" + for i, (p1, p2, result) in enumerate(history, 1): + you = p1 if actor_id == "player1" else p2 + opponent = p2 if actor_id == "player1" else p1 + history_str += f" Round {i}: You={you}, Opponent={opponent} -> {result}\n" + + return [ + {"role": 
"system", "content": actor.system_prompt}, + {"role": "user", "content": f"Round {round_num} of {self.num_rounds}. Make your choice!{history_str}"} + ] + + # ------------------------------------------------------------------------- + # Game Logic + # ------------------------------------------------------------------------- + + async def on_turn_complete(self, state: State) -> None: + """ + Process choice and resolve round if both players have chosen. + Called AFTER each turn completes. + """ + if not state["trajectory"]: + return + + last_step = state["trajectory"][-1] + last_completion = last_step.get("completion", []) + if not last_completion: + return + + # Parse choice + content = last_completion[-1].get("content", "").lower().strip() if isinstance(last_completion[-1], dict) else str(last_completion[-1]).lower().strip() + choice = self._parse_choice(content) + + # Store choice for the actor who just played + actor_id = last_step.get("extras", {}).get("actor_id") + if actor_id == "player1": + state["extras"]["p1_choice"] = choice + else: + state["extras"]["p2_choice"] = choice + + # If both have chosen, resolve the round + p1_choice = state["extras"]["p1_choice"] + p2_choice = state["extras"]["p2_choice"] + + if p1_choice and p2_choice: + winner = self._determine_winner(p1_choice, p2_choice) + + if winner == "player1": + state["extras"]["p1_score"] += 1 + result = "Player 1 wins" + elif winner == "player2": + state["extras"]["p2_score"] += 1 + result = "Player 2 wins" + else: + result = "Tie" + + state["extras"]["history"].append((p1_choice, p2_choice, result)) + state["extras"]["round"] += 1 + state["extras"]["p1_choice"] = None + state["extras"]["p2_choice"] = None + + async def on_game_end(self, state: State) -> None: + """Declare final winner after all rounds complete.""" + p1_score = state["extras"]["p1_score"] + p2_score = state["extras"]["p2_score"] + + if p1_score > p2_score: + state["extras"]["winner"] = "player1" + elif p2_score > p1_score: + state["extras"]["winner"] = "player2" + else: + state["extras"]["winner"] = "tie" + + # ------------------------------------------------------------------------- + # Helper Functions + # ------------------------------------------------------------------------- + + def _parse_choice(self, text: str) -> str: + """Extract rock/paper/scissors from model output.""" + text = text.lower() + if "rock" in text: + return "rock" + elif "paper" in text: + return "paper" + elif "scissors" in text: + return "scissors" + return "rock" + + def _determine_winner(self, p1: str, p2: str) -> str | None: + """Determine winner. 
Returns None for tie.""" + if p1 == p2: + return None + wins = {"rock": "scissors", "paper": "rock", "scissors": "paper"} + if wins.get(p1) == p2: + return "player1" + return "player2" + + +# ============================================================================= +# Rubric (Scoring) +# ============================================================================= + +def create_rubric() -> MultiAgentRubric: + """Create competitive rubric with per-actor rewards.""" + rubric = MultiAgentRubric() + + def player1_reward(state: State, **kwargs) -> float: + extras = state.get("extras", {}) + p1_score = extras.get("p1_score", 0) + total_rounds = extras.get("round", 1) + return p1_score / total_rounds if total_rounds > 0 else 0.0 + + def player2_reward(state: State, **kwargs) -> float: + extras = state.get("extras", {}) + p2_score = extras.get("p2_score", 0) + total_rounds = extras.get("round", 1) + return p2_score / total_rounds if total_rounds > 0 else 0.0 + + def rounds_played_metric(state: State, **kwargs) -> float: + return float(state.get("extras", {}).get("round", 0)) + + rubric.add_actor_reward_func("player1", player1_reward, weight=1.0) + rubric.add_actor_reward_func("player2", player2_reward, weight=1.0) + rubric.add_reward_func(rounds_played_metric, weight=0.0) + + return rubric + + +# ============================================================================= +# Dataset +# ============================================================================= + +def create_dataset() -> Dataset: + """Create dataset for RPS games.""" + return Dataset.from_list([ + { + "example_id": i, + "prompt": [{"role": "user", "content": "play"}], + "answer": "", + "task": "rock_paper_scissors" + } + for i in range(10) + ]) + + +# ============================================================================= +# Environment Loader +# ============================================================================= + +def load_environment( + num_rounds: int = 3, + num_examples: int = -1, +) -> RockPaperScissorsEnv: + """Factory function to create a fully configured RPS environment.""" + dataset = create_dataset() + if num_examples > 0: + dataset = dataset.select(range(min(num_examples, len(dataset)))) + + env = RockPaperScissorsEnv( + num_rounds=num_rounds, + rubric=create_rubric(), + max_turns=num_rounds * 2 + 2, + dataset=dataset, + ) + + Protocol(actors=[PLAYER1, PLAYER2], envs=[env]) + + return env diff --git a/environments/twenty_questions/pyproject.toml b/environments/twenty_questions/pyproject.toml new file mode 100644 index 000000000..9d3b174c8 --- /dev/null +++ b/environments/twenty_questions/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "twenty-questions" +description = "20 Questions multi-agent guessing game" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.9", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["twenty_questions.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 5 +rollouts_per_example = 1 diff --git a/environments/twenty_questions/twenty_questions.py b/environments/twenty_questions/twenty_questions.py new file mode 100644 index 000000000..98b872961 --- /dev/null +++ b/environments/twenty_questions/twenty_questions.py @@ -0,0 +1,324 @@ +""" +20 Questions: A simple multi-agent guessing game. 
+ +This environment demonstrates: +- Alternating turns via get_next_actor() (standard turn-based flow) +- Asymmetric actors (one trainable, one frozen) +- Multiple stop conditions (win or max questions) +- Fresh prompts per actor with different context +- Different models per actor (small guesser vs large thinker) + +Game flow: +1. Guesser receives category hint and asks first question +2. Thinker (with secret word) answers yes/no +3. Alternate until guesser wins or runs out of questions +4. Only guesser is trained - rewarded for winning quickly +""" + +import re +from datasets import Dataset + +from verifiers import Actor, MultiAgentEnv, MultiAgentRubric, Protocol +from verifiers.types import Messages, State +from verifiers.utils.client_utils import get_actor_client +import verifiers as vf + +# ============================================================================= +# Model Configuration +# ============================================================================= +# Change these to use different models for each actor. +# Set to None to use the default model from the eval command. +# +# Small models: "olmo3-7b-i", "trinity-mini", "haiku", "gemini-3-flash" +# Large models: "sonnet", "opus", "qwen3-235b-i", "gemini-3-pro" +# ============================================================================= + +THINKER_ENDPOINT = "qwen3-235b-i" # Large model answers questions +GUESSER_ENDPOINT = "olmo3-7b-i" # Small model asks questions + +thinker_client, thinker_model = get_actor_client(THINKER_ENDPOINT) +guesser_client, guesser_model = get_actor_client(GUESSER_ENDPOINT) + + +# ============================================================================= +# Actors +# ============================================================================= + +THINKER = Actor( + id="thinker", + system_prompt="""You are the Thinker in 20 Questions. You have a SECRET WORD. + +Rules: +1. Answer questions with ONLY "Yes" or "No" +2. Be honest and consistent +3. If asked to guess, confirm with "Correct!" or "No, try again" + +Format your response as exactly one of: +- Yes +- No +- Correct! +- No, try again""", + max_tokens=20, + is_trainable=False, + model=thinker_model, + client=thinker_client, +) + +GUESSER = Actor( + id="guesser", + system_prompt="""You are the Guesser in 20 Questions. Try to figure out the secret word. + +Rules: +1. Ask yes/no questions to narrow down possibilities +2. When ready to guess, say "Is it [your guess]?" +3. You have 20 questions maximum + +Good strategy: Start broad (Is it alive? Is it man-made?) then narrow down. 
+ +Format: Just ask your question directly.""", + max_tokens=50, + is_trainable=True, + model=guesser_model, + client=guesser_client, +) + + +# ============================================================================= +# Environment +# ============================================================================= + +class TwentyQuestionsEnv(MultiAgentEnv): + """20 Questions game environment.""" + + name = "twenty_questions" + actors = ["thinker", "guesser"] + + def __init__(self, max_questions: int = 20, **kwargs): + super().__init__(**kwargs) + self.max_questions = max_questions + + # ------------------------------------------------------------------------- + # Turn Management + # ------------------------------------------------------------------------- + + def get_initial_actor(self, state: State) -> str: + return "guesser" + + def get_next_actor(self, state: State) -> str: + current = state["extras"]["current_actor_id"] + return "thinker" if current == "guesser" else "guesser" + + # ------------------------------------------------------------------------- + # Stop Conditions + # ------------------------------------------------------------------------- + + @vf.stop + async def game_won(self, state: State) -> bool: + return state.get("extras", {}).get("won", False) + + @vf.stop + async def max_questions_reached(self, state: State) -> bool: + return state.get("extras", {}).get("question_count", 0) >= self.max_questions + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + state = await super().setup_state(state) + secret_word = state["input"].get("answer", "dog") + state["extras"]["secret_word"] = secret_word.lower() + state["extras"]["question_count"] = 0 + state["extras"]["won"] = False + state["extras"]["questions"] = [] + return state + + # ------------------------------------------------------------------------- + # Prompt Building + # ------------------------------------------------------------------------- + + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """Build fresh prompt for current actor.""" + secret = state["extras"]["secret_word"] + category = state["input"].get("info", {}).get("category", "thing") + + # Guesser's turn + if actor_id == "guesser": + # First turn: initial prompt + if len(state["trajectory"]) == 0: + return [ + {"role": "system", "content": GUESSER.system_prompt}, + {"role": "user", "content": f"I'm thinking of a {category}. You have {self.max_questions} questions to guess what it is. Ask your first question!"} + ] + + # Subsequent turns: show Q&A history + remaining = self.max_questions - state["extras"]["question_count"] + history_str = self._build_qa_history(state) + + return [ + {"role": "system", "content": GUESSER.system_prompt}, + {"role": "user", "content": f"I'm thinking of a {category}.\n\nConversation so far:\n{history_str}\n\nYou have {remaining} questions left. 
Ask another question or make a guess!"} + ] + + # Thinker's turn: show question with secret word + else: + last_question = self._get_last_guesser_response(state) + question_num = state["extras"]["question_count"] + 1 # Will be incremented after this turn + return [ + {"role": "system", "content": THINKER.system_prompt + f"\n\nYour secret word is: {secret}"}, + {"role": "user", "content": f"Question {question_num}: {last_question}"} + ] + + def _build_qa_history(self, state: State) -> str: + """Build Q&A history string from trajectory.""" + history_lines = [] + current_question = None + + for step in state["trajectory"]: + actor_id = step.get("extras", {}).get("actor_id") + completion = step.get("completion", []) + content = completion[-1].get("content", "") if completion else "" + + if actor_id == "guesser": + current_question = content + elif actor_id == "thinker" and current_question: + history_lines.append(f"Q: {current_question}") + history_lines.append(f"A: {content}") + current_question = None + + return "\n".join(history_lines) + + def _get_last_guesser_response(self, state: State) -> str: + """Get the most recent guesser response from trajectory.""" + for step in reversed(state["trajectory"]): + if step.get("extras", {}).get("actor_id") == "guesser": + completion = step.get("completion", []) + if completion: + return completion[-1].get("content", "") + return "" + + # ------------------------------------------------------------------------- + # Game Logic + # ------------------------------------------------------------------------- + + async def on_turn_complete(self, state: State) -> None: + """Process game logic after each turn.""" + if not state["trajectory"]: + return + + last_step = state["trajectory"][-1] + last_completion = last_step.get("completion", []) + if not last_completion: + return + + content = last_completion[-1].get("content", "") if isinstance(last_completion[-1], dict) else str(last_completion[-1]) + content_lower = content.lower().strip() + secret = state["extras"]["secret_word"] + last_actor = last_step.get("extras", {}).get("actor_id", "") + + if last_actor == "guesser": + # Guesser just asked a question + state["extras"]["question_count"] += 1 + state["extras"]["questions"].append(content) + + # Check if it's a correct guess + guess_match = re.search(r"is it (?:a |an )?([a-zA-Z]+)\s*\??$", content_lower) + if guess_match and guess_match.group(1).lower() == secret: + state["extras"]["won"] = True + state["final_env_response"] = [{"role": "user", "content": "Correct! You win!"}] + + else: + # Thinker just answered + if "correct" in content_lower: + state["extras"]["won"] = True + state["final_env_response"] = [{"role": "user", "content": "Correct! You win!"}] + elif state["extras"]["question_count"] >= self.max_questions: + state["final_env_response"] = [{"role": "user", "content": f"Game over! 
The word was: {secret}"}] + + async def on_game_end(self, state: State) -> None: + """Compute final game metrics after game loop exits.""" + won = state["extras"]["won"] + questions = state["extras"]["question_count"] + + # Efficiency: winning faster = higher score (1.0 for 1 question, 0.1 for 20) + if won: + state["extras"]["efficiency"] = 1.0 - 0.9 * (questions - 1) / 19 + else: + state["extras"]["efficiency"] = 0.0 + + +# ============================================================================= +# Rubric (Scoring) +# ============================================================================= + +def create_rubric() -> MultiAgentRubric: + """Create rubric - guesser rewarded for winning fast.""" + rubric = MultiAgentRubric() + + def guesser_reward(state, **kwargs) -> float: + """Read efficiency from extras (computed in on_game_end).""" + return state.get("extras", {}).get("efficiency", 0.0) + + def game_length_metric(state, **kwargs) -> float: + return float(state.get("extras", {}).get("question_count", 0)) + + def win_rate_metric(state, **kwargs) -> float: + return 1.0 if state.get("extras", {}).get("won", False) else 0.0 + + rubric.add_actor_reward_func("guesser", guesser_reward, weight=1.0) + rubric.add_reward_func(game_length_metric, weight=0.0) + rubric.add_reward_func(win_rate_metric, weight=0.0) + + return rubric + + +# ============================================================================= +# Dataset +# ============================================================================= + +def create_dataset() -> Dataset: + """Create dataset of secret words.""" + def make_prompt(category: str) -> list: + return [ + {"role": "system", "content": GUESSER.system_prompt}, + {"role": "user", "content": f"I'm thinking of a {category}. You have 20 questions to guess what it is. 
Ask your first question!"} + ] + + items = [ + {"prompt": make_prompt("animal"), "answer": "dog", "info": {"category": "animal"}, "example_id": 0, "task": "twenty_questions"}, + {"prompt": make_prompt("animal"), "answer": "cat", "info": {"category": "animal"}, "example_id": 1, "task": "twenty_questions"}, + {"prompt": make_prompt("animal"), "answer": "elephant", "info": {"category": "animal"}, "example_id": 2, "task": "twenty_questions"}, + {"prompt": make_prompt("animal"), "answer": "penguin", "info": {"category": "animal"}, "example_id": 3, "task": "twenty_questions"}, + {"prompt": make_prompt("object"), "answer": "chair", "info": {"category": "object"}, "example_id": 4, "task": "twenty_questions"}, + {"prompt": make_prompt("object"), "answer": "book", "info": {"category": "object"}, "example_id": 5, "task": "twenty_questions"}, + {"prompt": make_prompt("object"), "answer": "computer", "info": {"category": "object"}, "example_id": 6, "task": "twenty_questions"}, + {"prompt": make_prompt("object"), "answer": "bicycle", "info": {"category": "object"}, "example_id": 7, "task": "twenty_questions"}, + {"prompt": make_prompt("food"), "answer": "pizza", "info": {"category": "food"}, "example_id": 8, "task": "twenty_questions"}, + {"prompt": make_prompt("food"), "answer": "apple", "info": {"category": "food"}, "example_id": 9, "task": "twenty_questions"}, + ] + return Dataset.from_list(items) + + +# ============================================================================= +# Environment Loader +# ============================================================================= + +def load_environment( + max_questions: int = 20, + num_examples: int = -1, +) -> TwentyQuestionsEnv: + """Factory function to create a fully configured 20 Questions environment.""" + dataset = create_dataset() + if num_examples > 0: + dataset = dataset.select(range(min(num_examples, len(dataset)))) + + env = TwentyQuestionsEnv( + max_questions=max_questions, + rubric=create_rubric(), + max_turns=max_questions * 2 + 2, + dataset=dataset, + ) + + Protocol(actors=[THINKER, GUESSER], envs=[env]) + + return env diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 535a16870..5f571a57f 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -19,6 +19,12 @@ from .envs.multiturn_env import MultiTurnEnv # noqa # isort: skip from .envs.tool_env import ToolEnv # noqa # isort: skip +# Multi-agent support +from .envs.actor import Actor # noqa # isort: skip +from .envs.protocol import Protocol # noqa # isort: skip +from .envs.multiagent_env import MultiAgentEnv # noqa # isort: skip +from .rubrics.multiagent_rubric import MultiAgentRubric # noqa # isort: skip + # main imports from .envs.env_group import EnvGroup from .envs.singleturn_env import SingleTurnEnv @@ -54,6 +60,11 @@ "JudgeRubric", "RubricGroup", "MathRubric", + "MultiAgentRubric", + # Multi-agent support + "Actor", + "Protocol", + "MultiAgentEnv", "TextArenaEnv", "ReasoningGymEnv", "GymEnv", diff --git a/verifiers/envs/actor.py b/verifiers/envs/actor.py new file mode 100644 index 000000000..1dee51304 --- /dev/null +++ b/verifiers/envs/actor.py @@ -0,0 +1,53 @@ +""" +Actor: A trainable entity with distinct identity (system prompt) in multi-agent environments. +""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from openai import AsyncOpenAI + + +@dataclass +class Actor: + """ + A trainable actor with distinct system prompt. 
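+
+    Example (illustrative values; note that equality and hashing use only the id):
+
+        solver = Actor(id="solver", system_prompt="Solve the puzzle.")
+        assert solver == Actor(id="solver")  # same id -> treated as the same actor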
+ + Fields: + id: Unique identifier for this actor (e.g., "player1", "guesser") + system_prompt: The actor's persona/instructions (used in build_actor_prompt) + max_tokens: Max response length for this actor + is_trainable: Whether to compute GRPO advantages (False for frozen actors) + sampling_args: Per-actor model settings (temperature, etc.) + model: Model name override (e.g., "gpt-4o"). None = use default from trainer. + client: AsyncOpenAI client override (for different API endpoints). None = use default. + """ + + id: str + system_prompt: str = "" + max_tokens: int = 4096 + is_trainable: bool = True + sampling_args: dict[str, Any] = field(default_factory=dict) + model: str | None = None + client: "AsyncOpenAI | None" = None + + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Actor): + return self.id == other.id + return False + + def __repr__(self) -> str: + trainable_str = "trainable" if self.is_trainable else "frozen" + return f"Actor(id={self.id!r}, {trainable_str})" + + def merge_sampling_args(self, base_args: dict[str, Any]) -> dict[str, Any]: + """Merge actor's sampling args with base args (actor takes precedence).""" + merged = dict(base_args) + merged.update(self.sampling_args) + if self.max_tokens: + merged["max_tokens"] = self.max_tokens + return merged diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py new file mode 100644 index 000000000..c32e0864b --- /dev/null +++ b/verifiers/envs/multiagent_env.py @@ -0,0 +1,563 @@ +""" +Multi-agent environment with turn order management and hierarchical spawning. + +This module provides the base class for multi-agent RL environments, extending +MultiTurnEnv with support for: +- Multiple actors with distinct system prompts and sampling args +- Turn order management via get_initial_actor() / get_next_actor() / get_active_actors() +- Per-actor trajectory tagging for credit assignment +- Per-actor state splitting for individual reward computation +- Hierarchical episode spawning via Protocol.spawn() + +Key concepts: +- Actor: A trainable entity with its own system prompt (defined in actor.py) +- Protocol: Wires actors to environments, enables spawning (defined in protocol.py) +- State splitting: One game state -> multiple actor states for per-actor rewards + +Game Implementation: +- Subclasses implement these main hooks: + - build_actor_prompt(actor_id, state): Build fresh prompt for this actor + - on_turn_complete(state): Update game state after each turn + - on_game_end(state): Compute final metrics after game loop exits (optional) +- Framework handles timing automatically + +""" + +from __future__ import annotations + +import asyncio +import uuid +from abc import abstractmethod +from typing import TYPE_CHECKING + +from datasets import Dataset +from openai import AsyncOpenAI + +import verifiers as vf +from verifiers.envs.multiturn_env import MultiTurnEnv +from verifiers.types import ( + Messages, + RolloutInput, + SamplingArgs, + State, + TrajectoryStep, +) + +if TYPE_CHECKING: + from verifiers.envs.actor import Actor + from verifiers.envs.protocol import Protocol + + +def _dummy_dataset() -> Dataset: + """ + Create a placeholder dataset for environments that don't specify one. + + The real dataset is typically owned by Protocol. This prevents errors + when MultiTurnEnv requires a dataset but one isn't provided. 
+ """ + return Dataset.from_dict({ + "example_id": [0], + "prompt": [[{"role": "user", "content": "dummy"}]], + "answer": [""], + }) + + +# ============================================================================= +# MultiAgentEnv Base Class +# ============================================================================= + +class MultiAgentEnv(MultiTurnEnv): + """ + Base class for multi-agent environments. + + Subclasses must implement: + - get_initial_actor(): Who goes first + - get_next_actor(): Who goes next (for alternating turns) + - build_actor_prompt(): Build prompt for current actor + - on_turn_complete(): Game logic after each turn + + Subclasses may optionally override: + - on_game_end(): Compute final metrics after game loop exits + + The Protocol reference is injected by Protocol.__init__ when wiring + actors to environments. + """ + + # ------------------------------------------------------------------------- + # Class Attributes + # ------------------------------------------------------------------------- + + # List of actor IDs this environment uses (e.g., ["player1", "player2"]) + # Subclasses should override this + actors: list[str] = [] + + # Injected by Protocol.__init__ - provides actor lookup and spawning + protocol: "Protocol | None" = None + + def __init__(self, **kwargs): + """ + Initialize with dummy dataset if none provided. + + The parent class (MultiTurnEnv) requires a dataset, but for multi-agent + environments the Protocol often owns the real dataset. + """ + if "dataset" not in kwargs and "eval_dataset" not in kwargs: + kwargs["dataset"] = _dummy_dataset() + super().__init__(**kwargs) + + # ------------------------------------------------------------------------- + # Turn Management + # ------------------------------------------------------------------------- + + @abstractmethod + def get_initial_actor(self, state: State) -> str: + """ + Return the actor ID that starts the rollout. + + Example: return "guesser" for Twenty Questions + """ + pass + + @abstractmethod + def get_next_actor(self, state: State) -> str: + """ + Return the actor ID for the next turn. + + Example: return "thinker" if current == "guesser" else "guesser" + """ + pass + + def get_active_actors(self, state: State) -> list[str]: + """ + Return actor IDs that act this turn. + + Default: Single actor (standard alternating turns). + Override for simultaneous moves (e.g., RPS returns ["player1", "player2"]). + """ + current = state["extras"].get("current_actor_id") + if current is None: + return [self.get_initial_actor(state)] + return [self.get_next_actor(state)] + + def get_actor(self, actor_id: str) -> "Actor": + """Get an actor by ID from Protocol.""" + if self.protocol is None: + raise RuntimeError( + f"Cannot get_actor('{actor_id}') before Protocol is initialized. " + f"Ensure this environment is passed to Protocol(envs=[...])." + ) + return self.protocol.get_actor(actor_id) + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """ + Initialize multi-agent state fields. 
+ + Sets up state["extras"] with: + - current_actor_id: Who is currently speaking (set in rollout) + - actor_history: List of (actor_id, turn_index) for credit assignment + - episode_id: Unique ID for this rollout + - parent_episode_id: Links to parent if this is a spawned child + + Also initializes state["child_states"] for per-actor state splitting. + """ + state = await super().setup_state(state) + + state["child_states"] = [] + state["extras"] = { + "current_actor_id": None, # Set in rollout() after setup + "actor_history": [], # Tracks who spoke at each turn + "episode_id": state.get("trajectory_id", uuid.uuid4().hex), + "parent_episode_id": None, # Set if spawned from parent episode + } + + return state + + # ------------------------------------------------------------------------- + # Game Hooks (Subclasses Implement These) + # ------------------------------------------------------------------------- + + @abstractmethod + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """ + Build the prompt for the given actor's turn. + + This is called BEFORE the model generates a response. + Build a fresh prompt with whatever context this actor needs. + + Args: + actor_id: The actor who will respond (e.g., "guesser", "player1") + state: Current game state with trajectory and extras + + Returns: + Messages list with system prompt and user content + """ + pass + + @abstractmethod + async def on_turn_complete(self, state: State) -> None: + """ + Update game state after a turn completes. + + This is called AFTER the model response is stored in trajectory. + Use this for game logic: + - Update scores, counters, flags + - Check win conditions (set state["extras"]["won"] = True, etc.) + - Store choices for later resolution (simultaneous moves) + + The last turn's info is in state["trajectory"][-1]: + - ["completion"][-1]["content"]: The model's response text + - ["extras"]["actor_id"]: Which actor just responded + + Args: + state: Current game state (mutate extras as needed) + """ + pass + + async def on_game_end(self, state: State) -> None: + """ + Finalize game state after the game loop exits. + + This is called ONCE after the game loop completes (stop condition met), + BEFORE render_completion() and BEFORE the rubric scores the state. + + Use this for: + - Computing final metrics (win rates, efficiency scores, etc.) + - Declaring the winner + - Preparing data that the rubric will read from state["extras"] + + Unlike on_turn_complete() which is called after each turn, this is + called exactly once when the game is definitely over. + + Args: + state: Final game state (mutate extras as needed for scoring) + """ + pass + + # ------------------------------------------------------------------------- + # Parent Class Requirement (env_response) + # ------------------------------------------------------------------------- + + async def env_response( + self, messages: Messages, state: State, **kwargs + ) -> Messages: + """ + Satisfy MultiTurnEnv's abstract requirement. + + MultiAgentEnv uses on_turn_complete() instead, which is called + explicitly in our rollout() after storing the response. + This method is not used in the multi-agent flow. 
+ """ + return [] + + # ------------------------------------------------------------------------- + # Trajectory Management + # ------------------------------------------------------------------------- + + async def add_trajectory_step( + self, state: State, trajectory_step: TrajectoryStep + ) -> None: + """Tag trajectory step with actor_id and record actor history.""" + current_actor_id = state["extras"]["current_actor_id"] + if current_actor_id: + # Tag step with actor_id for credit assignment + trajectory_step["extras"]["actor_id"] = current_actor_id + # Record history + turn_index = len(state["trajectory"]) + state["extras"]["actor_history"].append((current_actor_id, turn_index)) + await super().add_trajectory_step(state, trajectory_step) + + # ------------------------------------------------------------------------- + # Main Rollout Loop + # ------------------------------------------------------------------------- + + async def rollout( + self, + input, + client, + model, + sampling_args=None, + ) -> State: + """ + Run a multi-agent episode. + + Flow: + 1. Setup state + 2. Loop until game ends: + a. Get active actors (1 for alternating, multiple for simultaneous) + b. For each actor: + - Build prompt via build_actor_prompt() + - Get model response + - Store in trajectory + - Process via on_turn_complete() + 3. Call on_game_end() for final metrics + 4. Return final state + """ + state = await self.init_state(input, client, model, sampling_args) + try: + state = await self.setup_state(state) + except vf.Error as e: + state["error"] = e + return state + + while not await self.is_completed(state): + active_actors = self.get_active_actors(state) + + for actor_id in active_actors: + state["extras"]["current_actor_id"] = actor_id + + try: + # 1. Build prompt for this actor + prompt_messages = await self.build_actor_prompt(actor_id, state) + + # 2. Get model response with actor's sampling args and optional model/client override + actor = self.get_actor(actor_id) + merged_args = actor.merge_sampling_args(sampling_args or {}) + + # Log which model is being used for this actor + used_model = actor.model or state.get("model", "default") + self.logger.info(f"[{actor_id}] using model: {used_model}") + + response = await self.get_model_response( + state, + prompt_messages, + client=actor.client, # None = use default from state + model=actor.model, # None = use default from state + sampling_args=merged_args, + ) + + # 3. Store in trajectory + await self.add_model_response(state, prompt_messages, response) + + # 4. Process turn (game logic) + await self.on_turn_complete(state) + + except vf.OverlongPromptError: + state["prompt_too_long"] = True + state["is_truncated"] = True + break + except vf.Error as e: + state["error"] = e + break + + # Check if we should stop after processing all active actors + if await self.is_completed(state): + break + + await self.on_game_end(state) + await self.render_completion(state) + return state + + # ------------------------------------------------------------------------- + # Per-Actor State Creation + # ------------------------------------------------------------------------- + # + # After a game completes, we split the single game state into per-actor + # states for individual reward computation and GRPO advantage calculation. + # + # Example: RPS game with 6 turns + # Full trajectory: [p1, p2, p1, p2, p1, p2] + # Player1 state: trajectory=[p1, p1, p1], prompt="You are Player 1..." + # Player2 state: trajectory=[p2, p2, p2], prompt="You are Player 2..." 
+ # ------------------------------------------------------------------------- + + # Fields shared by reference across all actor states + # NOTE: "input" is deliberately NOT shared because State.__getitem__/__setitem__ + # forward reads/writes for INPUT_FIELDS (prompt, answer, etc.) to input[key]. + # If we shared input, all actor_states would read the same prompt. + # NOTE: "timing" is deliberately NOT shared - each actor state gets its own copy + # to avoid bugs where score_group() updates the same dict multiple times. + SHARED_STATE_FIELDS = { + "client", # AsyncOpenAI API client + "model", # Model name string (e.g., "gpt-4o-mini") + "trajectory_id", # Unique rollout identifier + } + + def create_actor_state( + self, + parent_state: State, + actor_id: str, + actor_trajectory: list[TrajectoryStep], + ) -> State: + """ + Create a child state for a specific actor from a parent state. + + This splits a multi-actor game state into per-actor states for: + - Per-actor reward computation (via MultiAgentRubric) + - GRPO advantage calculation per actor + - Training only specific actors (is_trainable filtering) + + Args: + parent_state: The full game state with all actors' turns + actor_id: The actor this state is for (e.g., "guesser", "player1") + actor_trajectory: Only this actor's trajectory steps (filtered) + + Returns: + A new State with shared fields referenced and actor-specific fields fresh + """ + # Create empty State - no "input" key means INPUT_FIELDS forwarding doesn't apply + actor_state = State() + + # Copy shared fields by reference (not duplicated in memory) + for key in parent_state.keys(): + if key in self.SHARED_STATE_FIELDS: + actor_state[key] = parent_state[key] + + # Copy timing as a new dict (not shared) to avoid score_group() updating same dict multiple times + if "timing" in parent_state: + actor_state["timing"] = dict(parent_state["timing"]) + + # Copy INPUT_FIELDS directly (safe because actor_state has no "input" key) + actor_state["answer"] = parent_state.get("answer", "") + actor_state["task"] = parent_state.get("task", "") + actor_state["example_id"] = parent_state.get("example_id", 0) + actor_state["info"] = parent_state.get("info", {}) + + # Set actor-specific trajectory (filtered to just this actor's steps) + actor_state["trajectory"] = actor_trajectory + + # Copy extras but override actor_id to mark whose state this is + actor_state["extras"] = { + **parent_state.get("extras", {}), + "current_actor_id": actor_id, + } + + # Fresh fields for scoring (will be computed by rubric) + actor_state["child_states"] = [] + actor_state["reward"] = None + actor_state["advantage"] = None + actor_state["metrics"] = None + + # Copy trainability from Actor to State (so downstream doesn't need Protocol lookup) + actor = self.get_actor(actor_id) + actor_state["is_trainable"] = actor.is_trainable + + # Extract actor-specific prompt and completion + if actor_trajectory: + # Prompt: Find the LAST system message (actor's own prompt) + # The raw prompt may contain accumulated context from other actors + raw_prompt = actor_trajectory[0].get("prompt", []) + prompt_ref = raw_prompt + for i in range(len(raw_prompt) - 1, -1, -1): + if raw_prompt[i].get("role") == "system": + prompt_ref = raw_prompt[i:] # From last system message onward + break + actor_state["prompt"] = prompt_ref + + # Completion: Collect all responses across all turns + all_completions = [] + for step in actor_trajectory: + step_completion = step.get("completion", []) + all_completions.extend(step_completion) + 
actor_state["completion"] = all_completions + else: + # No trajectory for this actor - use parent's prompt + actor_state["prompt"] = parent_state.get("prompt", []) + actor_state["completion"] = [] + + return actor_state + + def create_actor_states(self, state: State, actor_ids: list[str] | None = None) -> list[State]: + """ + Split a parent state into per-actor child states. + + Filters the full trajectory by actor_id (set in add_trajectory_step), + then creates a state for each actor with their filtered trajectory. + + Args: + state: The full game state with all actors' turns + actor_ids: List of actor IDs to create states for. + Defaults to self.actors if not provided. + + Returns: + List of per-actor states, one for each actor_id + """ + if actor_ids is None: + actor_ids = self.actors + + actor_states = [] + for actor_id in actor_ids: + # Filter trajectory to only this actor's steps + actor_trajectory = [ + step for step in state.get("trajectory", []) + if step.get("extras", {}).get("actor_id") == actor_id + ] + + new_state = self.create_actor_state(state, actor_id, actor_trajectory) + actor_states.append(new_state) + + return actor_states + + # ------------------------------------------------------------------------- + # run_group Override (Flattening for prime-rl) + # ------------------------------------------------------------------------- + + async def run_group( + self, + group_inputs: list[RolloutInput], + client: AsyncOpenAI, + model: str, + gen_sampling_args: SamplingArgs, + gen_sem: asyncio.Semaphore, + score_sem: asyncio.Semaphore, + score: bool = True, + ) -> list[State]: + """ + Run rollouts and flatten to per-actor states for training. + + This is what prime-rl calls. Returns flattened states so GRPO + advantages get computed per-actor automatically. + + Flow: + 1. Run game rollouts via parent (each produces one game trajectory) + 2. Flatten: split each game into per-actor states + 3. Include any spawned child_states (proposer-solver pattern) + 4. 
Score all flattened states together (per-actor GRPO) + """ + # Run game rollouts (don't score yet - we'll score after flattening) + game_states = await super().run_group( + group_inputs, client, model, gen_sampling_args, + gen_sem, score_sem, score=False + ) + + # Flatten: one game -> multiple per-actor states + flattened = [] + for game_state in game_states: + # Split game trajectory by actor_id + flattened.extend(self.create_actor_states(game_state)) + # Include spawned children (proposer-solver pattern) + flattened.extend(game_state.get("child_states", [])) + + # Score flattened states (per-actor GRPO advantages) + if score and self.rubric: + await self.rubric.score_group(flattened, score_sem=score_sem) + + return flattened + + # ------------------------------------------------------------------------- + # Result Building (for generate/eval) + # ------------------------------------------------------------------------- + + def _prepare_rollout_results( + self, + all_states: list[State], + model: str, + client: AsyncOpenAI, + state_columns: list[str] | None, + results_path, + gen_sampling_args: SamplingArgs, + start_time: float, + ): + """Add actor_id to result dict for multi-agent environments.""" + result = super()._prepare_rollout_results( + all_states, model, client, state_columns, + results_path, gen_sampling_args, start_time + ) + result["actor_id"] = [ + s.get("extras", {}).get("current_actor_id", "unknown") + for s in all_states + ] + return result diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py index 9841cf6d8..84299b779 100644 --- a/verifiers/envs/multiturn_env.py +++ b/verifiers/envs/multiturn_env.py @@ -1,6 +1,5 @@ import logging from abc import abstractmethod -from typing import final from openai import AsyncOpenAI @@ -70,15 +69,20 @@ async def setup_state(self, state: State) -> State: return state async def get_prompt_messages(self, state: State) -> Messages: - """Override for rollouts with non-linear message sequences.""" + """Build prompt messages for the current turn.""" if len(state["trajectory"]) == 0: - return state["prompt"] + messages = list(state["prompt"]) # Copy to avoid mutation else: prev_turn_prompt = state["trajectory"][-1]["prompt"] prev_turn_completion = state["trajectory"][-1]["completion"] messages = concat_messages([prev_turn_prompt, prev_turn_completion]) env_response = await self.env_response(messages, state) - return concat_messages([messages, env_response]) + messages = concat_messages([messages, env_response]) + return self.modify_prompt_messages(messages, state) + + def modify_prompt_messages(self, messages: Messages, state: State) -> Messages: + """Override to transform prompt before sending to model (e.g., inject system prompt).""" + return messages async def render_completion(self, state: State): """Override for rollouts with non-linear message sequences.""" @@ -125,7 +129,6 @@ async def add_model_response( ) await self.add_trajectory_step(state, trajectory_step) - @final async def rollout( self, input: RolloutInput, @@ -138,18 +141,26 @@ async def rollout( state = await self.setup_state(state) except vf.Error as e: state["error"] = e + return state + while not await self.is_completed(state): try: prompt_messages = await self.get_prompt_messages(state) if state.get("final_env_response") is not None: - continue - response = await self.get_model_response(state, prompt_messages) + break + + response = await self.get_model_response( + state, prompt_messages, sampling_args=sampling_args + ) await 
self.add_model_response(state, prompt_messages, response) + except vf.Error as e: if isinstance(e, vf.OverlongPromptError): state["prompt_too_long"] = True state["is_truncated"] = True else: state["error"] = e + break + await self.render_completion(state) return state diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py new file mode 100644 index 000000000..35034c9c6 --- /dev/null +++ b/verifiers/envs/protocol.py @@ -0,0 +1,161 @@ +""" +Protocol: Wires actors to environments and enables cross-environment spawning. + +Protocol is the glue that connects: +- Actors (trainable entities with system prompts) +- Environments (where rollouts happen) + +Key functionality: +- Actor registry: Look up actors by ID +- Env registry: Look up environments by name +- spawn(): Run child rollouts in other environments (e.g., Proposer spawns Solvers) +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +from openai import AsyncOpenAI + +from verifiers.types import RolloutInput, SamplingArgs, State +from verifiers.utils.async_utils import maybe_semaphore + +from .actor import Actor + +if TYPE_CHECKING: + from .environment import Environment + + +class Protocol: + """ + Wires actors to environments. Enables spawn() for cross-env communication. + + """ + + def __init__( + self, + actors: list[Actor], + envs: list["Environment"], + ): + """ + Register actors and environments. + + Args: + actors: List of Actor instances to register + envs: List of Environment instances to register + """ + # Register actors by ID + self._actors: dict[str, Actor] = {} + for actor in actors: + if actor.id in self._actors: + raise ValueError(f"Duplicate actor id: {actor.id}") + self._actors[actor.id] = actor + + # Register environments by name + self._envs: dict[str, "Environment"] = {} + for env in envs: + name = getattr(env, "name", env.__class__.__name__) + if name in self._envs: + raise ValueError(f"Duplicate environment name: {name}") + self._envs[name] = env + # Inject protocol reference so env can call self.protocol.spawn() + env.protocol = self + + def get_actor(self, actor_id: str) -> Actor: + """Get actor by ID.""" + if actor_id not in self._actors: + raise KeyError( + f"Actor '{actor_id}' not found. Available: {list(self._actors.keys())}" + ) + return self._actors[actor_id] + + def get_env(self, name: str) -> "Environment": + """Get environment by name.""" + if name not in self._envs: + raise KeyError( + f"Environment '{name}' not found. Available: {list(self._envs.keys())}" + ) + return self._envs[name] + + @property + def actors(self) -> dict[str, Actor]: + """All registered actors.""" + return self._actors + + @property + def envs(self) -> dict[str, "Environment"]: + """All registered environments.""" + return self._envs + + async def spawn( + self, + inputs: list[RolloutInput], + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + score: bool = True, + ) -> list[State]: + """ + Spawn child rollouts in target environments. + + Routes each input to its environment based on input["task"], + runs rollouts in parallel, and optionally scores them. 
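+
+        Example (illustrative sketch, as called from a parent environment; it
+        assumes a child environment registered under the name "solver" and that
+        each input dict carries the same fields as a dataset row):
+
+            child_states = await self.protocol.spawn(
+                inputs=[{"task": "solver", "prompt": [...], "answer": "42"}],
+                client=state["client"],
+                model=state["model"],
+            )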
+ + Args: + inputs: List of rollout inputs, each with "task" field for routing + client: AsyncOpenAI client (required) + model: Model name (required) + sampling_args: Optional sampling parameters + score: Whether to score rollouts with env's rubric (default True) + + Returns: + List of completed states from child rollouts + + ) + """ + # Run all rollouts in parallel + tasks = [] + for inp in inputs: + env_name = inp.get("task") + if not env_name: + raise ValueError("spawn() requires 'task' field in each input") + env = self.get_env(env_name) + tasks.append( + env.rollout( + inp, + client=client, + model=model, + sampling_args=sampling_args, + ) + ) + + all_states = await asyncio.gather(*tasks) + + # Mark spawned states as children (for progress tracking) + for state in all_states: + if "extras" not in state: + state["extras"] = {} + state["extras"]["parent_episode_id"] = "spawned" + + # Score rollouts if requested + # Use score_group() instead of score_rollout() to properly handle: + # - MultiAgentRubric per-actor reward functions + # - GRPO advantage computation per-actor group + if score: + score_sem = await maybe_semaphore(-1) + # Group states by environment for proper score_group() semantics + states_by_env: dict[str, list[State]] = {} + for inp, state in zip(inputs, all_states): + env_name = inp.get("task") + if env_name not in states_by_env: + states_by_env[env_name] = [] + states_by_env[env_name].append(state) + + # Score each group with its environment's rubric + for env_name, env_states in states_by_env.items(): + env = self.get_env(env_name) + if env.rubric: + await env.rubric.score_group(env_states, score_sem=score_sem) + + return list(all_states) diff --git a/verifiers/rubrics/multiagent_rubric.py b/verifiers/rubrics/multiagent_rubric.py new file mode 100644 index 000000000..92f61732a --- /dev/null +++ b/verifiers/rubrics/multiagent_rubric.py @@ -0,0 +1,232 @@ +""" +Multi-agent rubric with per-actor rewards. + +Extends Rubric with: +- Per-actor reward functions (different rewards for different actors) +- Per-actor GRPO advantages (within-actor normalization) +- Group funcs run per-actor-group (solvers vs solvers, not vs proposers) +""" + +from __future__ import annotations + +import asyncio +import time +from collections import defaultdict +from typing import AsyncContextManager, cast + +import verifiers as vf +from verifiers.rubrics.rubric import Rubric +from verifiers.types import GroupRewardFunc, RewardFunc, State + + +class MultiAgentRubric(Rubric): + """ + Rubric with per-actor rewards. + + GRPO advantages are computed within actor groups (solver vs solver), + not across all actors, preventing unfair comparisons. + """ + + def __init__( + self, + funcs: list[RewardFunc] | None = None, + weights: list[float] | None = None, + parser: vf.Parser | None = None, + ): + super().__init__(funcs=funcs, weights=weights, parser=parser) + + # Per-actor reward functions: actor_id -> [(func, weight), ...] 
+ self.actor_reward_funcs: dict[str, list[tuple[RewardFunc, float]]] = defaultdict(list) + + def add_actor_reward_func( + self, + actor_id: str, + func: RewardFunc, + weight: float = 1.0, + ) -> None: + """Add a reward function specific to an actor.""" + self.actor_reward_funcs[actor_id].append((func, weight)) + + def add_actor_metric( + self, + actor_id: str, + func: RewardFunc, + ) -> None: + """Add a metric (zero-weight reward) for logging without affecting reward.""" + self.add_actor_reward_func(actor_id, func, weight=0.0) + + def get_actor_id_from_state(self, state: State) -> str | None: + """Extract actor ID from state (checks extras, actor_history, trajectory).""" + # Check extras first (primary location) + extras = state.get("extras", {}) + if "current_actor_id" in extras: + return extras["current_actor_id"] + + # Check actor_history (take first actor if available) + actor_history = extras.get("actor_history", []) + if actor_history: + return actor_history[0][0] + + # Check trajectory steps + trajectory = state.get("trajectory", []) + for step in trajectory: + step_extras = step.get("extras", {}) + if "actor_id" in step_extras: + return step_extras["actor_id"] + + return None + + async def _compute_actor_reward( + self, + state: State, + actor_id: str, + score_sem: AsyncContextManager, + ) -> tuple[float, dict[str, float]]: + """Compute reward using actor-specific + global reward functions.""" + total_reward = 0.0 + metrics: dict[str, float] = {} + + # Compute rewards for the current actor + actor_funcs = self.actor_reward_funcs.get(actor_id, []) + for func, weight in actor_funcs: + try: + score = await self._call_individual_reward_func(func, state, score_sem) + score = score if score is not None else 0.0 + metrics[func.__name__] = score + total_reward += score * weight + except Exception as e: + self.logger.error(f"Error in actor reward func {func.__name__}: {e}") + metrics[func.__name__] = 0.0 + + # Also compute global reward functions + for func, weight in zip(self.funcs, self.weights): + if not self._is_group_func(func): + try: + score = await self._call_individual_reward_func(func, state, score_sem) + score = score if score is not None else 0.0 + total_reward += score * weight + metrics[func.__name__] = score + except Exception as e: + self.logger.error(f"Error in global reward func {func.__name__}: {e}") + metrics[func.__name__] = 0.0 + + return total_reward, metrics + + async def score_group( + self, + states: list[State], + score_sem: AsyncContextManager, + ) -> None: + """ + Score with per-actor GRPO advantages (solver vs solver, not vs proposer). 
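+
+        Advantages are mean-centered within each actor group (no std
+        normalization): advantage_i = reward_i - mean(rewards of all states
+        sharing that actor_id), as computed in the loop below.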
+ + Features: + - Children scored first (so parents can read child rewards) + - Per-actor reward functions + - Group funcs run per-actor-group + - Per-actor GRPO advantage normalization + """ + if not states: + self.logger.warning("No states to score") + return + + start_time = time.time() + + # Score children first, then parents (so parents can read child rewards) + children = [s for s in states if not s.get("child_states")] + parents = [s for s in states if s.get("child_states")] + + await self._score_states(children, score_sem) + await self._score_states(parents, score_sem) + + # Run group funcs per-actor-group + await self._run_group_funcs_per_actor(states, score_sem) + + # Compute GRPO advantages per-actor group + actor_groups: dict[str, list[State]] = defaultdict(list) + for state in states: + actor_id = self.get_actor_id_from_state(state) or "default" + actor_groups[actor_id].append(state) + + for actor_id, actor_states in actor_groups.items(): + # Skip GRPO advantage computation for non-trainable actors + # (they still get scored for logging, just no advantage for training) + if actor_states and not actor_states[0].get("is_trainable", True): + for state in actor_states: + state["advantage"] = 0.0 + for step in state.get("trajectory", []): + if step.get("advantage") is None: + step["advantage"] = 0.0 + if step.get("reward") is None: + step["reward"] = state["reward"] + continue + + actor_rewards = [s["reward"] for s in actor_states] + mean_reward = sum(actor_rewards) / len(actor_rewards) + + for state in actor_states: + advantage = state["reward"] - mean_reward + state["advantage"] = advantage + + for step in state.get("trajectory", []): + if step.get("advantage") is None: + step["advantage"] = advantage + if step.get("reward") is None: + step["reward"] = state["reward"] + + # Timing tracking (match parent) + end_time = time.time() + scoring_ms = (end_time - start_time) * 1000 + for state in states: + if "timing" in state: + state["timing"]["scoring_ms"] = scoring_ms + state["timing"]["total_ms"] += scoring_ms + + async def _score_states( + self, + states: list[State], + score_sem: AsyncContextManager, + ) -> None: + """Score a list of states with individual reward funcs.""" + if not states: + return + + actor_ids = [self.get_actor_id_from_state(s) or "default" for s in states] + + reward_tasks = [ + self._compute_actor_reward(state, actor_id, score_sem) + for state, actor_id in zip(states, actor_ids) + ] + results = await asyncio.gather(*reward_tasks) + + for state, actor_id, (reward, metrics) in zip(states, actor_ids, results): + state["reward"] = reward + state["metrics"] = metrics + + async def _run_group_funcs_per_actor( + self, + states: list[State], + score_sem: AsyncContextManager, + ) -> None: + """Run group reward funcs per-actor-group (solvers vs solvers, not vs proposers).""" + group_funcs = [(f, w) for f, w in zip(self.funcs, self.weights) if self._is_group_func(f)] + if not group_funcs: + return + + # Group states by actor + actor_groups: dict[str, list[State]] = defaultdict(list) + for state in states: + actor_id = self.get_actor_id_from_state(state) or "default" + actor_groups[actor_id].append(state) + + # Run group funcs on each actor group + for actor_id, actor_states in actor_groups.items(): + for func, weight in group_funcs: + group_func = cast(GroupRewardFunc, func) + scores = await self._call_group_reward_func(group_func, actor_states, score_sem) + + for state, score in zip(actor_states, scores): + state["reward"] += score * weight + if state["metrics"] is 
None: + state["metrics"] = {} + state["metrics"][func.__name__] = score diff --git a/verifiers/utils/client_utils.py b/verifiers/utils/client_utils.py index b3a34ec9b..34ed9cf37 100644 --- a/verifiers/utils/client_utils.py +++ b/verifiers/utils/client_utils.py @@ -64,3 +64,50 @@ def setup_client( ) return client + + +# ============================================================================= +# Actor Client Helper +# ============================================================================= + +try: + from configs.endpoints import ENDPOINTS as _DEFAULT_ENDPOINTS +except ImportError: + _DEFAULT_ENDPOINTS: dict = {} + + +def get_actor_client( + endpoint_key: str | None, + endpoints: dict | None = None, +) -> tuple[AsyncOpenAI | None, str | None]: + """ + Get client and model from an endpoint key for use with Actor. + + Returns (None, None) if endpoint_key is None or not found, + meaning the Actor will use the default model from the eval command. + + Example: + client, model = get_actor_client("sonnet") + actor = Actor(id="player", model=model, client=client) + """ + if not endpoint_key: + return None, None + + endpoints = endpoints or _DEFAULT_ENDPOINTS + + if endpoint_key not in endpoints: + logger.warning(f"Endpoint '{endpoint_key}' not found. Using default model.") + return None, None + + endpoint = endpoints[endpoint_key] + api_key = os.environ.get(endpoint["key"], "") + + if endpoint["key"] == "PRIME_API_KEY" and not api_key: + api_key = load_prime_config().get("api_key", "") + + client = AsyncOpenAI( + base_url=endpoint["url"], + api_key=api_key or "EMPTY", + ) + + return client, endpoint["model"] \ No newline at end of file diff --git a/verifiers/utils/display_utils.py b/verifiers/utils/display_utils.py index ac7843c8b..6227007a2 100644 --- a/verifiers/utils/display_utils.py +++ b/verifiers/utils/display_utils.py @@ -55,7 +55,7 @@ def is_tty() -> bool: class DisplayLogHandler(logging.Handler): """Custom log handler that captures log records for display.""" - def __init__(self, max_lines: int = 3) -> None: + def __init__(self, max_lines: int = 10) -> None: super().__init__() self.logs: deque[str] = deque(maxlen=max_lines) self.setFormatter(logging.Formatter("%(name)s: %(message)s")) @@ -88,7 +88,7 @@ def __init__(self, screen: bool = False, refresh_per_second: int = 4) -> None: self.console = Console() self._live: Live | None = None self._old_terminal_settings: list | None = None - self._log_handler = DisplayLogHandler(max_lines=3) + self._log_handler = DisplayLogHandler(max_lines=10) self._old_handler_levels: dict[logging.Handler, int] = {} self._old_datasets_level: int | None = None diff --git a/verifiers/utils/eval_display.py b/verifiers/utils/eval_display.py index d975a91d2..66014de69 100644 --- a/verifiers/utils/eval_display.py +++ b/verifiers/utils/eval_display.py @@ -8,6 +8,7 @@ import json import time +from collections import defaultdict from collections.abc import Mapping from dataclasses import dataclass, field from pathlib import Path @@ -40,9 +41,10 @@ class EnvEvalState: total: int = 0 # total rollouts num_examples: int = -1 # num examples (-1 means "all", updated by on_start) rollouts_per_example: int = 1 # rollouts per example (from config) - reward: float = 0.0 # reward (rolling avg) + reward: float = 0.0 # reward (rolling avg, 0.0 for multi-agent) metrics: dict[str, float] = field(default_factory=dict) # metrics (rolling avg) error_rate: float = 0.0 # error rate (rolling avg) + is_multiagent: bool = False # whether this env has multiple actors # path where 
results were saved (if save_results=true) save_path: Path | None = None @@ -191,6 +193,7 @@ def update_env_state( reward: float | None = None, metrics: dict[str, float] | None = None, error_rate: float | None = None, + is_multiagent: bool | None = None, error: str | None = None, save_path: Path | None = None, log_message: str | None = None, @@ -225,6 +228,9 @@ def update_env_state( if error_rate is not None: env_state.error_rate = error_rate + if is_multiagent is not None: + env_state.is_multiagent = is_multiagent + if error is not None: env_state.error = error @@ -246,10 +252,17 @@ def _get_error_rate_color(self, error_rate: float) -> str: return "white" def _make_metrics_row( - self, reward: float, metrics: dict[str, float], error_rate: float + self, + reward: float, + metrics: dict[str, float], + error_rate: float, + is_multiagent: bool = False, ) -> Table | None: """Create a metrics row with metrics left-aligned and error_rate right-aligned.""" - metrics = {"reward": reward, **metrics} + # For multi-agent, per-actor rewards are already in metrics + # For single-agent, show combined "reward" + if not is_multiagent: + metrics = {"reward": reward, **metrics} # build the left-aligned metrics text metrics_text = Text() @@ -363,7 +376,7 @@ def fmt_concurrency(val: int) -> str: # metrics display metrics_content = self._make_metrics_row( - env_state.reward, env_state.metrics, env_state.error_rate + env_state.reward, env_state.metrics, env_state.error_rate, env_state.is_multiagent ) # log message for special events @@ -496,7 +509,20 @@ def print_final_summary(self) -> None: examples_str = str(num_examples) rollouts_str = str(config.rollouts_per_example) - reward = f"{env_state.reward:.3f}" + # Per-actor rewards for multi-agent, otherwise single reward + results = env_state.results + actor_ids = results.get("actor_id", []) if results else [] + if actor_ids: + actor_rewards: dict[str, list[float]] = defaultdict(list) + for aid, rew in zip(actor_ids, results["reward"]): + actor_rewards[aid].append(rew) + reward_parts = [ + f"{aid}: {sum(rews)/len(rews):.2f}" + for aid, rews in actor_rewards.items() + ] + reward = " | ".join(reward_parts) + else: + reward = f"{env_state.reward:.3f}" # error rate with color coding error_rate = env_state.error_rate @@ -568,77 +594,146 @@ def _make_env_detail( """Create detailed content for a single environment's summary.""" items: list[Panel] = [] + # Check if multi-agent (has actor_id field) + actor_ids = results.get("actor_id", []) + is_multiagent = bool(actor_ids) + # Example 0 prompt/completion if results["prompt"] and results["completion"]: - prompt = messages_to_printable(results["prompt"][0]) - completion = messages_to_printable(results["completion"][0]) - reward_0 = results["reward"][0] if results["reward"] else 0.0 - error_0 = results["state"][0].get("error") if results["state"] else None - - # Prompt panel - items.append( - Panel( - _format_messages(prompt), - title="[dim]example 0 — prompt[/dim]", - border_style="dim", + if is_multiagent: + # Multi-agent: show one panel per unique actor for example 0 + example_ids = results.get("example_id", []) + first_example_id = example_ids[0] if example_ids else 0 + + # Find first occurrence of each actor in example 0 + seen_actors: set[str] = set() + for idx, (eid, aid) in enumerate(zip(example_ids, actor_ids)): + if eid != first_example_id: + continue + if aid in seen_actors: + continue + seen_actors.add(aid) + + completion = messages_to_printable(results["completion"][idx]) + reward_i = results["reward"][idx] + 
error_i = results["state"][idx].get("error") + + completion_text = _format_messages(completion) + if error_i is not None: + completion_text.append("\n\nerror: ", style="bold red") + completion_text.append(str(ErrorChain(error_i)), style="bold red") + completion_text.append("\n\nreward: ", style="bold cyan") + completion_text.append(f"{reward_i:.3f}", style="bold cyan") + + items.append( + Panel( + completion_text, + title=f"[dim]example 0 — {aid}[/dim]", + border_style="dim", + ) + ) + else: + # Single-agent: original behavior + prompt = messages_to_printable(results["prompt"][0]) + completion = messages_to_printable(results["completion"][0]) + reward_0 = results["reward"][0] if results["reward"] else 0.0 + error_0 = results["state"][0].get("error") if results["state"] else None + + # Prompt panel + items.append( + Panel( + _format_messages(prompt), + title="[dim]example 0 — prompt[/dim]", + border_style="dim", + ) ) - ) - # Completion panel (with error if any) - completion_text = _format_messages(completion) - if error_0 is not None: - completion_text.append("\n\nerror: ", style="bold red") - completion_text.append(str(ErrorChain(error_0)), style="bold red") - completion_text.append("\n\nreward: ", style="bold cyan") - completion_text.append(f"{reward_0:.3f}", style="bold cyan") - - items.append( - Panel( - completion_text, - title="[dim]example 0 — completion[/dim]", - border_style="dim", + # Completion panel (with error if any) + completion_text = _format_messages(completion) + if error_0 is not None: + completion_text.append("\n\nerror: ", style="bold red") + completion_text.append(str(ErrorChain(error_0)), style="bold red") + completion_text.append("\n\nreward: ", style="bold cyan") + completion_text.append(f"{reward_0:.3f}", style="bold cyan") + + items.append( + Panel( + completion_text, + title="[dim]example 0 — completion[/dim]", + border_style="dim", + ) ) - ) # Reward distribution rewards = results["reward"] if rewards: - # All rollouts histogram - all_rollouts_content = Group( - Text("all rollouts:", style="bold"), - _make_histogram(rewards, bins=8, width=25), - ) - - # Per-example averages if multiple rollouts - rollouts_per = config.rollouts_per_example - if rollouts_per > 1 and len(rewards) >= rollouts_per: - num_examples = len(rewards) // rollouts_per - example_avgs = [] - for i in range(num_examples): - example_rewards = rewards[i * rollouts_per : (i + 1) * rollouts_per] - example_avgs.append(sum(example_rewards) / len(example_rewards)) - - per_example_content = Group( - Text("per-example avg:", style="bold"), - _make_histogram(example_avgs, bins=8, width=25), - ) - - # Side by side - reward_display = Columns( - [all_rollouts_content, per_example_content], - equal=True, - expand=True, + if is_multiagent: + # Multi-agent: show per-actor histograms + actor_rewards: dict[str, list[float]] = defaultdict(list) + for aid, rew in zip(actor_ids, rewards): + actor_rewards[aid].append(rew) + + # Build per-actor content with histogram for each + actor_contents = [] + for aid, rews in actor_rewards.items(): + avg_rew = sum(rews) / len(rews) if rews else 0.0 + actor_content = Group( + Text(f"{aid}: ", style="bold cyan"), + Text(f"{avg_rew:.3f} avg ({len(rews)} rollouts)"), + _make_histogram(rews, bins=6, width=20), + ) + actor_contents.append(actor_content) + + # Display side by side if 2 actors, otherwise stacked + if len(actor_contents) == 2: + reward_display = Columns(actor_contents, equal=True, expand=True) + else: + reward_display = Group(*actor_contents) + + items.append( + 
Panel( + reward_display, + title="[dim]reward distribution (per-actor)[/dim]", + border_style="dim", + ) ) else: - reward_display = all_rollouts_content + # Single-agent: original behavior + all_rollouts_content = Group( + Text("all rollouts:", style="bold"), + _make_histogram(rewards, bins=8, width=25), + ) - items.append( - Panel( - reward_display, - title="[dim]reward distribution[/dim]", - border_style="dim", + # Per-example averages if multiple rollouts + rollouts_per = config.rollouts_per_example + if rollouts_per > 1 and len(rewards) >= rollouts_per: + num_examples = len(rewards) // rollouts_per + example_avgs = [] + for i in range(num_examples): + example_rewards = rewards[i * rollouts_per : (i + 1) * rollouts_per] + example_avgs.append(sum(example_rewards) / len(example_rewards)) + + per_example_content = Group( + Text("per-example avg:", style="bold"), + _make_histogram(example_avgs, bins=8, width=25), + ) + + # Side by side + reward_display = Columns( + [all_rollouts_content, per_example_content], + equal=True, + expand=True, + ) + else: + reward_display = all_rollouts_content + + items.append( + Panel( + reward_display, + title="[dim]reward distribution[/dim]", + border_style="dim", + ) ) - ) # Metrics if env_state.metrics: diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 8035967c6..9746983ef 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -423,37 +423,69 @@ async def run_with_progress( env_config: EvalConfig, env_idx: int ) -> GenerateOutputs: """Run a single evaluation with display progress updates.""" - reward_accum = 0 - metrics_accum = defaultdict(float) - error_accum = 0 + actor_rewards: dict[str, list[float]] = defaultdict(list) + metrics_values: dict[str, list[float]] = defaultdict(list) + error_count = 0 + is_multiagent = False def on_start(total: int) -> None: - # total is num_examples * rollouts_per_example - # compute actual num_examples (resolves -1 to actual count) num_examples = total // env_config.rollouts_per_example display.update_env_state(env_idx, total=total, num_examples=num_examples) def on_progress(all_states: list[State], new_states: list[State]) -> None: - nonlocal error_accum, reward_accum, metrics_accum - - # Progress is always rollout-based - completed = len(all_states) + nonlocal error_count, is_multiagent for s in new_states: if s.get("error") is not None: - error_accum += 1 + error_count += 1 + + # Track rewards per-actor + actor_id = s.get("extras", {}).get("current_actor_id") reward = s.get("reward") if reward is not None: - reward_accum += reward - state_metrics = s.get("metrics") or {} - for name, value in state_metrics.items(): - if value is not None: - metrics_accum[name] += value + key = actor_id if actor_id else "_single" + actor_rewards[key].append(reward) + if actor_id: + is_multiagent = True + + # Track metrics (skip *_reward for multi-agent, we compute those ourselves) + for name, value in (s.get("metrics") or {}).items(): + if value is not None and not (is_multiagent and name.endswith("_reward")): + metrics_values[name].append(value) + + # Build display values + metrics: dict[str, float] = {} + + if is_multiagent: + # Multi-agent: per-actor rewards first + reward = 0.0 + for aid in sorted(actor_rewards.keys()): + vals = actor_rewards[aid] + if vals: + metrics[aid] = sum(vals) / len(vals) + else: + # Single-agent: combined reward + vals = actor_rewards.get("_single", []) + reward = sum(vals) / len(vals) if vals else 0.0 + + # Add other metrics + for name, vals in 
metrics_values.items(): + if vals: + metrics[name] = sum(vals) / len(vals) + + # Count completed: unique games for multi-agent, total states for single-agent + if is_multiagent: + # Multi-agent: count unique parent games (by trajectory_id) + # Exclude spawned children (parent_episode_id is set) to avoid overcounting + parent_states = [ + s for s in all_states + if s.get("extras", {}).get("parent_episode_id") is None + ] + completed = len(set(s.get("trajectory_id") for s in parent_states)) + else: + completed = len(all_states) - # Compute averages over completed rollouts - reward = reward_accum / completed - metrics = {name: metrics_accum[name] / completed for name in metrics_accum} - error_rate = error_accum / completed + error_rate = error_count / completed if completed > 0 else 0.0 display.update_env_state( env_idx, @@ -461,6 +493,7 @@ def on_progress(all_states: list[State], new_states: list[State]) -> None: reward=reward, metrics=metrics, error_rate=error_rate, + is_multiagent=is_multiagent, ) def on_log(message: str) -> None: @@ -562,6 +595,10 @@ def make_dataset(results: GenerateOutputs, **kwargs) -> Dataset: v = results["metrics"][k] results_dict[k] = v + # Add actor_id column for multi-agent environments + if "actor_id" in results and results["actor_id"]: + results_dict["actor_id"] = results["actor_id"] + # Add selected state columns if specified state_columns = results["metadata"]["state_columns"] if state_columns: