diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 000000000..c3352a5de Binary files /dev/null and b/.DS_Store differ diff --git a/examples/polaris_dev_1014.sh b/examples/polaris_dev_1014.sh new file mode 100644 index 000000000..0e7ccd560 --- /dev/null +++ b/examples/polaris_dev_1014.sh @@ -0,0 +1,230 @@ +#!/bin/bash + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +# ================== model config ============================= +SCRIPT_DIR=/root/slime/scripts +source "${SCRIPT_DIR}/models/qwen2.5-7B.sh" +# ============================================================= + +# ================= user config =============================== + +LOG_DIR=/home/projects/polyullm/kejing/slime_workspace/wandb/ + + +LOAD_DIR=/home/projects/polyullm/kejing/slime_workspace/slime_polaris/ckpt/Polaris-7B-bf16TI--8kstg1 +SAVE_DIR=/home/projects/polyullm/kejing/slime_workspace/slime_polaris/ckpt/Polaris-7B-bf16TI--8kstg1 +POLARIS_TRACKING_DIR=/home/projects/polyullm/kejing/slime_workspace/slime_polaris/Polaris-7B-bf16TI--8kstg1/polaris_reward_tracking + +DATA_DIR=/lustre/projects/polyullm/caishuo/cs_data/slime_rl/polaris-data-53K.jsonl + +# HF_CHECKPOINT=/lustre/projects/polyullm/caishuo/slime/InfiAlign-SFT-Qwen-7B-165K +HF_CHECKPOINT=/lustre/projects/polyullm/models/Qwen/Qwen2.5-7B-Instruct +#/lustre/projects/polyullm/caishuo/slime-0907/models/Qwen2.5-Math-7B__slime__fp8-bsz64-w0.05-f32scale-stg2-165k--0912----hf/iter_0012909__f32 +REF_LOAD=/home/projects/polyullm/caishuo/cs20251004/cs_models/slime_sft_models/TL0920_stg2/ +# ============================================================== + +# ================ paralle config ============================== +TP=4 +PP=1 +CP=2 +EP_MP=1 +EP_TP=1 +MAX_TOKENS_PER_GPU=8192 +# ============================================================== + +# ================ RL specific config ========================= +train_prompt_bsz=128 +gen_prompt_bsz=$((train_prompt_bsz)) #$((train_prompt_bsz * 3)) + +NUM_ROLLOUT=10240 +N_SAMPLES_PER_PROMPT=16 +GLOBAL_BATCH_SIZE=2048 +ROLLOUT_MAX_RESPONSE_LEN=8192 +ROLLOUT_TEMPERATURE=1.0 #1.1 +OVER_SAMPLING_BATCH_SIZE=${gen_prompt_bsz} +# ============================================================== + +CKPT_ARGS=( + --hf-checkpoint ${HF_CHECKPOINT} + --ref-load ${REF_LOAD} + --load ${LOAD_DIR} + --save ${SAVE_DIR} + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data ${DATA_DIR} + --input-key prompt + --label-key label + --apply-chat-template + # --rollout-shuffle # remove shufffle + --rm-type deepscaler + --num-rollout ${NUM_ROLLOUT} + --rollout-batch-size ${train_prompt_bsz} + --n-samples-per-prompt ${N_SAMPLES_PER_PROMPT} + --rollout-max-response-len ${ROLLOUT_MAX_RESPONSE_LEN} + --rollout-temperature ${ROLLOUT_TEMPERATURE} + --over-sampling-batch-size ${OVER_SAMPLING_BATCH_SIZE} # ${gen_prompt_bsz} + # --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std + + --global-batch-size ${GLOBAL_BATCH_SIZE} # ${train_prompt_bsz} + --balance-data + + +) + +POLARIS_ARGS=( + --enable-polaris-dynamic-sampling + --polaris-good-reward-min 0.1 + --polaris-good-reward-max 0.9 + --polaris-min-good-ratio 0.33 + --enable-polaris-reward-tracking + --polaris-reward-tracking-dir ${POLARIS_TRACKING_DIR} + --polaris-verbose + # align with verl's behavior + --polaris-skip-batch-when-insufficient +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime /home/projects/polyullm/kejing/slime_workspace/data/aime-2024.jsonl + --n-samples-per-eval-prompt 1 # 16 + 
--eval-max-response-len 30000 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size ${TP} + --sequence-parallel + --pipeline-model-parallel-size ${PP} + --context-parallel-size ${CP} + --expert-model-parallel-size ${EP_MP} + --expert-tensor-parallel-size ${EP_TP} + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu ${MAX_TOKENS_PER_GPU} +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project slime-dev + --wandb-group Polaris-7B-8kstg1 + --wandb-mode offline + --wandb-dir ${LOG_DIR} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 4 + --sglang-mem-fraction-static 0.5 + --sglang-max-running-requests 128 + +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +PRECISE_ARGS=( + --transformer-impl transformer_engine + --bf16 + #--fp8-format e4m3 + #--fp8-recipe blockwise + #--fp8-param-gather + # --direct-update-fp8-weight +) + +TENSORBOARD_ARGS=( + --profile-step-start 10 + --profile-step-end 12 + --tensorboard-dir ${LOG_DIR}/tensorboard + --record-memory-history +) + +# launch the master node of ray in container +export http_proxy="" +export https_proxy="" + + + + +# Build the runtime environment JSON with proper variable substitution +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json='{ + "working_dir": "/home/projects/polyullm/kejing/slime_workspace/slime_polaris/slime", + "env_vars": { + "PYTHONPATH": "/root/Megatron-LM/:/home/projects/polyullm/kejing/slime_workspace/slime_polaris/slime", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:1024", + "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1", + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "1", + "NVTE_DEBUG": "0", + "HOME": "/home/projects/polyullm/kejing/slime_workspace/slime_polaris/slime", + "http_proxy": "", + "https_proxy": "", + "NCCL_SOCKET_IFNAME": "bond0", + "WANDB_MODE": "offline", + "WANDB_DIR": "/home/projects/polyullm/kejing/slime_workspace/wandb/", + "RAY_DEDUP_LOGS_ALLOW_REGEX": "kejing", + "NO_PROXY": "localhost,127.0.0.1,klb-dgx-*", + "no_proxy": "localhost,127.0.0.1,klb-dgx-*" + } + }' \ + -- python3 train.py \ + --actor-num-nodes 4 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${POLARIS_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${DISTRIBUTED_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${PRECISE_ARGS[@]} \ + ${TENSORBOARD_ARGS[@]} \ + $@ + +# you can add more args to the script +# bash ./run-InfiAlign-SFT-Qwen-7B-165K-2node-rl.sh --num-rollout 20480 --rollout-batch-size 256 +# even in a sbatch script: +# sbatch --nodes=2 submit_4node_rl.sh ./run-InfiAlign-SFT-Qwen-7B-165K-2node-rl.sh --actor-num-nodes 2 + + diff --git a/examples/train_infer_mismatch_helper/mis.py b/examples/train_infer_mismatch_helper/mis.py new file mode 100644 index 
000000000..6514e8a09 --- /dev/null +++ b/examples/train_infer_mismatch_helper/mis.py @@ -0,0 +1,328 @@ +from typing import Any, Dict, Optional, Tuple + +import torch + +from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp + + +def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: + result = (x * loss_mask).sum() + return result.expand_as(x) if expand else result + + +def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: + result = masked_sum(x, loss_mask) / torch.clamp_min(loss_mask.sum(), 1) + return result.expand_as(x) if expand else result + + +def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: + """ + + Every metrics-dict value is a list of 1D tensor, i.e., [torch.Tensor] with shapes exactly the same as log_probs. + + All metrics will be aggregated and averaged by `sum_of_sample_mean` and divided by DP size automatically + - If calculate_per_token_loss=False (default), the final results will first be averaged in each sequence, + then across all the sequences in the global batch. + - If calculate_per_token_loss=True, the final results will be the mean of all the tokens in the global batch. + + No need to specifically handle loss_mask, sum_of_sample_mean automatically ignores statistics where loss_mask = 0. + + e.g. + For token-level metric: + value = [ + [0.1, 0.2], + [0.1, 0.2, 0.3, 0.4, 0.5], + [0.6] + ] + When calculate_per_token_loss = False (default): + result = (0.1 + 0.2) / 2 + (0.1 + 0.2 + 0.3 + 0.4 + 0.5) / 5 + (0.6) / 1 = 0.15 + 0.3 + 0.6 = 1.05 / 3 = 0.35 + When calculate_per_token_loss = True: + result = (0.1 + 0.2 + 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6) / 8 = 2.4 / 8 = 0.3 + For sequence-level metric: + original sequence lengths = [2, 5, 2] + We should expand the metrics to the length of each sequence: + value = [ + [2, 2], + [5, 5, 5, 5, 5], + [1, 1] + ] + When calculate_per_token_loss = False (default): + result = (2 + 2) / 2 + (5 + 5 + 5 + 5 + 5) / 5 + (1 + 1) / 2 = 2 + 5 + 1 = 8 / 3 = 2.6665 + Note that for sequence-level, calculating token-level loss is invalid; thus, calculate_per_token_loss should always be False. 
+ """ + if key not in metrics: + metrics[key] = [] + metrics[key].append(value.clone().detach()) + + +def calculate_veto_mask( + log_ratio: torch.Tensor, + loss_mask: torch.Tensor, + veto_threshold: Optional[float], + metrics: Dict[str, list[torch.Tensor]], +) -> torch.Tensor: + if veto_threshold is None: + return torch.ones_like(log_ratio) + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio.device)) + # For each sequence, if it has any catastrophic tokens, return 0 for the sequence + catastrophic_tokens = ((log_ratio < log_veto_threshold)) & loss_mask.bool() + has_catastrophic = catastrophic_tokens.any() + veto_mask = (~has_catastrophic).float().expand_as(log_ratio) + + metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int()) + metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask)) + return veto_mask + + +def truncate( + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, list[torch.Tensor]], upper_bound: float +) -> torch.Tensor: + assert upper_bound is not None + metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int()) + return weights.clamp(0, upper_bound) * loss_mask + + +def clip( + weights: torch.Tensor, + loss_mask: torch.Tensor, + metrics: Dict[str, list[torch.Tensor]], + lower_bound: float, + upper_bound: float, +) -> torch.Tensor: + assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound + metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) + metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int()) + return weights.clamp(lower_bound, upper_bound) * loss_mask + + +def mask( + weights: torch.Tensor, + loss_mask: torch.Tensor, + metrics: Dict[str, list[torch.Tensor]], + lower_bound: float, + upper_bound: float, +) -> torch.Tensor: + assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound + metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int()) + metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int()) + mask = (weights >= lower_bound) & (weights <= upper_bound) + return weights * mask * loss_mask + + +def compute_mis_weights( + args, + *, + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], +) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]: + """ + Compute the importance sampling (IS) weights and metrics between the inference and training engine. + Args: + train_log_probs: List of log probs from training backend. 1D tensor each. Lengths can be different. + rollout_log_probs: List of log probs from inference backend. 1D tensor each. + loss_masks: List of loss masks. 1D tensor each. + Note that for single turn RL, the loss_mask is [1] * response_length tensor for each sequence + For multi-turn RL, the tool response will be marked as 0 in the loss_mask. + + Returns: + weights: List of importance sampling weights. 1D tensor each. + metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. 
+ """ + + level: str = args.mis_level + metrics: Dict[str, list[torch.Tensor]] = {} + + if args.mis_lower_bound is None: + return 1.0 / args.mis_upper_bound + + # Validate input lists have same length and each sequence has matching shapes + assert ( + len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) + ), f"Input lists must have the same number of sequences: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}" + + for i, (train, rollout, loss_mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): + assert ( + train.shape == rollout.shape == loss_mask.shape + ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, loss_mask: {loss_mask.shape}" + + SAFETY_BOUND = 20.0 # Add a safety bound to avoid exp overflow + all_weights = [] + + # handle each sequence independently + for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks): + loss_mask = loss_mask.float() + add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics) + raw_log_ratio_diff = train_log_prob - rollout_log_prob + + # level: The aggregation level for the importance sampling weights. + if level == "token": + # Per-token ratio (biased) + log_ratio_for_metrics = raw_log_ratio_diff + elif level == "sequence": + # Product of ratios (unbiased but high variance) + log_ratio_for_metrics = masked_sum(raw_log_ratio_diff, loss_mask, expand=True) + elif level == "geometric": + # Geometric mean of ratios (biased but low variance) + log_ratio_for_metrics = masked_mean(raw_log_ratio_diff, loss_mask, expand=True) + else: + raise ValueError(f"Invalid importance sampling level: {level}") + + log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) + weights = torch.exp(log_ratio_safe) + metrics_append(metrics, "mean_is_weight_before_clip", weights) + + # mask out catastrophic tokens + if args.mis_veto_threshold is not None: + veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics) + + # mode: how to handle the importance sampling weights exceeding the thresholds. + if args.mis_mode == "truncate": + # Cap the importance sampling weights at the upper threshold + # https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 + weights = truncate(weights, loss_mask, metrics, args.mis_upper_bound) + elif args.mis_mode == "mask": + # Zero the importance sampling weights outside the [lower, upper] range. + # https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + weights = mask( + weights, + loss_mask, + metrics, + args.mis_lower_bound, + args.mis_upper_bound, + ) + elif args.mis_mode == "clip": + # Clip the importance sampling weights to the [lower, upper] range. + # Original behavior in slime. 
+ weights = clip( + weights, + loss_mask, + metrics, + args.mis_lower_bound, + args.mis_upper_bound, + ) + else: + raise ValueError(f"Unsupported mis_mode: {args.mis_mode}") + + metrics_append(metrics, "ratio_mean_after_mis", weights) + if args.mis_veto_threshold is not None: + weights = weights * veto_mask + metrics_append(metrics, "ratio_mean_after_veto_mask", weights) + + weights = weights.detach() + all_weights.append(weights) + + return all_weights, metrics + + +def compute_mis_weights_with_cp( + args, + *, + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], + total_lengths: list[int], + response_lengths: list[int], + **kwargs: Any, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Compute the importance sampling (IS) weights and metrics with context parallel. + Args: + train_log_probs: List of log probs from training backend on this cp rank. 1D tensor each. Lengths can be different. + rollout_log_probs: List of log probs from inference backend on this cp rank. 1D tensor each. + loss_masks: List of loss masks. 1D tensor each. + total_lengths: List of total lengths. + response_lengths: List of response lengths. + Returns: + is_weights: Importance sampling weights on this CP rank and flattened along dim=0. + is_metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. + Also flattened along dim=0. + """ + # Gather cp slice from other cp ranks + full_rollout_log_probs = [ + all_gather_with_cp(log_prob, total_length, response_length) + for log_prob, total_length, response_length in zip(rollout_log_probs, total_lengths, response_lengths) + ] + full_old_log_probs = [ + all_gather_with_cp(old_log_prob, total_length, response_length) + for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths) + ] + + # Main logic for is + is_weights, is_metrics = compute_mis_weights( + args=args, + train_log_probs=full_old_log_probs, + rollout_log_probs=full_rollout_log_probs, + loss_masks=loss_masks, + ) + + # Slice out the value shards for this CP rank and concat them into a 1D tensor along dim=0 for loss.py computation. + def slice_cp_and_concat( + values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] + ) -> torch.Tensor: + values = [ + # TODO: A rename of this function? + slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i]) + for i in range(len(values)) + ] + return torch.cat(values, dim=0) + + result_metrics = {} + is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths) + for key, values in is_metrics.items(): + key_name = f"mis_{key}" + values = slice_cp_and_concat(values, total_lengths, response_lengths) + result_metrics[key_name] = values + + return is_weights, result_metrics + + +def add_ppl_metrics( + train_log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + loss_mask: torch.Tensor, + metrics: Dict[str, list[torch.Tensor]], +): + loss_mask = loss_mask.float() + + # 1. Training policy perplexity metrics + mean_log_prob_training = masked_mean(train_log_prob, loss_mask, expand=True) + training_log_ppl = -mean_log_prob_training + training_ppl = torch.exp(training_log_ppl) + metrics_append(metrics, "training_log_ppl", training_log_ppl) + metrics_append(metrics, "training_ppl", training_ppl) + + # 2. 
Rollout policy perplexity metrics + mean_log_prob_rollout = masked_mean(rollout_log_prob, loss_mask, expand=True) + rollout_log_ppl = -mean_log_prob_rollout + rollout_ppl = torch.exp(rollout_log_ppl) + metrics_append(metrics, "rollout_log_ppl", rollout_log_ppl) + metrics_append(metrics, "rollout_ppl", rollout_ppl) + + # 3a. kl: Direct estimator for KL(π_rollout || π_training) + # This is the standard KL divergence: E[log(π_rollout) - log(π_training)] + # Positive value means rollout policy is more confident than training policy + kl_per_token = rollout_log_prob - train_log_prob + metrics_append(metrics, "kl", kl_per_token) + + # 3b. K3 KL estimator for improved stability + # More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1] + # Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout + log_ratio = train_log_prob - rollout_log_prob + k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1 + metrics_append(metrics, "k3_kl", k3_kl_matrix) + + # 3c. Log PPL difference (sequence-level perplexity difference) + # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training + # Since ppl = exp(-log_prob), we have: + # log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff + # Positive value means training assigns lower probability (higher PPL) than rollout + log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training + metrics_append(metrics, "log_ppl_diff", log_ppl_diff) + metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs()) + + # 3d. PPL ratio (how much higher is training PPL vs rollout PPL) + # For numerical stability, compute in log space using log_ppl_diff + # Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff) + ppl_ratio = torch.exp(log_ppl_diff) + metrics_append(metrics, "ppl_ratio", ppl_ratio) diff --git a/examples/train_infer_mismatch_helper/mis.yaml b/examples/train_infer_mismatch_helper/mis.yaml new file mode 100644 index 000000000..50bc9bfea --- /dev/null +++ b/examples/train_infer_mismatch_helper/mis.yaml @@ -0,0 +1,26 @@ +# Enable importance sampling, details refer to the comments of compute_mis_weights in mis.py +use_mis: false + +# Aggregation level for importance sampling weights: +# token: per-token +# sequence: product over tokens +# geometric: geometric mean +mis_level: "token" + +# Handling mode for IS weights: +# truncate: cap to upper bound, TIS +# mask: zero outside [lower, upper], MIS +# clip: clip to [lower, upper], CIS +mis_mode: "truncate" + +# For mask or clip mode, the lower bound of the IS weights. +# For truncate mode, it will not be used. +# If not set, it will be set to 1.0 / mis_upper_bound +mis_lower_bound: 0.5 + +# For truncate, mask, or clip mode, the upper bound of the IS weights +mis_upper_bound: 2.0 + +# Per-token veto threshold. If any token ratio < this, zero the entire sequence weight, the sequences won't have gradient +# Note: float number must be written with dot e.g. 
1.0e-4, not 1e-4 +mis_veto_threshold: 1.0e-4 diff --git a/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh new file mode 100755 index 000000000..8e6646296 --- /dev/null +++ b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh @@ -0,0 +1,157 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "/root/slime/scripts/models/qwen3-4B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-4B + #--hf-checkpoint /root/Qwen3-4B-FP8 + --ref-load /root/Qwen3-4B_torch_dist + # --load /root/Qwen3-4B_slime/ + --save /root/Qwen3-4B_slime/ + --save-interval 200 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + # --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 1 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 2 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project slime-mis + --wandb-group qwen3-4B-mis + --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +CUSTOM_ARGS=( + --custom-config-path examples/train_infer_mismatch_helper/mis.yaml + --custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_with_cp +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + 
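The CUSTOM_ARGS above point slime at mis.yaml and at compute_mis_weights_with_cp via a dotted import path. The load_function helper used in loss.py is not shown in this diff; a minimal sketch of how such a "module.attribute" path can be resolved (the helper name below is illustrative, not slime's API) would be:

    import importlib

    def load_function_by_path(path: str):
        # Split "pkg.module.attr" into module and attribute, import the module,
        # and return the attribute (here, the TIS/MIS weight function).
        module_name, _, attr_name = path.rpartition(".")
        return getattr(importlib.import_module(module_name), attr_name)

    # e.g. load_function_by_path(
    #     "examples.train_infer_mismatch_helper.mis.compute_mis_weights_with_cp")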
+ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} diff --git a/scripts/models/qwen2.5-1.5B.sh b/scripts/models/qwen2.5-1.5B.sh index b046a95c6..17472864f 100644 --- a/scripts/models/qwen2.5-1.5B.sh +++ b/scripts/models/qwen2.5-1.5B.sh @@ -12,5 +12,5 @@ MODEL_ARGS=( --rotary-base 10000 --group-query-attention --num-query-groups 2 - --vocab-size 151936 -) \ No newline at end of file + --vocab-size 152064 +) diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 161088247..8ce8f3e56 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -25,6 +25,7 @@ from .initialize import init, is_megatron_main_rank from .loss import compute_advantages_and_returns from .model import forward_only, initialize_model_and_optimizer, save, train +from .polaris_integration import apply_polaris_to_rollout_data, init_polaris_components, log_polaris_stats from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor, named_parameters @@ -108,6 +109,9 @@ def init(self, args, role, wandb_run_id, with_ref=False): ) self.prof.start() + # POLARIS components initialization + self.reward_tracker, self.dynamic_replacer = init_polaris_components(args) + Timer().start("train_wait") return start_rollout_id @@ -231,6 +235,25 @@ def train(self, rollout_id, rollout_data_ref): with timer("data_preprocess"): rollout_data = self._get_rollout_data(rollout_data_ref) + # POLARIS: Apply dynamic sampling and reward tracking + # This should be done before computing advantages_and_returns + polaris_stats = {} + if self.args.enable_polaris_dynamic_sampling or self.args.enable_polaris_reward_tracking: + with timer("polaris_processing"): + rollout_data, polaris_stats = apply_polaris_to_rollout_data( + self.args, + rollout_data, + rollout_id, + self.reward_tracker, + self.dynamic_replacer, + ) + + # If configured to skip batch when insufficient good samples (verl behavior) + if polaris_stats.get("polaris/skip_batch_due_to_insufficient_good", 0) == 1: + print("[POLARIS] Skip this batch due to insufficient medium-difficulty samples.") + Timer().start("train_wait") + return + # Create data iterator for log_probs and train. 
data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data) @@ -266,6 +289,10 @@ def train(self, rollout_id, rollout_data_ref): log_rollout_data(rollout_id, self.args, rollout_data) + # Log POLARIS statistics + if polaris_stats: + log_polaris_stats(rollout_id, self.args, polaris_stats) + if self.args.use_pytorch_profiler and torch.distributed.get_rank() == 0 and self.prof is not None: self.prof.step() diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index 6ee919396..6b805d61b 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -1,3 +1,5 @@ +from typing import Union + import torch import torch.distributed as dist import torch.nn.functional as F @@ -170,7 +172,22 @@ def slice_with_cp(tokens: torch.Tensor, pad_value): return torch.cat([tokens[start_1:end_1], tokens[start_2:end_2]]) -def slice_log_prob_with_cp(log_prob: list[float], total_length: int, response_length: int): +def slice_log_prob_with_cp( + log_prob: Union[list[float], torch.Tensor], + total_length: int, + response_length: int, +) -> Union[list[float], torch.Tensor]: + """ + Slice log probabilities for Context Parallel processing. + + Args: + log_prob: Log probabilities (list or tensor) of length response_length + total_length: Total sequence length (prompt + response) + response_length: Length of the response portion + + Returns: + Sliced log probabilities matching the input type + """ assert len(log_prob) == response_length cp_size = mpu.get_context_parallel_world_size() @@ -183,4 +200,9 @@ def slice_log_prob_with_cp(log_prob: list[float], total_length: int, response_le chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)] chunk_2 = log_prob[logits_offset[1][0] - (prompt_length - 1) : logits_offset[1][1] - (prompt_length - 1)] - return chunk_1 + chunk_2 + + # Handle both list and tensor types + if isinstance(log_prob, list): + return chunk_1 + chunk_2 + else: + return torch.cat([chunk_1, chunk_2], dim=0) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 105d13fd2..a07ad0a79 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -30,25 +30,126 @@ def get_log_probs_and_entropy( assert logits.size(0) == 1, f"{logits.shape}" assert logits.dtype == torch.float32, f"{logits.dtype}" + def _dump_non_finite( + prefix: str, + tensor: torch.Tensor, + sample_idx: int, + total_length: int, + response_length: int, + extra: dict | None = None, + tokens: torch.Tensor | None = None, + ): + try: + mask = ~torch.isfinite(tensor) + num_bad = int(mask.sum().item()) + first_pos = mask.nonzero(as_tuple=False)[0].tolist() if num_bad > 0 else None + + tensor_cpu = tensor.detach().float().cpu() + finite_tensor_cpu = torch.where( + torch.isfinite(tensor_cpu), tensor_cpu, torch.zeros_like(tensor_cpu) + ) + max_abs = finite_tensor_cpu.abs().max().item() + + token_stats = None + if tokens is not None and tokens.numel() > 0: + tokens_cpu = tokens.detach().long().cpu() + token_stats = { + "min": int(tokens_cpu.min().item()), + "max": int(tokens_cpu.max().item()), + "mean": float(tokens_cpu.float().mean().item()), + } + + rank = -1 + dist_module = getattr(torch, "distributed", None) + if dist_module is not None and dist_module.is_available() and dist_module.is_initialized(): + rank = dist_module.get_rank() + + payload = { + "prefix": prefix, + "sample_idx": sample_idx, + "total_length": 
total_length, + "response_length": response_length, + "num_bad": num_bad, + "first_bad_position": first_pos, + "max_abs": max_abs, + "extra": extra, + "token_stats": token_stats, + } + + path = f"/tmp/stability_bad_logits_rank{rank}_sample{sample_idx}_{prefix.replace(' ', '_')}.pt" + payload["tensor"] = tensor_cpu + if tokens is not None: + payload["tokens"] = tokens.detach().cpu() + torch.save(payload, path) + print( + f"[Training Stability Debug] Saved non-finite tensor snapshot to {path} " + f"(prefix={prefix}, num_bad={num_bad}, first_bad_position={first_pos}, max_abs={max_abs})", + flush=True, + ) + except Exception as exc: # pragma: no cover - best-effort debug aid + print( + f"[Training Stability Debug] Failed to dump non-finite tensor (prefix={prefix}, sample_idx={sample_idx}): {exc}", + flush=True, + ) + logits = logits.squeeze(0) logits = logits.div(args.rollout_temperature) cp_size = mpu.get_context_parallel_world_size() - log_probs_list = [] if with_entropy: entropy_list = [] end = 0 - for tokens, total_length, response_length in zip(unconcat_tokens, total_lengths, response_lengths): + for sample_idx, (tokens, total_length, response_length) in enumerate( + zip(unconcat_tokens, total_lengths, response_lengths) + ): if cp_size == 1: end += total_length start = end - response_length logits_chunk = logits[start - 1 : end - 1] tokens_chunk = tokens[-response_length:] - log_prob, entropy = calculate_log_probs_and_entropy( - logits_chunk, tokens_chunk, mpu.get_tensor_model_parallel_group(), with_entropy=with_entropy - ) + sanitized = False + if not torch.isfinite(logits_chunk).all(): + _dump_non_finite( + "chunk_0_non_cp", + logits_chunk, + sample_idx, + total_length, + response_length, + extra={ + "start": start, + "end": end, + "path": "non_cp", + }, + tokens=tokens_chunk, + ) + logits_chunk = torch.nan_to_num( + logits_chunk, nan=0.0, posinf=1e4, neginf=-1e4 + ) + sanitized = True + + tp_world_size = max(mpu.get_tensor_model_parallel_world_size(), 1) + local_vocab_size = logits_chunk.size(-1) + global_vocab_upper_bound = local_vocab_size * tp_world_size + token_max_val = tokens_chunk.max().item() + token_min_val = tokens_chunk.min().item() + if token_max_val >= global_vocab_upper_bound or token_min_val < 0: + print( + "[Training Stability] Token index out of bounds detected " + f"(sample_idx={sample_idx}, global_vocab_upper_bound={global_vocab_upper_bound}, " + f"token_min={token_min_val}, token_max={token_max_val}, " + f"total_length={total_length}, response_length={response_length})" + ) + raise RuntimeError("Token index out of vocabulary range before log prob computation.") + + if sanitized: + log_prob = logits.new_zeros(response_length) + entropy = logits.new_zeros(response_length) if with_entropy else None + else: + log_prob, entropy = calculate_log_probs_and_entropy( + logits_chunk, tokens_chunk, mpu.get_tensor_model_parallel_group(), with_entropy=with_entropy + ) else: # TODO: this is super ugly... do better abstraction. 
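            # Context-parallel path: each CP rank holds two non-contiguous chunks of the
            # sequence (see slice_with_cp in cp_utils.py), so logits and tokens are
            # recovered per chunk from the offsets below and the per-chunk log probs are
            # concatenated at the end.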
chunk_size, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( @@ -66,21 +167,75 @@ def get_log_probs_and_entropy( assert logits_0.size(0) == tokens_0.size(0), f"{logits_0.size(0)} vs {tokens_0.size(0)}" assert logits_1.size(0) == tokens_1.size(0), f"{logits_1.size(0)} vs {tokens_1.size(0)}" - log_prob_0, entropy_0 = calculate_log_probs_and_entropy( - logits_0, - tokens_0, - mpu.get_tensor_model_parallel_group(), - with_entropy=with_entropy, - ) - log_prob_1, entropy_1 = calculate_log_probs_and_entropy( - logits_1, - tokens_1, - mpu.get_tensor_model_parallel_group(), - with_entropy=with_entropy, - ) - log_prob = torch.cat([log_prob_0, log_prob_1], dim=0) - if with_entropy: - entropy = torch.cat([entropy_0, entropy_1], dim=0) + sanitized = False + if not torch.isfinite(logits_0).all(): + _dump_non_finite( + "chunk_0_cp", + logits_0, + sample_idx, + total_length, + response_length, + extra={"part": 0, "path": "cp"}, + tokens=tokens_0, + ) + logits_0 = torch.nan_to_num(logits_0, nan=0.0, posinf=1e4, neginf=-1e4) + sanitized = True + if not torch.isfinite(logits_1).all(): + _dump_non_finite( + "chunk_1_cp", + logits_1, + sample_idx, + total_length, + response_length, + extra={"part": 1, "path": "cp"}, + tokens=tokens_1, + ) + logits_1 = torch.nan_to_num(logits_1, nan=0.0, posinf=1e4, neginf=-1e4) + sanitized = True + + tp_world_size = max(mpu.get_tensor_model_parallel_world_size(), 1) + local_vocab_size = logits_0.size(-1) + global_vocab_upper_bound = local_vocab_size * tp_world_size + token_max_candidates = [] + token_min_candidates = [] + if tokens_0.numel() > 0: + token_max_candidates.append(tokens_0.max()) + token_min_candidates.append(tokens_0.min()) + if tokens_1.numel() > 0: + token_max_candidates.append(tokens_1.max()) + token_min_candidates.append(tokens_1.min()) + + if token_max_candidates: + token_max = torch.stack(token_max_candidates).max().item() + token_min = torch.stack(token_min_candidates).min().item() + if token_max >= global_vocab_upper_bound or token_min < 0: + print( + "[Training Stability] Token index out of bounds detected (CP path) " + f"(sample_idx={sample_idx}, global_vocab_upper_bound={global_vocab_upper_bound}, " + f"token_min={token_min}, token_max={token_max}, " + f"total_length={total_length}, response_length={response_length})" + ) + raise RuntimeError("Token index out of vocabulary range before log prob computation.") + + if sanitized: + log_prob = logits.new_zeros(response_length) + entropy = logits.new_zeros(response_length) if with_entropy else None + else: + log_prob_0, entropy_0 = calculate_log_probs_and_entropy( + logits_0, + tokens_0, + mpu.get_tensor_model_parallel_group(), + with_entropy=with_entropy, + ) + log_prob_1, entropy_1 = calculate_log_probs_and_entropy( + logits_1, + tokens_1, + mpu.get_tensor_model_parallel_group(), + with_entropy=with_entropy, + ) + log_prob = torch.cat([log_prob_0, log_prob_1], dim=0) + if with_entropy: + entropy = torch.cat([entropy_0, entropy_1], dim=0) end += 2 * chunk_size @@ -211,7 +366,6 @@ def compute_advantages_and_returns(args, rollout_data): def policy_loss_function(args, batch, logits, sum_of_sample_mean): advantages = torch.cat(batch["advantages"], dim=0) - old_log_probs = batch["log_probs"] response_lengths = batch["response_lengths"] total_lengths = batch["total_lengths"] @@ -226,7 +380,68 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): ) log_probs = log_probs_and_entropy["log_probs"] + entropy = log_probs_and_entropy["entropy"] + + # Apply high-entropy 
token filtering if enabled + if getattr(args, 'high_entropy_token_filter', False): + entropy_percentile = getattr(args, 'entropy_percentile', 0.2) + + # Concatenate all entropies and masks + all_entropy = torch.cat(entropy, dim=0) + loss_masks = batch["loss_masks"] + + cp_size = mpu.get_context_parallel_world_size() + if cp_size == 1: + all_masks = torch.cat(loss_masks) + else: + mask_chunks = [] + for i in range(len(entropy)): + total_len = total_lengths[i] + response_len = response_lengths[i] + prompt_len = total_len - response_len + _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp(total_len, response_len) + + s0, e0 = token_offsets[0] + s1, e1 = token_offsets[1] + res_s0, res_e0 = max(0, s0 - prompt_len), max(0, e0 - prompt_len) + res_s1, res_e1 = max(0, s1 - prompt_len), max(0, e1 - prompt_len) + + local_mask_parts = [] + full_mask = loss_masks[i] + if res_e0 > res_s0: + local_mask_parts.append(full_mask[res_s0:res_e0]) + if res_e1 > res_s1: + local_mask_parts.append(full_mask[res_s1:res_e1]) + + local_mask_chunk = ( + torch.cat(local_mask_parts) if local_mask_parts + else torch.tensor([], device=all_entropy.device, dtype=full_mask.dtype) + ) + mask_chunks.append(local_mask_chunk) + all_masks = torch.cat(mask_chunks) + + # Compute entropy threshold from valid tokens only + if all_masks.sum() > 0: + valid_entropy = all_entropy[all_masks.bool()] + entropy_threshold = torch.quantile(valid_entropy, 1.0 - entropy_percentile) + + # Create high-entropy mask + high_entropy_mask = (all_entropy >= entropy_threshold).float() + + # Update loss_masks + chunk_lengths = [ent.size(0) for ent in entropy] + high_entropy_chunks = list(torch.split(high_entropy_mask, chunk_lengths)) + batch["loss_masks"] = [ + loss_mask * high_entropy_chunk + for loss_mask, high_entropy_chunk in zip(loss_masks, high_entropy_chunks) + ] + + # Recompute sum_of_sample_mean with updated masks + sum_of_sample_mean = get_sum_of_sample_mean( + total_lengths, response_lengths, batch["loss_masks"], args.calculate_per_token_loss + ) + old_log_probs_flat: torch.Tensor | None = None if args.advantage_estimator == "gspo": full_log_probs = [ all_gather_with_cp(log_prob, total_length, response_length) @@ -245,30 +460,84 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): ppo_kl = [kl.expand_as(log_prob) for kl, log_prob in zip(ppo_kl, log_probs)] ppo_kl = torch.cat(ppo_kl, dim=0) log_probs = torch.cat(log_probs, dim=0) + old_log_probs_flat = torch.cat(full_old_log_probs, dim=0) else: - old_log_probs = torch.cat(batch["log_probs"], dim=0) + old_log_probs_flat = torch.cat(batch["log_probs"], dim=0) log_probs = torch.cat(log_probs, dim=0) - ppo_kl = old_log_probs - log_probs - - pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) + ppo_kl = old_log_probs_flat - log_probs + + cispo_stats: dict[str, torch.Tensor] | None = None + is_cispo = getattr(args, "policy_objective", "ppo") == "cispo" + + if is_cispo: + assert old_log_probs_flat is not None, "old log probs must be available for CISPO." 
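            # CISPO-style objective: the token log-prob is reweighted by a detached IS
            # ratio clipped to [1 - eps_low, 1 + eps_high] (each bound is applied only
            # when its eps is > 0), i.e. loss_t = -sg[clip(r_t)] * A_t * log_prob_t.
            # Unlike the PPO surrogate, clipped tokens are down-weighted but still
            # receive gradient through log_prob_t.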
+ ratio = torch.exp(log_probs - old_log_probs_flat) + ratio_clipped = ratio + eps_low = getattr(args, "cispo_eps_low", 0.0) + eps_high = getattr(args, "cispo_eps_high", 2.0) + if eps_low is not None and eps_low > 0.0: + ratio_clipped = ratio_clipped.clamp_min(1.0 - eps_low) + if eps_high is not None and eps_high > 0.0: + ratio_clipped = ratio_clipped.clamp_max(1.0 + eps_high) + clip_mask = (ratio_clipped != ratio).float() + pg_loss = -(ratio_clipped.detach() * advantages) * log_probs + pg_clipfrac = clip_mask + cispo_stats = { + "ratio_mean": sum_of_sample_mean(ratio_clipped.detach()), + "ratio_max": ratio.detach().max(), + } + else: + pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) # Apply TIS off-policy correction using importance sampling if enabled + tis_metrics = {} if args.use_tis: + + def vanilla_tis_function( + args, + *, + train_log_probs, + rollout_log_probs, + **kwargs, + ): + rollout_log_probs = torch.cat(rollout_log_probs, dim=0) + old_log_probs = torch.cat(train_log_probs, dim=0) + tis = torch.exp(old_log_probs - rollout_log_probs) + tis_weights = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip) + tis_clipfrac = (tis_weights != tis).float() + metrics = { + "tis": tis.clone().detach(), + "tis_clipfrac": tis_clipfrac.clone().detach(), + } + return tis_weights, metrics + assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" - rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) - old_log_probs = torch.cat(batch["log_probs"], dim=0) - tis = torch.exp(old_log_probs - rollout_log_probs) ois = (-ppo_kl).exp() - tis_clip = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip) - tis_clipfrac = tis_clip != tis + tis_kwargs = { + "args": args, + "train_log_probs": batch["log_probs"], + "rollout_log_probs": batch["rollout_log_probs"], + "loss_masks": batch["loss_masks"], + "total_lengths": total_lengths, + "response_lengths": response_lengths, + } + + if args.custom_tis_function_path is not None: + tis_func = load_function(args.custom_tis_function_path) + else: + tis_func = vanilla_tis_function + tis_weights, tis_metrics = tis_func(**tis_kwargs) - pg_loss = pg_loss * tis_clip + pg_loss = pg_loss * tis_weights pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) ppo_kl = sum_of_sample_mean(ppo_kl) + if cispo_stats is not None: + cispo_stats["clipfrac"] = pg_clipfrac.clone().detach() + # entropy loss entropy = log_probs_and_entropy["entropy"] entropy = torch.cat(entropy, dim=0) @@ -304,9 +573,14 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): reported_loss["kl_loss"] = kl_loss.clone().detach() if args.use_tis: - reported_loss["tis"] = sum_of_sample_mean(tis).clone().detach() reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - reported_loss["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac).clone().detach() + for metric_key, metric_value in tis_metrics.items(): + reported_loss[metric_key] = sum_of_sample_mean(metric_value).clone().detach() + + if cispo_stats is not None: + reported_loss["cispo_ratio_mean"] = cispo_stats["ratio_mean"].clone().detach() + reported_loss["cispo_ratio_max"] = cispo_stats["ratio_max"].clone().detach() + reported_loss["cispo_clipfrac"] = cispo_stats["clipfrac"].clone().detach() return loss, reported_loss @@ -324,10 +598,57 @@ def sft_loss_function(args, batch, logits, sum_of_sample_mean): with_entropy=False, ) - log_probs = log_probs_and_entropy["log_probs"] - log_probs = torch.cat(log_probs, 
dim=0) + log_probs_list = log_probs_and_entropy["log_probs"] + + non_finite_info = [] + for sample_idx, (log_prob_tensor, resp_len, tot_len, loss_mask) in enumerate( + zip( + log_probs_list, + batch.get("response_lengths", []), + batch.get("total_lengths", []), + batch.get("loss_masks", []), + ) + ): + if not torch.isfinite(log_prob_tensor).all(): + non_finite_values = log_prob_tensor[~torch.isfinite(log_prob_tensor)] + non_finite_info.append( + { + "idx": sample_idx, + "count": non_finite_values.numel(), + "min": non_finite_values.min().item() + if non_finite_values.numel() > 0 + else float("nan"), + "max": non_finite_values.max().item() + if non_finite_values.numel() > 0 + else float("nan"), + "response_length": resp_len, + "total_length": tot_len, + "loss_mask_sum": loss_mask.sum().item() + if isinstance(loss_mask, torch.Tensor) + else None, + } + ) + + log_probs = torch.cat(log_probs_list, dim=0) loss = -sum_of_sample_mean(log_probs) + if not torch.isfinite(loss) or non_finite_info: + loss_mask_sums = ( + [mask.sum().item() for mask in batch.get("loss_masks", [])] + if isinstance(batch.get("loss_masks", None), list) + else None + ) + print( + "[Training Stability] Non-finite loss detected. " + f"loss={loss}, " + f"num_log_probs={log_probs.numel()}, " + f"loss_mask_sums={loss_mask_sums}, " + f"response_lengths={batch.get('response_lengths', None)}, " + f"total_lengths={batch.get('total_lengths', None)}, " + f"non_finite_info={non_finite_info}" + ) + raise RuntimeError("Encountered non-finite loss during SFT training.") + # make sure the gradient could backprop correctly. if log_probs.numel() == 0: loss += 0 * logits.sum() diff --git a/slime/backends/megatron_utils/polaris_integration.py b/slime/backends/megatron_utils/polaris_integration.py new file mode 100644 index 000000000..0cf8423b8 --- /dev/null +++ b/slime/backends/megatron_utils/polaris_integration.py @@ -0,0 +1,377 @@ +""" +Integration module for POLARIS features in Megatron actor. + +This module provides functions to integrate POLARIS dynamic sampling +and reward tracking into the training loop. +""" + +from pathlib import Path +import numbers +import math +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from megatron.core import mpu + +from slime.utils.polaris_utils import ( + DynamicSampleReplacer, + RewardTracker, + aggregate_rewards_per_prompt, + extract_sample_indices, + replace_samples_in_rollout, +) + + +def init_polaris_components(args): + """ + Initialize POLARIS components (reward tracker and dynamic replacer). 
+ + Args: + args: Training arguments + + Returns: + Tuple of (reward_tracker, dynamic_replacer) + """ + # Initialize reward tracker + if args.enable_polaris_reward_tracking: + if args.polaris_reward_tracking_dir is None: + # Default to data directory + data_path = getattr(args, "rollout_data_path", None) or getattr(args, "prompt_data", None) + if data_path: + tracking_dir = str(Path(data_path).parent) + else: + tracking_dir = "polaris_tracking" + else: + tracking_dir = args.polaris_reward_tracking_dir + + experiment_name = ( + getattr(args, "wandb_name", None) + or getattr(args, "wandb_group", None) + or "experiment" + ) + + reward_tracker = RewardTracker( + save_dir=tracking_dir, + experiment_name=experiment_name, + enabled=True, + ) + else: + reward_tracker = RewardTracker( + save_dir="", + experiment_name="", + enabled=False, + ) + + # Initialize dynamic sample replacer + if args.enable_polaris_dynamic_sampling: + dynamic_replacer = DynamicSampleReplacer( + enabled=True, + good_reward_range=(args.polaris_good_reward_min, args.polaris_good_reward_max), + min_good_ratio=args.polaris_min_good_ratio, + verbose=args.polaris_verbose, + ) + else: + dynamic_replacer = DynamicSampleReplacer( + enabled=False, + ) + + return reward_tracker, dynamic_replacer + + +def extract_rewards_from_rollout_data( + rollout_data: Dict, + n_samples_per_prompt: int = 1, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Extract and aggregate rewards from rollout data. + + Args: + rollout_data: Dictionary containing rollout data + n_samples_per_prompt: Number of samples per prompt (rollout_n) + + Returns: + per_rollout_rewards: Rewards for each rollout, shape (batch_size * n_samples,) + per_prompt_rewards: Average reward per prompt, shape (batch_size,) + """ + # Prefer raw reward when shape matches local samples; otherwise fall back to rewards. 
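    # Example: with n_samples_per_prompt = 4, eight rollout rewards
    # [1, 0, 0, 1, 0, 0, 0, 0] aggregate (grouped consecutively per prompt) to
    # per-prompt means [0.5, 0.0]; these per-prompt values are what the reward
    # tracker and the dynamic replacer consume downstream.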
+ rewards = None + local_count = len(rollout_data.get("tokens", [])) + if "raw_reward" in rollout_data: + rr = rollout_data["raw_reward"] + if isinstance(rr, (list, tuple)) and len(rr) == local_count: + rewards = rr + if rewards is None and "rewards" in rollout_data: + rewards = rollout_data["rewards"] + + if rewards is not None: + if isinstance(rewards, list): + first = rewards[0] + if isinstance(first, torch.Tensor): + per_rollout_rewards = np.array([ + r.item() if r.numel() == 1 else r.sum().item() for r in rewards + ]) + else: + per_rollout_rewards = np.array(rewards, dtype=float) + elif isinstance(rewards, torch.Tensor): + per_rollout_rewards = rewards.detach().cpu().numpy() + else: + per_rollout_rewards = np.array(rewards, dtype=float) + else: + num_samples = len(rollout_data.get("tokens", [])) + per_rollout_rewards = np.zeros(num_samples) + + # Aggregate to per-prompt rewards + per_prompt_rewards = aggregate_rewards_per_prompt(per_rollout_rewards, n_samples_per_prompt) + + return per_rollout_rewards, per_prompt_rewards + + +def apply_polaris_to_rollout_data( + args, + rollout_data: Dict, + rollout_id: int, + reward_tracker: RewardTracker, + dynamic_replacer: DynamicSampleReplacer, +) -> Tuple[Dict, Dict]: + """Apply POLARIS features to rollout data before training.""" + is_controller = mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage() + + polaris_stats: Dict = {} + replacement_plan: Optional[Dict[str, List[int]]] = None + + if is_controller: + n_samples_per_prompt = getattr(args, "n_samples_per_prompt", 1) + per_rollout_rewards, per_prompt_rewards = extract_rewards_from_rollout_data( + rollout_data, + n_samples_per_prompt=n_samples_per_prompt, + ) + + polaris_stats["polaris/mean_reward"] = per_prompt_rewards.mean() + polaris_stats["polaris/std_reward"] = per_prompt_rewards.std() + polaris_stats["polaris/reward_0_count"] = (per_prompt_rewards == 0).sum() + polaris_stats["polaris/reward_1_count"] = (per_prompt_rewards == 1).sum() + polaris_stats["polaris/reward_mid_count"] = ( + (per_prompt_rewards > 0) & (per_prompt_rewards < 1) + ).sum() + + # Only one writer across DP and CP: DP(rank==0, with_context_parallel=True) and CP(rank==0 if available) + dp_rank_with_cp = mpu.get_data_parallel_rank(with_context_parallel=True) + is_dp0_with_cp = dp_rank_with_cp == 0 + cp_rank_fn = getattr(mpu, "get_context_parallel_rank", None) + current_cp_rank = cp_rank_fn() if cp_rank_fn is not None else None + is_cp0 = True if cp_rank_fn is None else (current_cp_rank == 0) + if reward_tracker.enabled: + # Build per-prompt indices aligned with per_prompt_rewards + num_samples_local = len(rollout_data.get("tokens", [])) + rollout_n = n_samples_per_prompt + if ( + "sample_indices" in rollout_data + and isinstance(rollout_data["sample_indices"], list) + and len(rollout_data["sample_indices"]) == num_samples_local + and rollout_n > 0 + and num_samples_local % rollout_n == 0 + ): + # Take the first sample index of each prompt group + per_prompt_indices = rollout_data["sample_indices"][::rollout_n] + else: + # Fallback to sequential indices + per_prompt_indices = list(range(len(per_prompt_rewards))) + + tracker_payload = { + "indices": [int(idx) for idx in per_prompt_indices], + "scores": [float(score) for score in per_prompt_rewards.tolist()], + } + dp_src_with_cp = mpu.get_data_parallel_src_rank(with_context_parallel=True) + dp_world_size_with_cp = mpu.get_data_parallel_world_size(with_context_parallel=True) + gathered_tracker_payloads = ( + [None] * dp_world_size_with_cp if 
dp_rank_with_cp == dp_src_with_cp else None + ) + + dist.gather_object( + tracker_payload, + gathered_tracker_payloads, + dst=dp_src_with_cp, + group=mpu.get_data_parallel_group_gloo(with_context_parallel=True), + ) + + if is_dp0_with_cp and is_cp0 and dp_rank_with_cp == dp_src_with_cp: + combined_indices: list[int] = [] + combined_scores: list[float] = [] + seen_scores: dict[int, list[float]] = {} + for entry in gathered_tracker_payloads or []: + if not entry: + continue + entry_indices = entry.get("indices", []) + entry_scores = entry.get("scores", []) + for idx, score in zip(entry_indices, entry_scores): + score_history = seen_scores.setdefault(idx, []) + if any(math.isclose(score, logged, rel_tol=1e-6, abs_tol=1e-6) for logged in score_history): + continue + score_history.append(score) + combined_indices.append(idx) + combined_scores.append(score) + + reward_tracker.log_batch_rewards( + sample_indices=combined_indices, + rewards=np.array(combined_scores, dtype=float), + rollout_id=rollout_id, + ) + tracker_stats = reward_tracker.get_statistics() + polaris_stats["polaris/tracker_total_batches"] = tracker_stats["total_batches"] + polaris_stats["polaris/tracker_total_samples"] = tracker_stats["total_samples"] + + if dynamic_replacer.enabled: + ( + rollout_data, + modified_per_prompt_rewards, + replacement_stats, + replacement_plan, + ) = dynamic_replacer.replace_samples( + rollout_data=rollout_data, + rewards=per_prompt_rewards, + rollout_n=n_samples_per_prompt, + ) + + polaris_stats.update({ + f"polaris/replacer_{k}": v for k, v in replacement_stats.items() + }) + for bool_key in ("polaris/replacer_enabled", "polaris/replacer_replaced"): + if bool_key in polaris_stats: + polaris_stats[bool_key] = 1.0 if polaris_stats[bool_key] else 0.0 + + # Optionally skip this batch to align with verl when insufficient good samples + if ( + not replacement_stats.get("replaced", False) + and replacement_stats.get("reason") == "insufficient_good_samples" + and getattr(args, "polaris_skip_batch_when_insufficient", False) + ): + # Mark a flag so the caller can choose to skip training on this batch. + polaris_stats["polaris/skip_batch_due_to_insufficient_good"] = 1 + # No replacement plan should be applied across ranks in this case. + replacement_plan = None + + if replacement_stats.get("replaced", False): + polaris_stats["polaris/mean_reward_after"] = modified_per_prompt_rewards.mean() + polaris_stats["polaris/std_reward_after"] = modified_per_prompt_rewards.std() + + replacer_stats = dynamic_replacer.get_statistics() + polaris_stats["polaris/replacer_total_calls"] = replacer_stats["total_calls"] + polaris_stats["polaris/replacer_total_replacements"] = replacer_stats["total_replacements"] + polaris_stats["polaris/replacer_rate"] = replacer_stats["replacement_rate"] + + mp_group = mpu.get_model_parallel_group() + if mp_group is not None and dist.get_world_size(group=mp_group) > 1: + plan_buffer = [replacement_plan] + dist.broadcast_object_list(plan_buffer, src=mpu.get_model_parallel_src_rank(), group=mp_group) + replacement_plan = plan_buffer[0] + + if replacement_plan: + rollout_data = replace_samples_in_rollout( + rollout_data, + replacement_plan["bad_indices"], + replacement_plan["chosen_indices"], + ) + + return rollout_data, polaris_stats if is_controller else {} + + +def log_polaris_stats(rollout_id, args, polaris_stats): + """ + Log POLARIS statistics to console and wandb. 
+ + Args: + rollout_id: Current rollout/step ID + args: Training arguments + polaris_stats: Dictionary of POLARIS statistics + """ + payload = polaris_stats if polaris_stats else {"_polaris_empty": True} + + # Only log from main rank + if mpu.get_data_parallel_rank(with_context_parallel=True) == 0: + # Gather statistics across data parallel ranks if needed + gathered_stats = [None] * mpu.get_data_parallel_world_size(with_context_parallel=True) + dist.gather_object( + payload, + gathered_stats, + dst=mpu.get_data_parallel_src_rank(with_context_parallel=True), + group=mpu.get_data_parallel_group_gloo(with_context_parallel=True), + ) + + if mpu.get_data_parallel_rank(with_context_parallel=True) == 0: + # Average statistics while tolerating ranks without this key. + valid_stats = [ + s for s in gathered_stats if isinstance(s, dict) and not s.get("_polaris_empty", False) + ] + if not valid_stats: + return + + rank_count = len(valid_stats) + + dp_world_size_with_cp = mpu.get_data_parallel_world_size(with_context_parallel=True) + dp_world_size_without_cp = mpu.get_data_parallel_world_size(with_context_parallel=False) + + cp_world_size_fn = getattr(mpu, "get_context_parallel_world_size", None) + cp_world_size = cp_world_size_fn() if cp_world_size_fn is not None else 1 + + expected_controller_count = dp_world_size_with_cp + assert rank_count == expected_controller_count, ( + f"Missing POLARIS stats: expected {expected_controller_count} controller reports, " + f"got {rank_count}. This may indicate a crash, early exit, or communication issue in one or more ranks." + ) + + averaged_stats = {} + all_keys = set().union(*(s.keys() for s in valid_stats)) + for key in all_keys: + values = [s[key] for s in valid_stats if key in s] + if not values: + continue + + if all(isinstance(v, numbers.Number) for v in values): + averaged_stats[key] = sum(values) / len(values) + else: + averaged_stats[key] = values[0] + + averaged_stats["polaris/dp_world_size"] = dp_world_size_with_cp + averaged_stats["polaris/dp_world_size_without_cp"] = dp_world_size_without_cp + averaged_stats["polaris/cp_world_size"] = cp_world_size + reward_bucket_keys = [ + "polaris/reward_0_count", + "polaris/reward_mid_count", + "polaris/reward_1_count", + ] + if all(key in averaged_stats for key in reward_bucket_keys): + reward_0_avg = averaged_stats["polaris/reward_0_count"] + reward_mid_avg = averaged_stats["polaris/reward_mid_count"] + reward_1_avg = averaged_stats["polaris/reward_1_count"] + batch_total_prompts = (reward_0_avg + reward_mid_avg + reward_1_avg) * dp_world_size_without_cp + averaged_stats["polaris/batch_total_prompts"] = batch_total_prompts + averaged_stats["polaris/batch_solve_none_total"] = reward_0_avg * dp_world_size_without_cp + averaged_stats["polaris/batch_solve_partial_total"] = reward_mid_avg * dp_world_size_without_cp + averaged_stats["polaris/batch_solve_all_total"] = reward_1_avg * dp_world_size_without_cp + if "polaris/replacer_replaced" in averaged_stats: + avg_replaced = averaged_stats["polaris/replacer_replaced"] + successful_ranks = avg_replaced * rank_count + averaged_stats["polaris/replacer_successful_ranks"] = successful_ranks + averaged_stats["polaris/replacer_total_ranks"] = rank_count + averaged_stats["polaris/replacer_success_rate"] = ( + successful_ranks / rank_count if rank_count > 0 else 0.0 + ) + print(f"POLARIS stats {rollout_id}: {averaged_stats}") + + if args.use_wandb: + import wandb + averaged_stats["rollout/step"] = ( + rollout_id + if not args.wandb_always_use_train_step + else rollout_id * 
args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size + ) + wandb.log(averaged_stats) + else: + dist.gather_object( + payload, + None, + dst=mpu.get_data_parallel_src_rank(with_context_parallel=True), + group=mpu.get_data_parallel_group_gloo(with_context_parallel=True), + ) \ No newline at end of file diff --git a/slime/ray/buffer.py b/slime/ray/buffer.py index ff5ae6206..91ac8a479 100644 --- a/slime/ray/buffer.py +++ b/slime/ray/buffer.py @@ -121,6 +121,72 @@ def _convert_samples_to_train_data(self, samples: Union[list[Sample], list[list[ assert len(raw_rewards) == len(samples) assert len(rewards) == len(samples) + filtered_samples: list[Sample] = [] + filtered_raw_rewards: list[float] = [] + filtered_rewards: list[float] = [] + filtered_dataset_indices: list[int | None] = [] + dropped = 0 + + for sample, raw_reward, reward in zip(samples, raw_rewards, rewards): + response_len = int(sample.response_length) + + if response_len <= 0: + dropped += 1 + continue + + if sample.loss_mask is None: + sample.loss_mask = [1] * response_len + + if len(sample.loss_mask) != response_len: + if len(sample.loss_mask) > response_len and response_len > 0: + sample.loss_mask = sample.loss_mask[:response_len] + else: + dropped += 1 + continue + + idx = None + if sample.metadata and isinstance(sample.metadata, dict): + value = sample.metadata.get("dataset_index") + if isinstance(value, int): + idx = value + + filtered_samples.append(sample) + filtered_raw_rewards.append(raw_reward) + filtered_rewards.append(reward) + filtered_dataset_indices.append(idx) + + if dropped > 0: + print( + f"[RolloutController] Dropped {dropped} invalid samples " + "(response_length<=0 or mismatched loss_mask)." + ) + + if not filtered_samples: + raise RuntimeError("No valid samples remained after filtering invalid rollout entries.") + + samples = filtered_samples + raw_rewards = filtered_raw_rewards + rewards = filtered_rewards + dataset_indices = filtered_dataset_indices + + global_bs = getattr(self.args, "global_batch_size", None) + if global_bs: + valid_len = len(samples) + trimmed_len = (valid_len // global_bs) * global_bs + if trimmed_len == 0: + raise RuntimeError( + f"Filtered rollout batch ({valid_len} samples) smaller than global batch size {global_bs}." + ) + if trimmed_len != valid_len: + drop_tail = valid_len - trimmed_len + print( + f"[RolloutController] Trimmed {drop_tail} samples to keep batch size divisible by global_batch_size ({global_bs})." 
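# Illustrative note (not part of the patch): with 2050 valid samples and global_batch_size=2048,
# trimmed_len is (2050 // 2048) * 2048 = 2048, so the last 2 samples are dropped here.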
+ ) + samples = samples[:trimmed_len] + raw_rewards = raw_rewards[:trimmed_len] + rewards = rewards[:trimmed_len] + dataset_indices = dataset_indices[:trimmed_len] + train_data = { "tokens": [sample.tokens for sample in samples], "response_lengths": [sample.response_length for sample in samples], @@ -129,7 +195,9 @@ def _convert_samples_to_train_data(self, samples: Union[list[Sample], list[list[ "rewards": rewards, "raw_reward": raw_rewards, "truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples], - "sample_indices": [sample.index for sample in samples], + "sample_indices": [ + dataset_indices[i] if dataset_indices[i] is not None else samples[i].index for i in range(len(samples)) + ], } # loss mask diff --git a/slime/rollout/rm_hub/__init__.py b/slime/rollout/rm_hub/__init__.py index 00d653665..7e40e1e7b 100644 --- a/slime/rollout/rm_hub/__init__.py +++ b/slime/rollout/rm_hub/__init__.py @@ -30,7 +30,9 @@ async def async_rm(args, sample: Sample, **kwargs): if args.custom_rm_path is not None: rm_function = load_function(args.custom_rm_path) return await rm_function(args, sample, **kwargs) - + # Check if this is evaluation mode - disable length penalty for evaluation + is_evaluation = kwargs.get('evaluation', False) + rm_type = args.rm_type response = sample.response label = sample.label @@ -43,7 +45,7 @@ async def async_rm(args, sample: Sample, **kwargs): if rm_type == "remote_rm": return await remote_rm(args, sample) elif rm_type == "deepscaler": - return get_deepscaler_rule_based_reward(response, label) + return get_deepscaler_rule_based_reward(response, label, args=args, sample=sample, evaluation=is_evaluation) elif rm_type == "dapo": return compute_score_dapo(response, label) elif rm_type == "math": diff --git a/slime/rollout/rm_hub/deepscaler.py b/slime/rollout/rm_hub/deepscaler.py index 39d4de383..54b56bb1a 100644 --- a/slime/rollout/rm_hub/deepscaler.py +++ b/slime/rollout/rm_hub/deepscaler.py @@ -1,13 +1,13 @@ from .math_utils import extract_answer, grade_answer_mathd, grade_answer_sympy -def get_deepscaler_rule_based_reward(response, label): +def get_deepscaler_rule_based_reward(response, label, args=None, sample=None, evaluation=False): if "" in response: model_solution = response.split("")[-1] elif "###Response" in response: model_solution = response.split("###Response")[1] else: - return 0 + model_solution = response model_answer = extract_answer(model_solution) if model_answer is None: @@ -34,9 +34,61 @@ def get_deepscaler_rule_based_reward(response, label): return 0 # Check against all possible correct answers + base_reward = 0 for ground_truth in processed_ground_truths: is_correct = grade_answer_mathd(model_answer, ground_truth) or grade_answer_sympy(model_answer, ground_truth) if is_correct: - return 1 + base_reward = 1 + break + # Apply response length penalty if enabled (but not during evaluation) + final_reward = base_reward + if (args and sample and hasattr(args, 'enable_length_penalty') and args.enable_length_penalty + and not evaluation): # Skip penalty during evaluation + length_penalty = _compute_length_penalty_deepscaler(args, sample) + final_reward = base_reward + length_penalty + + # Debug output when penalty is applied + if length_penalty != 0: + print(f"[DEEPSCALER REWARD DEBUG] Base: {base_reward}, Length penalty: {length_penalty:.3f}, Final: {final_reward:.3f}") + elif evaluation and args and hasattr(args, 'enable_length_penalty') and args.enable_length_penalty: + print(f"[DEEPSCALER EVAL] Length penalty disabled for 
evaluation - Base reward only: {base_reward}") + + return final_reward + +def _compute_length_penalty_deepscaler(args, sample) -> float: + """Compute response length penalty for deepscaler (same as DAPO). + + Args: + args: Configuration arguments + sample: Sample object containing response information + + Returns: + Length penalty (non-positive value) + """ + if not sample or not hasattr(sample, 'response_length'): + return 0.0 + + if not args.max_response_length: + return 0.0 + + response_length = sample.response_length + max_length = args.max_response_length + buffer_length = getattr(args, 'length_penalty_buffer', 1024) + penalty_factor = getattr(args, 'length_penalty_factor', 1.0) + + # Calculate expected length (same logic as DAPO) + expected_length = max_length - buffer_length + + if response_length <= expected_length: + return 0.0 # No penalty for responses within expected length + + # Calculate penalty for responses exceeding expected length + exceed_length = response_length - expected_length + raw_penalty = -exceed_length / buffer_length * penalty_factor + + # Limit penalty to prevent extremely negative rewards + # Max penalty should not exceed the base reward magnitude + max_penalty = -1.0 + length_penalty = max(raw_penalty, max_penalty) - return 0 + return length_penalty diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 76b3a091b..e57e90540 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -167,7 +167,7 @@ async def generate_and_rm(args, sample: Sample, sampling_params: dict, evaluatio # for multi agent system, the reward of some sample is calculated during generation. samples_need_reward = [sample for sample in samples if sample.reward is None] - rewards = await batched_async_rm(args, samples_need_reward) + rewards = await batched_async_rm(args, samples_need_reward, evaluation=evaluation) for sample, reward in zip(samples_need_reward, rewards): sample.reward = reward return samples @@ -175,7 +175,7 @@ async def generate_and_rm(args, sample: Sample, sampling_params: dict, evaluatio if sample.status == Sample.Status.ABORTED: return sample - sample.reward = await async_rm(args, sample) + sample.reward = await async_rm(args, sample, evaluation=evaluation) return sample @@ -192,7 +192,7 @@ async def generate_and_rm_group(args, group: list[Sample], sampling_params: dict # for the rm that need the whole group, we will not do the rm here if not state.aborted and args.group_rm: - rewards = await batched_async_rm(args, group) + rewards = await batched_async_rm(args, group, evaluation=evaluation) for sample, reward in zip(group, rewards): sample.reward = reward diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index cfa6ff25a..2ec18b4ad 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1,6 +1,8 @@ import os from typing import Any, Dict +import yaml + from transformers import AutoConfig from slime.backends.sglang_utils.arguments import add_sglang_arguments @@ -345,6 +347,15 @@ def add_data_arguments(parser): "the input should be the same structure as an openai message, e.g. [\{'role': 'user', 'content': 'blabla'\}]. " ), ) + parser.add_argument( + "--rollout-data-path", + type=str, + default=None, + help=( + "Alias for --prompt-data kept for POLARIS examples. " + "When provided, sets the prompt dataset used for rollouts." 
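# Illustrative note (not part of the patch): slime_validate_args later mirrors whichever of
# --prompt-data / --rollout-data-path is provided onto the other, so both flag names stay usable.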
+ ), + ) parser.add_argument("--apply-chat-template", action="store_true", default=False) parser.add_argument("--input-key", type=str, default="input", help="JSON dataset key") parser.add_argument("--label-key", type=str, default=None, help="JSON dataset key") @@ -528,6 +539,13 @@ def add_algo_arguments(parser): "if custom_loss is set, we will use the function path from `--custom-loss-function-path`." ), ) + parser.add_argument( + "--policy-objective", + type=str, + choices=["ppo", "cispo"], + default="ppo", + help="Policy gradient objective. PPO applies token clipping; CISPO clips importance weights.", + ) parser.add_argument( "--custom-loss-function-path", type=str, @@ -572,6 +590,38 @@ def add_algo_arguments(parser): parser.add_argument("--entropy-coef", type=float, default=0.0, help="Entropy loss coef") parser.add_argument("--gamma", type=float, default=1.0, help="Discount factor for rewards in REINFORCE++.") parser.add_argument("--normalize-advantages", action="store_true", default=False) + parser.add_argument( + "--high-entropy-token-filter", + action="store_true", + default=False, + help=( + "Whether to apply policy gradient updates only to high-entropy tokens. " + "Inspired by 'Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective RL for LLM Reasoning'. " + "This focuses training on 'forking tokens' that steer reasoning directions." + ), + ) + parser.add_argument( + "--entropy-percentile", + type=float, + default=0.2, + help=( + "The percentile of highest-entropy tokens to retain for gradient updates when --high-entropy-token-filter is enabled. " + "Default 0.2 means only the top 20%% highest-entropy tokens will receive gradients. " + "According to the paper, 20%% achieves optimal balance between exploration and performance." + ), + ) + parser.add_argument( + "--cispo-eps-high", + type=float, + default=2.0, + help="Upper bound ε_high for CISPO importance-weight clipping (ratio limited to 1 + ε_high).", + ) + parser.add_argument( + "--cispo-eps-low", + type=float, + default=0.0, + help="Lower bound ε_low for CISPO importance-weight clipping (set ≤0 to disable the lower clamp).", + ) parser.add_argument( "--disable-grpo-std-normalization", action="store_false", @@ -612,6 +662,12 @@ def add_algo_arguments(parser): default=0, help="Lower bound clipping threshold C for importance sampling ratios to control variance.", ) + parser.add_argument( + "--custom-tis-function-path", + type=str, + default=None, + help="Path to the custom TIS function.", + ) return parser # wandb @@ -635,6 +691,12 @@ def add_wandb_arguments(parser): parser.add_argument("--wandb-host", type=str, default=None) parser.add_argument("--wandb-team", type=str, default=None) parser.add_argument("--wandb-group", type=str, default=None) + parser.add_argument( + "--wandb-name", + type=str, + default=None, + help="Alias for --wandb-group preserved for POLARIS examples.", + ) reset_arg(parser, "--wandb-project", type=str, default=None) parser.add_argument( "--disable-wandb-random-suffix", @@ -781,6 +843,32 @@ def add_reward_model_arguments(parser): "Path to the custom function that will post process reward, by default it will be the normalization for grpo. 
" ), ) + # Response length penalty arguments + parser.add_argument( + "--enable-length-penalty", + action="store_true", + default=False, + help="Enable response length truncation penalty (similar to verl DAPO)", + ) + parser.add_argument( + "--max-response-length", + type=int, + default=None, + help="Maximum allowed response length before applying penalty", + ) + parser.add_argument( + "--length-penalty-buffer", + type=int, + default=50, + help="Buffer length for gradual penalty application", + ) + parser.add_argument( + "--length-penalty-factor", + type=float, + default=1.0, + help="Penalty factor for response length truncation (higher = more penalty)", + + ) return parser def add_rollout_buffer_arguments(parser): @@ -812,7 +900,7 @@ def add_rollout_buffer_arguments(parser): "--loss-mask-type", type=str, default="qwen", - choices=["qwen", "distill_qwen"], + choices=["qwen", "qwen3", "distill_qwen"], help="Loss mask type", ) return parser @@ -847,6 +935,77 @@ def add_ci_arguments(parser): ) return parser + def add_polaris_arguments(parser): + """ + Add POLARIS-style training tricks arguments. + Implements dynamic sampling and reward tracking from POLARIS paper. + """ + parser.add_argument( + "--enable-polaris-dynamic-sampling", + action="store_true", + default=False, + help=( + "Enable POLARIS dynamic sample replacement. " + "This replaces trivial samples (reward=0 or 1) with medium-difficulty ones " + "during training to improve data quality and training efficiency." + ), + ) + parser.add_argument( + "--polaris-good-reward-min", + type=float, + default=0.0, + help="Minimum reward (exclusive) for 'good' samples in dynamic sampling.", + ) + parser.add_argument( + "--polaris-good-reward-max", + type=float, + default=1.0, + help="Maximum reward (exclusive) for 'good' samples in dynamic sampling.", + ) + parser.add_argument( + "--polaris-min-good-ratio", + type=float, + default=0.33, + help=( + "Minimum ratio of good samples required to perform replacement. " + "If less than this ratio, skip replacement and print warning." + ), + ) + parser.add_argument( + "--enable-polaris-reward-tracking", + action="store_true", + default=False, + help=( + "Enable reward tracking to JSONL files for post-training analysis. " + "This enables difficulty-based data filtering between training stages." + ), + ) + parser.add_argument( + "--polaris-reward-tracking-dir", + type=str, + default=None, + help=( + "Directory to save reward tracking files. " + "Defaults to the same directory as the training data." + ), + ) + parser.add_argument( + "--polaris-verbose", + action="store_true", + default=True, + help="Print verbose information about POLARIS operations.", + ) + parser.add_argument( + "--polaris-skip-batch-when-insufficient", + action="store_true", + default=False, + help=( + "Skip the current batch when insufficient medium-difficulty samples are available " + "(i.e., good ratio <= min_good_ratio). This matches verl's behavior." + ), + ) + return parser + # Add custom arguments in front to prevent overwritten some slime arguments. 
if add_custom_arguments is not None: parser = add_custom_arguments(parser) @@ -862,11 +1021,18 @@ def add_ci_arguments(parser): parser = add_network_arguments(parser) parser = add_reward_model_arguments(parser) parser = add_rollout_buffer_arguments(parser) + parser = add_polaris_arguments(parser) parser = add_ci_arguments(parser) # For megatron parser = add_custom_megatron_plugins_arguments(parser) try: + parser.add_argument( + "--custom-config-path", + type=str, + default=None, + help="Path to the YAML config for custom function arguments.", + ) parser.add_argument("--padded-vocab-size", type=int, default=None) except: pass @@ -941,6 +1107,18 @@ def parse_args(add_custom_arguments=None): def slime_validate_args(args): + # Harmonize backward compatibility aliases used in POLARIS examples. + if getattr(args, "rollout_data_path", None) and getattr(args, "prompt_data", None) is None: + args.prompt_data = args.rollout_data_path + elif getattr(args, "prompt_data", None) is not None and getattr(args, "rollout_data_path", None) is None: + args.rollout_data_path = args.prompt_data + + if getattr(args, "wandb_name", None): + if getattr(args, "wandb_group", None) is None: + args.wandb_group = args.wandb_name + else: + args.wandb_name = getattr(args, "wandb_group", None) + if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") @@ -1072,6 +1250,15 @@ def slime_validate_args(args): "num_epoch is not set, but num_rollout is not set, " "please set --num-rollout or --num-epoch" ) + if getattr(args, "custom_config_path", None): + with open(args.custom_config_path, "r") as f: + data = yaml.safe_load(f) or {} + for k, v in data.items(): + if not hasattr(args, k): + setattr(args, k, v) + else: + print(f"Warning: Argument {k} is already set to {getattr(args, k)}, will not override with {v}.") + def hf_validate_args(args, hf_config): equal = lambda x, y: x == y diff --git a/slime/utils/data.py b/slime/utils/data.py index 00e7c6e06..91e5f5140 100644 --- a/slime/utils/data.py +++ b/slime/utils/data.py @@ -18,8 +18,43 @@ def read_file(path): df = pd.read_json(path, lines=True) elif path.endswith(".parquet"): df = pd.read_parquet(path, dtype_backend="pyarrow") + elif path.endswith(".json"): + # Support .json format (both list and columnar) + with open(path, 'r', encoding='utf-8') as f: + data = json.load(f) + + # Handle different JSON formats + if isinstance(data, list): + # Standard list format: [{"key": "value"}, ...] + for item in data: + yield item + elif isinstance(data, dict): + # Columnar format: {"key": {"0": ..., "1": ...}, ...} + # or indexed format: {"0": {...}, "1": {...}} + + # Check if it's columnar format (has keys like "conversations", "problem", etc.) + first_key = list(data.keys())[0] + if isinstance(data[first_key], dict): + # Determine number of samples + num_samples = len(data[first_key]) + + # Convert columnar to row format + for idx in range(num_samples): + idx_str = str(idx) + item = {} + for key in data.keys(): + if isinstance(data[key], dict) and idx_str in data[key]: + item[key] = data[key][idx_str] + yield item + else: + # Single sample dict + yield data + else: + raise ValueError(f"Unsupported JSON structure: {type(data)}") + return else: - raise ValueError(f"Unsupported file format: {path}. Supported formats are .jsonl and .parquet.") + raise ValueError(f"Unsupported file format: {path}. 
Supported formats are .json, .jsonl and .parquet.") + for _, row in df.iterrows(): yield row.to_dict() @@ -58,11 +93,16 @@ def __init__( if len(tokenizer(prompt)["input_ids"]) > max_length: continue + metadata = data.get(metadata_key) or {} + if not isinstance(metadata, dict): + metadata = {} + metadata.setdefault("dataset_index", len(self.origin_samples)) + self.origin_samples.append( Sample( prompt=prompt, label=data[label_key] if label_key is not None else None, - metadata=data.get(metadata_key) or {}, + metadata=metadata, ) ) @@ -114,7 +154,7 @@ def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size): data = data[0] # save the unprocessed reward for logging - rollout_data["raw_reward"] = data["raw_reward"] + # rollout_data["raw_reward"] = data["raw_reward"] total_lengths = [len(t) for t in data["tokens"]] data["total_lengths"] = total_lengths @@ -153,12 +193,13 @@ def get_partition(val): return [val[i] for i in parititions[dp_rank]] else: return val[dp_rank::dp_size] - + # add raw_reward in dp for key in [ "tokens", "total_lengths", "response_lengths", "rewards", + "raw_reward", "truncated", "loss_masks", "round_number", diff --git a/slime/utils/mask_utils.py b/slime/utils/mask_utils.py index bddb43261..66b584c74 100644 --- a/slime/utils/mask_utils.py +++ b/slime/utils/mask_utils.py @@ -3,6 +3,10 @@ from transformers import AutoTokenizer +def get_response_lengths(loss_masks: List[List[int]]) -> List[int]: + return [len(mask[mask.index(1) :]) if 1 in mask else 0 for mask in loss_masks] + + class MultiTurnLossMaskGenerator: def __init__(self, tokenizer: AutoTokenizer, tokenizer_type: str = "qwen"): self.tokenizer = tokenizer @@ -10,7 +14,7 @@ def __init__(self, tokenizer: AutoTokenizer, tokenizer_type: str = "qwen"): self.tokenizer_type = tokenizer_type def get_response_lengths(self, loss_masks: List[List[int]]) -> List[int]: - return [len(mask[mask.index(1) :]) if 1 in mask else 0 for mask in loss_masks] + return get_response_lengths(loss_masks) def find_all_sublist_indices(self, main_list, sublist): sublist_len = len(sublist) @@ -57,6 +61,36 @@ def gen_multi_turn_loss_mask_qwen(self, messages: List[Dict]) -> Tuple[List[int] else: loss_mask = [0] * len(message_ids) + if message.get("step_loss_mask", 1) != 1: + loss_mask = [0] * len(message_ids) + + all_loss_masks.extend(loss_mask) + all_token_ids.extend(message_ids) + + return all_token_ids, all_loss_masks + + def gen_multi_turn_loss_mask_qwen3(self, messages: List[Dict]) -> Tuple[List[int], List[int]]: + all_loss_masks = [] + all_token_ids = [] + + prefix_message = {"role": "user", "content": "FOR CALCULATING LOSS MASK ONLY"} + prefix_token_ids = self.tokenizer.apply_chat_template([prefix_message], tokenize=True) + + for i, message in enumerate(messages): + prefixed_message_ids = self.tokenizer.apply_chat_template([prefix_message, message], tokenize=True) + message_ids = prefixed_message_ids[len(prefix_token_ids) :] + + if message["role"] != "system" and i > 0: + message_ids = message_ids[self.system_message_length :] + + if message["role"] == "assistant": + loss_mask = [0] * self.gen_token_length + [1] * (len(message_ids) - self.gen_token_length) + else: + loss_mask = [0] * len(message_ids) + + if message.get("step_loss_mask", 1) != 1: + loss_mask = [0] * len(message_ids) + all_loss_masks.extend(loss_mask) all_token_ids.extend(message_ids) @@ -71,6 +105,8 @@ def gen_multi_turn_loss_mask_distill_qwen(self, messages: List[Dict]) -> Tuple[L response_length = len(response_tokens) token_ids = prompt_tokens + response_tokens 
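# Illustrative note (not part of the patch): with 3 prompt tokens and 2 response tokens, the mask
# built below is [0, 0, 0, 1, 1]; when the final message carries step_loss_mask=0, the added branch
# zeroes the whole mask so that turn contributes no loss.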
loss_mask = [0] * len(prompt_tokens) + [1] * response_length + if messages[-1].get("step_loss_mask", 1) != 1: + loss_mask = [0] * len(token_ids) return token_ids, loss_mask def get_loss_mask(self, messages: List[Dict]) -> List[int]: @@ -79,6 +115,8 @@ def get_loss_mask(self, messages: List[Dict]) -> List[int]: return self.gen_multi_turn_loss_mask_distill_qwen(messages) return self.gen_multi_turn_loss_mask_qwen(messages) + elif self.tokenizer_type == "qwen3": + return self.gen_multi_turn_loss_mask_qwen3(messages) elif self.tokenizer_type == "distill_qwen": return self.gen_multi_turn_loss_mask_distill_qwen(messages) else: diff --git a/slime/utils/polaris_filter_easy_data.py b/slime/utils/polaris_filter_easy_data.py new file mode 100644 index 000000000..da84b6a6e --- /dev/null +++ b/slime/utils/polaris_filter_easy_data.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +""" +Filter easy data based on reward tracking logs (POLARIS-style). + +This script reads the reward tracking JSONL files generated during training +and filters out samples with high average rewards (easy samples) from the dataset. + +Usage: + python polaris_filter_easy_data.py \ + --data-path data/train.json \ + --reward-log polaris_tracking/experiment.jsonl \ + --output data/train_filtered.json \ + --threshold 0.9 +""" + +import argparse +import json +from collections import defaultdict +from pathlib import Path + +import pandas as pd + + +def process_reward_log(jsonl_path: str, threshold: float = 0.9) -> set: + """ + Process reward tracking JSONL to identify easy samples. + + Args: + jsonl_path: Path to the reward tracking JSONL file + threshold: Average reward threshold above which samples are considered "easy" + + Returns: + Set of sample indices to remove + """ + index_to_scores = defaultdict(list) + + # Read all reward entries + with open(jsonl_path, 'r') as f: + for line in f: + entry = json.loads(line) + indices = entry['index'] + scores = entry['score'] + + for idx, score in zip(indices, scores): + index_to_scores[idx].append(score) + + # Compute average and filter + remove_indices = set() + for idx, scores in index_to_scores.items(): + avg_score = sum(scores) / len(scores) + if avg_score > threshold: + remove_indices.add(idx) + + print(f"Total unique samples: {len(index_to_scores)}") + print(f"Samples to remove (avg reward > {threshold}): {len(remove_indices)}") + print(f"Remaining samples: {len(index_to_scores) - len(remove_indices)}") + + return remove_indices + + +def filter_json_data(input_path: str, output_path: str, remove_indices: set): + """ + Filter JSON data by removing specified indices. + + Supports both columnar format (like your example) and list format. + """ + with open(input_path, 'r') as f: + data = json.load(f) + + if isinstance(data, list): + # List format: [{"problem": ..., "conversations": ...}, ...] 
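# Illustrative note (not part of the patch): remove_indices holds positional indices taken from the
# reward-tracking log, so enumerate(data) here assumes the dataset is in the same order it had when
# those indices were recorded during training.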
+ filtered_data = [item for i, item in enumerate(data) if i not in remove_indices] + elif isinstance(data, dict): + # Columnar format: {"problem": {"0": ..., "1": ...}, ...} + first_key = list(data.keys())[0] + if isinstance(data[first_key], dict): + # Get all indices + all_indices = set(data[first_key].keys()) + keep_indices = sorted([idx for idx in all_indices if int(idx) not in remove_indices]) + + # Rebuild columnar data with only kept indices + filtered_data = {} + for key in data.keys(): + filtered_data[key] = {} + for new_idx, old_idx in enumerate(keep_indices): + filtered_data[key][str(new_idx)] = data[key][old_idx] + else: + raise ValueError("Unexpected JSON structure") + else: + raise ValueError("Unsupported JSON format") + + # Save filtered data + with open(output_path, 'w') as f: + json.dump(filtered_data, f, indent=2, ensure_ascii=False) + + print(f"Filtered data saved to: {output_path}") + + +def filter_jsonl_data(input_path: str, output_path: str, remove_indices: set): + """Filter JSONL data by removing specified indices.""" + with open(output_path, 'w') as out_f: + with open(input_path, 'r') as in_f: + for i, line in enumerate(in_f): + if i not in remove_indices: + out_f.write(line) + + print(f"Filtered data saved to: {output_path}") + + +def filter_parquet_data(input_path: str, output_path: str, remove_indices: set): + """Filter Parquet data by removing specified indices.""" + df = pd.read_parquet(input_path) + print(f"Original dataframe size: {len(df)}") + + # Assume dataframe has an implicit index + mask = ~df.index.isin(remove_indices) + filtered_df = df[mask].reset_index(drop=True) + + print(f"Filtered dataframe size: {len(filtered_df)}") + filtered_df.to_parquet(output_path) + print(f"Filtered data saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Filter easy data based on reward tracking") + parser.add_argument( + "--data-path", + type=str, + required=True, + help="Path to input data file (.json, .jsonl, or .parquet)", + ) + parser.add_argument( + "--reward-log", + type=str, + required=True, + help="Path to reward tracking JSONL file", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Path to output filtered data file", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.9, + help="Reward threshold for filtering (default: 0.9)", + ) + + args = parser.parse_args() + + # Process reward log to get indices to remove + print(f"Processing reward log: {args.reward_log}") + remove_indices = process_reward_log(args.reward_log, args.threshold) + + # Filter data based on file format + input_path = Path(args.data_path) + print(f"\nFiltering data: {args.data_path}") + + if input_path.suffix == '.json': + filter_json_data(args.data_path, args.output, remove_indices) + elif input_path.suffix == '.jsonl': + filter_jsonl_data(args.data_path, args.output, remove_indices) + elif input_path.suffix == '.parquet': + filter_parquet_data(args.data_path, args.output, remove_indices) + else: + raise ValueError(f"Unsupported file format: {input_path.suffix}") + + print("\nFiltering complete!") + + +if __name__ == "__main__": + main() diff --git a/slime/utils/polaris_utils.py b/slime/utils/polaris_utils.py new file mode 100644 index 000000000..b40a3deac --- /dev/null +++ b/slime/utils/polaris_utils.py @@ -0,0 +1,424 @@ +""" +POLARIS utilities for dynamic sampling and reward tracking. + +This module implements the key tricks from POLARIS: +1. 
Reward tracking - Save reward history to JSONL for difficulty filtering +2. Dynamic sample replacement - Replace trivial samples (reward=0 or 1) with medium-difficulty ones +""" + +import json +import os +from collections import defaultdict +from copy import deepcopy +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch + + +def _clone_sample_entry(entry): + """Clone a rollout entry so replacing one sample does not alias another.""" + if isinstance(entry, torch.Tensor): + return entry.clone() + try: + return deepcopy(entry) + except Exception: + return entry + + +def replace_samples_in_rollout( + rollout_data: Dict, + bad_indices: List[int], + chosen_indices: List[int], +) -> Dict: + """Apply a replacement plan to rollout data and return a modified copy.""" + if not bad_indices: + return rollout_data + + num_samples = len(rollout_data.get("tokens", [])) + if num_samples == 0: + return rollout_data + + modified = {} + for key, value in rollout_data.items(): + if isinstance(value, list) and len(value) == num_samples: + updated_list = list(value) + for bad_idx, chosen_idx in zip(bad_indices, chosen_indices): + updated_list[bad_idx] = _clone_sample_entry(updated_list[chosen_idx]) + modified[key] = updated_list + else: + modified[key] = value + return modified + + +class RewardTracker: + """ + Track and save reward scores for each sample during training. + + This enables post-training analysis and difficulty-based data filtering, + similar to POLARIS's drop_easy_data.py functionality. + """ + + def __init__( + self, + save_dir: str, + experiment_name: str, + enabled: bool = True, + ): + """ + Args: + save_dir: Directory to save reward tracking files + experiment_name: Name of the experiment (used as filename) + enabled: Whether to enable reward tracking + """ + self.enabled = enabled + if not self.enabled: + return + + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + self.save_path = self.save_dir / f"{experiment_name}.jsonl" + + # Track statistics + self.total_batches = 0 + self.total_samples = 0 + + def log_batch_rewards( + self, + sample_indices: List[int], + rewards: np.ndarray, + rollout_id: Optional[int] = None, + ): + """ + Log rewards for a batch of samples. + + Args: + sample_indices: List of sample indices in the dataset + rewards: Array of reward scores, shape (batch_size,) + rollout_id: Optional rollout/step ID + """ + if not self.enabled: + return + + # Convert to list if needed + if isinstance(rewards, torch.Tensor): + rewards = rewards.cpu().numpy() + if isinstance(rewards, np.ndarray): + rewards = rewards.tolist() + + # Create log entry + log_entry = { + "index": sample_indices, + "score": rewards, + } + if rollout_id is not None: + log_entry["rollout_id"] = rollout_id + + # Append to JSONL file + with open(self.save_path, "a") as f: + f.write(json.dumps(log_entry) + "\n") + + self.total_batches += 1 + self.total_samples += len(sample_indices) + + def get_statistics(self) -> Dict[str, int]: + """Get tracking statistics.""" + return { + "total_batches": self.total_batches, + "total_samples": self.total_samples, + "save_path": str(self.save_path), + } + + +class DynamicSampleReplacer: + """ + Dynamically replace trivial samples (reward=0 or 1) with medium-difficulty ones. + + This implements POLARIS's dynamic sampling trick to maintain training data quality + and avoid wasting compute on trivial samples. 
+ """ + + def __init__( + self, + enabled: bool = True, + good_reward_range: Tuple[float, float] = (0.0, 1.0), + min_good_ratio: float = 0.33, + verbose: bool = True, + ): + """ + Args: + enabled: Whether to enable dynamic sample replacement + good_reward_range: (min, max) range for "good" samples (exclusive) + min_good_ratio: Minimum ratio of good samples required to perform replacement + verbose: Whether to print replacement information + """ + self.enabled = enabled + self.good_reward_range = good_reward_range + self.min_good_ratio = min_good_ratio + self.verbose = verbose + + # Statistics + self.total_calls = 0 + self.total_replacements = 0 + self.total_skipped = 0 + + def should_replace_batch( + self, + rewards: np.ndarray, + ) -> Tuple[bool, np.ndarray]: + """ + Determine if batch should undergo sample replacement. + + Args: + rewards: Array of reward scores, shape (batch_size,) + + Returns: + should_replace: Whether replacement should be performed + good_mask: Boolean mask indicating "good" samples + """ + if not self.enabled: + return False, np.ones(len(rewards), dtype=bool) + + # Identify "good" samples (not trivial) + good_mask = (rewards > self.good_reward_range[0]) & (rewards < self.good_reward_range[1]) + + good_count = good_mask.sum() + total_count = len(rewards) + good_ratio = good_count / total_count if total_count > 0 else 0 + + # Only replace if we have enough good samples + should_replace = good_ratio > self.min_good_ratio + + return should_replace, good_mask + + def get_replacement_indices( + self, + good_mask: np.ndarray, + rollout_n: int = 1, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Get indices for sample replacement. + + Args: + good_mask: Boolean mask indicating "good" samples (per prompt) + rollout_n: Number of rollouts per prompt + + Returns: + bad_indices: Indices of samples to be replaced (expanded for all rollouts) + chosen_indices: Indices of good samples to use as replacements + """ + # Get bad and good indices at the prompt level + bad_indices_prompt = np.where(~good_mask)[0] + good_indices_prompt = np.where(good_mask)[0] + + num_bad = len(bad_indices_prompt) + num_good = len(good_indices_prompt) + + if num_bad == 0 or num_good == 0: + return np.array([]), np.array([]) + + # Sample with replacement if necessary + if num_good >= num_bad: + chosen_prompt_indices = np.random.choice(good_indices_prompt, size=num_bad, replace=False) + else: + chosen_prompt_indices = np.random.choice(good_indices_prompt, size=num_bad, replace=True) + + # Expand to all rollouts + bad_indices = self._expand_to_rollouts(bad_indices_prompt, rollout_n) + chosen_indices = self._expand_to_rollouts(chosen_prompt_indices, rollout_n) + + return bad_indices, chosen_indices + + @staticmethod + def _expand_to_rollouts(indices: np.ndarray, rollout_n: int) -> np.ndarray: + """ + Expand prompt-level indices to include all rollouts. + + For example, if indices=[0, 2] and rollout_n=3: + - Prompt 0 has rollouts at positions [0, 1, 2] + - Prompt 2 has rollouts at positions [6, 7, 8] + - Returns: [0, 1, 2, 6, 7, 8] + """ + expanded = [] + for idx in indices: + start = idx * rollout_n + expanded.extend(range(start, start + rollout_n)) + return np.array(expanded) + + def replace_samples( + self, + rollout_data: Dict, + rewards: np.ndarray, + rollout_n: int = 1, + ) -> Tuple[Dict, np.ndarray, Dict[str, int], Optional[Dict[str, List[int]]]]: + """ + Replace trivial samples in rollout data. + + Args: + rollout_data: Dictionary containing rollout data with per-sample fields. 
+ rewards: Array of average rewards per prompt, shape (batch_size,) + rollout_n: Number of rollouts per prompt + + Returns: + modified_rollout_data: Rollout data with bad samples replaced + modified_rewards: Updated reward array after replacement + stats: Dictionary of replacement statistics + replacement_plan: Replacement indices to replay on other ranks + """ + self.total_calls += 1 + + if not self.enabled: + return rollout_data, rewards, {"enabled": False}, None + + # Check if replacement should be performed + should_replace, good_mask = self.should_replace_batch(rewards) + + if not should_replace: + self.total_skipped += 1 + if self.verbose: + print("=" * 60) + print("[POLARIS Dynamic Sampling] Warning: Skipping replacement") + print(f" Reason: Insufficient good samples ({good_mask.sum()}/{len(good_mask)}, " + f"ratio={good_mask.sum()/len(good_mask):.2%} < {self.min_good_ratio:.2%})") + print(f" Most samples have trivial rewards (0 or 1)") + print(f" Check your data difficulty distribution!") + print("=" * 60) + return rollout_data, rewards, { + "enabled": True, + "replaced": False, + "reason": "insufficient_good_samples", + "good_count": int(good_mask.sum()), + "total_count": len(good_mask), + }, None + + # Get replacement indices + bad_indices, chosen_indices = self.get_replacement_indices(good_mask, rollout_n) + + if len(bad_indices) == 0: + return rollout_data, rewards, {"enabled": True, "replaced": False, "reason": "no_bad_samples"}, None + + num_samples_total = len(rollout_data.get("tokens", [])) + valid_mask = (bad_indices < num_samples_total) & (chosen_indices < num_samples_total) + if not valid_mask.all(): + if self.verbose: + invalid = int((~valid_mask).sum()) + print(f"[POLARIS Dynamic Sampling] Warning: discard {invalid} invalid replacement pairs") + bad_indices = bad_indices[valid_mask] + chosen_indices = chosen_indices[valid_mask] + + if len(bad_indices) == 0: + return rollout_data, rewards, {"enabled": True, "replaced": False, "reason": "invalid_indices"}, None + + # Apply replacements and recompute rewards + modified_data = replace_samples_in_rollout(rollout_data, bad_indices.tolist(), chosen_indices.tolist()) + + if isinstance(rewards, torch.Tensor): + modified_rewards = rewards.clone() + else: + modified_rewards = rewards.copy() + + reward_key = "raw_reward" if "raw_reward" in modified_data else "rewards" + reward_list = modified_data.get(reward_key) + if isinstance(reward_list, torch.Tensor): + reward_array = reward_list.detach().cpu().numpy() + elif isinstance(reward_list, (list, np.ndarray)): + reward_array = np.array([ + r if isinstance(r, (int, float)) else r.item() if hasattr(r, "item") else r + for r in reward_list + ], dtype=float) + else: + reward_array = None + + if reward_array is not None and reward_array.size >= rollout_n and reward_array.size % rollout_n == 0: + modified_rewards = reward_array.reshape(-1, rollout_n).mean(axis=1) + else: + modified_rewards = np.array(modified_rewards, dtype=float) + + self.total_replacements += 1 + + if self.verbose: + print("=" * 60) + print("[POLARIS Dynamic Sampling] Sample Replacement Performed") + print(f" Before: {rewards.tolist()}") + print(f" After: {modified_rewards.tolist()}") + print(f" Replaced {len(bad_indices)} samples ({len(bad_indices)//rollout_n} prompts × {rollout_n} rollouts)") + print("=" * 60) + + stats = { + "enabled": True, + "replaced": True, + "num_bad_samples": len(bad_indices), + "num_bad_prompts": len(bad_indices) // rollout_n, + "good_count": int(good_mask.sum()), + "total_count": 
len(good_mask), + } + + plan = { + "bad_indices": bad_indices.tolist(), + "chosen_indices": chosen_indices.tolist(), + "rollout_n": rollout_n, + } + + return modified_data, modified_rewards, stats, plan + + def get_statistics(self) -> Dict[str, int]: + """Get replacement statistics.""" + return { + "total_calls": self.total_calls, + "total_replacements": self.total_replacements, + "total_skipped": self.total_skipped, + "replacement_rate": self.total_replacements / self.total_calls if self.total_calls > 0 else 0, + } + + +def aggregate_rewards_per_prompt( + rewards: np.ndarray, + rollout_n: int, +) -> np.ndarray: + """ + Aggregate per-rollout rewards to per-prompt average rewards. + + Args: + rewards: Array of rewards, shape (batch_size * rollout_n,) + rollout_n: Number of rollouts per prompt + + Returns: + avg_rewards: Average reward per prompt, shape (batch_size,) + """ + return rewards.reshape(-1, rollout_n).mean(axis=1) + + +def extract_sample_indices( + rollout_data: Dict, + index_key: str = "index", +) -> List[int]: + """ + Extract sample indices from rollout data. + + Args: + rollout_data: Dictionary containing rollout data + index_key: Key for sample indices in metadata + + Returns: + sample_indices: List of sample indices + """ + + if "sample_indices" in rollout_data and isinstance(rollout_data["sample_indices"], list): + return rollout_data["sample_indices"] + + # Try to get from metadata + if "metadata" in rollout_data: + metadata = rollout_data["metadata"] + if isinstance(metadata, list): + indices = [] + for meta in metadata: + if isinstance(meta, dict) and index_key in meta: + indices.append(meta[index_key]) + else: + indices.append(-1) # Unknown index + return indices + + # If not available, use sequential indices + batch_size = len(rollout_data.get("tokens", [])) + return list(range(batch_size)) diff --git a/slime/utils/wandb_utils.py b/slime/utils/wandb_utils.py index 87ea79ded..c9c1d846e 100644 --- a/slime/utils/wandb_utils.py +++ b/slime/utils/wandb_utils.py @@ -86,6 +86,16 @@ def init_wandb_secondary(args, wandb_run_id): if (not offline) and args.wandb_key is not None: wandb.login(key=args.wandb_key, host=args.wandb_host) + # Configure settings based on offline/online mode + if offline: + settings_kwargs = dict(mode="offline") + else: + settings_kwargs = dict( + mode="shared", + x_primary=False, + x_update_finish_state=False, + ) + init_kwargs = { "id": wandb_run_id, "entity": args.wandb_team, @@ -93,17 +103,9 @@ def init_wandb_secondary(args, wandb_run_id): "config": args.__dict__, "resume": "allow", "reinit": True, + "settings": wandb.Settings(**settings_kwargs), } - # Configure settings based on offline/online mode - if offline: - init_kwargs["settings"] = wandb.Settings(mode="offline") - else: - init_kwargs["settings"] = wandb.Settings( - mode="shared", - x_primary=False, - x_update_finish_state=False, - ) # Add custom directory if specified if args.wandb_dir: diff --git a/tools/modify_choose_datav5_indexed.py b/tools/modify_choose_datav5_indexed.py new file mode 100644 index 000000000..9f16c8c02 --- /dev/null +++ b/tools/modify_choose_datav5_indexed.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +""" +基于索引的选择题提取工具(v5 Ultra 版本) +- 从带 original_index 的源文件中提取 +- 保留 original_index 以便后续追踪 +""" + +import json +import re +from pathlib import Path +from collections import Counter +from typing import Dict, Optional, Tuple, List + +# 配置 +SRC = Path('analysis/polaris-data-53K-indexed.jsonl') # 带索引的源文件 +DST = Path('analysis/polaris-data-53K__choose_ultra_indexed.jsonl') + +if not 
SRC.exists(): + raise SystemExit(f'Missing source file: {SRC}') + +skip_reason = Counter() + +# ==================== Pattern Matching (与 v5 相同) ==================== + +WRAPPED_PAIR = [ + re.compile(r"\\textbf\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\\mathrm\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\\mathbf\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\\textit\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\\text\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\\bf\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), +] + +WRAPPED_SINGLE = [ + re.compile(r"\\textbf\s*\{\s*([A-E])\s*\}", re.IGNORECASE), + re.compile(r"\\mathrm\s*\{\s*([A-E])\s*\}", re.IGNORECASE), + re.compile(r"\\mathbf\s*\{\s*([A-E])\s*\}", re.IGNORECASE), + re.compile(r"\\textit\s*\{\s*([A-E])\s*\}", re.IGNORECASE), + re.compile(r"\\text\s*\{\s*([A-E])\s*\}", re.IGNORECASE), + re.compile(r"\\bf\s*\{\s*([A-E])\s*\}", re.IGNORECASE), +] + +SIMPLE_CMDS = ['mathrm', 'text', 'operatorname', 'mathbf', 'textit', 'textbf', 'bf'] + +OPTION_PATTERNS = [ + re.compile(r"^\s*(?:\(([A-E])\)|([A-E]))[)\.::-]?\s+", re.IGNORECASE), + re.compile(r"^\s*\$\\textbf\{?\(([A-E])\)\}?\$\s*", re.IGNORECASE), + re.compile(r"^\s*\(\s*([A-E])\s*\)\s*[:\-\.]?\s+", re.IGNORECASE), + re.compile(r"^\s*([A-E])\s*[\.::\-]\s+", re.IGNORECASE), + re.compile(r"^\s*([A-E])\s+(?=[A-Za-z\d\$\(])", re.IGNORECASE), + re.compile(r"^\s*\$\s*\\textbf\s*\{\s*\(([A-E])\)\s*\}\s*\$", re.IGNORECASE), + re.compile(r"^\s*\$\s*\\mathrm\s*\{\s*\(([A-E])\)\s*\}\s*\$", re.IGNORECASE), + re.compile(r"^\s*\$\s*\(([A-E])\)\s*\$", re.IGNORECASE), + re.compile(r"\s+\$\\textbf\{?\(([A-E])\)\}?\$\s+", re.IGNORECASE), + re.compile(r"^\s*\(([A-E])\)", re.IGNORECASE), + re.compile(r"^([A-E])\)", re.IGNORECASE), +] + +LABEL_PATTERNS = [ + re.compile(r"\\textbf\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\\mathrm\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\\mathbf\s*\{\s*\(([A-E])\)\s*\}", re.IGNORECASE), + re.compile(r"\(([A-E])\)", re.IGNORECASE), + re.compile(r"^([A-E])$", re.IGNORECASE), + re.compile(r"^([A-E])[)\.::\-]", re.IGNORECASE), + re.compile(r"^([A-E])\s*$", re.IGNORECASE), + re.compile(r"^([A-E])[)\s]", re.IGNORECASE), + re.compile(r"([A-E])(?:\s*,\s*([A-E]))*", re.IGNORECASE), +] + + +def strip_simple_command(cmd: str, text: str) -> str: + """Remove LaTeX commands while preserving content.""" + pattern = re.compile(rf"\\{cmd}\s*\{{([^{{}}]*)\}}", re.IGNORECASE) + iterations = 0 + while iterations < 10: + new_text, count = pattern.subn(r"\1", text) + text = new_text + iterations += 1 + if count == 0: + break + return text + + +def normalize_prompt(raw_text: str) -> str: + """Normalize prompt by cleaning LaTeX and formatting.""" + text = raw_text.replace('\r', '') + text = text.replace('\\qquad', '\n').replace('\\quad', ' ') + text = text.replace('\\\\', '\n') + text = text.replace('\\newline', '\n') + + for _ in range(3): + for pat in WRAPPED_PAIR: + text = pat.sub(lambda m: f"({m.group(1)})", text) + for pat in WRAPPED_SINGLE: + text = pat.sub(lambda m: m.group(1), text) + + for cmd in SIMPLE_CMDS: + text = strip_simple_command(cmd, text) + + text = re.sub(r"\\[a-zA-Z]+\*?", " ", text) + text = text.replace('$', ' ').replace('~', ' ') + text = text.replace('{', ' ').replace('}', ' ') + + lines = [] + for line in text.split('\n'): + clean = re.sub(r"\s+", ' ', line).strip() + if clean: + lines.append(clean) + return '\n'.join(lines) + + +def try_match_option(line: str): + """Try to match option 
marker using multiple patterns.""" + for pattern in OPTION_PATTERNS: + m = pattern.search(line) + if m: + letter = None + for group in m.groups(): + if group: + letter = group.upper() + break + if letter and letter in 'ABCDE': + return letter, m.end() + return None, None + + +def extract_choices_inline(clean_prompt: str): + """Extract choices when they appear inline.""" + all_matches = [] + for pattern in OPTION_PATTERNS: + for m in pattern.finditer(clean_prompt): + letter = None + for group in m.groups(): + if group: + letter = group.upper() + break + if letter and letter in 'ABCDE': + all_matches.append((m.start(), m.end(), letter)) + + if len(all_matches) < 2: + return None + + all_matches.sort() + option_texts = {} + letters = [] + + for i, (start, end, letter) in enumerate(all_matches): + if letter in letters: + continue + letters.append(letter) + + content_start = end + content_end = all_matches[i + 1][0] if i + 1 < len(all_matches) else len(clean_prompt) + content = clean_prompt[content_start:content_end].strip() + + if not content: + content = "(blank)" + option_texts[letter] = content + + if len(option_texts) < 2: + return None + + stem = clean_prompt[:all_matches[0][0]].strip() + if not stem: + stem = "Select the correct option." + + return stem, option_texts, letters + + +def extract_choices(clean_prompt: str): + """Extract multiple choice options with ultra-enhanced pattern matching.""" + lines = clean_prompt.split('\n') + choices = [] + + for idx, line in enumerate(lines): + letter, pos = try_match_option(line) + if letter: + choices.append((letter, idx, pos)) + + if len(choices) >= 2: + seen = set() + filtered = [] + for letter, idx, pos in choices: + if letter not in seen: + seen.add(letter) + filtered.append((letter, idx, pos)) + + if len(filtered) >= 2: + letters = [letter for letter, _, _ in filtered] + if is_valid_sequence(letters): + option_texts = {} + for i, (letter, idx, pos) in enumerate(filtered): + start_line = idx + end_line = filtered[i + 1][1] if i + 1 < len(filtered) else len(lines) + desc_lines = [] + + remainder = lines[idx][pos:].strip() + if remainder: + desc_lines.append(remainder) + + for j in range(idx + 1, end_line): + desc_lines.append(lines[j]) + + desc = ' '.join(desc_lines).strip() + if not desc: + desc = '(blank)' + option_texts[letter] = desc + + stem_lines = lines[:filtered[0][1]] + stem = ' '.join(stem_lines).strip() + if not stem: + stem = "Select the correct option." 
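# Illustrative example (not part of the original script): for a normalized prompt whose lines are
#   "What is 2 + 2?", "(A) 3", "(B) 4",
# this branch returns ("What is 2 + 2?", {"A": "3", "B": "4"}, ["A", "B"]).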
+ + return stem, option_texts, letters + + return extract_choices_inline(clean_prompt) + + +def is_valid_sequence(letters): + """Check if letters form a valid subsequence of ABCDE.""" + if not letters: + return False + + unique_letters = [] + for letter in letters: + if letter not in unique_letters: + unique_letters.append(letter) + + if len(unique_letters) < 2: + return False + + expected = 'ABCDE' + expected_idx = 0 + + for letter in unique_letters: + try: + idx = expected.index(letter, expected_idx) + expected_idx = idx + 1 + except ValueError: + return False + + return True + + +def parse_label(label_raw: str): + """Parse label with enhanced pattern matching.""" + if not label_raw: + return None + + for pattern in LABEL_PATTERNS: + matches = pattern.findall(label_raw.upper()) + if matches: + if matches and isinstance(matches[0], tuple): + letters = [m for group in matches for m in (group if isinstance(group, tuple) else [group]) if m] + else: + letters = matches + + for letter in letters: + if letter and letter in 'ABCDE': + return letter + + simple_match = re.search(r'\b([A-E])\b', label_raw.upper()) + if simple_match: + return simple_match.group(1) + + return None + + +def fuzzy_match_label(label_raw: str, available_options, option_texts): + """Try to match label even if format is unusual.""" + if not label_raw: + return None + + cleaned = label_raw.strip() + letter = parse_label(cleaned) + if letter and letter in available_options: + return letter + + cleaned_lower = cleaned.lower() + for opt_letter, opt_text in option_texts.items(): + opt_text_clean = opt_text.strip().lower() + + if cleaned_lower == opt_text_clean: + return opt_letter + + if cleaned_lower in opt_text_clean or opt_text_clean in cleaned_lower: + if len(cleaned) > 0 and len(opt_text_clean) > 0: + similarity = min(len(cleaned), len(opt_text_clean)) / max(len(cleaned), len(opt_text_clean)) + if similarity > 0.5: + return opt_letter + + for opt_letter in available_options: + if cleaned.startswith(opt_letter) or cleaned.startswith(f"({opt_letter})"): + return opt_letter + + return None + + +# ==================== 主处理逻辑 ==================== + +processed = [] +debug_samples = [] + +# 使用索引作为查找键 +indexed_data: Dict[int, dict] = {} + +# 第一遍:读取所有数据并建立索引 +print("步骤 1: 读取源文件并建立索引...") +with SRC.open('r', encoding='utf-8') as fin: + for line in fin: + line = line.strip() + if not line: + continue + + try: + record = json.loads(line) + original_index = record.get('original_index') + + if original_index is None: + skip_reason['missing_index'] += 1 + continue + + indexed_data[int(original_index)] = record + + except (json.JSONDecodeError, ValueError) as e: + skip_reason['json_error'] += 1 + +print(f"已加载 {len(indexed_data)} 条索引记录\n") + +# 第二遍:处理每条记录 +print("步骤 2: 提取选择题...") +for original_index, record in indexed_data.items(): + prompt_turns = record.get('prompt', []) + contents = [turn.get('content', '') for turn in prompt_turns if turn.get('content')] + + if not contents: + skip_reason['empty_prompt'] += 1 + continue + + # 提取和处理 + raw_prompt = '\n'.join(contents) + normalized = normalize_prompt(raw_prompt) + result = extract_choices(normalized) + + if not result: + skip_reason['no_choice_detected'] += 1 + if len(debug_samples) < 10: + debug_samples.append({ + 'original_index': original_index, + 'raw': raw_prompt[:300], + 'normalized': normalized[:300] + }) + continue + + stem, option_texts, letter_order = result + + # 解析标签 + label_raw = str(record.get('label', '')) + answer_letter = fuzzy_match_label(label_raw, option_texts.keys(), 
option_texts) + + if not answer_letter: + skip_reason['label_no_letter'] += 1 + if len([s for s in debug_samples if 'label_raw' in s]) < 10: + debug_samples.append({ + 'original_index': original_index, + 'label_raw': label_raw, + 'options': list(option_texts.keys()), + 'option_texts': {k: v[:50] for k, v in option_texts.items()} + }) + continue + + if answer_letter not in option_texts: + skip_reason['letter_not_in_options'] += 1 + continue + + # 构建最终问题 + option_lines = [f"{letter}) {option_texts[letter]}" for letter in letter_order] + question_text = ( + f"{stem}\n\n" + f"Options:\n" + '\n'.join(option_lines) + '\n\n' + f"Answer with the option letter only, in the form \\boxed{{{answer_letter}}}." + ) + + # **保留 original_index** + processed.append({ + 'original_index': original_index, # 保留索引 + 'prompt': [{'role': 'user', 'content': question_text}], + 'label': answer_letter + }) + +# 写入输出 +with DST.open('w', encoding='utf-8') as fout: + for item in processed: + json.dump(item, fout, ensure_ascii=False) + fout.write('\n') + +# 打印统计 +total_lines = len(indexed_data) + sum(skip_reason.values()) - skip_reason.get('missing_index', 0) +print(f'\n{"="*60}') +print(f'total records: {len(indexed_data)}') +print(f'kept lines: {len(processed)}') +print(f'success rate: {len(processed) / len(indexed_data) * 100:.1f}%') +print(f'\nskip reasons:') +for reason, count in skip_reason.most_common(): + print(f' {reason:>20}: {count}') + +# 打印调试样本 +if debug_samples: + print(f'\n{"="*60}') + print('Debug samples (first 5 failed cases):') + print('='*60) + for i, sample in enumerate(debug_samples[:5], 1): + print(f'\nSample {i}:') + for key, value in sample.items(): + if isinstance(value, str) and len(value) > 150: + print(f' {key}: {value[:150]}...') + else: + print(f' {key}: {value}') + +print(f'\n输出文件: {DST}') diff --git a/tools/polaris_extract_hard_samples.py b/tools/polaris_extract_hard_samples.py new file mode 100644 index 000000000..351811030 --- /dev/null +++ b/tools/polaris_extract_hard_samples.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +""" +Extract "hard" dataset samples based on POLARIS reward tracking logs. + +This script scans a reward tracking JSONL file and keeps the entries whose +average reward falls below a user-provided threshold. The matching samples +are written to a new dataset file, preserving the original format +(.json, .jsonl, or .parquet). + +Each retained sample in the output carries an `_original_index` field so that +you can trace it back to the exact line/row in the source dataset. + +Example: + python polaris_extract_hard_samples.py \ + --data-path data/train.jsonl \ + --reward-log polaris_tracking/run.jsonl \ + --output data/train_hard_cases.jsonl \ + --threshold 0.1 \ + --threshold-low 0.0 +""" + +from __future__ import annotations + +import argparse +import json +from collections import defaultdict +from pathlib import Path + +import pandas as pd + +ORIGINAL_INDEX_KEY = "_original_index" + + +def collect_low_reward_indices(jsonl_path: str, threshold_high: float, threshold_low: float) -> set[int]: + """ + Parse reward tracking logs and return indices whose average reward is within the thresholds. 
+ """ + index_to_scores: dict[int, list[float]] = defaultdict(list) + + with open(jsonl_path, "r") as f: + for line in f: + if not line.strip(): + continue + entry = json.loads(line) + indices = entry.get("index") or entry.get("indices") + scores = entry.get("score") or entry.get("scores") + if indices is None or scores is None: + raise KeyError("Reward log entries must contain 'index'/'indices' and 'score'/'scores'.") + + for idx, score in zip(indices, scores): + index_to_scores[int(idx)].append(float(score)) + + keep_indices: set[int] = set() + for idx, scores in index_to_scores.items(): + avg_score = sum(scores) / len(scores) + if threshold_low <= avg_score < threshold_high: + keep_indices.add(idx) + + print(f"Total tracked samples: {len(index_to_scores)}") + print( + "Samples kept " + f"(avg reward in [{threshold_low}, {threshold_high}))" + f": {len(keep_indices)}" + ) + return keep_indices + + +def keep_from_json(input_path: str, output_path: str, keep_indices: set[int]) -> None: + """ + Save only the selected indices from a JSON dataset. + Supports list-of-objects and columnar dict formats. + """ + with open(input_path, "r") as f: + data = json.load(f) + + if isinstance(data, list): + filtered = [] + for idx, item in enumerate(data): + if idx not in keep_indices: + continue + if isinstance(item, dict): + sample = dict(item) + sample[ORIGINAL_INDEX_KEY] = idx + else: + sample = { "value": item, ORIGINAL_INDEX_KEY: idx } + filtered.append(sample) + elif isinstance(data, dict): + first_key = next(iter(data)) + if isinstance(data[first_key], dict): + available_indices = set(data[first_key].keys()) + selected_old_indices = [str(idx) for idx in sorted(keep_indices) if str(idx) in available_indices] + + filtered = {key: {} for key in data} + filtered[ORIGINAL_INDEX_KEY] = {} + for new_idx, old_idx in enumerate(selected_old_indices): + for key in data: + filtered[key][str(new_idx)] = data[key][old_idx] + filtered[ORIGINAL_INDEX_KEY][str(new_idx)] = int(old_idx) + else: + raise ValueError("Unsupported JSON column format.") + else: + raise ValueError("Unsupported JSON structure.") + + with open(output_path, "w") as f: + json.dump(filtered, f, indent=2, ensure_ascii=False) + saved_count = len(filtered) if isinstance(filtered, list) else len(selected_old_indices) + print(f"Saved {saved_count} samples to {output_path}") + + +def keep_from_jsonl(input_path: str, output_path: str, keep_indices: set[int]) -> None: + """Write only the selected lines from a JSONL dataset.""" + kept = 0 + with open(output_path, "w") as out_f: + with open(input_path, "r") as in_f: + for idx, line in enumerate(in_f): + if idx in keep_indices: + entry = json.loads(line) + if isinstance(entry, dict): + payload = dict(entry) + payload[ORIGINAL_INDEX_KEY] = idx + else: + payload = {"value": entry, ORIGINAL_INDEX_KEY: idx} + out_f.write(json.dumps(payload, ensure_ascii=False) + "\n") + kept += 1 + print(f"Saved {kept} samples to {output_path}") + + +def keep_from_parquet(input_path: str, output_path: str, keep_indices: set[int]) -> None: + """Write only the selected rows from a Parquet dataset.""" + df = pd.read_parquet(input_path) + filtered_df = df[df.index.isin(keep_indices)].copy() + filtered_df[ORIGINAL_INDEX_KEY] = filtered_df.index.astype(int) + filtered_df.to_parquet(output_path) + print(f"Original rows: {len(df)}, kept rows: {len(filtered_df)}") + print(f"Saved parquet to {output_path}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Extract low-reward samples for case study.") + 
parser.add_argument("--data-path", required=True, help="Input dataset path (.json, .jsonl, or .parquet).") + parser.add_argument("--reward-log", required=True, help="Reward tracking JSONL file.") + parser.add_argument("--output", required=True, help="Destination file for the filtered dataset.") + parser.add_argument( + "--threshold", + type=float, + default=0.1, + help="Upper bound: keep samples whose avg reward is below this value (default: 0.1).", + ) + parser.add_argument( + "--threshold-low", + type=float, + default=0.0, + help="Lower bound: keep samples whose avg reward is >= this value (default: 0.0).", + ) + args = parser.parse_args() + + if args.threshold_low >= args.threshold: + raise ValueError("`threshold-low` must be smaller than `threshold`.") + + print(f"Parsing rewards from {args.reward_log}") + keep_indices = collect_low_reward_indices(args.reward_log, args.threshold, args.threshold_low) + + input_path = Path(args.data_path) + suffix = input_path.suffix.lower() + print(f"Selecting hard samples from {args.data_path}") + + if suffix == ".json": + keep_from_json(args.data_path, args.output, keep_indices) + elif suffix == ".jsonl": + keep_from_jsonl(args.data_path, args.output, keep_indices) + elif suffix == ".parquet": + keep_from_parquet(args.data_path, args.output, keep_indices) + else: + raise ValueError(f"Unsupported data format: {suffix}") + + print("Extraction complete.") + + +if __name__ == "__main__": + main() diff --git a/tools/remove_hard_cases.py b/tools/remove_hard_cases.py new file mode 100644 index 000000000..917ac0d9b --- /dev/null +++ b/tools/remove_hard_cases.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +""" +从数据集中移除 hard cases(问题样本) +- 读取 hard cases 列表 +- 从源数据中移除对应的 original_index +- 使用模糊匹配验证内容一致性 +- 生成清理后的数据集 +""" + +import json +from pathlib import Path +from collections import Counter +from typing import Dict, Tuple, Optional +from difflib import SequenceMatcher + + +class HardCaseRemover: + """Hard Case 移除器""" + + def __init__(self): + # 配置路径 + self.hard_cases_file = Path( + "/Users/shuocai/Downloads/home/projects/polyullm/kejing/slime_workspace/" + "slime_polaris/slime/analysis/polaris_hard_cases_issues.jsonl" + ) + self.source_file = Path( + "/lustre/projects/polyullm/caishuo/cs_data/slime_rl/" + "polaris-data-53K-indexed.jsonl" + ) + self.output_file = Path( + "/lustre/projects/polyullm/caishuo/cs_data/slime_rl/" + "polaris-data-53K-indexed__clean.jsonl" + ) + self.log_file = Path("analysis/remove_hard_cases_log.txt") + + # 模糊匹配配置 + self.similarity_threshold = 0.7 # 内容相似度阈值 + self.prompt_preview_length = 200 # 用于比较的 prompt 预览长度 + + self.stats = Counter() + self.log_entries = [] + self.fuzzy_warnings = [] + + def load_hard_case_indices(self) -> Dict[int, dict]: + """ + 加载 hard cases 的 original_index 及其详细信息 + + Returns: + {original_index: hard_case_record} 字典 + """ + hard_cases = {} + + if not self.hard_cases_file.exists(): + self.log(f"错误: Hard cases 文件不存在: {self.hard_cases_file}") + return hard_cases + + self.log(f"读取 hard cases 文件: {self.hard_cases_file}") + + with self.hard_cases_file.open('r', encoding='utf-8') as f: + for line_no, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + + try: + record = json.loads(line) + original_index = record.get('original_index') + + if original_index is None: + self.log(f"警告: Hard cases 第 {line_no} 行缺少 original_index") + self.stats['hard_missing_index'] += 1 + continue + + original_index = int(original_index) + hard_cases[original_index] = record + self.stats['hard_cases_loaded'] += 1 + + # 
记录问题类型(用于日志) + issue = record.get('issue', 'Unknown issue') + if line_no <= 10: # 只记录前10个 + self.log(f" - index {original_index}: {issue[:80]}") + + except (json.JSONDecodeError, ValueError) as e: + self.log(f"错误: Hard cases 第 {line_no} 行解析失败: {e}") + self.stats['hard_parse_error'] += 1 + + return hard_cases + + def calculate_similarity(self, text1: str, text2: str) -> float: + """ + 计算两段文本的相似度 + + Args: + text1: 第一段文本 + text2: 第二段文本 + + Returns: + 相似度分数 (0-1) + """ + return SequenceMatcher(None, text1, text2).ratio() + + def extract_prompt_text(self, record: dict, is_hard_case: bool = False) -> str: + """ + 从记录中提取 prompt 文本用于比较 + + Args: + record: 数据记录 + is_hard_case: 是否为 hard case 记录(使用 prompt_full 字段) + + Returns: + 提取的文本内容 + """ + if is_hard_case: + # Hard case 使用 prompt_full 字段 + prompt_text = record.get('prompt_full', '') + return str(prompt_text)[:self.prompt_preview_length] + else: + # 源数据使用 prompt 列表结构 + prompt = record.get('prompt', []) + if isinstance(prompt, list): + contents = [] + for turn in prompt: + if isinstance(turn, dict): + content = turn.get('content', '') + if content: + contents.append(str(content)) + return ' '.join(contents)[:self.prompt_preview_length] + else: + return str(prompt)[:self.prompt_preview_length] + + def verify_hard_case_match( + self, + original_index: int, + source_record: dict, + hard_case_record: dict + ) -> Tuple[bool, Optional[str]]: + """ + 验证源记录是否与 hard case 记录匹配 + + Args: + original_index: 索引号 + source_record: 源数据记录 + hard_case_record: hard case 记录 + + Returns: + (是否匹配, 警告信息) + """ + # 提取 prompt 内容 + source_prompt = self.extract_prompt_text(source_record, is_hard_case=False) + hard_prompt = self.extract_prompt_text(hard_case_record, is_hard_case=True) + + # 如果都为空,认为匹配 + if not source_prompt and not hard_prompt: + return True, None + + # 计算相似度 + similarity = self.calculate_similarity(source_prompt, hard_prompt) + + if similarity >= self.similarity_threshold: + return True, None + else: + warning = ( + f"索引 {original_index} 内容相似度低 ({similarity:.2f}):\n" + f" 源数据: {source_prompt[:100]}...\n" + f" Hard case: {hard_prompt[:100]}..." 
+ ) + return False, warning + + def remove_hard_cases(self, hard_cases: Dict[int, dict]) -> int: + """ + 从源文件中移除 hard cases(带模糊匹配验证) + + Args: + hard_cases: {original_index: hard_case_record} 字典 + + Returns: + 实际移除的数量 + """ + if not self.source_file.exists(): + self.log(f"错误: 源文件不存在: {self.source_file}") + return 0 + + self.log(f"\n读取源文件: {self.source_file}") + self.log(f"输出文件: {self.output_file}") + + removed_count = 0 + kept_count = 0 + verified_removal_count = 0 + unverified_removal_count = 0 + + with self.source_file.open('r', encoding='utf-8') as fin, \ + self.output_file.open('w', encoding='utf-8') as fout: + + for physical_line, line in enumerate(fin): + line = line.strip() + + # 空行直接写入 + if not line: + fout.write('\n') + continue + + try: + record = json.loads(line) + original_index = record.get('original_index') + + if original_index is None: + # 缺少索引,保留(并记录警告) + self.log(f"警告: 源文件物理行 {physical_line} 缺少 original_index,保留") + self.stats['source_missing_index'] += 1 + fout.write(line + '\n') + kept_count += 1 + continue + + original_index = int(original_index) + + # 检查是否在移除列表中 + if original_index in hard_cases: + # 模糊匹配验证 + hard_case_record = hard_cases[original_index] + is_match, warning = self.verify_hard_case_match( + original_index, record, hard_case_record + ) + + if is_match: + # 验证通过,移除 + removed_count += 1 + verified_removal_count += 1 + self.stats['removed'] += 1 + self.stats['verified_removal'] += 1 + + # 只记录前几个 + if removed_count <= 10: + self.log(f"✓ 移除 (已验证): index {original_index}") + elif removed_count == 11: + self.log(f"... (后续移除不再详细记录)") + + else: + # 验证未通过,记录警告但仍移除(可配置) + self.log(f"⚠️ 警告: {warning}") + self.fuzzy_warnings.append(warning) + removed_count += 1 + unverified_removal_count += 1 + self.stats['removed'] += 1 + self.stats['unverified_removal'] += 1 + + if unverified_removal_count <= 5: + self.log(f"⚠️ 移除 (未验证): index {original_index}") + + else: + # 保留 + fout.write(line + '\n') + kept_count += 1 + self.stats['kept'] += 1 + + except (json.JSONDecodeError, ValueError) as e: + # JSON 解析错误,保留原样(并记录) + self.log(f"错误: 源文件物理行 {physical_line} 解析失败: {e},保留原样") + self.stats['source_parse_error'] += 1 + fout.write(line + '\n') + kept_count += 1 + + self.log(f"\n处理完成:") + self.log(f" 保留: {kept_count} 条") + self.log(f" 移除总数: {removed_count} 条") + self.log(f" - 已验证移除: {verified_removal_count} 条") + if unverified_removal_count > 0: + self.log(f" - ⚠️ 未验证移除: {unverified_removal_count} 条") + + return removed_count + + def verify_removal(self, hard_cases: Dict[int, dict]): + """ + 验证移除操作是否成功 + + Args: + hard_cases: 应该被移除的 hard cases 字典 + """ + if not self.output_file.exists(): + self.log("错误: 输出文件不存在,无法验证") + return + + self.log(f"\n验证输出文件: {self.output_file}") + + found_hard_cases = [] + total_records = 0 + + with self.output_file.open('r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + try: + record = json.loads(line) + original_index = record.get('original_index') + + if original_index is not None: + total_records += 1 + original_index = int(original_index) + + # 检查是否还存在应该被移除的索引 + if original_index in hard_cases: + found_hard_cases.append(original_index) + + except (json.JSONDecodeError, ValueError): + pass + + if found_hard_cases: + self.log(f"⚠️ 警告: 发现 {len(found_hard_cases)} 个应该被移除但仍存在的索引:") + self.log(f" {found_hard_cases[:20]}") + self.stats['verification_failed'] += len(found_hard_cases) + else: + self.log(f"✅ 验证通过: 所有 hard cases 已被移除") + self.log(f" 输出文件总记录数: {total_records}") + self.stats['verification_passed'] = 1 
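+
+ # Illustrative note on the fuzzy matching in verify_hard_case_match above:
+ # SequenceMatcher.ratio() returns 2*M / (len(a) + len(b)), so e.g. comparing
+ # "What is 2+2?" with "What is 2+2? Answer:" gives 2*12 / (12 + 20) = 0.75, which
+ # clears similarity_threshold = 0.7 and counts as a verified removal; prompts
+ # scoring below the threshold are still removed but logged as fuzzy-match warnings.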
+ + def log(self, message: str): + """记录日志""" + print(message) + self.log_entries.append(message) + + def save_log(self): + """保存日志到文件""" + self.log_file.parent.mkdir(parents=True, exist_ok=True) + with self.log_file.open('w', encoding='utf-8') as f: + f.write('\n'.join(self.log_entries)) + print(f"\n日志已保存: {self.log_file}") + + def print_summary(self): + """打印统计摘要""" + print("\n" + "="*60) + print("Hard Cases 移除操作摘要(带模糊匹配验证)") + print("="*60) + + print(f"\nHard Cases 加载:") + print(f" 成功加载: {self.stats['hard_cases_loaded']}") + print(f" 索引缺失: {self.stats['hard_missing_index']}") + print(f" 解析错误: {self.stats['hard_parse_error']}") + + print(f"\n移除统计:") + print(f" 成功移除: {self.stats['removed']}") + if self.stats['verified_removal'] > 0 or self.stats['unverified_removal'] > 0: + print(f" - ✓ 已验证移除: {self.stats['verified_removal']}") + if self.stats['unverified_removal'] > 0: + print(f" - ⚠️ 未验证移除: {self.stats['unverified_removal']}") + print(f" 保留记录: {self.stats['kept']}") + + if self.stats['source_missing_index'] > 0: + print(f"\n源文件警告:") + print(f" 缺少索引: {self.stats['source_missing_index']}") + + if self.fuzzy_warnings: + print(f"\n⚠️ 模糊匹配警告 (共 {len(self.fuzzy_warnings)} 个):") + for i, warning in enumerate(self.fuzzy_warnings[:3], 1): + print(f"\n 警告 {i}:") + for line in warning.split('\n'): + print(f" {line}") + if len(self.fuzzy_warnings) > 3: + print(f"\n ... 还有 {len(self.fuzzy_warnings) - 3} 个警告") + + if self.stats['verification_failed'] > 0: + print(f"\n⚠️ 验证失败:") + print(f" 未成功移除: {self.stats['verification_failed']}") + elif self.stats['verification_passed'] > 0: + print(f"\n✅ 验证通过: 所有 hard cases 已被移除") + + print(f"\n最终结果:") + original_count = self.stats['removed'] + self.stats['kept'] + final_count = self.stats['kept'] + print(f" 原始记录数: {original_count}") + print(f" 输出记录数: {final_count}") + print(f" 净减少: {self.stats['removed']} 条 ({self.stats['removed'] / max(original_count, 1) * 100:.2f}%)") + + print(f"\n配置:") + print(f" 相似度阈值: {self.similarity_threshold}") + print(f" 比较长度: {self.prompt_preview_length} 字符") + + print(f"\n输出文件: {self.output_file}") + print("="*60) + + +def main(): + """主函数""" + remover = HardCaseRemover() + + print("开始移除 Hard Cases(带模糊匹配验证)...") + print() + + # 步骤 1: 加载 hard cases 索引和详细信息 + remover.log("步骤 1: 加载 hard cases 索引及详细信息") + hard_cases = remover.load_hard_case_indices() + + if not hard_cases: + remover.log("\n错误: 未找到任何 hard cases 索引") + return 1 + + remover.log(f"\n共需要移除 {len(hard_cases)} 个 hard cases") + remover.log(f"索引列表: {sorted(hard_cases.keys())}") + + # 步骤 2: 移除 hard cases(带模糊匹配验证) + remover.log(f"\n步骤 2: 从源文件中移除 hard cases(相似度阈值: {remover.similarity_threshold})") + removed_count = remover.remove_hard_cases(hard_cases) + + if removed_count == 0: + remover.log("\n警告: 未移除任何记录") + + # 步骤 3: 验证移除结果 + remover.log("\n步骤 3: 验证移除结果") + remover.verify_removal(hard_cases) + + # 打印摘要 + remover.print_summary() + + # 保存日志 + remover.save_log() + + if remover.stats['unverified_removal'] > 0: + print(f"\n⚠️ 注意: 有 {remover.stats['unverified_removal']} 个索引的内容相似度低于阈值,但仍被移除") + print(f" 请检查日志文件以获取详细信息: {remover.log_file}") + + print("\n✅ 完成!") + return 0 + + +if __name__ == '__main__': + exit(main()) diff --git a/tools/update_src_data_with_index_v2.py b/tools/update_src_data_with_index_v2.py new file mode 100644 index 000000000..ad166974a --- /dev/null +++ b/tools/update_src_data_with_index_v2.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +""" +优雅、安全的数据替换工具 +- 使用 0-based 行号索引 +- 严格的内容验证,避免替换错误 +- 详细的统计和日志 +- 生成新文件,不覆盖原数据 +""" + +import json +import hashlib +from pathlib 
import Path +from collections import Counter, defaultdict +from dataclasses import dataclass +from typing import Dict, List, Tuple, Optional + + +# ==================== 配置 ==================== + +@dataclass +class Config: + """配置类,集中管理所有路径""" + original_file: Path = Path("/lustre/projects/polyullm/caishuo/cs_data/slime_rl/polaris-data-53K.jsonl") + new_data_file: Path = Path("analysis/polaris-data-53K__choose_ultra.jsonl") # 使用 ultra 版本 + old_backup_file: Path = Path("analysis/polaris-data-53K__choose_old.jsonl") + output_file: Path = Path("/lustre/projects/polyullm/caishuo/cs_data/slime_rl/polaris-data-53K__patched_v2.jsonl") + log_file: Path = Path("analysis/update_log.txt") + + def validate(self) -> List[str]: + """验证必需文件是否存在""" + errors = [] + if not self.original_file.exists(): + errors.append(f"原始数据文件不存在: {self.original_file}") + if not self.new_data_file.exists(): + errors.append(f"新数据文件不存在: {self.new_data_file}") + return errors + + +# ==================== 数据验证 ==================== + +class DataValidator: + """数据验证器,确保替换的安全性""" + + @staticmethod + def compute_signature(data: dict, keys: List[str] = None) -> str: + """ + 计算数据签名,用于快速比对 + + Args: + data: 要计算签名的数据 + keys: 用于计算签名的键列表,None 表示使用所有键 + + Returns: + MD5 签名字符串 + """ + if keys: + filtered = {k: data.get(k) for k in keys if k in data} + else: + filtered = data + + content = json.dumps(filtered, sort_keys=True, ensure_ascii=False) + return hashlib.md5(content.encode()).hexdigest()[:16] + + @staticmethod + def compare_prompts(orig: dict, old_backup: dict) -> Tuple[bool, str]: + """ + 比较原始数据和备份数据的 prompt 是否匹配 + + Returns: + (是否匹配, 不匹配的原因) + """ + orig_prompt = orig.get('prompt', []) + old_prompt = old_backup.get('prompt', []) + + # 提取文本内容进行比较(忽略格式差异) + def extract_text(prompt_list): + return ' '.join([ + turn.get('content', '') + for turn in prompt_list + if isinstance(turn, dict) + ]).strip() + + orig_text = extract_text(orig_prompt) + old_text = extract_text(old_prompt) + + # 归一化比较(去除多余空格) + orig_normalized = ' '.join(orig_text.split()) + old_normalized = ' '.join(old_text.split()) + + # 计算相似度 + if orig_normalized == old_normalized: + return True, "完全匹配" + + # 允许一定程度的差异(如 LaTeX 格式化) + similarity = len(set(orig_normalized) & set(old_normalized)) / max(len(orig_normalized), len(old_normalized), 1) + + if similarity > 0.9: + return True, f"高度相似 ({similarity:.2%})" + + return False, f"内容不匹配 (相似度: {similarity:.2%})" + + @staticmethod + def validate_new_record(record: dict, expected_index: int) -> Tuple[bool, str]: + """ + 验证新记录的有效性 + + Args: + record: 新记录 + expected_index: 期望的索引 + + Returns: + (是否有效, 错误信息) + """ + # 检查必需字段 + if 'prompt' not in record: + return False, "缺少 prompt 字段" + + if 'label' not in record: + return False, "缺少 label 字段" + + # 检查 prompt 格式 + prompt = record.get('prompt', []) + if not isinstance(prompt, list) or not prompt: + return False, "prompt 格式错误或为空" + + # 检查 original_index(如果存在) + if 'original_index' in record: + actual_index = record['original_index'] + if actual_index != expected_index: + return False, f"索引不匹配: 期望 {expected_index}, 实际 {actual_index}" + + return True, "验证通过" + + +# ==================== 主处理逻辑 ==================== + +class DataReplacer: + """数据替换器""" + + def __init__(self, config: Config): + self.config = config + self.validator = DataValidator() + self.stats = Counter() + self.issues = defaultdict(list) + self.log_entries = [] + + def load_new_records(self) -> Dict[int, dict]: + """ + 加载新数据,建立索引映射 + + Returns: + {行号: 数据记录} 的字典 (0-based) + """ + new_records = {} + + with 
self.config.new_data_file.open('r', encoding='utf-8') as f: + for line_no, line in enumerate(f): + line = line.strip() + if not line: + continue + + try: + item = json.loads(line) + + # 获取原始索引 (0-based) + orig_index = item.get('original_index') + if orig_index is None: + self.log(f"警告: 第 {line_no} 行缺少 original_index,跳过") + self.stats['missing_index'] += 1 + continue + + # 转换为整数(确保是 0-based) + orig_index = int(orig_index) + + # 验证新记录 + valid, msg = self.validator.validate_new_record(item, orig_index) + if not valid: + self.log(f"警告: 索引 {orig_index} 的新记录验证失败: {msg}") + self.stats['invalid_new_record'] += 1 + continue + + # 移除 original_index 键(替换时不需要) + clean_item = {k: v for k, v in item.items() if k != 'original_index'} + + # 检查重复 + if orig_index in new_records: + self.log(f"警告: 索引 {orig_index} 重复出现") + self.stats['duplicate_index'] += 1 + + new_records[orig_index] = clean_item + self.stats['new_records_loaded'] += 1 + + except json.JSONDecodeError as e: + self.log(f"错误: 第 {line_no} 行 JSON 解析失败: {e}") + self.stats['json_error'] += 1 + + return new_records + + def load_old_backup(self) -> Dict[int, dict]: + """ + 加载旧备份数据(用于验证) + + Returns: + {行号: 数据记录} 的字典 (0-based) + """ + if not self.config.old_backup_file.exists(): + self.log("提示: 未找到旧备份文件,跳过验证") + return {} + + old_records = {} + + with self.config.old_backup_file.open('r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + try: + item = json.loads(line) + orig_index = int(item.get('original_index', -1)) + if orig_index >= 0: + old_records[orig_index] = item + except (json.JSONDecodeError, ValueError): + pass + + self.log(f"已加载 {len(old_records)} 条旧备份记录") + return old_records + + def process_replacement(self, + new_records: Dict[int, dict], + old_records: Dict[int, dict]) -> Tuple[int, int]: + """ + 执行替换操作 + + 逻辑说明: + 1. 如果行号在 new_records 中 -> 替换为新数据 + 2. 如果行号在 old_records 中但不在 new_records 中 -> 删除(说明被筛选掉了) + 3. 其他情况 -> 保持原样 + + Returns: + (替换的记录数, 删除的记录数) + """ + replaced_count = 0 + removed_count = 0 + + # 计算应该被删除的索引集合 + indices_to_remove = set(old_records.keys()) - set(new_records.keys()) + + if indices_to_remove: + self.log(f"检测到 {len(indices_to_remove)} 条记录在备份中但未通过筛选,将被删除:") + for idx in sorted(list(indices_to_remove)[:10]): # 只显示前10个 + self.log(f" - 索引 {idx}") + if len(indices_to_remove) > 10: + self.log(f" ... 
还有 {len(indices_to_remove) - 10} 个") + + with self.config.original_file.open('r', encoding='utf-8') as fin, \ + self.config.output_file.open('w', encoding='utf-8') as fout: + + for line_idx, line in enumerate(fin): + line = line.strip() + + # 空行直接写入 + if not line: + fout.write('\n') + continue + + # 检查是否应该删除 (在备份中但不在新数据中) + if line_idx in indices_to_remove: + self.log(f"删除: 第 {line_idx} 行 (未通过筛选)") + removed_count += 1 + self.stats['removed'] += 1 + self.issues['removed_indices'].append(line_idx) + continue # 跳过这一行,不写入输出文件 + + # 检查是否需要替换 (0-based index) + if line_idx in new_records: + # 解析原始记录 + try: + orig_record = json.loads(line) + except json.JSONDecodeError as e: + self.log(f"错误: 原始文件第 {line_idx} 行 JSON 解析失败: {e}") + self.stats['orig_json_error'] += 1 + fout.write(line + '\n') + continue + + # 如果有旧备份,进行验证 + if line_idx in old_records: + match, reason = self.validator.compare_prompts( + orig_record, + old_records[line_idx] + ) + + if not match: + self.log(f"警告: 第 {line_idx} 行内容验证失败: {reason}") + self.issues['validation_failed'].append({ + 'index': line_idx, + 'reason': reason + }) + self.stats['validation_warning'] += 1 + + # 执行替换 + new_record = new_records[line_idx] + json.dump(new_record, fout, ensure_ascii=False) + fout.write('\n') + + replaced_count += 1 + self.stats['replaced'] += 1 + + # 记录替换信息 + if replaced_count <= 5: # 只记录前几条 + self.log(f"替换: 第 {line_idx} 行") + + else: + # 保持原样 + fout.write(line + '\n') + self.stats['unchanged'] += 1 + + return replaced_count, removed_count + + def log(self, message: str): + """记录日志""" + print(message) + self.log_entries.append(message) + + def save_log(self): + """保存日志文件""" + with self.config.log_file.open('w', encoding='utf-8') as f: + f.write('\n'.join(self.log_entries)) + print(f"\n日志已保存到: {self.config.log_file}") + + def print_summary(self): + """打印统计摘要""" + print("\n" + "="*60) + print("替换操作摘要") + print("="*60) + + print(f"\n加载统计:") + print(f" 新记录加载成功: {self.stats['new_records_loaded']}") + print(f" 新记录验证失败: {self.stats['invalid_new_record']}") + print(f" 索引缺失: {self.stats['missing_index']}") + print(f" 索引重复: {self.stats['duplicate_index']}") + + print(f"\n替换统计:") + print(f" 成功替换: {self.stats['replaced']}") + print(f" 删除记录: {self.stats['removed']} (备份中有但未通过筛选)") + print(f" 保持原样: {self.stats['unchanged']}") + print(f" 验证警告: {self.stats['validation_warning']}") + + if self.stats['removed'] > 0: + print(f"\n删除详情:") + removed_indices = self.issues['removed_indices'] + print(f" 被删除的索引 (前 20 个): {sorted(removed_indices[:20])}") + if len(removed_indices) > 20: + print(f" ... 还有 {len(removed_indices) - 20} 个") + + if self.issues['validation_failed']: + print(f"\n验证失败的索引 ({len(self.issues['validation_failed'])} 个):") + for issue in self.issues['validation_failed'][:10]: # 只显示前10个 + print(f" 索引 {issue['index']}: {issue['reason']}") + if len(self.issues['validation_failed']) > 10: + print(f" ... 
还有 {len(self.issues['validation_failed']) - 10} 个") + + print(f"\n最终结果:") + original_count = self.stats['replaced'] + self.stats['removed'] + self.stats['unchanged'] + final_count = self.stats['replaced'] + self.stats['unchanged'] + print(f" 原始文件记录数: {original_count}") + print(f" 输出文件记录数: {final_count}") + print(f" 净减少记录数: {self.stats['removed']}") + + print(f"\n输出文件: {self.config.output_file}") + print("="*60) + + +# ==================== 主函数 ==================== + +def main(): + """主函数""" + # 初始化配置 + config = Config() + + # 验证配置 + errors = config.validate() + if errors: + print("配置验证失败:") + for error in errors: + print(f" - {error}") + return 1 + + # 创建替换器 + replacer = DataReplacer(config) + + print("开始数据替换流程...") + print(f"原始文件: {config.original_file}") + print(f"新数据文件: {config.new_data_file}") + print(f"输出文件: {config.output_file}") + print() + + # 加载新数据 + replacer.log("步骤 1: 加载新数据...") + new_records = replacer.load_new_records() + replacer.log(f"已加载 {len(new_records)} 条新记录\n") + + # 加载旧备份(用于验证) + replacer.log("步骤 2: 加载旧备份数据...") + old_records = replacer.load_old_backup() + replacer.log("") + + # 执行替换 + replacer.log("步骤 3: 执行替换操作...") + replaced_count, removed_count = replacer.process_replacement(new_records, old_records) + replacer.log(f"替换完成: 替换 {replaced_count} 条, 删除 {removed_count} 条\n") + + # 打印摘要 + replacer.print_summary() + + # 保存日志 + replacer.save_log() + + return 0 + + +if __name__ == '__main__': + exit(main())
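+
+ # Related tools in tools/ (illustrative sketch; the example paths below are placeholders,
+ # and remove_hard_cases.py as well as this script read their actual input/output
+ # locations from values hard-coded inside each script rather than from CLI flags):
+ #   python tools/polaris_extract_hard_samples.py \
+ #       --data-path data/train.jsonl --reward-log polaris_tracking/run.jsonl \
+ #       --output data/train_hard_cases.jsonl --threshold 0.1 --threshold-low 0.0
+ #   python tools/remove_hard_cases.py              # drops rows flagged by original_index
+ #   python tools/update_src_data_with_index_v2.py  # patches/removes rows by 0-based line index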