284 changes: 283 additions & 1 deletion tests/unit/distributed/test_model_utils.py
@@ -18,6 +18,7 @@
import torch

from nemo_rl.distributed.model_utils import (
ChunkedDistributedEntropy,
ChunkedDistributedGatherLogprob,
ChunkedDistributedLogprob,
DistributedLogprob,
@@ -609,7 +610,7 @@ def _torch_baseline_logprob(self, full_logits, target):
log_softmax = torch.nn.functional.log_softmax(full_logits, dim=-1)

# Gather log probabilities for target tokens
-    target_mask = target >= 0 # Valid targets (assuming -1 or similar for padding)
+    target_mask = target >= 0  # Valid targets (assuming -1 or similar for padding)
log_probs = torch.gather(log_softmax, -1, target.unsqueeze(-1)).squeeze(-1)
log_probs = log_probs * target_mask.float()

@@ -956,3 +957,284 @@ def test_distributed_logprob_all_tests(

finally:
cluster.shutdown()


@ray.remote(num_gpus=1)
class ChunkedDistributedEntropyTestActor:
def __init__(self, tp_size, chunk_size):
self.tp_size = tp_size
self.chunk_size = chunk_size
self.env_vars = dict(os.environ)
torch.distributed.init_process_group(backend="nccl")
self.tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))

def _torch_baseline_entropy(self, full_logits):
"""Single-GPU PyTorch baseline implementation for entropy computation."""
# Compute log softmax and softmax using standard PyTorch
log_probs = torch.nn.functional.log_softmax(full_logits, dim=-1)
probs = torch.exp(log_probs)

# Compute entropy: H = -sum(p * log(p)) = -sum(p * log_p)
entropy = (probs * log_probs).sum(dim=-1) # [B, S]

return entropy
Comment on lines +971 to +980

⚠️ Potential issue | 🟡 Minor

Entropy docstring/sign confusion.

The implementation returns sum_v p_v log p_v (non-positive), not the conventional positive Shannon entropy -sum p log p. Adjust the docstring/comments to match to avoid confusion.

-    def _torch_baseline_entropy(self, full_logits):
-        """Single-GPU PyTorch baseline implementation for entropy computation."""
+    def _torch_baseline_entropy(self, full_logits):
+        """Single-GPU baseline for sum_v p_v log p_v (non-positive 'entropy' used by ChunkedDistributedEntropy)."""
@@
-        # Compute entropy: H = -sum(p * log(p)) = -sum(p * log_p)
-        entropy = (probs * log_probs).sum(dim=-1)  # [B, S]
+        # Compute H_all = sum_v p_v log p_v (<= 0)
+        entropy = (probs * log_probs).sum(dim=-1)  # [B, S]
🤖 Prompt for AI Agents
In tests/unit/distributed/test_model_utils.py around lines 971 to 980, the
docstring and inline comment claim the function computes "entropy" but the code
returns sum_v p_v * log p_v (a non-positive value), i.e. the negative of
conventional Shannon entropy; update the function docstring and the inline
comment to state explicitly that the function returns the negative Shannon
entropy (sum_v p_v * log p_v, ≤ 0) or "negative entropy" (or change the sign in
the computation if you prefer positive Shannon entropy), and make the comment
above the entropy computation consistent with that description.
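As a quick illustration of the sign convention the reviewer is pointing at (illustrative only, not part of this PR): for a uniform distribution over V tokens, the quantity the baseline returns, (probs * log_probs).sum(-1), evaluates to -log(V), whereas conventional Shannon entropy is +log(V). A minimal standalone PyTorch check, assuming nothing beyond stock torch:

# Minimal sign-convention check (illustrative sketch, not part of the PR).
# For a uniform distribution over V tokens, (probs * log_probs).sum(-1)
# gives -log(V); conventional Shannon entropy H = -sum(p * log p) is +log(V).
import math

import torch

V = 4
logits = torch.zeros(1, 1, V)  # equal logits -> uniform distribution over V tokens
log_probs = torch.log_softmax(logits, dim=-1)
probs = log_probs.exp()

baseline_style = (probs * log_probs).sum(dim=-1)  # what _torch_baseline_entropy returns
shannon = -(probs * log_probs).sum(dim=-1)        # conventional Shannon entropy

assert torch.allclose(baseline_style, torch.tensor(-math.log(V)))
assert torch.allclose(shannon, torch.tensor(math.log(V)))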


def test_chunked_distributed_entropy_forward_and_backward(self):
"""Test ChunkedDistributedEntropy forward and backward passes against PyTorch baseline."""
rank = int(os.environ["RANK"])

# Test parameters
batch_size = 4
seq_len = 8
full_vocab_size = 1024
vocab_part_size = full_vocab_size // self.tp_size
chunk_size = self.chunk_size

# Calculate vocab partition for this rank
vocab_start_index = rank * vocab_part_size
vocab_end_index = (rank + 1) * vocab_part_size

# Create test data with fixed seed for reproducibility (same across all ranks)
torch.manual_seed(42)

# Create full logits (same on all ranks for fair comparison)
full_logits = torch.randn(
batch_size, seq_len, full_vocab_size, device="cuda", requires_grad=True
)

# Extract this rank's vocab partition
vocab_parallel_logits = (
full_logits[:, :, vocab_start_index:vocab_end_index]
.clone()
.detach()
.requires_grad_(True)
)

# === FORWARD PASS TEST ===
# Use the same full logits for baseline computation (without gradient tracking for forward test)
baseline_logits_forward = full_logits.clone().detach()
baseline_entropy_forward = self._torch_baseline_entropy(baseline_logits_forward)

# Compute using ChunkedDistributedEntropy (forward only first)
distributed_entropy_inference = ChunkedDistributedEntropy.apply(
vocab_parallel_logits.clone().detach(), # Clone to avoid affecting backward test
chunk_size,
self.tp_group,
True, # inference_only=True for forward test
)

# Compare forward results
torch.testing.assert_close(
distributed_entropy_inference,
baseline_entropy_forward,
rtol=1e-4,
atol=1e-4,
)

forward_max_diff = torch.max(
torch.abs(distributed_entropy_inference - baseline_entropy_forward)
).item()

# === BACKWARD PASS TEST ===
# Compute baseline gradients - use full_logits with gradient tracking
baseline_entropy = self._torch_baseline_entropy(full_logits)
baseline_loss = torch.sum(baseline_entropy)
baseline_loss.backward()
baseline_grad = full_logits.grad[
:, :, vocab_start_index:vocab_end_index
].clone()

# Reset full_logits grad for clean comparison
full_logits.grad = None

# Compute distributed gradients
distributed_entropy = ChunkedDistributedEntropy.apply(
vocab_parallel_logits,
chunk_size,
self.tp_group,
False, # inference_only=False to enable backward
)

distributed_loss = torch.sum(distributed_entropy)
distributed_loss.backward()
distributed_grad = vocab_parallel_logits.grad

# Compare gradients
torch.testing.assert_close(
distributed_grad, baseline_grad, rtol=1e-4, atol=1e-4
)

# Compare entropy values again (should be same as forward test)
torch.testing.assert_close(
distributed_entropy, baseline_entropy, rtol=1e-4, atol=1e-4
)

grad_max_diff = torch.max(torch.abs(distributed_grad - baseline_grad)).item()
entropy_max_diff = torch.max(
torch.abs(distributed_entropy - baseline_entropy)
).item()

return {
"forward_max_diff": forward_max_diff,
"grad_max_diff": grad_max_diff,
"entropy_max_diff": entropy_max_diff,
}

def test_edge_cases(self):
"""Test edge cases like extreme values for numerical stability."""
rank = int(os.environ["RANK"])

# Test parameters
batch_size = 2
seq_len = 4
full_vocab_size = 128
vocab_part_size = full_vocab_size // self.tp_size

vocab_start_index = rank * vocab_part_size
vocab_end_index = (rank + 1) * vocab_part_size

# Test 1: Very large logits (test numerical stability)
torch.manual_seed(42)
large_logits = (
torch.randn(batch_size, seq_len, full_vocab_size, device="cuda") * 100
) # Large values
vocab_parallel_logits = large_logits[
:, :, vocab_start_index:vocab_end_index
].clone()

# Should not produce NaN or Inf
entropy = ChunkedDistributedEntropy.apply(
vocab_parallel_logits,
self.chunk_size,
self.tp_group,
True,
)

assert not torch.isnan(entropy).any(), "Entropy contains NaN"
assert not torch.isinf(entropy).any(), "Entropy contains Inf"

# Test 2: Compare with baseline for large values
baseline_entropy = self._torch_baseline_entropy(large_logits)
torch.testing.assert_close(entropy, baseline_entropy, rtol=1e-4, atol=1e-4)

# Test 3: Uniform distribution (maximum entropy case)
uniform_logits = torch.zeros(
batch_size, seq_len, full_vocab_size, device="cuda"
) # All equal -> uniform distribution
vocab_parallel_uniform = uniform_logits[
:, :, vocab_start_index:vocab_end_index
].clone()

entropy_uniform = ChunkedDistributedEntropy.apply(
vocab_parallel_uniform,
self.chunk_size,
self.tp_group,
True,
)

        # For a uniform distribution over V items the baseline returns sum(1/V * log(1/V)) = -log(V)
expected_uniform_entropy = self._torch_baseline_entropy(uniform_logits)

# All positions should have the same entropy
torch.testing.assert_close(
entropy_uniform,
expected_uniform_entropy.expand_as(entropy_uniform),
rtol=1e-4,
atol=1e-4,
)

return {"success": True}


CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN = (
f"{ChunkedDistributedEntropyTestActor.__module__}.ChunkedDistributedEntropyTestActor"
)


@pytest.fixture
def register_chunked_distributed_entropy_test_actor():
"""Register the ChunkedDistributedEntropyTestActor for use in tests."""
original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get(
CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN
)
ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = (
PY_EXECUTABLES.SYSTEM
)

yield CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN

# Clean up registry
if CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY:
if original_registry_value is None:
del ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN]
else:
ACTOR_ENVIRONMENT_REGISTRY[CHUNKED_DISTRIBUTED_ENTROPY_TEST_ACTOR_FQN] = (
original_registry_value
)


@pytest.mark.parametrize(
"tp_size, chunk_size",
[
(1, 4),
(2, 4),
(1, 1),
(2, 1),
],
)
def test_chunked_distributed_entropy_all_tests(
register_chunked_distributed_entropy_test_actor, tp_size, chunk_size
):
"""Test all ChunkedDistributedEntropy functionality for a given TP size."""
# Skip if not enough GPUs
if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size:
pytest.skip(
f"Not enough GPUs available. Need {tp_size}, got {torch.cuda.device_count()}"
)

cluster = RayVirtualCluster(bundle_ct_per_node_list=[tp_size], use_gpus=True)

try:
actor_fqn = register_chunked_distributed_entropy_test_actor

# Create sharding for TP
sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"])
builder = RayWorkerBuilder(actor_fqn, tp_size, chunk_size)

worker_group = RayWorkerGroup(
cluster=cluster,
remote_worker_builder=builder,
workers_per_node=None,
sharding_annotations=sharding,
)

# Test 1: Combined Forward and Backward pass
print(
f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: ChunkedDistributedEntropy Forward & Backward Pass ==="
)
futures = worker_group.run_all_workers_single_data(
"test_chunked_distributed_entropy_forward_and_backward"
)
results = ray.get(futures)

for i, result in enumerate(results):
if "forward_max_diff" in result:
print(f"Worker {i} forward max diff: {result['forward_max_diff']:.2e}")
if "grad_max_diff" in result and "entropy_max_diff" in result:
print(
f"Worker {i} gradient max diff: {result['grad_max_diff']:.2e}, "
f"entropy max diff: {result['entropy_max_diff']:.2e}"
)

# Test 2: Edge cases
print(f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Edge Cases ===")
futures = worker_group.run_all_workers_single_data("test_edge_cases")
results = ray.get(futures)
for i, result in enumerate(results):
if "success" in result:
print(f"Worker {i} edge cases test: {'PASSED' if result['success'] else 'FAILED'}")

worker_group.shutdown(force=True)

finally:
cluster.shutdown()