diff --git a/cvs/monitors/cluster-mon/.gitignore b/cvs/monitors/cluster-mon/.gitignore index b6fc8d77..58b31661 100644 --- a/cvs/monitors/cluster-mon/.gitignore +++ b/cvs/monitors/cluster-mon/.gitignore @@ -5,6 +5,7 @@ __pycache__/ *.so .Python venv/ +.venv/ env/ ENV/ *.egg-info/ diff --git a/cvs/monitors/cluster-mon/backend/app/api/__init__.py b/cvs/monitors/cluster-mon/backend/app/api/__init__.py index c2d1c6e8..e1629251 100644 --- a/cvs/monitors/cluster-mon/backend/app/api/__init__.py +++ b/cvs/monitors/cluster-mon/backend/app/api/__init__.py @@ -3,7 +3,7 @@ """ from fastapi import APIRouter -from app.api import cluster, nodes, metrics, config, software, restart, packages, logs, ssh_keys +from app.api import cluster, nodes, metrics, config, software, restart, packages, logs, ssh_keys, collectors, rccl_endpoints router = APIRouter() @@ -17,3 +17,5 @@ router.include_router(packages.router, prefix="/packages", tags=["packages"]) router.include_router(logs.router, prefix="/logs", tags=["logs"]) router.include_router(ssh_keys.router, prefix="/ssh-keys", tags=["ssh-keys"]) +router.include_router(collectors.router, prefix="/collectors", tags=["collectors"]) +router.include_router(rccl_endpoints.router, prefix="/rccl", tags=["rccl"]) diff --git a/cvs/monitors/cluster-mon/backend/app/api/collectors.py b/cvs/monitors/cluster-mon/backend/app/api/collectors.py new file mode 100644 index 00000000..f53fedf7 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/api/collectors.py @@ -0,0 +1,76 @@ +""" +Collectors status API endpoint. +Returns per-collector state and aggregate overall_status. +""" + +from fastapi import APIRouter +from typing import Any + +router = APIRouter() + + +def _compute_overall_status(collector_results: dict, collectors_meta: dict) -> str: + """ + Compute aggregate status from per-collector results. 
+ + - "healthy" : all collectors in OK or NO_SERVICE state + - "degraded" : some collectors erroring, but no critical ones + - "critical" : any collector marked critical=True is in ERROR or UNREACHABLE + """ + if not collector_results: + return "healthy" + + for name, result in collector_results.items(): + state = result.state if hasattr(result, 'state') else result.get('state', 'ok') + state_str = state.value if hasattr(state, 'value') else str(state) + is_error = state_str in ("error", "unreachable") + is_critical = collectors_meta.get(name, {}).get('critical', False) + if is_error and is_critical: + return "critical" + + any_error = any( + (r.state.value if hasattr(r.state, 'value') else str(r.state)) in ("error", "unreachable") + for r in collector_results.values() + if hasattr(r, 'state') + ) + return "degraded" if any_error else "healthy" + + +@router.get("/status") +async def get_collectors_status() -> dict[str, Any]: + """ + Return per-collector state and aggregate overall_status. 
+ + Response shape: + { + "gpu": {"state": "ok", "timestamp": "...", "error": null}, + "nic": {"state": "ok", "timestamp": "...", "error": null}, + "rccl": {"state": "no_service", "timestamp": "...", "error": "..."}, + "overall_status": "healthy" + } + """ + from app.main import app_state, REGISTERED_COLLECTORS + + # Build collectors metadata (critical flag) from REGISTERED_COLLECTORS + collectors_meta = { + cls.name: {"critical": getattr(cls, "critical", False)} + for cls in REGISTERED_COLLECTORS + } + + result: dict[str, Any] = {} + for name, collector_result in app_state.collector_results.items(): + state_val = ( + collector_result.state.value + if hasattr(collector_result.state, 'value') + else str(collector_result.state) + ) + result[name] = { + "state": state_val, + "timestamp": collector_result.timestamp, + "error": collector_result.error, + } + + result["overall_status"] = _compute_overall_status( + app_state.collector_results, collectors_meta + ) + return result diff --git a/cvs/monitors/cluster-mon/backend/app/api/config.py b/cvs/monitors/cluster-mon/backend/app/api/config.py index f1a93e4d..18874229 100644 --- a/cvs/monitors/cluster-mon/backend/app/api/config.py +++ b/cvs/monitors/cluster-mon/backend/app/api/config.py @@ -184,9 +184,8 @@ async def update_configuration(config: ClusterConfigUpdate) -> Dict[str, Any]: if config.jump_host.auth_method == "password" and config.jump_host.password: # Store in memory app_state.jump_host_password = config.jump_host.password - # Also save to YAML for development/testing (WARNING: Not secure for production) - cluster_config["cluster"]["ssh"]["jump_host"]["password"] = config.jump_host.password - logger.warning("⚠️ Jump host password saved to cluster.yaml - NOT RECOMMENDED FOR PRODUCTION") + # SECURITY: Password is stored in memory only (app_state), never persisted to disk + # cluster_config["cluster"]["ssh"]["jump_host"]["password"] is intentionally omitted else: app_state.jump_host_password = None # Remove password 
from YAML if using key-based auth @@ -251,7 +250,7 @@ async def get_current_configuration() -> Dict[str, Any]: """ Get current configuration including all SSH and jump host settings. """ - from app.core.simple_config import config as settings + from app.core.config import settings nodes = settings.load_nodes_from_file() diff --git a/cvs/monitors/cluster-mon/backend/app/api/rccl_endpoints.py b/cvs/monitors/cluster-mon/backend/app/api/rccl_endpoints.py new file mode 100644 index 00000000..57def3d3 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/api/rccl_endpoints.py @@ -0,0 +1,141 @@ +""" +RCCL monitoring REST API endpoints. +Phase 1: status, communicators, events, markers. +""" + +import logging +from typing import Any, Optional +from fastapi import APIRouter, HTTPException, Query +from app.models.rccl_models import RCCLMarker + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("/status") +async def get_rccl_status() -> dict[str, Any]: + """ + Current RCCL job state and communicator health summary. + Returns the latest snapshot from app_state.latest_rccl_snapshot. + Falls back to {'state': 'no_job'} if no snapshot yet collected. 
+ """ + from app.main import app_state + + snapshot = getattr(app_state, 'latest_rccl_snapshot', None) + if snapshot is None: + return {"state": "no_job", "message": "No RCCL snapshot collected yet"} + return snapshot + + +@router.get("/communicators") +async def get_rccl_communicators() -> list[dict]: + """All communicators with per-rank detail from the latest snapshot.""" + from app.main import app_state + + snapshot = getattr(app_state, 'latest_rccl_snapshot', None) + if snapshot is None: + return [] + return snapshot.get("communicators", []) + + +@router.get("/communicators/{comm_hash}") +async def get_rccl_communicator(comm_hash: str) -> dict[str, Any]: + """Single communicator deep-dive by hash.""" + from app.main import app_state + + snapshot = getattr(app_state, 'latest_rccl_snapshot', None) + if snapshot is None: + raise HTTPException(status_code=404, detail="No snapshot available") + for comm in snapshot.get("communicators", []): + if comm.get("comm_hash") == comm_hash: + return comm + raise HTTPException(status_code=404, detail=f"Communicator {comm_hash!r} not found") + + +@router.get("/events") +async def get_rccl_events( + since: Optional[float] = Query(None, description="Start timestamp (unix)"), + until: Optional[float] = Query(None, description="End timestamp (unix)"), + event_type: Optional[str] = Query(None, alias="type"), +) -> dict: + """Filtered event log from Redis event stream (or in-memory fallback).""" + from app.main import app_state + import time + + data_store = getattr(app_state, 'rccl_data_store', None) + if data_store is None: + return {"events": [], "truncated": False} + + start = since or (time.time() - 3600) # default: last hour + end = until or time.time() + events = await data_store.get_events_in_range(start, end) + + if event_type: + events = [e for e in events if e.get("event_type") == event_type] + + return { + "events": events, + "truncated": data_store.is_memory_capped, + } + + +@router.get("/performance") +async def 
get_rccl_performance() -> dict: + """ + Latest Inspector performance snapshot. + Returns bandwidth stats across all ranks from the most recent poll cycle. + Returns 503 when Inspector collector is disabled or has not run yet. + """ + from app.main import app_state + from fastapi import Response + import time + + data_store = getattr(app_state, 'rccl_data_store', None) + if data_store is None: + raise HTTPException(status_code=503, detail="Data store not initialized") + + snapshot = await data_store.get_inspector_current() + if snapshot is None: + raise HTTPException( + status_code=503, + detail="No Inspector snapshot available. Check that rccl.inspector.enabled=true and a job is running.", + ) + return snapshot + + +@router.get("/performance/history") +async def get_rccl_performance_history( + count: int = Query(50, ge=1, le=500, description="Number of snapshots to return"), +) -> dict: + """ + Recent Inspector performance snapshots for time-series charting. + Returns up to `count` snapshots, newest first. + """ + from app.main import app_state + + data_store = getattr(app_state, 'rccl_data_store', None) + if data_store is None: + return {"snapshots": []} + + snapshots = await data_store.get_inspector_snapshots(count=count) + return {"snapshots": snapshots, "count": len(snapshots)} + + +@router.post("/markers", status_code=201) +async def post_rccl_marker(marker: RCCLMarker) -> dict[str, str]: + """ + PyTorch callback endpoint for training step/loss markers. + Stores marker as an event in the RCCL event stream. 
+ """ + from app.main import app_state + import time + + event = marker.model_dump() + event.setdefault("event_type", "training_marker") + event.setdefault("timestamp", time.time()) + + data_store = getattr(app_state, 'rccl_data_store', None) + if data_store: + await data_store.push_event(event) + + return {"status": "accepted"} diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/base.py b/cvs/monitors/cluster-mon/backend/app/collectors/base.py new file mode 100644 index 00000000..113e4ade --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/collectors/base.py @@ -0,0 +1,148 @@ +""" +BaseCollector ABC and supporting types for CVS cluster-mon collectors. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +try: + from enum import StrEnum # Python 3.11+ +except ImportError: + class StrEnum(str, Enum): # type: ignore[no-redef] + """Backport for Python 3.10: str+Enum with correct f-string behavior.""" +from typing import Any, Optional +import asyncio +import logging + +logger = logging.getLogger(__name__) + + +class CollectorState(StrEnum): + OK = "ok" + NO_SERVICE = "no_service" # Service not running (e.g. 
no RCCL job) + UNREACHABLE = "unreachable" # SSH/TCP timeout — node down + ERROR = "error" # Parse or protocol failure + + +@dataclass +class CollectorResult: + collector_name: str + timestamp: str # ISO-8601 UTC string + state: CollectorState + data: dict[str, Any] + error: Optional[str] = None + node_errors: dict[str, bool] = field(default_factory=dict) + + @staticmethod + def now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +class BaseCollector(ABC): + name: str # class-level attribute — set on each subclass + poll_interval: int # seconds between collection cycles + collect_timeout: float # max seconds per collect() call + critical: bool = False # if True, failures affect overall_status as "critical" + + @abstractmethod + async def collect(self, ssh_manager) -> CollectorResult: + """ + One collection cycle. Must NOT raise — all errors go into CollectorResult. + ssh_manager is Union[Pssh, JumpHostPssh]. + Must call ssh_manager.exec_async() (not exec()) to avoid blocking the event loop. + """ + ... + + async def on_collect_timeout(self, app_state: Any) -> None: + """ + Called when collect() is cancelled by the outer collect_timeout. + Override in subclasses to update internal state machines on timeout. + The default implementation is a no-op. + """ + + async def run(self, ssh_manager, app_state: Any) -> None: + """ + Default task body. Loops until app_state.is_collecting is False. + Wraps collect() in asyncio.wait_for to enforce collect_timeout. + Subclasses with non-poll lifecycles (e.g. RCCL monitor mode) override this. 
+ """ + while app_state.is_collecting: + try: + result = await asyncio.wait_for( + self.collect(ssh_manager), + timeout=self.collect_timeout, + ) + except asyncio.TimeoutError: + logger.warning( + f"{self.name} collect() timed out after {self.collect_timeout}s" + ) + await self.on_collect_timeout(app_state) + result = CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.ERROR, + data={}, + error=f"collect() timed out after {self.collect_timeout}s", + ) + except asyncio.CancelledError: + raise + except ConnectionError as e: + logger.error(f"{self.name} collector ConnectionError: {e}") + if hasattr(app_state, 'probe_requested') and app_state.probe_requested: + app_state.probe_requested.set() + result = CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.UNREACHABLE, + data={}, + error=str(e), + ) + except Exception as e: + logger.error( + f"{self.name} collector unexpected error: {e}", exc_info=True + ) + result = CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.ERROR, + data={}, + error=str(e), + ) + + # Store per-collector result + if hasattr(app_state, 'collector_results'): + app_state.collector_results[self.name] = result + + # Node health update — only GPU collector populates node_errors + if hasattr(app_state, 'node_health_status'): + for node, has_error in result.node_errors.items(): + _update_node_status_via_app_state(app_state, node, has_error) + + # Update latest_metrics for WebSocket broadcast + if hasattr(app_state, 'latest_metrics'): + app_state.latest_metrics[self.name] = result.data + # Shared timestamp key: last-writer-wins across collectors. + # This preserves the existing WebSocket contract + # {"gpu": ..., "nic": ..., "timestamp": "..."}. + # Clients needing per-collector timestamps should use + # GET /api/collectors/status instead. 
+ app_state.latest_metrics["timestamp"] = result.timestamp + + # Broadcast (imported lazily to avoid circular imports) + try: + from app.main import broadcast_metrics + await broadcast_metrics(app_state.latest_metrics) + except Exception as e: + logger.debug(f"broadcast_metrics not available: {e}") + + await asyncio.sleep(self.poll_interval) + + +def _update_node_status_via_app_state(app_state: Any, node: str, has_error: bool) -> None: + """Update node health status via app_state. Avoids circular import from main.""" + try: + from app.main import update_node_status + update_node_status(node, has_error, "unreachable") + except Exception as e: + logger.debug(f"update_node_status not available: {e}") diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/cvs_rccl_monitoring_technical_report.md b/cvs/monitors/cluster-mon/backend/app/collectors/cvs_rccl_monitoring_technical_report.md new file mode 100644 index 00000000..dafea394 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/collectors/cvs_rccl_monitoring_technical_report.md @@ -0,0 +1,316 @@ +# CVS Cluster-Mon: RCCL Monitoring Extension + +> **Date:** 2026-04-05 +> **Branch:** `users/nileshnegi/add-rcclras-inspector-support` +> **Status:** rcclras health monitoring tested on Ruby MI350 cluster (1-node, 2-node) +> **Status:** RCCL Inspector performance monitoring tested on CV350 MI350 cluster + +--- + +## Problem Statement + +Large-scale distributed training and inference jobs using RCCL often suffer from opaque failure modes: hangs from communicator deadlocks, silent performance degradation from degraded network links, segfaults from GPU memory errors, and cascading failures when a single node becomes unresponsive. Today, users have no unified way to observe RCCL's internal state during a live job — they resort to ad-hoc NCCL_DEBUG log analysis after the fact, losing critical temporal context. 
+ +NCCL ships a built-in RAS (Reliability, Availability, Serviceability) subsystem that exposes communicator health, peer mesh connectivity, and lifecycle events through a dedicated TCP service (`ncclras`). RCCL inherits this subsystem as `rcclras`, but no tooling exists to leverage it for continuous monitoring, correlation with system-level metrics (GPU health, RDMA errors, kernel logs), or integration with application-level signals (training step progression, loss curves). + +A second observability gap exists on the performance side. Even when a job is healthy — all ranks alive, no dead peers — it may still run slower than expected due to a single straggler rank bottlenecking the collective. NCCL ships a profiler plugin interface (`nccl_profiler.h`) and a reference implementation called the Inspector plugin (`ext-profiler/inspector/`) that records per-collective bandwidth and latency to disk as JSONL files. RCCL v2.28.3 ships the Inspector plugin as well. However, no tooling exists to continuously read these files, surface the per-rank bandwidth breakdown, and correlate it with health events from rcclras — users must manually inspect raw log files after the fact. + +--- + +## What does this "CVS RCCL Monitoring Extension" do? + +CVS `cluster-mon` can now monitor live RCCL jobs in real time across two complementary channels: + +**Health monitoring via `rcclras`:** Connects directly to the `rcclras` TCP service embedded in every RCCL process. When a rank dies, hangs, or loses connectivity, the dashboard reflects it within one poll cycle — no log parsing, no application-level timeout required. + +**Performance monitoring via the `RCCL Inspector plugin`:** Reads JSONL files produced by the RCCL Inspector profiler plugin (`ext-profiler/inspector/`) to surface per-collective bus bandwidth, algorithm bandwidth, and latency — broken down by rank. Identifies stragglers (the slowest rank relative to peers) automatically. 
+ +These two channels are complementary, not competing: + +| Capability | rcclras | Inspector | +|---|---|---| +| Rank alive / dead / missing | Yes | No | +| Dead peer detection | Yes | No | +| Per-collective bus bandwidth (GB/s) | No | Yes | +| Per-collective latency (µs) | No | Yes | +| Straggler rank identification | No | Yes | +| Message size per collective | No | Yes | + +Based on a cursory search, this is the only open-source tool that uses the RAS interface for RCCL/NCCL monitoring. Industry practice relies on post-mortem log analysis and training-side watchdog timeouts. Neither gives users visibility into communicator state while the job is still running. + +--- + +## Background: What is `rcclras` + +`rcclras` is a TCP server embedded in every RCCL process (port 28028, IPv6 loopback `[::1]` only). It exposes the internal communicator state machine via a line-oriented ASCII protocol: + +``` +CLIENT PROTOCOL 2 → handshake +SERVER PROTOCOL 2 ← +TIMEOUT 30 → set collective timeout +OK ← +VERBOSE STATUS → request full dump + ← streams header, then waits for all ranks to + check in before sending communicator table +``` + +The server streams the response in two bursts. It sends the header and job summary immediately, then blocks until all ranks report before appending the communicator table. The client must read until EOF to get the full response. + +`rcclras` is not reachable directly from the CVS backend — it binds only to the IPv6 loopback interface on each compute node. Access goes through an SSH port-forward tunnel. + +--- + +## Background: What is the `RCCL Inspector Plugin` + +The Inspector plugin is a profiler plugin shipped with RCCL v2.28.3 at `ext-profiler/inspector/`. It hooks into RCCL's profiler interface (`nccl_profiler.h`, `ncclProfiler_v5`) and records timing data for every completed collective operation. 
+ +**Activation** is entirely via environment variables — no changes to the training script or RCCL source are required: + +```bash +NCCL_PROFILER_PLUGIN=/path/to/librccl-profiler-inspector.so +(or) +NCCL_PROFILER_PLUGIN=inspector +LD_LIBRARY_PATH=/path/to/dir/containing/librccl-profiler-inspector.so:$LD_LIBRARY_PATH + +NCCL_INSPECTOR_ENABLE=1 +NCCL_INSPECTOR_DUMP_THREAD_INTERVAL_MICROSECONDS=1000000 # 1 s recommended +NCCL_INSPECTOR_DUMP_DIR=/path/to/shared/inspector-logs +``` + +**Output format:** one JSONL file per process, named `-pid.log`. Each line is one JSON object representing the most recently completed collective for a communicator during a dump interval ("latest snapshot" model — not a complete event log): + +```json +{ + "header": { "id": "0x7f8c496ae9f661", "rank": 0, "n_ranks": 16, "nnodes": 2 }, + "metadata": { "inspector_output_format_version": "v4.0", "dump_timestamp_us": 1711800000000000, + "hostname": "gpu-node-01", "pid": 12345 }, + "coll_perf": { + "coll": "AllReduce", "coll_sn": 4217, "coll_msg_size_bytes": 2097152, + "coll_exec_time_us": 412, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 193.6, "coll_busbw_gbs": 387.2 + } +} +``` + +**Bandwidth formulas** (verified against NCCL-tests documentation): + +| Collective | algBw | busBw factor | +|---|---|---| +| AllReduce | `msgBytes / execTime` | `2*(n-1)/n` | +| AllGather, ReduceScatter | `msgBytes*nranks / execTime` | `(n-1)/n` | +| Reduce, Broadcast, Send, Recv | `msgBytes / execTime` | 1 | + +**Timing source hierarchy:** GPU clock (`kernel_gpu`) > CPU kernel timing (`kernel_cpu`) > CPU collective timing (`collective_cpu`). + +**GPU clock units:** AMD `wall_clock64()` / HIP `clock64()` runs at **100 MHz (10 ns per tick)**, not 1 GHz. The Inspector converts ticks to µs as `ticks / 100`. Using `/ 1000` (treating ticks as nanoseconds) would produce exec times 10× too small and bandwidth 10× too large. 
+ +**Graceful degradation:** RCCL's `src/plugin/profiler.cc` catches any Inspector init failure and continues without profiling — the job always proceeds regardless of Inspector state. + +--- + +## Architecture + +``` +rcclras :28028 (IPv6 loopback, each compute node) + │ + │ SSH port-forward tunnel (Pssh / JumpHostPssh) + ▼ +RCCLRasClient ── VERBOSE STATUS ──► RCCLTextParser + │ + RCCLSnapshot ──────────────────────────────┐ + │ │ active PIDs + ┌───────────────────┼───────────────────┐ ▼ + ▼ ▼ ▼ InspectorCollector + Redis Streams app_state /ws/rccl (SSH exec_async) + (ring buffer) latest_rccl_snapshot WebSocket │ + │ InspectorParser + REST API ──► Frontend (4 pages) │ + InspectorSnapshot + │ + ┌───────────────────┤ + ▼ ▼ + Redis Streams /api/rccl/ + rccl:inspector:* performance +``` + +**Collector cadence:** 30-second poll interval for rcclras; 10-second poll interval for Inspector. The rcclras collector tries each healthy node in turn until it finds one with an active listener on port 28028. One successful response per cycle is sufficient — all ranks within a job report to the same rcclras instance. + +**State machine:** The collector tracks job state across polls. + +``` +NO_JOB ──► HEALTHY ──► DEGRADED ──► NO_JOB + │ ▲ + └──────── NO_JOB ─────────┘ + │ + UNREACHABLE + │ + ERROR +``` + +Every state transition emits a typed event (`job_start`, `job_end`, `job_degraded`, `job_recovered`, `node_unreachable`, etc.) stored in the event stream and visible on the Timeline page. + +--- + +## Components + +### RCCLRasClient +Async TCP client for the rcclras wire protocol. Takes a pre-connected `asyncio.StreamReader/Writer` from the SSH port-forward context manager. Handles the handshake, timeout setting, and VERBOSE STATUS dump. Reads until EOF in a loop — a single `read(n)` returns only the first burst and misses the communicator table. 
+ +Includes protocol version guards for `SET FORMAT json` (protocol v3, rcclras v2.28.9) and `MONITOR` mode (protocol v4, rcclras v2.29.2) — these are not yet enabled but the client won't send unknown commands to older servers. + +### RCCLTextParser +Regex parser for the rcclras v2.28.3 VERBOSE STATUS text format. Built and tested against real output captured from a live MI300X cluster. Extracts: + +- **Job summary** — node count, process count, GPU count, RCCL version, HIP/driver versions +- **Communicator table** — group number, comm count, rank counts, status, error column +- **Dead peers** — IP:port of unreachable peers +- **Errors section** — raw error lines reported by rcclras + +The parser determines job state from the parsed data: `NO_JOB` if no valid output, `DEGRADED` if any communicator has missing ranks, dead peers, or errors, `HEALTHY` otherwise. + +### RCCLCollector +`BaseCollector` subclass running on a 30-second cycle. Key behaviours: + +- **Leader selection:** tries all healthy nodes (from `node_health_status`) in order until one has an active rcclras listener on port 28028. +- **Bootstrap:** on first poll after startup, seeds `job_state` from the last stored snapshot to avoid emitting a spurious `job_start` event on backend restart. +- **State transfer on config reload:** when configuration is reloaded and the collector is restarted, the previous `job_state` is copied to the new instance — same reason. +- **Timeout handling:** if the outer `asyncio.wait_for` fires, `on_collect_timeout()` updates the state machine to UNREACHABLE so the next cycle doesn't start from a stale state. + +### InspectorCollector + +`BaseCollector` subclass polling on a 10-second cycle (`critical = False` — Inspector failure never affects the overall cluster health status). Two collection modes: + +**File mode** (`mode: file`): reads `*.log` files directly from a locally-mounted NFS path. Zero SSH overhead. 
Requires the Inspector `dump_dir` to be visible from the CVS backend host. + +**SSH mode** (`mode: ssh`): runs `tail -n ` on each compute node via `exec_async`. Used when NFS is not available. When the rcclras snapshot is present, only log files matching active PIDs are read — stale files from previous runs in the same `dump_dir` are ignored automatically. Falls back to reading all `*-pid*.log` files when no snapshot is available (e.g., before rcclras has connected). + +The Inspector plugin names log files using the result of `gethostname()`, which is typically the FQDN (e.g., `cv350-zts-gtu-g31a-18.prov.gtu.zts.cpe.ice.amd.com-pid3404720.log`). The glob pattern uses `*-pid*.log` to avoid short-hostname vs. FQDN mismatches. + +### InspectorParser + +JSONL parser for the RCCL Inspector v4.0 output format. Reads the tail of each log file (bounded by `max_records_per_file`, default 500) and parses each line independently — malformed lines are skipped silently at DEBUG level. Produces `InspectorCollPerf` records. + +`aggregate_snapshot()` computes bandwidth statistics (avg/min/max busBw, straggler rank) from the full tail window, then deduplicates the `records` field to the **latest entry per (rank, comm_hash)** by `sequence_num` before storing. This keeps the WebSocket payload and frontend table bounded to one row per rank regardless of tail window size. + +#### Bugs fixed in Inspector v2.28.3 (branch `users/nileshnegi/rccl/inspector-fixes`) + +Five bugs in the Inspector plugin prevented it from producing valid output on RCCL v2.28.3. All five were diagnosed and fixed: + +| # | Bug | Root Cause | Fix | +|---|---|---|---| +| 1 | **Log files always empty** (`inspector_plugin.cc`) | `ncclTaskColl::nChannels` is uninitialized at allocation (`ncclMemoryPoolAlloc` does not zero memory). Garbage value (224) made the dump condition `nKernelChCompleted == nChannels` unsatisfiable — no collective ever completed. | Add `collStopFired` flag to `inspectorCollInfo`. 
Trigger dump when `collStopFired && nKernelChCompleted == nKernelChStarted == nChannels`. Initialize `nChannels = 0` at allocation in `enqueue.cc` so `scheduleCollTasksToPlan` overwrites it with the correct value before the profiler reads it. | +| 2 | **Teardown hang** (`src/transport/profiler.cc`) | `profilerProxyProgress` polled GPU counters indefinitely during teardown for channels that were never dispatched to GPU. The proxy thread never marked those channels done, blocking RCCL teardown. | Detect teardown (`proxyState->progressState.stop`) and drain: skip Start+Stop for channels whose GPU start counter was never written; skip Stop for channels whose GPU stop counter was never written. | +| 3 | **Zero bandwidth / garbage exec time** (`inspector.cc`) | `calculateMaxKernelExecTimeUsecs` iterated `for (i = 0; i < nChannels; i++)` using the same garbage `nChannels = 224`. Channels 57–223 had uninitialized `tsStartUsec` / `tsCompletedUsec` (random memory values), producing a spuriously large max exec time. | Normalize `collInfo->nChannels = collInfo->nKernelChStarted` and `collEvtTrk.nChannels = nKernelChStarted` before calling `inspectorUpdateCollPerf`, so the timing loop and JSON dump loop iterate only over channels that actually fired. | +| 4 | **GPU bandwidth 10× too large** (`inspector.cc`) | `calculateKernelGpuExecTimeUsecs` computed `execTimeNanosecs / 1000` treating AMD `wall_clock64()` ticks as 1 GHz nanoseconds. AMD GPU hardware timer runs at **100 MHz (10 ns/tick)**, so the correct divisor is 100, not 1000. | Change `execTimeNanosecs / 1000` to `ticks / 100`. Verified: reported busBw (387–388 GB/s) matches rccl-tests output (386 GB/s) within noise. | +| 5 | **Channels under-counted in dump** (`inspector_plugin.cc` + `enqueue.cc`) | The dump fired when `nKernelChCompleted == nKernelChStarted`, but some channels' GPU start counters are written to host memory slightly later than others. The dump fired on the first batch (e.g. 43 of 48), freeing `collInfo`. 
Remaining channels found a freed (zeroed) `parentObj` and became no-ops. | Fix 1's `enqueue.cc` change restores the correct `nChannels = 48` so the dump condition correctly waits for all channels. `collStopFired` ensures the dump never fires before CollStop is recorded. | + +#### Remaining known bugs (not fixed — low impact for CVS use) + +| Bug | Impact | +|---|---| +| ReduceScatter `trafficSize` may be inflated by `nranks` | busBw for ReduceScatter may be wrong; verify on cluster before alerting | +| Unsigned underflow in collective CPU fallback: no guard on `tsCompletedUsec < tsStartUsec` | Very unlikely (requires clock going backwards); produces garbage large value if it occurs | + +### RCCLDataStore +Dual-mode storage backend: + +| Mode | When | Storage | Capacity | +|------|------|---------|----------| +| **Redis Streams** | Redis available | `rccl:snapshots`, `rccl:events`, `rccl:inspector:snapshots` | 1 000 snapshots, 10 000 events each | +| **In-memory deque** | No Redis / Redis error | `collections.deque` | 500 events, 100 snapshots, 100 Inspector snapshots | + +Redis mode uses `XADD ... MAXLEN` — atomic append and cap in a single command. Time-range event queries use Redis Stream entry IDs (millisecond timestamps embedded). The in-memory fallback activates automatically if Redis is unavailable or throws an exception mid-operation. + +### REST API + +| Endpoint | Description | +|----------|-------------| +| `GET /api/rccl/status` | Latest snapshot: state, job summary, communicators, errors | +| `GET /api/rccl/communicators` | Communicator list from latest snapshot | +| `GET /api/rccl/communicators/{hash}` | Single communicator detail | +| `GET /api/rccl/events?since=&until=&type=` | Time-filtered event log. Returns `{events, truncated}` — `truncated: true` when the in-memory buffer is at capacity and older events may be missing | +| `POST /api/rccl/markers` | PyTorch training step/loss callback. 
Stores as `training_marker` event | +| `GET /api/rccl/performance` | Latest Inspector snapshot: avg/min/max busBw, straggler rank, per-rank table, collective breakdown. Returns 503 when Inspector is disabled or no data yet | +| `GET /api/rccl/performance/history?count=N` | Up to N recent Inspector snapshots for time-series charting (max 500) | +| `WebSocket /ws/rccl` | Real-time snapshot push on every collector cycle | + +### Frontend Pages + +**RCCL Health** — primary view. Shows job state banner (Healthy / Degraded / Unreachable / No Job), a staleness indicator when the snapshot is more than 75 seconds old (2.5× the poll interval), the raw rcclras Errors section when present, and a communicator card per group showing total/responding/missing rank counts. + +**RAS Topology** — peer mesh visualization. Disabled for rcclras v2.28.3, which does not include per-peer connectivity in its text output. A compatibility note is shown; peer data is expected in a future rcclras version. + +**Timeline** — chronological event log with type filter (job_start, job_end, degraded, recovered, etc.) and time-range selector. Shows `from_state → to_state` for state-change events and step/loss for training markers. + +**RCCL Performance** — Inspector data view. Shows avg/min/max bus bandwidth summary cards, collective breakdown table (call count per collective type), and a per-rank bandwidth table with a proportional bar chart. The slowest rank (straggler) is highlighted in red. Polls `/api/rccl/performance` every 15 seconds. Shows a descriptive 503 message when Inspector is not active. + +Each rank row is expandable to show the **per-channel kernel trace** (requires `NCCL_INSPECTOR_DUMP_VERBOSE=1`). 
Columns: `channel_idx`, `start_event` / `stop_event` / `record_event` (monotonic sequence numbers across all profiler events for the collective, showing relative ordering), `start_timestamp` / `end_timestamp` (CPU epoch µs from `gettimeofday`, recorded when the proxy thread observed the GPU counter), and `duration (µs)` (end − start, already in µs — no clock-tick conversion applied). + +Note on min busBw: the summary card min and slowest-rank are computed over **all records in the tail window**, not just the latest per rank. Warmup collectives (cold GPU on first iteration) typically appear in the tail and produce a low min — this is expected and does not indicate a genuine straggler. + +--- + +## RCCL Health — Live Screenshots + +### Healthy State +All 16 ranks across 2 nodes responding. `group_0` communicator: 16/16 responding, 0 missing. + +![RCCL Health - Healthy](images/rccl_health_good.png) + +### Degraded State +One rank dropped mid-job. rcclras identifies the exact rank (Rank 7), GPU (GPU 7), PID (3871587), and node (10.245.40.180) in its Errors section. The communicator card reflects 15/16 responding, 1 missing. + +![RCCL Health - Degraded](images/rccl_health_degraded.png) + +--- + +## RCCL Performance — Live Screenshots + +### Performance Overview +RCCL Performance Summary (min/max/avg bus_bandwidth, straggler rank(s), and collective breakdown) + +![RCCL Performance - Overview](images/rccl_perf_overview.png) + +### Channel Breakdown +When run with `NCCL_INSPECTOR_DUMP_VERBOSE=1`, the RCCL Inspector plugin can record channel count and GPU timing events per channel. + +![RCCL Performance - Channel Breakdown](images/rccl_perf_channel_breakdown.png) + +--- + +## Known Limitations — rcclras v2.28.3 + +**Single communicator group visible.** rcclras v2.28.3 exposes only the communicator group that rank 0 belongs to. A job using 8 independent communicator groups across 16 GPUs will show only one group. Full multi-group visibility is expected in a later rcclras version. 
+ +**No per-peer connectivity data.** The v2.28.3 text format does not include peer-level mesh data. The RAS Topology page is present but shows a compatibility notice until the data is available. + +**In-memory events do not survive restarts.** Without Redis, events are held in a bounded in-memory buffer (500 events). Restarting the backend clears this history. Redis is not required for the core health dashboard — only for event history retention across restarts. + +--- + +## Known Limitations — RCCL Inspector v2.28.3 + +**Activation requires job-side configuration.** The Inspector plugin must be built from RCCL source (`make` in `ext-profiler/inspector/`) and enabled via env vars in the job submission script (`NCCL_PROFILER_PLUGIN=inspector`, `NCCL_INSPECTOR_ENABLE=1`, `NCCL_INSPECTOR_DUMP_DIR=`). CVS reads the output files; it cannot enable the plugin retroactively. + +**Latest-snapshot sampling model.** The Inspector writes only the most recently completed collective per communicator per dump interval — not a complete event log. Collectives that finish between two dump wakeups are overwritten and lost. The dump interval is a sampling rate, not a capture window. A 1 s interval (1 000 000 µs) gives one fresh record per second per rank, which is sufficient for bandwidth trending in a 30 s CVS poll cycle. + +**dump_dir accumulates stale files.** Files from previous runs persist in the same `dump_dir` until manually cleaned. CVS mitigates this by filtering to PIDs known from the current rcclras snapshot; when rcclras is unavailable, all files are read and deduplicated by `sequence_num`. + +--- + +## Testing + +Tested cluster-mon with long-running RCCL-Tests on 1-node and 2-nodes of Ruby MI350 cluster and introducing artificial chaos (e.g. killing a rccl-tests process). CVS backend/frontend app running on local laptop could directly connect to Ruby cluster compute nodes running RCCL. + +Unit-Tests: 103 tests across 10 files. 
Coverage includes: BaseCollector lifecycle (timeout, crash, ConnectionError, supervisor restart), Pydantic config defaults and environment variable overrides, SSH bridge (bidirectional data, EOF propagation), collectors/status API, RAS protocol client against a mock TCP server, text parser against 4 fixture files (healthy, single-node degraded, 2-node degraded with heterogeneous `7-8` ranks-per-node range, connection reset), RCCL collector state machine (all 20 transitions, bootstrap, no-duplicate-event on unchanged state), WebSocket ConnectionManager, config reload diff detection, and Inspector parser (JSONL parsing, malformed-line skipping, timestamp conversion, tail limiting, `aggregate_snapshot` avg/min/max/straggler/deduplication/collective breakdown). + +--- + +## Future Work + +| Phase | Scope | +|-------|-------| +| **2** | Switch rcclras to JSON output (rcclras v2.28.9)
Prometheus `/metrics` endpoint
InfluxDB long-term storage (structured data pipeline) | +| **3** | Persistent `MONITOR` mode (rcclras v2.29.2) for push-based event streaming (eliminates polling)
Per-rank structured error parsing | +| **4** | `/api/rccl/preflight` for Slurm prolog health gate
Slurm job ID correlation
Grafana dashboard templates
Snapshot replay for post-mortem analysis | +| **Inspector** | Bandwidth time-series charts on the Performance page (history endpoint already exists)
Per-communicator breakdown when multiple communicators are active
Alert threshold when straggler busBw falls >50% below the mean | diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/gpu_collector.py b/cvs/monitors/cluster-mon/backend/app/collectors/gpu_collector.py index 1a56c8f4..6eb89846 100644 --- a/cvs/monitors/cluster-mon/backend/app/collectors/gpu_collector.py +++ b/cvs/monitors/cluster-mon/backend/app/collectors/gpu_collector.py @@ -6,14 +6,22 @@ import json import logging from typing import Dict, Any -from datetime import datetime +from datetime import datetime, timezone + +from app.collectors.base import BaseCollector, CollectorResult, CollectorState +from app.core.config import settings as _settings logger = logging.getLogger(__name__) -class GPUMetricsCollector: +class GPUMetricsCollector(BaseCollector): """Collects GPU metrics via rocm-smi and amd-smi commands.""" + name = "gpu" + poll_interval: int = 60 # will be set after class definition + collect_timeout: float = 48.0 # 80% of poll_interval default + critical = True # GPU collection failure is critical + @staticmethod def parse_json_output(output_dict: Dict[str, str]) -> Dict[str, Any]: """ @@ -81,7 +89,7 @@ async def collect_gpu_utilization(self, ssh_manager) -> Dict[str, Any]: """ logger.info("Collecting GPU utilization") # Use amd-smi metric which provides comprehensive GPU metrics - output = ssh_manager.exec("amd-smi metric --json", timeout=120) + output = await ssh_manager.exec_async("amd-smi metric --json", timeout=120) return self.parse_json_output(output) async def collect_gpu_memory(self, ssh_manager) -> Dict[str, Any]: @@ -100,7 +108,7 @@ async def collect_gpu_memory(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting GPU memory usage") - output = ssh_manager.exec("amd-smi metric --json", timeout=120) + output = await ssh_manager.exec_async("amd-smi metric --json", timeout=120) return self.parse_json_output(output) async def collect_gpu_temperature(self, ssh_manager) -> Dict[str, Any]: @@ -120,7 +128,7 @@ async def 
collect_gpu_temperature(self, ssh_manager) -> Dict[str, Any]: """ logger.info("Collecting GPU temperature") # amd-smi metric provides temperature in the main metric output - output = ssh_manager.exec("amd-smi metric --json", timeout=120) + output = await ssh_manager.exec_async("amd-smi metric --json", timeout=120) return self.parse_json_output(output) async def collect_gpu_power(self, ssh_manager) -> Dict[str, Any]: @@ -139,7 +147,7 @@ async def collect_gpu_power(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting GPU power metrics") - output = ssh_manager.exec("amd-smi metric --power --json", timeout=120) + output = await ssh_manager.exec_async("amd-smi metric --power --json", timeout=120) return self.parse_json_output(output) async def collect_gpu_metrics(self, ssh_manager) -> Dict[str, Any]: @@ -158,7 +166,7 @@ async def collect_gpu_metrics(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting comprehensive GPU metrics") - output = ssh_manager.exec("amd-smi metric --json", timeout=120) + output = await ssh_manager.exec_async("amd-smi metric --json", timeout=120) return self.parse_json_output(output) async def collect_pcie_metrics(self, ssh_manager) -> Dict[str, Any]: @@ -177,7 +185,7 @@ async def collect_pcie_metrics(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting PCIe metrics") - output = ssh_manager.exec("amd-smi metric --pcie --json", timeout=120) + output = await ssh_manager.exec_async("amd-smi metric --pcie --json", timeout=120) return self.parse_json_output(output) async def collect_xgmi_metrics(self, ssh_manager) -> Dict[str, Any]: @@ -196,11 +204,7 @@ async def collect_xgmi_metrics(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting XGMI metrics") - output = ssh_manager.exec("amd-smi metric --xgmi-err --json", timeout=120) - logger.info('%%%%%%%%%%%') - logger.info('parsed value of xgmi') - logger.info(output) - logger.info(self.parse_json_output(output)) + output = await 
ssh_manager.exec_async("amd-smi metric --xgmi-err --json", timeout=120) return self.parse_json_output(output) async def collect_ras_errors(self, ssh_manager) -> Dict[str, Any]: @@ -219,10 +223,7 @@ async def collect_ras_errors(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting RAS error metrics") - output = ssh_manager.exec("amd-smi metric --ecc --json", timeout=120) - logger.info('%%%%%%%%%%') - logger.info('Output of ecc') - logger.info(output) + output = await ssh_manager.exec_async("amd-smi metric --ecc --json", timeout=120) return self.parse_json_output(output) async def collect_gpu_info(self, ssh_manager) -> Dict[str, Any]: @@ -239,7 +240,7 @@ async def collect_gpu_info(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting GPU info") - output = ssh_manager.exec("rocm-smi --loglevel error --showproductname --json", timeout=120) + output = await ssh_manager.exec_async("rocm-smi --loglevel error --showproductname --json", timeout=120) return self.parse_json_output(output) async def collect_pcie_info(self, ssh_manager) -> Dict[str, Any]: @@ -259,13 +260,13 @@ async def collect_pcie_info(self, ssh_manager) -> Dict[str, Any]: logger.info("Collecting PCIe link info via lspci") # First get BDF (Bus/Device/Function) addresses from amd-smi - static_output = ssh_manager.exec("amd-smi static --json", timeout=120) + static_output = await ssh_manager.exec_async("amd-smi static --json", timeout=120) static_data = self.parse_json_output(static_output) # OPTIMIZATION: Run lspci once per node instead of once per GPU # This reduces 288 commands (36 nodes * 8 GPUs) to just 36 commands! 
logger.info("Running lspci once per node (optimized)") - lspci_output = ssh_manager.exec("bash -c 'sudo lspci -vvv 2>/dev/null'", timeout=120) + lspci_output = await ssh_manager.exec_async("bash -c 'sudo lspci -vvv 2>/dev/null'", timeout=120) pcie_info = {} import re @@ -329,14 +330,12 @@ async def collect_all_metrics(self, ssh_manager) -> Dict[str, Any]: "info": {...} } """ - import asyncio - logger.info("Collecting all GPU metrics") # OPTIMIZATION: Call amd-smi metric --json ONCE to get ALL data # This single command includes: utilization, memory, temperature, PCIe, XGMI, and ECC metrics logger.info("Calling amd-smi metric --json for comprehensive GPU data") - amd_smi_output = await asyncio.to_thread(ssh_manager.exec, "amd-smi metric --json") + amd_smi_output = await ssh_manager.exec_async("amd-smi metric --json") amd_smi_data = self.parse_json_output(amd_smi_output) # Parse all metrics from single amd-smi output @@ -346,15 +345,15 @@ async def collect_all_metrics(self, ssh_manager) -> Dict[str, Any]: # Call dedicated commands for PCIe and ECC for cleaner data logger.info("Collecting PCIe metrics with dedicated command") - pcie_output = await asyncio.to_thread(ssh_manager.exec, "amd-smi metric --pcie --json") + pcie_output = await ssh_manager.exec_async("amd-smi metric --pcie --json") pcie_data = self.parse_json_output(pcie_output) logger.info("Collecting XGMI metrics with dedicated command") - xgmi_output = await asyncio.to_thread(ssh_manager.exec, "amd-smi metric --xgmi-err --json") + xgmi_output = await ssh_manager.exec_async("amd-smi metric --xgmi-err --json") xgmi_data = self.parse_json_output(xgmi_output) logger.info("Collecting ECC/RAS metrics with dedicated command") - ecc_output = await asyncio.to_thread(ssh_manager.exec, "amd-smi metric --ecc --json") + ecc_output = await ssh_manager.exec_async("amd-smi metric --ecc --json") ecc_data = self.parse_json_output(ecc_output) # Parse for frontend display @@ -365,7 +364,7 @@ async def collect_all_metrics(self, 
ssh_manager) -> Dict[str, Any]: # Package results metrics = { - "timestamp": datetime.utcnow().isoformat() + "Z", + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), "utilization": utilization, "memory": memory, "temperature": temperature, @@ -379,6 +378,37 @@ async def collect_all_metrics(self, ssh_manager) -> Dict[str, Any]: return metrics + + async def collect(self, ssh_manager) -> CollectorResult: + """ + BaseCollector interface. Calls collect_all_metrics() and wraps result. + node_errors is populated so BaseCollector.run() can call update_node_status(). + """ + try: + metrics = await self.collect_all_metrics(ssh_manager) + except Exception as e: + return CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.ERROR, + data={}, + error=str(e), + ) + + # Determine per-node errors from metrics + node_errors: dict[str, bool] = {} + util_data = metrics.get("utilization", {}) if isinstance(metrics, dict) else {} + for node, node_data in util_data.items(): + has_error = isinstance(node_data, dict) and "error" in node_data + node_errors[node] = has_error + + return CollectorResult( + collector_name=self.name, + timestamp=metrics.get("timestamp", CollectorResult.now_iso()) if isinstance(metrics, dict) else CollectorResult.now_iso(), + state=CollectorState.OK, + data=metrics if isinstance(metrics, dict) else {}, + node_errors=node_errors, + ) + def _parse_utilization_from_amd_smi(self, amd_smi_data: Dict) -> Dict: """Parse utilization from amd-smi metric output.""" util_data = {} @@ -630,50 +660,6 @@ def _parse_pcie_metrics_from_amd_smi(self, amd_smi_data: Dict) -> Dict: return pcie_info - def _parse_pcie_metrics_from_amd_smi_OLD(self, amd_smi_data: Dict) -> Dict: - """OLD VERSION - keeping for reference""" - pcie_info = {} - for node, data in amd_smi_data.items(): - if isinstance(data, dict) and 'gpu_data' in data: - pcie_info[node] = {} - for gpu in data['gpu_data']: - gpu_id = f"card{gpu.get('gpu', 0)}" - 
pcie_data = gpu.get('pcie', {}) - - # Parse width - width = pcie_data.get('width', '-') - if width != '-': - width = f"x{width}" # Format as x16, x8, etc. - - # Parse speed - speed_data = pcie_data.get('speed', {}) - if isinstance(speed_data, dict): - speed_value = speed_data.get('value', '-') - speed_unit = speed_data.get('unit', 'GT/s') - speed = f"{speed_value}{speed_unit}" if speed_value != '-' else '-' - else: - speed = str(speed_data) if speed_data else '-' - - # Parse bandwidth - bandwidth_data = pcie_data.get('bandwidth', {}) - if isinstance(bandwidth_data, dict): - bw_value = bandwidth_data.get('value', '-') - bw_unit = bandwidth_data.get('unit', 'Mb/s') - bandwidth = f"{bw_value} {bw_unit}" if bw_value != '-' else '-' - else: - bandwidth = str(bandwidth_data) if bandwidth_data else '-' - - pcie_info[node][gpu_id] = { - 'width': width, - 'speed': speed, - 'bandwidth': bandwidth, - 'replay_count': pcie_data.get('replay_count', 0), - 'l0_to_recovery_count': pcie_data.get('l0_to_recovery_count', 0), - 'nak_sent_count': pcie_data.get('nak_sent_count', 0), - 'nak_received_count': pcie_data.get('nak_received_count', 0), - } - return pcie_info - def _parse_xgmi_metrics_from_amd_smi(self, amd_smi_data: Dict) -> Dict: """ Parse XGMI metrics from amd-smi output. 
@@ -823,3 +809,7 @@ def normalize_metrics(self, raw_metrics: Dict[str, Any]) -> Dict[str, Any]: normalized[node]["gpus"].append(gpu_metrics) return normalized + + +GPUMetricsCollector.poll_interval = _settings.polling.interval +GPUMetricsCollector.collect_timeout = _settings.polling.interval * 0.8 diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_health_degraded.png b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_health_degraded.png new file mode 100755 index 00000000..7cd30e4d Binary files /dev/null and b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_health_degraded.png differ diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_health_good.png b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_health_good.png new file mode 100755 index 00000000..1c7c9e49 Binary files /dev/null and b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_health_good.png differ diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_perf_channel_breakdown.png b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_perf_channel_breakdown.png new file mode 100755 index 00000000..05a5e899 Binary files /dev/null and b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_perf_channel_breakdown.png differ diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_perf_overview.png b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_perf_overview.png new file mode 100755 index 00000000..93428e3c Binary files /dev/null and b/cvs/monitors/cluster-mon/backend/app/collectors/images/rccl_perf_overview.png differ diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/inspector_collector.py b/cvs/monitors/cluster-mon/backend/app/collectors/inspector_collector.py new file mode 100644 index 00000000..d09cf57b --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/collectors/inspector_collector.py @@ -0,0 +1,201 @@ +""" +InspectorCollector — reads RCCL 
Inspector JSONL log files and pushes snapshots. + +Two collection modes: + - "file": reads directly from an NFS path visible to the CVS backend (primary) + - "ssh": reads via SSH exec_async, using the existing JumpHostPssh/Pssh manager + +The Inspector plugin writes one file per process: + /-pid.log + +This collector glob-matches all *.log files in dump_dir (file mode) or +runs `tail -n ` on each compute node (ssh mode). + +critical = False — Inspector failure never affects overall cluster health status. +""" + +import asyncio +import logging +from pathlib import Path +from typing import Any, Optional + +from app.collectors.base import BaseCollector, CollectorResult, CollectorState +from app.collectors.inspector_parser import InspectorParser, aggregate_snapshot +from app.models.rccl_models import InspectorCollPerf, InspectorSnapshot + +logger = logging.getLogger(__name__) + + +class InspectorCollector(BaseCollector): + """ + Polls RCCL Inspector log files every poll_interval seconds. + + File mode (default): + Reads all *.log files from rccl.inspector.dump_dir via the local + filesystem. Use this when dump_dir is on a shared NFS mount visible + from the CVS backend host. + + SSH mode: + For each healthy node, runs `tail -n /-pid*.log` + via exec_async. Use this when NFS is not available. 
+ """ + + name = "inspector" + poll_interval: int = 30 + collect_timeout: float = 20.0 + critical = False + + def __init__(self): + self._parser = InspectorParser() + + async def collect(self, ssh_manager) -> CollectorResult: + from app.core.config import settings + + cfg = settings.rccl.inspector + if not cfg.enabled: + return CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.NO_SERVICE, + data={}, + error="Inspector collector is disabled (rccl.inspector.enabled=false)", + ) + + try: + if cfg.mode == "ssh": + records = await self._collect_ssh(ssh_manager, cfg) + else: + records = self._collect_file(cfg) + except Exception as e: + logger.error(f"InspectorCollector unexpected error: {e}", exc_info=True) + return CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.ERROR, + data={}, + error=str(e), + ) + + snapshot = aggregate_snapshot(records) + snapshot_dict = snapshot.model_dump() + + # Push to data store if available + from app.main import app_state + data_store = getattr(app_state, 'rccl_data_store', None) + if data_store is not None: + await data_store.push_inspector_snapshot(snapshot_dict) + + return CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.OK, + data=snapshot_dict, + ) + + # ------------------------------------------------------------------ + # File mode + # ------------------------------------------------------------------ + + def _collect_file(self, cfg) -> list[InspectorCollPerf]: + """Read all *.log files in dump_dir from the local/NFS filesystem.""" + if not cfg.dump_dir: + logger.warning("Inspector file mode: rccl.inspector.dump_dir is not set") + return [] + + dump_path = Path(cfg.dump_dir) + if not dump_path.exists(): + logger.warning(f"Inspector dump_dir does not exist: {dump_path}") + return [] + + records: list[InspectorCollPerf] = [] + log_files = 
list(dump_path.glob("*.log")) + if not log_files: + logger.debug(f"Inspector: no *.log files found in {dump_path}") + return [] + + for log_file in log_files: + file_records = self._parser.parse_file(log_file, tail=cfg.max_records_per_file) + records.extend(file_records) + logger.debug(f"Inspector: parsed {len(file_records)} records from {log_file.name}") + + logger.info(f"Inspector file mode: {len(records)} records from {len(log_files)} files") + return records + + # ------------------------------------------------------------------ + # SSH mode + # ------------------------------------------------------------------ + + def _active_pids(self) -> set[int]: + """ + Extract PIDs of currently active RCCL ranks from the latest rcclras snapshot. + Returns an empty set when no job is running or rcclras hasn't connected yet. + """ + from app.main import app_state + snapshot = getattr(app_state, 'latest_rccl_snapshot', None) + if not snapshot: + return set() + pids: set[int] = set() + for comm in snapshot.get('communicators', []): + for rank in comm.get('ranks', []): + pid = rank.get('pid') + if pid: + pids.add(pid) + return pids + + async def _collect_ssh(self, ssh_manager, cfg) -> list[InspectorCollPerf]: + """Collect from each compute node via SSH tail. + + When active PIDs are known from the rcclras snapshot, only log files + belonging to those PIDs are read — stale files from previous runs in + the same dump_dir are ignored. + Falls back to reading all *-pid*.log files when no snapshot is available. + """ + if not cfg.dump_dir: + logger.warning("Inspector SSH mode: rccl.inspector.dump_dir is not set") + return [] + + active_pids = self._active_pids() + + if active_pids: + # Build a grep pattern so only files for the current job are read. 
+ # Example: grep -E 'pid(3404720|3404721|...)\.log$' + pid_pattern = "|".join(f"pid{p}" for p in sorted(active_pids)) + cmd = ( + f"ls {cfg.dump_dir}/ 2>/dev/null " + f"| grep -E '({pid_pattern})\\.log$' " + f"| xargs -I{{}} tail -n {cfg.max_records_per_file} {cfg.dump_dir}/{{}} " + f"2>/dev/null || true" + ) + logger.debug(f"Inspector SSH: filtering to {len(active_pids)} active PIDs") + else: + # No rcclras snapshot yet — read all pid log files as fallback + cmd = ( + f"tail -n {cfg.max_records_per_file} " + f"{cfg.dump_dir}/*-pid*.log 2>/dev/null || true" + ) + logger.debug("Inspector SSH: no active PIDs known, reading all log files") + + try: + outputs = await ssh_manager.exec_async(cmd) + except Exception as e: + logger.warning(f"Inspector SSH exec failed: {e}") + return [] + + records: list[InspectorCollPerf] = [] + for node, output in outputs.items(): + if not output: + continue + node_records = self._parser.parse_lines(output) + records.extend(node_records) + logger.debug(f"Inspector SSH: {len(node_records)} records from {node}") + + logger.info( + f"Inspector SSH mode: {len(records)} records from {len(outputs)} nodes" + + (f" (PIDs: {sorted(active_pids)})" if active_pids else " (all files)") + ) + return records + + +# Set poll_interval from config at import time (consistent with other collectors) +from app.core.config import settings as _settings +InspectorCollector.poll_interval = _settings.rccl.inspector.poll_interval diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/inspector_parser.py b/cvs/monitors/cluster-mon/backend/app/collectors/inspector_parser.py new file mode 100644 index 00000000..3cb5c3d0 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/collectors/inspector_parser.py @@ -0,0 +1,202 @@ +""" +Parser for RCCL Inspector plugin JSONL output (format version v4.0). + +Each line in an Inspector log file is one JSON object representing the +most recently completed collective for a communicator during a dump interval. 
+This is a "latest snapshot" model — not a complete event log. + +Reference: ext-profiler/inspector/inspector.cc (RCCL v2.28.3) +""" + +import json +import logging +from pathlib import Path +from typing import Optional + +from app.models.rccl_models import ( + InspectorCollPerf, + InspectorEventTrace, + InspectorKernelChannel, + InspectorSnapshot, +) + +logger = logging.getLogger(__name__) + + +class InspectorParser: + """ + Parse Inspector JSONL log files into InspectorCollPerf records. + + Usage: + parser = InspectorParser() + records = parser.parse_file(Path("/nfs/inspector-logs/gpu-node-01-pid12345.log")) + records = parser.parse_lines("line1\\nline2\\n...") + """ + + def parse_file(self, path: Path, tail: int = 100) -> list[InspectorCollPerf]: + """ + Read the last `tail` lines from an Inspector log file and parse them. + Only the tail is read to bound memory usage on long-running jobs. + """ + try: + text = path.read_text(errors="replace") + except OSError as e: + logger.warning(f"Inspector: cannot read {path}: {e}") + return [] + lines = text.splitlines() + return self.parse_lines("\n".join(lines[-tail:])) + + def parse_lines(self, text: str) -> list[InspectorCollPerf]: + """ + Parse a block of text (one JSON object per line). Malformed lines + are silently skipped with a debug-level log. 
+ """ + records: list[InspectorCollPerf] = [] + for lineno, line in enumerate(text.splitlines(), start=1): + line = line.strip() + if not line: + continue + record = self._parse_line(line, lineno) + if record is not None: + records.append(record) + return records + + def _parse_line(self, line: str, lineno: int) -> Optional[InspectorCollPerf]: + try: + obj = json.loads(line) + except json.JSONDecodeError: + logger.debug(f"Inspector: skipping malformed JSON at line {lineno}") + return None + + try: + header = obj["header"] + meta = obj["metadata"] + perf = obj["coll_perf"] + return InspectorCollPerf( + timestamp=meta["dump_timestamp_us"] / 1_000_000.0, + comm_hash=header["id"], + rank=header["rank"], + nranks=header["n_ranks"], + nnodes=header["nnodes"], + hostname=meta["hostname"], + pid=meta["pid"], + collective=perf["coll"], + sequence_num=perf["coll_sn"], + msg_size_bytes=perf["coll_msg_size_bytes"], + exec_time_us=perf["coll_exec_time_us"], + timing_source=perf["coll_timing_source"], + algo_bw_gbps=float(perf["coll_algobw_gbs"]), + bus_bw_gbps=float(perf["coll_busbw_gbs"]), + event_trace=self._parse_event_trace(perf), + ) + except (KeyError, TypeError, ValueError) as e: + logger.debug(f"Inspector: skipping line {lineno} — missing field: {e}") + return None + + def _parse_event_trace(self, perf: dict) -> Optional[InspectorEventTrace]: + """ + Parse verbose event trace from coll_perf when NCCL_INSPECTOR_DUMP_VERBOSE=1. + event_trace_sn and event_trace_ts are both nested inside the coll_perf object. + Returns None if neither key is present (non-verbose mode). 
+ """ + sn = perf.get("event_trace_sn") + ts = perf.get("event_trace_ts") + if sn is None and ts is None: + return None + # Guard against non-dict values (malformed verbose data) + if not isinstance(sn, dict): + sn = None + if not isinstance(ts, dict): + ts = None + if sn is None and ts is None: + return None + + try: + # Merge per-channel sn and ts by channel_id + sn_channels = {c["channel_id"]: c for c in (sn or {}).get("kernel_events", [])} + ts_channels = {c["channel_id"]: c for c in (ts or {}).get("kernel_events", [])} + all_ids = sorted(set(sn_channels) | set(ts_channels)) + + channels = [ + InspectorKernelChannel( + channel_id=ch_id, + kernel_start_sn=sn_channels.get(ch_id, {}).get("kernel_start_sn"), + kernel_stop_sn=sn_channels.get(ch_id, {}).get("kernel_stop_sn"), + kernel_record_sn=sn_channels.get(ch_id, {}).get("kernel_record_sn"), + kernel_start_ts=ts_channels.get(ch_id, {}).get("kernel_start_ts"), + kernel_stop_ts=ts_channels.get(ch_id, {}).get("kernel_stop_ts"), + kernel_record_ts=ts_channels.get(ch_id, {}).get("kernel_record_ts"), + ) + for ch_id in all_ids + ] + + return InspectorEventTrace( + coll_start_sn=(sn or {}).get("coll_start_sn"), + coll_stop_sn=(sn or {}).get("coll_stop_sn"), + coll_start_ts=(ts or {}).get("coll_start_ts"), + coll_stop_ts=(ts or {}).get("coll_stop_ts"), + channels=channels, + ) + except (KeyError, TypeError, ValueError) as e: + logger.debug(f"Inspector: skipping malformed event_trace: {e}") + return None + + +def aggregate_snapshot(records: list[InspectorCollPerf]) -> InspectorSnapshot: + """ + Aggregate a list of InspectorCollPerf records into an InspectorSnapshot. + + Stats (avg/min/max busBw, collective breakdown) are computed from ALL + records in the tail window, giving a richer sample for accuracy. + + The `records` field stored in the snapshot is deduplicated to the LATEST + entry per (rank, comm_hash) — one row per rank per communicator. 
This + prevents the frontend table and WebSocket payload from growing with the + tail window size (e.g. 500 lines × 8 files = 4000 rows → 8 rows). + + Zero-bandwidth records (exec_time_us == 0) are excluded from bandwidth + statistics since they indicate a timing fallback or tiny collective. + """ + import time + + # Stats over all records in the tail window + bw_records = [r for r in records if r.bus_bw_gbps > 0.0] + + avg_bw: Optional[float] = None + min_bw: Optional[float] = None + max_bw: Optional[float] = None + slowest_rank: Optional[int] = None + + if bw_records: + bws = [r.bus_bw_gbps for r in bw_records] + avg_bw = sum(bws) / len(bws) + min_bw = min(bws) + max_bw = max(bws) + slowest_rank = bw_records[bws.index(min_bw)].rank + + collective_breakdown: dict[str, int] = {} + for r in records: + collective_breakdown[r.collective] = collective_breakdown.get(r.collective, 0) + 1 + + ts = records[0].timestamp if records else time.time() + + # Deduplicate to latest record per (rank, comm_hash) by sequence_num. + # sequence_num is a monotonically increasing counter per communicator, + # so the highest value is the most recent collective. 
+ latest: dict[tuple, InspectorCollPerf] = {} + for r in records: + key = (r.rank, r.comm_hash) + existing = latest.get(key) + if existing is None or r.sequence_num > existing.sequence_num: + latest[key] = r + display_records = sorted(latest.values(), key=lambda r: (r.rank, r.comm_hash)) + + return InspectorSnapshot( + timestamp=ts, + records=display_records, + avg_bus_bw_gbps=avg_bw, + min_bus_bw_gbps=min_bw, + max_bus_bw_gbps=max_bw, + slowest_rank=slowest_rank, + collective_breakdown=collective_breakdown, + ) diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/nic_collector.py b/cvs/monitors/cluster-mon/backend/app/collectors/nic_collector.py index c0a5e079..e359fa38 100644 --- a/cvs/monitors/cluster-mon/backend/app/collectors/nic_collector.py +++ b/cvs/monitors/cluster-mon/backend/app/collectors/nic_collector.py @@ -7,14 +7,22 @@ import json import logging from typing import Dict, Any -from datetime import datetime +from datetime import datetime, timezone + +from app.collectors.base import BaseCollector, CollectorResult, CollectorState +from app.core.config import settings as _settings logger = logging.getLogger(__name__) -class NICMetricsCollector: +class NICMetricsCollector(BaseCollector): """Collects NIC metrics via rdma, ethtool, and ip commands.""" + name = "nic" + poll_interval: int = 60 + collect_timeout: float = 48.0 + critical = True + async def collect_rdma_links(self, ssh_manager) -> Dict[str, Any]: """ Collect RDMA link information. 
@@ -33,7 +41,7 @@ async def collect_rdma_links(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting RDMA link info") - output = ssh_manager.exec("rdma link", timeout=60) + output = await ssh_manager.exec_async("rdma link", timeout=60) rdma_dict = {} for node, out_str in output.items(): @@ -82,7 +90,7 @@ async def collect_rdma_stats(self, ssh_manager) -> Dict[str, Any]: """ logger.info("Collecting RDMA statistics (includes congestion control metrics)") # Use bash -c to properly handle shell redirection and || operator - output = ssh_manager.exec("bash -c 'rdma statistic show --json 2>/dev/null || echo \"{}\"'", timeout=60) + output = await ssh_manager.exec_async("bash -c 'rdma statistic show --json 2>/dev/null || echo \"{}\"'", timeout=60) logger.info(f"RDMA stats output received from {len(output)} nodes") @@ -165,7 +173,7 @@ async def collect_ethtool_stats(self, ssh_manager, interfaces: Dict[str, list] = # Run 'ip -s link' once per node to get all interface stats cmd = "ip -s link show" - output = ssh_manager.exec(cmd, timeout=60) + output = await ssh_manager.exec_async(cmd, timeout=60) eth_stats = {} @@ -263,7 +271,7 @@ async def collect_ip_addr(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting IP address info") - output = ssh_manager.exec("bash -c 'ip addr show | grep -A 5 mtu --color=never'", timeout=60) + output = await ssh_manager.exec_async("bash -c 'ip addr show | grep -A 5 mtu --color=never'", timeout=60) ip_dict = {} @@ -339,7 +347,7 @@ async def collect_lldp(self, ssh_manager) -> Dict[str, Any]: """ logger.info("Collecting LLDP info") # Use bash -c to properly handle shell redirection and || operator - output = ssh_manager.exec("bash -c 'sudo lldpctl -f json 2>/dev/null || echo \"{}\"'", timeout=60) + output = await ssh_manager.exec_async("bash -c 'sudo lldpctl -f json 2>/dev/null || echo \"{}\"'", timeout=60) lldp_dict = {} for node, out_str in output.items(): @@ -420,6 +428,26 @@ def _filter_lldp_by_rdma(self, lldp_data: 
Dict[str, Any], rdma_links_data: Dict[ return filtered_lldp + async def collect(self, ssh_manager) -> CollectorResult: + """BaseCollector interface. Calls collect_all_metrics() and wraps result.""" + try: + metrics = await self.collect_all_metrics(ssh_manager) + except Exception as e: + return CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.ERROR, + data={}, + error=str(e), + ) + + return CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.OK, + data=metrics if isinstance(metrics, dict) else {}, + ) + async def collect_all_metrics(self, ssh_manager) -> Dict[str, Any]: """ Collect all NIC metrics in parallel. @@ -465,7 +493,7 @@ async def collect_all_metrics(self, ssh_manager) -> Dict[str, Any]: filtered_lldp = self._filter_lldp_by_rdma(lldp_data, rdma_links_data) metrics = { - "timestamp": datetime.utcnow().isoformat() + "Z", + "timestamp": datetime.now(timezone.utc).isoformat() + "Z", "rdma_links": rdma_links_data, "rdma_stats": results[1] if not isinstance(results[1], Exception) else {}, "rdma_resources": rdma_res, @@ -490,7 +518,7 @@ async def collect_rdma_resources(self, ssh_manager) -> Dict[str, Any]: } """ logger.info("Collecting RDMA resources") - output = ssh_manager.exec("rdma res", timeout=60) + output = await ssh_manager.exec_async("rdma res", timeout=60) rdma_res = {} for node, out_str in output.items(): @@ -520,3 +548,7 @@ async def collect_rdma_resources(self, ssh_manager) -> Dict[str, Any]: rdma_res[node][device] = resources return rdma_res + + +NICMetricsCollector.poll_interval = _settings.polling.interval +NICMetricsCollector.collect_timeout = _settings.polling.interval * 0.8 diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/rccl_collector.py b/cvs/monitors/cluster-mon/backend/app/collectors/rccl_collector.py new file mode 100644 index 00000000..aa8959cb --- /dev/null +++ 
"""
RCCL Collector -- CVS cluster-mon Phase 1.

Implements BaseCollector. Polls rcclras via SSH port-forward on each cycle.
Lifecycle managed by the unified REGISTERED_COLLECTORS loop in main.py.

Critical=False: RCCLJobState.NO_JOB is expected when no RCCL job is running
and does NOT count as a collector failure for overall_status purposes.
"""

import asyncio
import logging
import time
from typing import Optional, Any

from paramiko.ssh_exception import ChannelException

from app.collectors.base import BaseCollector, CollectorResult, CollectorState
from app.collectors.rccl_ras_client import RCCLRasClient, ProtocolError
from app.models.rccl_models import RCCLJobState, RCCLSnapshot

logger = logging.getLogger(__name__)


class RCCLCollector(BaseCollector):
    """
    Polls rcclras (the RCCL RAS TCP service on port 28028) via SSH port-forward.

    - name = "rccl"
    - poll_interval: read from settings.rccl.poll_interval (default 30s)
    - collect_timeout: collective_timeout_secs + 10s for SSH + protocol overhead
    - critical = False (NO_JOB is expected; not a system failure)
    """

    name = "rccl"
    poll_interval: int = 30        # overridden at module level from settings
    collect_timeout: float = 20.0  # overridden at module level from settings
    critical = False

    def __init__(self):
        # Last observed job state; drives the state-transition events.
        self.job_state: RCCLJobState = RCCLJobState.NO_JOB
        self._app_state: Optional[Any] = None  # set in run() before collect()
        self._bootstrapped: bool = False       # True after first poll seeds job_state from store

    def _healthy_nodes(self, app_state: Any) -> list[str]:
        """
        Return all nodes with healthy status from app_state.node_health_status,
        in the order they appear in the config. The rcclras listener (port 28028)
        only runs on nodes that are part of an active RCCL job, which may be any
        subset of the configured nodes — so we must try each one.
        """
        return [
            node
            for node, status in app_state.node_health_status.items()
            if status == "healthy"
        ]

    def _health_from_snapshot(self, snapshot: RCCLSnapshot) -> RCCLJobState:
        """Return the job state already computed by the parser (covers missing ranks,
        async errors, dead peers, and inconsistent topology)."""
        return snapshot.state

    async def _bootstrap_job_state(self, app_state: Any) -> None:
        """
        On the first poll after startup, seed job_state from the data store's last
        known snapshot state. This prevents a spurious job_start event when the
        backend restarts mid-job.
        """
        if self._bootstrapped:
            return
        self._bootstrapped = True
        data_store = getattr(app_state, 'rccl_data_store', None)
        if not data_store:
            return
        try:
            last = await data_store.get_current_snapshot()
            if last and 'state' in last:
                self.job_state = RCCLJobState(last['state'])
                logger.info(f"RCCL collector bootstrapped from stored state: {self.job_state}")
        except Exception:
            # Best-effort seed: an unknown stored state value (ValueError from the
            # enum constructor) or any store error simply leaves job_state at
            # NO_JOB. (Was `except (ValueError, Exception)` — the tuple was
            # redundant since Exception already subsumes ValueError.)
            pass

    async def on_collect_timeout(self, app_state: Any) -> None:
        """
        Called by BaseCollector.run() when collect() is cancelled by collect_timeout.
        Updates the state machine so the timeout is visible as an UNREACHABLE transition.
        """
        prev = self.job_state
        self.job_state = RCCLJobState.UNREACHABLE
        await self._push_state_event(prev, self.job_state, app_state)
        if hasattr(app_state, 'latest_rccl_snapshot'):
            app_state.latest_rccl_snapshot = {"state": "unreachable"}

    async def run(self, ssh_manager, app_state: Any) -> None:
        """
        Override BaseCollector.run() to pass app_state to collect().
        Stores app_state reference so _healthy_nodes() and data_store are accessible.
        """
        self._app_state = app_state
        # NOTE(review): _ssh_manager is stored but not read in this module —
        # presumably kept for external/API access; confirm before removing.
        self._ssh_manager = ssh_manager
        await super().run(ssh_manager, app_state)

    async def collect(self, ssh_manager) -> CollectorResult:
        """
        One RCCL poll cycle:
        1. Bootstrap job_state from data store on first poll (prevents spurious job_start).
        2. Try each healthy node for an active rcclras listener on ras_port.
        3. On ConnectionRefused/ChannelException: continue to next node (not our job).
        4. On TimeoutError: continue to next node (may be transient SSH delay).
        5. On ProtocolError/Exception: abort cycle with ERROR.
        6. If no node has a listener: NO_JOB.

        Connection refused -> try next node -> NO_JOB if none respond.
        Timeout -> try next node -> UNREACHABLE only if ALL candidates time out.
        Protocol error -> ERROR (abort immediately; protocol errors are not transient).
        """
        from app.core.config import settings

        app_state = self._app_state
        if app_state is None:
            return CollectorResult(
                collector_name=self.name,
                timestamp=CollectorResult.now_iso(),
                state=CollectorState.ERROR,
                data={},
                error="collect() called before run() -- app_state not set",
            )

        await self._bootstrap_job_state(app_state)

        prev_state = self.job_state

        candidates = self._healthy_nodes(app_state)
        if not candidates:
            self.job_state = RCCLJobState.UNREACHABLE
            await self._push_state_event(prev_state, self.job_state, app_state)
            if hasattr(app_state, 'latest_rccl_snapshot'):
                app_state.latest_rccl_snapshot = {"state": "unreachable"}
            return CollectorResult(
                collector_name=self.name,
                timestamp=CollectorResult.now_iso(),
                state=CollectorState.UNREACHABLE,
                data={},
                error="No healthy nodes available for RCCL polling",
            )

        ras_port = settings.rccl.ras_port
        collective_timeout = settings.rccl.collective_timeout_secs

        timed_out_nodes: list[str] = []

        # Try each healthy node — rcclras only listens on nodes that are part of
        # an active RCCL job, which may be a subset of the configured nodes.
        for leader in candidates:
            try:
                async with ssh_manager.open_port_forward(leader, ras_port) as (reader, writer):
                    client = RCCLRasClient(reader, writer)
                    await client.handshake()
                    await client.set_timeout(collective_timeout)
                    raw_text = await client.get_status(verbose=True)
                    logger.debug(f"rcclras raw output from {leader}:\n{raw_text}")

                    snapshot = self._parse_text_response(raw_text, leader)
                    self.job_state = self._health_from_snapshot(snapshot)
                    await self._push_state_event(prev_state, self.job_state, app_state, leader)
                    snapshot_dict = snapshot.model_dump()

                    if hasattr(app_state, 'rccl_data_store') and app_state.rccl_data_store:
                        await app_state.rccl_data_store.push_snapshot(snapshot_dict)
                    if hasattr(app_state, 'latest_rccl_snapshot'):
                        app_state.latest_rccl_snapshot = snapshot_dict

                    try:
                        from app.main import broadcast_rccl
                        await broadcast_rccl(snapshot_dict)
                    except Exception as e:
                        logger.warning(f"broadcast_rccl failed (snapshot not sent to WebSocket clients): {e}")

                    collector_state = (
                        CollectorState.OK
                        if self.job_state in (RCCLJobState.HEALTHY, RCCLJobState.DEGRADED)
                        else CollectorState.NO_SERVICE
                    )
                    return CollectorResult(
                        collector_name=self.name,
                        timestamp=CollectorResult.now_iso(),
                        state=collector_state,
                        data=snapshot_dict,
                    )

            except (ConnectionRefusedError, ChannelException):
                # Port 28028 closed on this node — no RCCL job here, try next.
                logger.debug(f"No rcclras listener on {leader}:{ras_port}, trying next node")
                continue

            except asyncio.TimeoutError:
                # SSH/protocol timeout on this node — could be transient, try remaining nodes.
                logger.debug(f"Timeout on {leader}:{ras_port}, trying next node")
                timed_out_nodes.append(leader)
                continue

            except ProtocolError as e:
                # Protocol errors are not transient — abort immediately.
                prev = self.job_state
                self.job_state = RCCLJobState.ERROR
                await self._push_state_event(prev, self.job_state, app_state, leader)
                logger.error(f"RAS protocol error on {leader}: {e}")
                return CollectorResult(
                    collector_name=self.name,
                    timestamp=CollectorResult.now_iso(),
                    state=CollectorState.ERROR,
                    data={},
                    error=str(e),
                )

            except Exception as e:
                prev = self.job_state
                self.job_state = RCCLJobState.ERROR
                await self._push_state_event(prev, self.job_state, app_state, leader)
                logger.error(f"RCCL collect() unexpected error on {leader}: {e}", exc_info=True)
                return CollectorResult(
                    collector_name=self.name,
                    timestamp=CollectorResult.now_iso(),
                    state=CollectorState.ERROR,
                    data={},
                    error=str(e),
                )

        # All candidates tried. If some timed out and none responded, mark UNREACHABLE.
        if timed_out_nodes:
            prev = self.job_state
            self.job_state = RCCLJobState.UNREACHABLE
            await self._push_state_event(prev, self.job_state, app_state)
            if hasattr(app_state, 'latest_rccl_snapshot'):
                app_state.latest_rccl_snapshot = {"state": "unreachable"}
            return CollectorResult(
                collector_name=self.name,
                timestamp=CollectorResult.now_iso(),
                state=CollectorState.UNREACHABLE,
                data={},
                error=f"RAS collective timed out on all nodes: {timed_out_nodes}",
            )

        # All healthy nodes tried — no rcclras listener found anywhere.
        self.job_state = RCCLJobState.NO_JOB
        await self._push_state_event(prev_state, self.job_state, app_state)
        if hasattr(app_state, 'latest_rccl_snapshot'):
            app_state.latest_rccl_snapshot = {"state": "no_job"}
        return CollectorResult(
            collector_name=self.name,
            timestamp=CollectorResult.now_iso(),
            state=CollectorState.NO_SERVICE,
            data={},
            error=f"Port {ras_port} not listening on any healthy node -- no RCCL job running",
        )

    def _parse_text_response(self, raw_text: str, leader: str) -> RCCLSnapshot:
        """Parse rcclras VERBOSE STATUS text output using RCCLTextParser."""
        from app.collectors.rccl_text_parser import RCCLTextParser
        return RCCLTextParser().parse(raw_text)

    async def _push_state_event(
        self,
        prev: RCCLJobState,
        curr: RCCLJobState,
        app_state: Any,
        leader: Optional[str] = None,
    ) -> None:
        """Push a state_change event when job_state transitions between polls."""
        if prev == curr:
            return
        data_store = getattr(app_state, 'rccl_data_store', None)
        if not data_store:
            return

        _TYPE_MAP = {
            (RCCLJobState.NO_JOB, RCCLJobState.HEALTHY): "job_start",
            (RCCLJobState.NO_JOB, RCCLJobState.DEGRADED): "job_start_degraded",
            (RCCLJobState.NO_JOB, RCCLJobState.UNREACHABLE): "nodes_unreachable",
            (RCCLJobState.NO_JOB, RCCLJobState.ERROR): "collector_error",
            (RCCLJobState.HEALTHY, RCCLJobState.DEGRADED): "job_degraded",
            (RCCLJobState.HEALTHY, RCCLJobState.NO_JOB): "job_end",
            (RCCLJobState.HEALTHY, RCCLJobState.UNREACHABLE): "node_unreachable",
            (RCCLJobState.HEALTHY, RCCLJobState.ERROR): "collector_error",
            (RCCLJobState.DEGRADED, RCCLJobState.HEALTHY): "job_recovered",
            (RCCLJobState.DEGRADED, RCCLJobState.NO_JOB): "job_end",
            (RCCLJobState.DEGRADED, RCCLJobState.UNREACHABLE): "node_unreachable",
            (RCCLJobState.DEGRADED, RCCLJobState.ERROR): "collector_error",
            (RCCLJobState.UNREACHABLE, RCCLJobState.HEALTHY): "node_recovered",
            (RCCLJobState.UNREACHABLE, RCCLJobState.DEGRADED): "node_recovered_degraded",
            (RCCLJobState.UNREACHABLE, RCCLJobState.NO_JOB): "job_end",
            (RCCLJobState.UNREACHABLE, RCCLJobState.ERROR): "collector_error",
            (RCCLJobState.ERROR, RCCLJobState.HEALTHY): "job_start",
            (RCCLJobState.ERROR, RCCLJobState.DEGRADED): "job_start_degraded",
            (RCCLJobState.ERROR, RCCLJobState.NO_JOB): "job_end",
            (RCCLJobState.ERROR, RCCLJobState.UNREACHABLE): "node_unreachable",
        }
        event_type = _TYPE_MAP.get((prev, curr), "state_change")

        await data_store.push_event({
            "event_type": event_type,
            "timestamp": time.time(),
            "from_state": prev,
            "to_state": curr,
            "leader_node": leader,
        })
        logger.info(f"RCCL state transition: {prev} → {curr} (event: {event_type})")


# Fail loudly if settings cannot be loaded — consistent with GPU/NIC collectors.
from app.core.config import settings as _settings
RCCLCollector.poll_interval = _settings.rccl.poll_interval
RCCLCollector.collect_timeout = _settings.rccl.collective_timeout_secs + 10
+""" + +import collections +import json +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +# In-memory fallback caps (no Redis) +_MEMORY_EVENT_MAX = 500 +_MEMORY_SNAPSHOT_MAX = 100 +_MEMORY_INSPECTOR_MAX = 100 + + +class RCCLDataStore: + # Redis Streams — atomic append+cap in one command + SNAPSHOT_STREAM = "rccl:snapshots" # Stream, capped at 1000 entries + EVENT_STREAM = "rccl:events" # Stream, capped at 10000 entries + CURRENT_KEY = "rccl:current" # Hash, latest snapshot only + INSPECTOR_STREAM = "rccl:inspector:snapshots" # Stream, capped at 1000 entries + INSPECTOR_CURRENT_KEY = "rccl:inspector:current" # Hash, latest Inspector snapshot + + def __init__(self, redis_client, snapshot_max: int = 1000, event_max: int = 10000): + """ + Args: + redis_client: redis.asyncio.Redis instance from app_state.redis, or None. + When None, an in-memory deque is used as a fallback so that + events are available for the current session. + snapshot_max: Maximum number of entries in the snapshot stream. + event_max: Maximum number of entries in the event stream. 
+ """ + self._r = redis_client + self._snapshot_max = snapshot_max + self._event_max = event_max + # In-memory fallback buffers (used only when Redis is unavailable) + self._mem_events: collections.deque[dict] = collections.deque(maxlen=_MEMORY_EVENT_MAX) + self._mem_snapshots: collections.deque[dict] = collections.deque(maxlen=_MEMORY_SNAPSHOT_MAX) + self._mem_current: Optional[dict] = None + self._mem_inspector_snapshots: collections.deque[dict] = collections.deque(maxlen=_MEMORY_INSPECTOR_MAX) + self._mem_inspector_current: Optional[dict] = None + + async def push_snapshot(self, snapshot: dict) -> None: + """Atomically append snapshot to ring buffer and update current.""" + if self._r is None: + self._mem_snapshots.append(snapshot) + self._mem_current = snapshot + return + try: + payload = json.dumps(snapshot) + await self._r.xadd( + self.SNAPSHOT_STREAM, + {"data": payload}, + maxlen=self._snapshot_max, + ) + await self._r.hset( + self.CURRENT_KEY, + mapping={"data": payload, "ts": str(snapshot.get("timestamp", ""))}, + ) + except Exception as e: + logger.warning(f"RCCLDataStore.push_snapshot failed (falling back to memory): {e}") + self._mem_snapshots.append(snapshot) + self._mem_current = snapshot + + async def push_event(self, event: dict) -> None: + """Atomically append event to event stream.""" + if self._r is None: + self._mem_events.append(event) + return + try: + # approximate=True trims in whole radix tree nodes — efficient for high-volume + await self._r.xadd( + self.EVENT_STREAM, + {"data": json.dumps(event)}, + maxlen=self._event_max, + approximate=True, + ) + except Exception as e: + logger.warning(f"RCCLDataStore.push_event failed (falling back to memory): {e}") + self._mem_events.append(event) + + async def get_recent_snapshots(self, count: int = 50) -> list[dict]: + """Return the most recent N snapshots, newest first.""" + if self._r is None: + return list(reversed(list(self._mem_snapshots)))[:count] + try: + entries = await 
self._r.xrevrange(self.SNAPSHOT_STREAM, count=count) + return [json.loads(e[1]["data"]) for e in entries] + except Exception as e: + logger.warning(f"RCCLDataStore.get_recent_snapshots failed: {e}") + return [] + + async def get_current_snapshot(self) -> Optional[dict]: + """Return the latest snapshot from the CURRENT_KEY hash.""" + if self._r is None: + return self._mem_current + try: + result = await self._r.hget(self.CURRENT_KEY, "data") + if result: + return json.loads(result) + return None + except Exception as e: + logger.warning(f"RCCLDataStore.get_current_snapshot failed: {e}") + return None + + @property + def is_memory_capped(self) -> bool: + """True when the in-memory event buffer has reached its maximum capacity.""" + return self._r is None and len(self._mem_events) >= _MEMORY_EVENT_MAX + + async def push_inspector_snapshot(self, snapshot: dict) -> None: + """Append an Inspector performance snapshot and update the current-key.""" + if self._r is None: + self._mem_inspector_snapshots.append(snapshot) + self._mem_inspector_current = snapshot + return + try: + payload = json.dumps(snapshot) + await self._r.xadd( + self.INSPECTOR_STREAM, + {"data": payload}, + maxlen=self._snapshot_max, + ) + await self._r.hset( + self.INSPECTOR_CURRENT_KEY, + mapping={"data": payload, "ts": str(snapshot.get("timestamp", ""))}, + ) + except Exception as e: + logger.warning(f"RCCLDataStore.push_inspector_snapshot failed (falling back to memory): {e}") + self._mem_inspector_snapshots.append(snapshot) + self._mem_inspector_current = snapshot + + async def get_inspector_current(self) -> Optional[dict]: + """Return the latest Inspector snapshot.""" + if self._r is None: + return self._mem_inspector_current + try: + result = await self._r.hget(self.INSPECTOR_CURRENT_KEY, "data") + if result: + return json.loads(result) + return None + except Exception as e: + logger.warning(f"RCCLDataStore.get_inspector_current failed: {e}") + return None + + async def 
get_inspector_snapshots(self, count: int = 50) -> list[dict]: + """Return the most recent N Inspector snapshots, newest first.""" + if self._r is None: + return list(reversed(list(self._mem_inspector_snapshots)))[:count] + try: + entries = await self._r.xrevrange(self.INSPECTOR_STREAM, count=count) + return [json.loads(e[1]["data"]) for e in entries] + except Exception as e: + logger.warning(f"RCCLDataStore.get_inspector_snapshots failed: {e}") + return [] + + async def get_events_in_range(self, start_ts: float, end_ts: float) -> list[dict]: + """Return events within a UTC timestamp range using stream entry IDs. + In-memory results are sorted by timestamp to handle NTP clock adjustments.""" + if self._r is None: + results = [ + e for e in self._mem_events + if start_ts <= e.get("timestamp", 0) <= end_ts + ] + return sorted(results, key=lambda e: e.get("timestamp", 0)) + try: + start_id = f"{int(start_ts * 1000)}-0" + end_id = f"{int(end_ts * 1000)}-0" + entries = await self._r.xrange(self.EVENT_STREAM, min=start_id, max=end_id) + return [json.loads(e[1]["data"]) for e in entries] + except Exception as e: + logger.warning(f"RCCLDataStore.get_events_in_range failed: {e}") + return [] diff --git a/cvs/monitors/cluster-mon/backend/app/collectors/rccl_ras_client.py b/cvs/monitors/cluster-mon/backend/app/collectors/rccl_ras_client.py new file mode 100644 index 00000000..f43ee3be --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/collectors/rccl_ras_client.py @@ -0,0 +1,148 @@ +""" +Async rcclras TCP client for the CVS RCCL monitoring extension. + +Speaks the rcclras wire protocol (newline-terminated ASCII over TCP). +Connection is owned by the caller's context manager (ssh_manager.open_port_forward). + +Warning: Protocol version caveat: Values 3 and 4 for JSON_FORMAT and MONITOR_MODE +are ASSUMPTIONS -- not verified against actual rcclras server responses for +v2.28.9 and v2.29.2. 
# (Module-docstring caveat, continued: the JSON_FORMAT and MONITOR_MODE
# protocol-version values below are ASSUMPTIONS — verify by running a v2.28.9
# rcclras server and checking its handshake response before implementing
# version-gated features. If the server always responds SERVER PROTOCOL 2,
# the version guards will never activate and an alternative feature-detection
# mechanism will be required.)

import asyncio
import logging
from typing import AsyncIterator

logger = logging.getLogger(__name__)


class ProtocolError(Exception):
    """Raised when the rcclras server responds unexpectedly."""


class ProtocolVersionError(ProtocolError):
    """Raised when a feature requires a higher protocol version than the server supports."""


class ProtocolVersion:
    TEXT_ONLY = 2     # v2.28.3: STATUS, VERBOSE STATUS, TIMEOUT only
    JSON_FORMAT = 3   # v2.28.9+: adds SET FORMAT json (ASSUMPTION — verify)
    MONITOR_MODE = 4  # v2.29.2+: adds MONITOR [groups] (ASSUMPTION — verify)


class RCCLRasClient:
    """
    Async client for the rcclras wire protocol (newline-terminated ASCII over TCP).

    The reader/writer pair is pre-connected (from ssh_manager.open_port_forward())
    and its lifetime belongs to the caller's context manager — this class never
    closes the writer. Commands gated on a protocol version the server does not
    advertise raise ProtocolVersionError instead of being sent, since an unknown
    command would elicit "ERROR: Unknown command" and stall the reader.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        self._reader = reader
        self._writer = writer
        self.server_protocol: int = 0  # populated by handshake()

    async def _send(self, payload: bytes) -> None:
        """Write one command line and flush it to the transport."""
        self._writer.write(payload)
        await self._writer.drain()

    async def _expect_ok(self, label: str) -> None:
        """Read one reply line and require it to be exactly OK."""
        reply = await asyncio.wait_for(self._reader.readline(), timeout=5.0)
        if reply.decode().strip() != "OK":
            raise ProtocolError(f"Expected OK after {label}, got: {reply!r}")

    def _require_protocol(self, minimum: int, feature: str) -> None:
        """Raise ProtocolVersionError when the server's protocol is below *minimum*."""
        if self.server_protocol < minimum:
            raise ProtocolVersionError(
                f"{feature} requires protocol {minimum}+, "
                f"server is {self.server_protocol}"
            )

    async def handshake(self) -> int:
        """
        Send CLIENT PROTOCOL 2 and read the SERVER PROTOCOL N reply.
        Stores and returns the server's protocol version; a version mismatch is
        tolerated (matches rcclras behavior), an unparsable reply is fatal.
        """
        await self._send(b"CLIENT PROTOCOL 2\n")
        reply = await asyncio.wait_for(self._reader.readline(), timeout=5.0)
        try:
            self.server_protocol = int(reply.decode().strip().removeprefix("SERVER PROTOCOL "))
        except ValueError:
            raise ProtocolError(f"Unexpected handshake response: {reply!r}")
        return self.server_protocol

    async def set_timeout(self, secs: int) -> None:
        """Set the collective timeout. Available in all versions (v2.28.3+)."""
        await self._send(f"TIMEOUT {secs}\n".encode())
        await self._expect_ok("TIMEOUT")

    async def set_format(self, fmt: str = "json") -> None:
        """
        Select the output format. Requires protocol 3+ (v2.28.9+);
        raises ProtocolVersionError on older servers.
        """
        self._require_protocol(ProtocolVersion.JSON_FORMAT, "SET FORMAT")
        await self._send(f"SET FORMAT {fmt}\n".encode())
        await self._expect_ok("SET FORMAT")

    async def get_status(self, verbose: bool = True) -> str:
        """
        Issue STATUS or VERBOSE STATUS and return the raw text reply.
        Available in all versions; parsing is the caller's job.

        The server streams its dump in chunks and signals completion by closing
        the connection, while StreamReader.read(n) returns as soon as ANY bytes
        arrive — so we accumulate until read() yields b'' (EOF). The whole
        accumulation shares one 30s deadline and is capped at 1 MB.
        """
        await self._send(b"VERBOSE STATUS\n" if verbose else b"STATUS\n")

        limit = 1024 * 1024
        buf = bytearray()

        async def _drain_until_eof() -> None:
            while len(buf) < limit:
                piece = await self._reader.read(limit - len(buf))
                if not piece:  # EOF — server closed after the dump
                    return
                buf.extend(piece)

        await asyncio.wait_for(_drain_until_eof(), timeout=30.0)
        if len(buf) >= limit:
            logger.warning("rcclras VERBOSE STATUS response truncated at 1 MB — unexpected output size")
        return bytes(buf).decode()

    async def start_monitor(self, groups: str = "all") -> AsyncIterator[str]:
        """
        Issue MONITOR [groups] and yield decoded lines until the connection closes.
        Requires protocol 4+ (v2.29.2+); raises ProtocolVersionError on older servers.
        """
        self._require_protocol(ProtocolVersion.MONITOR_MODE, "MONITOR")
        await self._send(f"MONITOR {groups}\n".encode())
        await self._expect_ok("MONITOR")
        async for raw in self._reader:
            yield raw.decode()
The C source (client_support.cc) +confirms field semantics but this parser is built from actual rcclras output. +""" + +import re +import time +import logging +from typing import Optional + +from app.models.rccl_models import ( + RCCLSnapshot, + RCCLJobState, + RCCLJobSummary, + RCCLCommunicator, +) + +logger = logging.getLogger(__name__) + + +class RCCLTextParser: + """ + Parses rcclras -v (VERBOSE STATUS) text output into RCCLSnapshot. + + Format (RCCL v2.28.3): + - Line 1: RCCL version X.Y.Z compiled with ROCm "..." + - Line 2: HIP runtime version N, amdgpu driver version N + - Job summary table: Nodes/Processes/GPUs counts + - Communicators table: Group-based with Status and Errors columns + - Errors section (if any) + - Warnings section (if any) + """ + + # Regex patterns + _VERSION_RE = re.compile( + r"RCCL version (\S+)\s+compiled with ROCm" + ) + _HIP_DRIVER_RE = re.compile( + r"HIP runtime version (\d+),\s*amdgpu driver version (\d+)" + ) + # Anchored to the "Job summary" section to avoid matching spurious 5-integer lines. + # Lookahead stops at the next section header (word chars followed by a === underline) + # or end-of-string. We deliberately do NOT use (?=^\S|\Z) because the column header + # line "(total) per node ..." starts at column 0 and would prematurely end the match. 
+ _JOB_SUMMARY_SECTION_RE = re.compile( + r"^Job summary\s*\n=+\s*\n(.*?)(?=^\w[^\n]*\n=+|\Z)", + re.MULTILINE | re.DOTALL, + ) + _JOB_SUMMARY_ROW_RE = re.compile( + r"^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s*$", + re.MULTILINE, + ) + # Anchored to the "Communicators" section (same lookahead rationale as above) + _COMM_SECTION_RE = re.compile( + r"^Communicators.*?\n=+\s*\n(.*?)(?=^\w[^\n]*\n=+|\Z)", + re.MULTILINE | re.DOTALL, + ) + _COMM_ROW_RE = re.compile( + r"^\s*(\d+)\s+(\d+)\s+(\d+)\s+(\d+(?:-\d+)?)\s+(\d+)\s+(\d+)\s+(\S+)\s+(\S+)\s*$", + re.MULTILINE, + ) + _CONNECTION_REFUSED_RE = re.compile( + r"Connection refused|Failed to connect|Connection reset by peer", + re.IGNORECASE, + ) + + def parse(self, raw_text: str) -> RCCLSnapshot: + """Parse rcclras -v output into an RCCLSnapshot.""" + if not raw_text or not raw_text.strip(): + return RCCLSnapshot.empty(state=RCCLJobState.NO_JOB) + + # Check for connection refused (no job running) + if self._CONNECTION_REFUSED_RE.search(raw_text): + return RCCLSnapshot.empty(state=RCCLJobState.NO_JOB) + + try: + job_summary = self._parse_job_summary(raw_text) + communicators = self._parse_communicators(raw_text) + dead_peers = self._parse_dead_peers(raw_text) + errors = self._parse_errors_section(raw_text) + + # Determine overall state + state = self._determine_state(communicators, dead_peers, errors, job_summary) + + return RCCLSnapshot( + timestamp=time.time(), + state=state, + job_summary=job_summary, + communicators=communicators, + peers=[], # Not in v2.28.3 text output + dead_peers=dead_peers, + errors=errors, + ) + except Exception as e: + logger.error(f"Failed to parse rcclras output: {e}", exc_info=True) + return RCCLSnapshot.empty(state=RCCLJobState.ERROR) + + def _parse_job_summary(self, text: str) -> Optional[RCCLJobSummary]: + """Extract job summary from the table after 'Job summary'.""" + # RCCL version + version_match = self._VERSION_RE.search(text) + rccl_version = version_match.group(1) if version_match 
else "unknown" + + # HIP/driver versions + hip_match = self._HIP_DRIVER_RE.search(text) + hip_version = int(hip_match.group(1)) if hip_match else 0 + driver_version = int(hip_match.group(2)) if hip_match else 0 + + # Job summary numbers table — search within the "Job summary" section only + # to avoid false-matching any other line that happens to have 5 integers. + section_match = self._JOB_SUMMARY_SECTION_RE.search(text) + section = section_match.group(1) if section_match else "" + summary_match = self._JOB_SUMMARY_ROW_RE.search(section) + if not summary_match: + return None + + total_nodes = int(summary_match.group(1)) + procs_per_node = int(summary_match.group(2)) + gpus_per_proc = int(summary_match.group(3)) + total_procs = int(summary_match.group(4)) + total_gpus = int(summary_match.group(5)) + + # Check for inconsistent topology (when processes/node varies) + # In v2.28.3, the table always shows uniform values. Non-uniform + # topologies would show different format. For now, assume consistent + # if we get a single row. + inconsistent = False + + return RCCLJobSummary( + total_nodes=total_nodes, + total_processes=total_procs, + total_gpus=total_gpus, + rccl_version=rccl_version, + hip_runtime_version=hip_version, + amdgpu_driver_version=driver_version, + inconsistent_topology=inconsistent, + ) + + def _parse_communicators(self, text: str) -> list[RCCLCommunicator]: + """ + Extract communicator groups from the table after 'Communicators'. 
+ + Format: + Group Comms Nodes Ranks Ranks Ranks Status Errors + # in grp per c per node per comm in group + 0 1 2 7-8 16 15 RUNNING INCOMPLETE + """ + # Search within the Communicators section to avoid matching other tables + section_match = self._COMM_SECTION_RE.search(text) + section = section_match.group(1) if section_match else text + + comms = [] + for match in self._COMM_ROW_RE.finditer(section): + group_num = int(match.group(1)) + comms_in_group = int(match.group(2)) + # ranks_per_node (group 4) may be a range like "7-8" on heterogeneous topologies + ranks_per_comm = int(match.group(5)) # expected ranks in ONE communicator + ranks_in_group = int(match.group(6)) # actual respondents across ALL comms in group + status = match.group(7) + errors = match.group(8) + + # Determine health from status and errors columns + if errors == "OK" and status == "RUNNING": + health = RCCLJobState.HEALTHY + elif errors != "OK": + health = RCCLJobState.DEGRADED + else: + health = RCCLJobState.HEALTHY + + # "Ranks in group" spans ALL communicators in the group, so total expected + # = ranks_per_comm * comms_in_group. Missing = expected - responding. + total_expected = ranks_per_comm * comms_in_group + missing = max(0, total_expected - ranks_in_group) + + # In v2.28.3, we don't get per-communicator hashes from the table, + # so use group number as placeholder + comms.append(RCCLCommunicator( + comm_hash=f"group_{group_num}", + total_ranks=total_expected, + responding_ranks=ranks_in_group, + missing_ranks=missing, + ranks=[], # Per-rank detail only in verbose with outliers + health=health, + )) + return comms + + def _parse_dead_peers(self, text: str) -> list[str]: + """Extract dead peer addresses if present.""" + # v2.28.3: dead peers appear between job summary and communicators + # Format: "Dead peers: IP:port, IP:port, ..." 
+ dead_re = re.compile(r"Dead peers?:\s*(.+)", re.IGNORECASE) + match = dead_re.search(text) + if match: + peers_str = match.group(1).strip() + return [p.strip() for p in peers_str.split(",") if p.strip()] + return [] + + def _parse_errors_section(self, text: str) -> list[str]: + """Extract error lines from the Errors section.""" + errors = [] + # Find content between "Errors" header and "Warnings" header (or end) + # Stop at any section header (word followed by === underline) or end of string. + # Using \Z (end-of-string) instead of $ (end-of-line) so the lazy .*? + # doesn't stop at the first line ending when re.MULTILINE is active. + errors_section = re.search( + r"^Errors\s*\n=+\s*\n(.*?)(?=^\w[^\n]*\n=+|\Z)", + text, + re.MULTILINE | re.DOTALL, + ) + if errors_section: + content = errors_section.group(1).strip() + if content: + errors = [line.strip() for line in content.splitlines() if line.strip()] + return errors + + def _determine_state( + self, + communicators: list[RCCLCommunicator], + dead_peers: list[str], + errors: list[str], + job_summary: Optional[RCCLJobSummary] = None, + ) -> RCCLJobState: + """Determine overall job state from parsed data.""" + # If we couldn't parse a job summary, the text isn't valid rcclras output + if job_summary is None and not communicators: + return RCCLJobState.NO_JOB + if dead_peers: + return RCCLJobState.DEGRADED + if errors: + return RCCLJobState.DEGRADED + for comm in communicators: + if comm.health == RCCLJobState.DEGRADED: + return RCCLJobState.DEGRADED + if comm.missing_ranks > 0: + return RCCLJobState.DEGRADED + if not communicators: + return RCCLJobState.HEALTHY # No comms but no errors = job starting up + return RCCLJobState.HEALTHY diff --git a/cvs/monitors/cluster-mon/backend/app/core/config.py b/cvs/monitors/cluster-mon/backend/app/core/config.py index 3a9df112..9a3c8b44 100644 --- a/cvs/monitors/cluster-mon/backend/app/core/config.py +++ b/cvs/monitors/cluster-mon/backend/app/core/config.py @@ -1,175 +1,194 @@ 
""" -Configuration management for GPU cluster monitor. +Configuration management for CVS Cluster Monitor. +Uses pydantic-settings with a YAML source and env var overrides. """ -from pydantic_settings import BaseSettings -from pydantic import Field -from typing import Optional, List +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict +from typing import Any, Optional, Tuple, Type, List import yaml from pathlib import Path +import os -class JumpHostConfig(BaseSettings): - """Jump host (bastion) configuration.""" +class JumpHostConfig(BaseModel): + enabled: bool = False + host: Optional[str] = None + username: str = "root" + password: Optional[str] = None + key_file: str = "~/.ssh/id_rsa" + node_username: str = "root" # replaces: node_username_via_jumphost + node_key_file: str = "~/.ssh/id_rsa" # replaces: node_key_file_on_jumphost - enabled: bool = Field(default=False, description="Enable jump host") - host: Optional[str] = Field(default=None, description="Jump host IP/hostname") - username: str = Field(default="root", description="Jump host username") - password: Optional[str] = Field(default=None, description="Jump host password") - key_file: str = Field(default="~/.ssh/id_rsa", description="Jump host SSH key") +class SSHConfig(BaseModel): + username: str = "root" + key_file: str = "~/.ssh/id_rsa" + password: Optional[str] = None + timeout: int = 30 + jump_host: JumpHostConfig = Field(default_factory=JumpHostConfig) -class ClusterSSHConfig(BaseSettings): - """SSH configuration for cluster nodes.""" - username: str = Field(default="root", description="SSH username") - password: Optional[str] = Field(default=None, description="SSH password") - key_file: str = Field(default="~/.ssh/id_rsa", description="SSH private key file") - timeout: int = Field(default=30, description="SSH timeout in seconds") - jump_host: JumpHostConfig = Field(default_factory=JumpHostConfig, description="Jump host 
config") +class PollingConfig(BaseModel): + interval: int = 60 + batch_size: int = 10 + stagger_delay: int = 2 + failure_threshold: int = 5 -class PollingConfig(BaseSettings): - """Polling configuration.""" +class AlertsConfig(BaseModel): + gpu_temp_threshold: float = 85.0 + gpu_util_threshold: float = 95.0 - interval: int = Field(default=15, description="Polling interval in seconds") - batch_size: int = Field(default=10, description="Number of nodes per batch") - stagger_delay: int = Field(default=2, description="Delay between batches in seconds") +class RedisConfig(BaseModel): + url: str = "redis://localhost:6379" + db: int = 0 + password: Optional[str] = None + snapshot_max_entries: int = 1000 + event_max_entries: int = 10000 -class RedisConfig(BaseSettings): - """Redis configuration.""" - host: str = Field(default="localhost", description="Redis host") - port: int = Field(default=6379, description="Redis port") - password: Optional[str] = Field(default=None, description="Redis password") - db: int = Field(default=0, description="Redis database number") - ttl: int = Field(default=900, description="TTL for cached data in seconds (15 min)") +class StorageConfig(BaseModel): + redis: RedisConfig = Field(default_factory=RedisConfig) -class InfluxDBConfig(BaseSettings): - """InfluxDB configuration.""" +class InspectorConfig(BaseModel): + """Configuration for the RCCL Inspector plugin collector.""" + enabled: bool = False + mode: str = "file" # "file" (NFS) or "ssh" + dump_dir: Optional[str] = None # NFS path where Inspector writes .log files + poll_interval: int = 30 # seconds between collection cycles + max_records_per_file: int = 100 # tail last N lines per log file - url: str = Field(default="http://localhost:8086", description="InfluxDB URL") - token: Optional[str] = Field(default=None, description="InfluxDB auth token") - org: str = Field(default="gpu-monitor", description="InfluxDB organization") - bucket: str = Field(default="gpu_cluster", description="InfluxDB 
bucket") +class RCCLConfig(BaseModel): + """Forward-declaration for RCCL extension config. No runtime behaviour in base robustness spec.""" + ras_port: int = 28028 + poll_interval: int = 30 + collective_timeout_secs: int = 10 + debug_log_path: Optional[str] = None + inspector: InspectorConfig = Field(default_factory=InspectorConfig) -class AlertConfig(BaseSettings): - """Alert threshold configuration.""" - gpu_temp_threshold: float = Field(default=85.0, description="GPU temperature threshold (C)") - gpu_util_threshold: float = Field(default=95.0, description="GPU utilization threshold (%)") - gpu_mem_threshold: float = Field(default=95.0, description="GPU memory threshold (%)") - error_count_threshold: int = Field(default=10, description="Error count threshold") - nic_error_threshold: int = Field(default=100, description="NIC error count threshold") +class _YamlSource(PydanticBaseSettingsSource): + """ + Loads cluster.yaml as a pydantic-settings source. + Compatible with pydantic-settings 2.1.0. + """ + def __init__(self, settings_cls: Type[BaseSettings], yaml_path: Path): + super().__init__(settings_cls) + self._path = yaml_path -class Settings(BaseSettings): - """Main application settings.""" + def __call__(self) -> dict[str, Any]: + if self._path.exists(): + raw = yaml.safe_load(self._path.read_text()) or {} + return raw.get("cluster", {}) + return {} - # Application - app_name: str = "GPU Cluster Monitor" - debug: bool = False + def get_field_value(self, field: Any, field_name: str) -> tuple[Any, str, bool]: + # Required by PydanticBaseSettingsSource ABC in pydantic-settings 2.1.0. + # Not called when __call__() returns the full dict; stub satisfies the ABC. 
+ raise NotImplementedError - # Cluster nodes - nodes_file: str = Field(default="../config/nodes.txt", description="File with node IPs") - nodes: List[str] = Field(default_factory=list, description="List of node IPs") + def field_is_complex(self, field: Any) -> bool: + return True - # Sub-configurations - ssh: ClusterSSHConfig = Field(default_factory=ClusterSSHConfig) - polling: PollingConfig = Field(default_factory=PollingConfig) - redis: RedisConfig = Field(default_factory=RedisConfig) - influxdb: InfluxDBConfig = Field(default_factory=InfluxDBConfig) - alerts: AlertConfig = Field(default_factory=AlertConfig) - # API +class Settings(BaseSettings): + app_name: str = "CVS Cluster Monitor" api_prefix: str = "/api" - cors_origins: List[str] = Field( - default_factory=lambda: [ - "http://localhost:3000", - "http://localhost:5173", - # Configured via environment or simple_config.py - ] - ) + cors_origins: List[str] = Field(default_factory=lambda: [ + "http://localhost:3000", + "http://localhost:5173", + ]) + nodes_file: str = "config/nodes.txt" + ssh: SSHConfig = Field(default_factory=SSHConfig) + polling: PollingConfig = Field(default_factory=PollingConfig) + alerts: AlertsConfig = Field(default_factory=AlertsConfig) + storage: StorageConfig = Field(default_factory=StorageConfig) + rccl: RCCLConfig = Field(default_factory=RCCLConfig) + + model_config = SettingsConfigDict(env_nested_delimiter="__") - class Config: - env_file = ".env" - env_nested_delimiter = "__" - extra = "allow" # Allow extra fields for YAML loading compatibility + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + # Try Docker path first, then dev path + yaml_path = Path("/app/config/cluster.yaml") + if not yaml_path.exists(): + 
yaml_path = Path("../config/cluster.yaml") + return ( + init_settings, + env_settings, # env vars override YAML + _YamlSource(settings_cls, yaml_path), # YAML is primary source + file_secret_settings, + ) def load_nodes_from_file(self) -> List[str]: - """Load node IPs from file.""" - nodes_path = Path(self.nodes_file) - if nodes_path.exists(): - with open(nodes_path) as f: - nodes = [line.strip() for line in f if line.strip() and not line.startswith("#")] - self.nodes = nodes - return nodes + """Load node IPs from nodes file, trying multiple paths.""" + possible_paths = [ + Path("/app/config/nodes.txt"), + Path("../config/nodes.txt"), + Path(self.nodes_file), + ] + for p in possible_paths: + p = p.resolve() + if p.exists(): + nodes = [ + line.strip() + for line in p.read_text().splitlines() + if line.strip() and not line.startswith("#") + ] + if nodes: + return nodes return [] - @classmethod - def load_from_yaml(cls, yaml_file: str) -> "Settings": - """Load settings from YAML file.""" - with open(yaml_file) as f: - config_data = yaml.safe_load(f) + # Backward-compat properties used in existing main.py and api/config.py + # These will be removed after main.py is fully migrated. 
+ @property + def node_username_via_jumphost(self) -> str: + return self.ssh.jump_host.node_username - # Flatten nested structure for Pydantic - settings_dict = {} + @property + def node_key_file_on_jumphost(self) -> str: + return self.ssh.jump_host.node_key_file - if "cluster" in config_data: - cluster = config_data["cluster"] + @property + def ssh_username(self) -> str: + return self.ssh.username - if "nodes_file" in cluster: - settings_dict["nodes_file"] = cluster["nodes_file"] + @property + def ssh_password(self) -> Optional[str]: + return self.ssh.password - if "ssh" in cluster: - for key, value in cluster["ssh"].items(): - if key == "jump_host" and isinstance(value, dict): - # Handle nested jump_host configuration - for jh_key, jh_value in value.items(): - settings_dict[f"ssh__jump_host__{jh_key}"] = jh_value - else: - settings_dict[f"ssh__{key}"] = value + @property + def ssh_key_file(self) -> str: + return self.ssh.key_file - if "polling" in cluster: - for key, value in cluster["polling"].items(): - settings_dict[f"polling__{key}"] = value + @property + def jump_host_enabled(self) -> bool: + return self.ssh.jump_host.enabled - if "storage" in cluster: - storage = cluster["storage"] - if "redis" in storage: - for key, value in storage["redis"].items(): - settings_dict[f"redis__{key}"] = value - if "influxdb" in storage: - for key, value in storage["influxdb"].items(): - settings_dict[f"influxdb__{key}"] = value + @property + def jump_host(self) -> Optional[str]: + return self.ssh.jump_host.host - if "alerts" in cluster: - for key, value in cluster["alerts"].items(): - settings_dict[f"alerts__{key}"] = value + @property + def jump_host_username(self) -> str: + return self.ssh.jump_host.username - return cls(**settings_dict) + @property + def jump_host_key_file(self) -> str: + return self.ssh.jump_host.key_file # Global settings instance -# Try to load from YAML first, fall back to defaults -try: - # Get absolute path to config file (relative to backend 
directory) - yaml_path = Path("../config/cluster.yaml").resolve() - - if yaml_path.exists(): - print(f"Loading configuration from: {yaml_path}") - settings = Settings.load_from_yaml(str(yaml_path)) - print(f"Jump host enabled: {settings.ssh.jump_host.enabled}") - print(f"Jump host: {settings.ssh.jump_host.host}") - else: - print(f"YAML file not found at: {yaml_path}, using defaults") - settings = Settings() -except Exception as e: - print(f"Error loading YAML config: {e}") - import traceback - - traceback.print_exc() - settings = Settings() +settings = Settings() diff --git a/cvs/monitors/cluster-mon/backend/app/core/cvs_parallel_ssh_reliable.py b/cvs/monitors/cluster-mon/backend/app/core/cvs_parallel_ssh_reliable.py index 28cdbf7e..6b154a28 100644 --- a/cvs/monitors/cluster-mon/backend/app/core/cvs_parallel_ssh_reliable.py +++ b/cvs/monitors/cluster-mon/backend/app/core/cvs_parallel_ssh_reliable.py @@ -9,6 +9,10 @@ from pssh.clients import ParallelSSHClient from pssh.exceptions import Timeout, ConnectionError +import asyncio +import socket +from contextlib import asynccontextmanager +from typing import AsyncIterator import time import logging import threading @@ -20,6 +24,7 @@ # TCP probe for fast reachability detection from app.core.host_probe import discover_reachable_hosts +from app.core.ssh_port_forward import _run_bridge # Module-level logger logger = logging.getLogger(__name__) @@ -78,9 +83,9 @@ def __init__( # Add authentication if self.password is None: - print(self.reachable_hosts) - print(self.user) - print(self.pkey) + logger.debug(f"Reachable hosts: {self.reachable_hosts}") + logger.debug(f"SSH user: {self.user}") + logger.debug(f"SSH key: {self.pkey}") client_params['pkey'] = self.pkey else: client_params['password'] = self.password @@ -113,6 +118,9 @@ def __init__( f"{len(self.reachable_hosts)} reachable, {len(self.unreachable_hosts)} unreachable" ) + self._pf_clients: dict[str, paramiko.SSHClient] = {} + self._pf_lock = threading.Lock() # protects 
_pf_clients dict + # Only create ParallelSSHClient with reachable hosts if not self.reachable_hosts: logger.warning("No reachable hosts found! SSH manager will be inactive") @@ -258,7 +266,7 @@ def prune_unreachable_hosts(self, output): ] unreachable = self.check_connectivity(failed_hosts) for host in unreachable: - print(f"Host {host} is unreachable, pruning from reachable hosts list.") + logger.info(f"Host {host} is unreachable, pruning from reachable hosts list.") self.unreachable_hosts.append(host) self.reachable_hosts.remove(host) if len(self.unreachable_hosts) > initial_unreachable_len: @@ -296,22 +304,22 @@ def _process_output(self, output, cmd=None, cmd_list=None, print_console=True): cmd_output = {} i = 0 for item in output: - print('#----------------------------------------------------------#') - print(f'Host == {item.host} ==') - print('#----------------------------------------------------------#') + logger.debug('#----------------------------------------------------------#') + logger.debug(f'Host == {item.host} ==') + logger.debug('#----------------------------------------------------------#') cmd_out_str = '' if cmd_list: - print(cmd_list[i]) + logger.debug(cmd_list[i]) else: - print(cmd) + logger.debug(cmd) try: for line in item.stdout or []: if print_console: - print(line) + logger.debug(line) cmd_out_str += line.replace('\t', ' ') + '\n' for line in item.stderr or []: if print_console: - print(line) + logger.debug(line) cmd_out_str += line.replace('\t', ' ') + '\n' except Timeout as e: if not self.stop_on_errors: @@ -323,7 +331,7 @@ def _process_output(self, output, cmd=None, cmd_list=None, print_console=True): exc_str = exc_str.replace('\t', ' ') if isinstance(item.exception, Timeout): exc_str += "\nABORT: Timeout Error in Host: " + item.host - print(exc_str) + logger.debug(exc_str) cmd_out_str += exc_str + '\n' if cmd_list: i += 1 @@ -375,12 +383,18 @@ def exec(self, cmd, timeout=None, print_console=True): # CRITICAL: Acquire lock to prevent 
concurrent SSH operations # parallel-ssh/paramiko/libssh2 are NOT thread-safe with _ssh_lock: + # Re-check after acquiring the lock: destroy_clients() may have run + # while we were waiting, setting self.client = None. + if not self.client: + logger.info("SSH client destroyed before command could run (shutdown race) — skipping") + return {host: "ABORT: Host Unreachable Error" for host in self.host_list} + logger.info(f"CVS Pssh executing: {cmd[:100]}...") logger.info(f"Calling ParallelSSHClient.run_command() on {len(self.reachable_hosts)} reachable nodes...") logger.info(f" Timeout: {timeout if timeout else 'default'}") logger.info(f" Stop on errors: {self.stop_on_errors}") - print(f'cmd = {cmd}') + logger.debug(f'cmd = {cmd}') try: if timeout is None: @@ -409,7 +423,7 @@ def exec_cmd_list(self, cmd_list, timeout=None, print_console=True): which runs the same command on all hosts. Returns a dictionary of host as key and command output as values """ - print(cmd_list) + logger.debug(cmd_list) if timeout is None: output = self.client.run_command('%s', host_args=cmd_list, stop_on_errors=self.stop_on_errors) else: @@ -420,7 +434,7 @@ def exec_cmd_list(self, cmd_list, timeout=None, print_console=True): return cmd_output def scp_file(self, local_file, remote_file, recurse=False): - print('About to copy local file {} to remote {} on all Hosts'.format(local_file, remote_file)) + logger.info('About to copy local file {} to remote {} on all Hosts'.format(local_file, remote_file)) cmds = self.client.copy_file(local_file, remote_file, recurse=recurse) self.client.pool.join() for cmd in cmds: @@ -439,13 +453,98 @@ def get_unreachable_hosts(self): return self.unreachable_hosts.copy() def reboot_connections(self): - print('Rebooting Connections') + logger.info('Rebooting Connections') self.client.run_command('reboot -f', stop_on_errors=self.stop_on_errors) + def _get_pf_transport(self, node: str) -> paramiko.Transport: + """ + Get or create a dedicated paramiko SSH client for 
port-forwarding to node. + Thread-safe. Separate from the pssh connection pool (which does not expose transports). + + Security note: Uses AutoAddPolicy() (TOFU). See plan for hardening options. + """ + with self._pf_lock: + client = self._pf_clients.get(node) + transport = client.get_transport() if client else None + if transport is None or not transport.is_active(): + new_client = paramiko.SSHClient() + new_client.set_missing_host_key_policy( + paramiko.AutoAddPolicy() + # Security note: AutoAddPolicy() accepts any host key without + # verification (TOFU). Production hardening: pre-distribute known + # host keys via Ansible/Puppet and use RejectPolicy() + a + # pre-populated known_hosts file, or use OpenSSH certificate auth. + ) + new_client.connect( + node, + username=self.user, + key_filename=self.pkey, + password=self.password, + timeout=self.timeout, + ) + if client: + try: + client.close() # close the stale connection before replacing + except Exception: + pass + self._pf_clients[node] = new_client + return self._pf_clients[node].get_transport() + + @asynccontextmanager + async def open_port_forward( + self, node: str, remote_port: int + ) -> AsyncIterator[tuple]: + """ + Open a single-hop SSH tunnel: monitoring_host -> node:remote_port. + + Yields (asyncio.StreamReader, asyncio.StreamWriter) ready for asyncio use. + Uses a Unix socketpair() -- no ephemeral TCP port allocation, no TOCTOU race. 
+ """ + asyncio_end, thread_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + try: + transport = await asyncio.to_thread(self._get_pf_transport, node) + channel = await asyncio.to_thread( + transport.open_channel, + "direct-tcpip", + ("::1", remote_port), + ("127.0.0.1", 0), + ) + except Exception: + asyncio_end.close() + thread_end.close() + raise + + _run_bridge(channel, thread_end) + + try: + reader, writer = await asyncio.open_connection(sock=asyncio_end) + except Exception: + asyncio_end.close() + channel.close() + raise + + try: + yield reader, writer + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + channel.close() + thread_end.close() + def destroy_clients(self): - print('Destroying Current phdl connections ..') - if self.client: - del self.client + logger.info('Destroying Current phdl connections ..') + with _ssh_lock: + self.client = None # set to None (not del) so exec() guard stays valid + with self._pf_lock: + for c in self._pf_clients.values(): + try: + c.close() + except Exception: + pass + self._pf_clients.clear() async def exec_async(self, cmd, timeout=None, print_console=True): """ diff --git a/cvs/monitors/cluster-mon/backend/app/core/jump_host_pssh.py b/cvs/monitors/cluster-mon/backend/app/core/jump_host_pssh.py index c5ea5c69..0e26dfd7 100644 --- a/cvs/monitors/cluster-mon/backend/app/core/jump_host_pssh.py +++ b/cvs/monitors/cluster-mon/backend/app/core/jump_host_pssh.py @@ -3,11 +3,18 @@ Based on working test_auth_script.py approach. 
""" -import paramiko -from typing import List, Optional, Dict +import asyncio +from contextlib import asynccontextmanager +from typing import List, Optional, Dict, AsyncIterator import logging +import socket +import threading import time +import paramiko + +from app.core.ssh_port_forward import _run_bridge + # TCP probe for fast reachability detection from app.core.host_probe import probe_from_bastion @@ -80,6 +87,9 @@ def __init__( self._create_parallel_client() + self._exec_lock = threading.Lock() # serializes concurrent exec() calls + self._hosts_lock = threading.Lock() # protects unreachable_hosts/reachable_hosts mutations + def _is_jump_host_alive(self): """Check if jump host connection is still active.""" if not self.jump_client: @@ -87,7 +97,7 @@ def _is_jump_host_alive(self): try: transport = self.jump_client.get_transport() return transport is not None and transport.is_active() - except: + except Exception: return False def _ensure_jump_host_connection(self): @@ -101,7 +111,7 @@ def _ensure_jump_host_connection(self): if self.jump_client: try: self.jump_client.close() - except: + except Exception: pass # Reconnect @@ -116,7 +126,7 @@ def _connect_to_jump_host(self): logger.info(f"Connecting to jump host: {self.jump_host}") logger.info(f" Jump user: {self.jump_user}") logger.info( - f" Jump password: {'***SET*** (length={len(self.jump_password)})' if self.jump_password else 'NOT SET'}" + f" Jump password: {'***SET***' if self.jump_password else 'NOT SET'}" ) logger.info(f" Jump pkey: {self.jump_pkey if self.jump_pkey else 'NOT SET'}") @@ -126,9 +136,6 @@ def _connect_to_jump_host(self): try: if self.jump_password: logger.info(f"Attempting password authentication to {self.jump_host}...") - logger.info( - f" Password value check: {self.jump_password[:3]}*** (showing first 3 chars for verification)" - ) logger.info("Using password authentication for jump host") self.jump_client.connect( hostname=self.jump_host, @@ -154,15 +161,6 @@ def 
_connect_to_jump_host(self): logger.error(f"❌ Failed to connect to jump host: {e}") raise - def _make_proxy(self, host, port): - """Create proxy socket through jump host.""" - logger.debug(f"Creating proxy socket for {host}:{port}") - return self.jump_transport.open_channel( - "direct-tcpip", - (host, port), - ("", 0), - ) - def _create_parallel_client(self): """Setup for parallel execution - key file is ON the jump host.""" logger.info(f"Ready for parallel SSH execution to {len(self.target_hosts)} nodes") @@ -201,11 +199,12 @@ def _execute_on_node(self, node: str, cmd: str, timeout: Optional[int] = None) - x in error.lower() for x in ['connection timed out', 'connection refused', 'no route to host', 'host is down'] ): - if node not in self.unreachable_hosts: - logger.warning(f"[{node}] Marking as unreachable: {error[:200]}") - self.unreachable_hosts.append(node) - if node in self.reachable_hosts: - self.reachable_hosts.remove(node) + with self._hosts_lock: + if node not in self.unreachable_hosts: + logger.warning(f"[{node}] Marking as unreachable: {error[:200]}") + self.unreachable_hosts.append(node) + if node in self.reachable_hosts: + self.reachable_hosts.remove(node) return f"ABORT: Host Unreachable Error - {error[:100]}" elif not output: logger.warning(f"[{node}] stderr: {error[:200]}") @@ -217,11 +216,12 @@ def _execute_on_node(self, node: str, cmd: str, timeout: Optional[int] = None) - # Check if it's a timeout exception error_str = str(e).lower() if 'timeout' in error_str or 'timed out' in error_str: - if node not in self.unreachable_hosts: - logger.warning(f"[{node}] Marking as unreachable due to timeout: {e}") - self.unreachable_hosts.append(node) - if node in self.reachable_hosts: - self.reachable_hosts.remove(node) + with self._hosts_lock: + if node not in self.unreachable_hosts: + logger.warning(f"[{node}] Marking as unreachable due to timeout: {e}") + self.unreachable_hosts.append(node) + if node in self.reachable_hosts: + 
self.reachable_hosts.remove(node) return "ABORT: Host Unreachable Error - Timeout" logger.error(f"[{node}] Exception: {e}") @@ -233,80 +233,81 @@ def exec(self, cmd: str, timeout: Optional[int] = None, print_console: bool = Tr Uses ThreadPoolExecutor for parallel execution. Skips unreachable nodes and reports them separately. """ - # Ensure jump host connection is active before executing - if not self._ensure_jump_host_connection(): - logger.error("Cannot execute command - jump host connection failed") - return {node: "ERROR: Jump host connection failed" for node in self.target_hosts} - - logger.info(f"Executing command: {cmd[:100]}...") - logger.info( - f"Total nodes: {len(self.target_hosts)}, Reachable: {len(self.reachable_hosts)}, Unreachable: {len(self.unreachable_hosts)}" - ) + with self._exec_lock: + # Ensure jump host connection is active before executing + if not self._ensure_jump_host_connection(): + logger.error("Cannot execute command - jump host connection failed") + return {node: "ERROR: Jump host connection failed" for node in self.target_hosts} + + logger.info(f"Executing command: {cmd[:100]}...") + logger.info( + f"Total nodes: {len(self.target_hosts)}, Reachable: {len(self.reachable_hosts)}, Unreachable: {len(self.unreachable_hosts)}" + ) - from concurrent.futures import ThreadPoolExecutor, as_completed + from concurrent.futures import ThreadPoolExecutor, as_completed - results = {} - success_count = 0 - fail_count = 0 + results = {} + success_count = 0 + fail_count = 0 - # First, add unreachable hosts to results - for node in self.unreachable_hosts: - results[node] = "ABORT: Host Unreachable Error" - fail_count += 1 + # First, add unreachable hosts to results + for node in self.unreachable_hosts: + results[node] = "ABORT: Host Unreachable Error" + fail_count += 1 - try: - # Execute in parallel using ThreadPoolExecutor on reachable hosts only - with ThreadPoolExecutor(max_workers=self.max_parallel) as executor: - # Submit tasks only for reachable 
nodes - future_to_node = { - executor.submit(self._execute_on_node, node, cmd, timeout): node for node in self.reachable_hosts - } - - # Collect results as they complete - for future in as_completed(future_to_node): - node = future_to_node[future] - try: - output = future.result() - results[node] = output - - if output.startswith("ERROR") or output.startswith("ABORT"): - logger.error(f"❌ [{node}] FAILED: {output[:200]}") + try: + # Execute in parallel using ThreadPoolExecutor on reachable hosts only + with ThreadPoolExecutor(max_workers=self.max_parallel) as executor: + # Submit tasks only for reachable nodes + future_to_node = { + executor.submit(self._execute_on_node, node, cmd, timeout): node for node in self.reachable_hosts + } + + # Collect results as they complete + for future in as_completed(future_to_node): + node = future_to_node[future] + try: + output = future.result() + results[node] = output + + if output.startswith("ERROR") or output.startswith("ABORT"): + logger.error(f"❌ [{node}] FAILED: {output[:200]}") + fail_count += 1 + else: + # Log first 3 lines + lines = output.split('\n')[:3] + logger.info(f"✅ [{node}] SUCCESS (first 3 lines):") + for line in lines: + if line.strip(): + logger.info(f" {line[:150]}") + success_count += 1 + + except Exception as e: + results[node] = f"ERROR: {str(e)}" + logger.error(f"❌ [{node}] Exception: {e}") fail_count += 1 - else: - # Log first 3 lines - lines = output.split('\n')[:3] - logger.info(f"✅ [{node}] SUCCESS (first 3 lines):") - for line in lines: - if line.strip(): - logger.info(f" {line[:150]}") - success_count += 1 - - except Exception as e: - results[node] = f"ERROR: {str(e)}" - logger.error(f"❌ [{node}] Exception: {e}") - fail_count += 1 - - logger.info(f"Results: {success_count} successful, {fail_count} failed") - - # If too many failures, trigger re-probe (connection issue detection) - failure_rate = fail_count / len(self.target_hosts) if self.target_hosts else 0 - if failure_rate > 0.5 and fail_count > 
5: # More than 50% failed and at least 5 failures - logger.warning(f"High failure rate ({failure_rate:.1%}) - triggering re-probe") - self._handle_connection_failure() - - return results - except Exception as e: - logger.error(f"❌ Parallel execution failed: {e}", exc_info=True) - # Check if it's a connection error to jump host - if "connection" in str(e).lower() or "transport" in str(e).lower(): - logger.warning("Jump host connection issue detected - triggering re-probe") - self._handle_connection_failure() - raise + logger.info(f"Results: {success_count} successful, {fail_count} failed") + + # If too many failures, trigger re-probe (connection issue detection) + failure_rate = fail_count / len(self.target_hosts) if self.target_hosts else 0 + if failure_rate > 0.5 and fail_count > 5: # More than 50% failed and at least 5 failures + logger.warning(f"High failure rate ({failure_rate:.1%}) - triggering re-probe") + self._handle_connection_failure() + + return results - async def exec_async(self, cmd, timeout=None, print_console=True): - """Async wrapper - just calls exec() directly.""" - return self.exec(cmd, timeout, print_console) + except Exception as e: + logger.error(f"❌ Parallel execution failed: {e}", exc_info=True) + # Check if it's a connection error to jump host + if "connection" in str(e).lower() or "transport" in str(e).lower(): + logger.warning("Jump host connection issue detected - triggering re-probe") + self._handle_connection_failure() + raise + + async def exec_async(self, cmd: str, timeout: int = 30, print_console: bool = False) -> dict: + """Non-blocking wrapper around exec() using asyncio.to_thread.""" + return await asyncio.to_thread(self.exec, cmd, timeout, print_console) def get_reachable_hosts(self): """Return list of reachable hosts.""" @@ -350,8 +351,9 @@ def refresh_host_reachability(self): logger.info(f" Newly unreachable ({len(newly_unreachable)}): {list(newly_unreachable)[:10]}") # Update lists - self.reachable_hosts = new_reachable - 
self.unreachable_hosts = new_unreachable + with self._hosts_lock: + self.reachable_hosts = new_reachable + self.unreachable_hosts = new_unreachable return len(old_reachable) != len(new_reachable_set) or old_reachable != new_reachable_set @@ -388,17 +390,76 @@ def _handle_connection_failure(self): else: logger.info("No reachability changes detected") + @asynccontextmanager + async def open_port_forward( + self, node: str, remote_port: int + ) -> AsyncIterator[tuple]: + """ + Open a two-hop SSH tunnel: monitoring_host -> jump_host -> node:remote_port. + + Yields (asyncio.StreamReader, asyncio.StreamWriter) ready for asyncio use. + Uses a Unix socketpair() -- no ephemeral TCP port allocation, no TOCTOU race. + + Security note: Uses AutoAddPolicy() for host key verification (TOFU). + See plan for hardening options (pre-distributed known_hosts, SSH certificates). + + Args: + node: Target node hostname/IP + remote_port: Port on the target node to forward to + + Yields: + (reader, writer) connected to node:remote_port via the jump host + """ + # Ensure jump host connection is alive before opening port forward + if not await asyncio.to_thread(self._is_jump_host_alive): + await asyncio.to_thread(self._ensure_jump_host_connection) + + asyncio_end, thread_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + try: + channel = await asyncio.to_thread( + self.jump_transport.open_channel, + "direct-tcpip", + ("::1", remote_port), # rcclras binds to IPv6 loopback only + ("127.0.0.1", 0), + ) + except Exception: + asyncio_end.close() + thread_end.close() + raise + + _run_bridge(channel, thread_end) + + try: + reader, writer = await asyncio.open_connection(sock=asyncio_end) + except Exception: + asyncio_end.close() + channel.close() + raise + + try: + yield reader, writer + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + channel.close() + thread_end.close() + # asyncio_end is owned by the asyncio transport after open_connection(); + 
# closing writer causes it to be closed automatically. + def destroy_clients(self): """Clean up connections.""" logger.info("Closing connections...") if self.client: try: self.client.disconnect() - except: + except Exception: pass if self.jump_client: try: self.jump_client.close() logger.info("✅ Jump host connection closed") - except: + except Exception: pass diff --git a/cvs/monitors/cluster-mon/backend/app/core/simple_config.py b/cvs/monitors/cluster-mon/backend/app/core/simple_config.py deleted file mode 100644 index 1b14e894..00000000 --- a/cvs/monitors/cluster-mon/backend/app/core/simple_config.py +++ /dev/null @@ -1,251 +0,0 @@ -""" -Simplified configuration loader - reads directly from YAML. -Avoids Pydantic BaseSettings nested model issues. -""" - -import yaml -from pathlib import Path -from typing import List, Optional - - -class SimpleConfig: - """Simple configuration loader from YAML file.""" - - def __init__(self, yaml_path: str = None): - # Auto-detect config path for both dev and Docker - if yaml_path is None: - # Try Docker path first - docker_path = Path("/app/config/cluster.yaml") - if docker_path.exists(): - yaml_path = str(docker_path) - else: - # Fallback to development path - yaml_path = "../config/cluster.yaml" - - self.yaml_path = Path(yaml_path).resolve() - self.config_data = {} - self.load() - - def load(self): - """Load configuration from YAML file.""" - if self.yaml_path.exists(): - with open(self.yaml_path) as f: - data = yaml.safe_load(f) - self.config_data = data.get("cluster", {}) - else: - print(f"Warning: Config file not found at {self.yaml_path}") - self.config_data = {} - - def get_nodes_file(self) -> str: - """Get nodes file path.""" - # Try Docker path first - docker_nodes = Path("/app/config/nodes.txt") - if docker_nodes.exists(): - return str(docker_nodes) - return self.config_data.get("nodes_file", "../config/nodes.txt") - - def load_nodes_from_file(self) -> List[str]: - """Load node IPs from nodes file.""" - # Try multiple 
possible paths - import os - - possible_paths = [ - Path("/app/config/nodes.txt"), # Docker path (first priority) - Path("../config/nodes.txt"), # Development path - Path(os.path.join(os.getenv("CLUSTER_MONITOR_HOME", "."), "config/nodes.txt")), - Path(self.get_nodes_file()), - ] - - for nodes_file in possible_paths: - nodes_file = nodes_file.resolve() - if nodes_file.exists(): - with open(nodes_file) as f: - nodes = [line.strip() for line in f if line.strip() and not line.startswith("#")] - if nodes: - return nodes - - return [] - - # SSH Configuration - @property - def ssh_username(self) -> str: - import os - - default_user = os.getenv("USER", "root") - return self.config_data.get("ssh", {}).get("username", default_user) - - @property - def ssh_password(self) -> Optional[str]: - # SECURITY: Password is stored in memory only (app_state), never in YAML - try: - from app.main import app_state - - return app_state.ssh_password - except: - return None - - @property - def ssh_key_file(self) -> str: - return self.config_data.get("ssh", {}).get("key_file", "~/.ssh/id_rsa") - - @property - def ssh_timeout(self) -> int: - return self.config_data.get("ssh", {}).get("timeout", 30) - - # Jump Host Configuration - @property - def jump_host_enabled(self) -> bool: - return self.config_data.get("ssh", {}).get("jump_host", {}).get("enabled", False) - - @property - def jump_host(self) -> Optional[str]: - if self.jump_host_enabled: - return self.config_data.get("ssh", {}).get("jump_host", {}).get("host") - return None - - @property - def jump_host_username(self) -> str: - import os - - default_user = os.getenv("USER", "root") - return self.config_data.get("ssh", {}).get("jump_host", {}).get("username", default_user) - - @property - def jump_host_password(self) -> Optional[str]: - # SECURITY: Password is stored in memory only (app_state), never in YAML - # However, for testing/development, we also check YAML - try: - from app.main import app_state - - if app_state.jump_host_password: 
- return app_state.jump_host_password - except: - pass - - # Fallback: Read from YAML (for testing/development only) - # Production should never have passwords in YAML - return self.config_data.get("ssh", {}).get("jump_host", {}).get("password") - - @property - def jump_host_key_file(self) -> str: - """Local keyfile to SSH to jump host.""" - return self.config_data.get("ssh", {}).get("jump_host", {}).get("key_file", "~/.ssh/id_rsa") - - @property - def node_username_via_jumphost(self) -> str: - """Username for cluster nodes when using jump host.""" - import os - - default_user = os.getenv("USER", "root") - # First check for node_username_via_jumphost at ssh level, then check jump_host.node_username - return self.config_data.get("ssh", {}).get("node_username_via_jumphost") or self.config_data.get("ssh", {}).get( - "jump_host", {} - ).get("node_username", default_user) - - @property - def node_key_file_on_jumphost(self) -> str: - """Path to private key ON JUMP HOST for accessing cluster nodes.""" - # First check for node_key_file_on_jumphost at ssh level, then check jump_host.node_key_file - return self.config_data.get("ssh", {}).get("node_key_file_on_jumphost") or self.config_data.get("ssh", {}).get( - "jump_host", {} - ).get("node_key_file", "~/.ssh/id_rsa") - - # Polling Configuration - @property - def polling_interval(self) -> int: - return self.config_data.get("polling", {}).get("interval", 60) - - @property - def polling_batch_size(self) -> int: - return self.config_data.get("polling", {}).get("batch_size", 10) - - @property - def polling_stagger_delay(self) -> int: - return self.config_data.get("polling", {}).get("stagger_delay", 2) - - # Alert Thresholds - @property - def gpu_temp_threshold(self) -> float: - return self.config_data.get("alerts", {}).get("gpu_temp_threshold", 85.0) - - @property - def gpu_util_threshold(self) -> float: - return self.config_data.get("alerts", {}).get("gpu_util_threshold", 95.0) - - # CORS - @property - def cors_origins(self) -> 
List[str]: - import os - - # Allow all origins in Docker, or specific origins from environment variable - cors_env = os.getenv("CORS_ORIGINS", "*") - if cors_env == "*": - return ["*"] - else: - return cors_env.split(",") - # Default for development - # return ["http://localhost:3000", "http://localhost:5173"] - - # App settings - @property - def app_name(self) -> str: - return "CVS Cluster Monitor" - - @property - def debug(self) -> bool: - return False - - @property - def api_prefix(self) -> str: - return "/api" - - @property - def nodes(self) -> List[str]: - return self.load_nodes_from_file() - - # SSH sub-object for compatibility - @property - def ssh(self): - class SSHConfig: - def __init__(self, parent): - self.parent = parent - self.username = parent.ssh_username - self.password = parent.ssh_password - self.key_file = parent.ssh_key_file - self.timeout = parent.ssh_timeout - - # Jump host sub-object - class JumpHost: - def __init__(self, parent): - self.enabled = parent.jump_host_enabled - self.host = parent.jump_host - self.username = parent.jump_host_username - self.password = parent.jump_host_password - self.key_file = parent.jump_host_key_file - - self.jump_host = JumpHost(parent) - - return SSHConfig(self) - - # Polling sub-object - @property - def polling(self): - class PollingConfig: - def __init__(self, parent): - import os - - # Allow environment variable override - self.interval = int(os.getenv('POLLING__INTERVAL', parent.polling_interval)) - self.batch_size = parent.polling_batch_size - self.stagger_delay = parent.polling_stagger_delay - self.failure_threshold = int( - os.getenv( - 'POLLING__FAILURE_THRESHOLD', parent.config_data.get('polling', {}).get('failure_threshold', 5) - ) - ) - - return PollingConfig(self) - - -# Global config instance -config = SimpleConfig() diff --git a/cvs/monitors/cluster-mon/backend/app/core/ssh_port_forward.py b/cvs/monitors/cluster-mon/backend/app/core/ssh_port_forward.py new file mode 100644 index 
00000000..9e2c165b --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/core/ssh_port_forward.py @@ -0,0 +1,69 @@ +""" +Shared SSH port-forwarding bridge for CVS cluster-mon. + +_run_bridge() creates a bidirectional byte-copy between a paramiko Channel +and a Unix socketpair. Used by both Pssh and JumpHostPssh to implement +open_port_forward() without ephemeral TCP port allocation. +""" + +import socket +import threading +import paramiko +import logging + +logger = logging.getLogger(__name__) + + +def _run_bridge(channel: paramiko.Channel, sock: socket.socket) -> None: + """ + Start two daemon threads that copy bytes bidirectionally between + a paramiko channel and a Unix socket. + + When either direction closes (clean EOF or exception), close_all() + is called, which closes both the channel and the socket. This causes + the other direction's recv() to return empty bytes or raise, causing + the other thread to also exit cleanly. No thread leaks. + + Args: + channel: Open paramiko.Channel (e.g., from transport.open_channel) + sock: One end of a socketpair (the thread end, not the asyncio end) + """ + + def copy(src_recv, dst_send, cleanup): + try: + while True: + data = src_recv(4096) + if not data: + break + dst_send(data) + except Exception: + pass + finally: + cleanup() + + def close_all(): + try: + channel.close() + except Exception: + pass + try: + sock.shutdown(socket.SHUT_RDWR) + except Exception: + pass + try: + sock.close() + except Exception: + pass + + threading.Thread( + target=copy, + args=(channel.recv, sock.sendall, close_all), + daemon=True, + name=f"bridge-ch→sock-{id(channel)}", + ).start() + threading.Thread( + target=copy, + args=(sock.recv, channel.sendall, close_all), + daemon=True, + name=f"bridge-sock→ch-{id(channel)}", + ).start() diff --git a/cvs/monitors/cluster-mon/backend/app/main.py b/cvs/monitors/cluster-mon/backend/app/main.py index 31f748ad..3cea4085 100644 --- a/cvs/monitors/cluster-mon/backend/app/main.py +++ 
b/cvs/monitors/cluster-mon/backend/app/main.py @@ -14,13 +14,18 @@ import time from pathlib import Path -from app.core.simple_config import config as settings +from app.core.config import settings from app.core.cvs_parallel_ssh_reliable import Pssh from app.core.jump_host_pssh import JumpHostPssh from app.collectors.gpu_collector import GPUMetricsCollector from app.collectors.nic_collector import NICMetricsCollector +from app.collectors.rccl_collector import RCCLCollector +from app.collectors.inspector_collector import InspectorCollector +from app.collectors.base import BaseCollector, CollectorResult, CollectorState from app.api import router as api_router +import redis.asyncio as aioredis + # Configure logging based on DEBUG environment variable # Using RotatingFileHandler for circular log files with 1MB max size DEBUG_MODE = os.getenv("DEBUG", "false").lower() in ("true", "1", "yes") @@ -61,86 +66,164 @@ class AppState: """Global application state.""" def __init__(self): + # SSH manager self.ssh_manager: Optional[Union[Pssh, JumpHostPssh]] = None + + # Unified collector registry (BaseCollector pattern) + self.collectors: dict[str, BaseCollector] = {} + self.collector_tasks: dict[str, asyncio.Task] = {} + self.collector_results: dict[str, CollectorResult] = {} + + # Legacy fields kept for backward compat during transition self.gpu_collector: GPUMetricsCollector = None self.nic_collector: NICMetricsCollector = None + self.collection_task: asyncio.Task = None # deprecated + self.latest_metrics: dict = {} self.websocket_clients: List[WebSocket] = [] - self.collection_task: asyncio.Task = None self.is_collecting: bool = False - # Node health tracking (for stability - require 5 consecutive failures) - self.node_failure_count: dict = {} # {node: consecutive_failure_count} - self.node_health_status: dict = {} # {node: 'healthy'|'unhealthy'|'unreachable'} - # Software info cache (updated every 180 seconds since it rarely changes) + + # Node health tracking + 
self.node_failure_count: dict = {} + self.node_health_status: dict = {} + + # Software info cache self.cached_gpu_software: dict = {} self.cached_nic_software: dict = {} self.cached_nic_advanced: dict = {} self.gpu_software_cache_time: float = 0 self.nic_software_cache_time: float = 0 self.nic_advanced_cache_time: float = 0 - self.software_cache_ttl: int = 180 # 3 minutes - # SECURITY: Passwords stored in memory only (never persisted to disk) - self.ssh_password: str = None # Direct SSH password - self.jump_host_password: str = None # Jump host password - # Periodic host probe task + self.software_cache_ttl: int = 180 + + # SECURITY: Passwords stored in memory only + self.ssh_password: str = None + self.jump_host_password: str = None + + # Periodic host probe self.probe_task: Optional[asyncio.Task] = None self.last_probe_time: Optional[float] = None - self.probe_count: int = 0 # Track number of probes for periodic client recreation + self.probe_count: int = 0 + self.probe_requested: asyncio.Event = None # set by collectors on ConnectionError + + # Redis client + self.redis: Optional[object] = None + + # RCCL state + self.rccl_data_store = None # RCCLDataStore, set in lifespan + self.latest_rccl_snapshot: Optional[dict] = None + self.rccl_websocket_clients: List[WebSocket] = [] app_state = AppState() +_reload_lock = asyncio.Lock() + + +# SSH Transport Scaling Note: +# The SSH-based collection transport has a practical limit of ~500-800 nodes at +# 60-second poll intervals. Known constraints at 600 nodes: 3-5GB RSS, pool_size +# reduced to 50, global threading lock serializes SSH batches. For clusters +# significantly larger, consider deploying lightweight push agents (Telegraf +# amd_rocm_smi plugin or rocm-smi-exporter) on compute nodes. 
+ +REGISTERED_COLLECTORS: list[type[BaseCollector]] = [ + GPUMetricsCollector, + NICMetricsCollector, + RCCLCollector, + InspectorCollector, +] + + +def _start_collector_task(c: BaseCollector) -> asyncio.Task: + """Create a supervised collector task that restarts on crash with exponential backoff.""" + _backoff = [1.0] # mutable cell for closure + + def _on_done(task: asyncio.Task) -> None: + if task.cancelled() or not app_state.is_collecting: + return + exc = task.exception() + if exc is None: + logger.warning(f"Collector {c.name} task exited unexpectedly — restarting") + else: + delay = _backoff[0] + logger.error( + f"Collector {c.name} crashed: {exc!r} — restarting in {delay:.0f}s", + exc_info=exc, + ) + + async def _restart(): + await asyncio.sleep(_backoff[0]) + _backoff[0] = min(_backoff[0] * 2, 120) + new_task = _start_collector_task(c) + app_state.collector_tasks[c.name] = new_task + + def _schedule_restart(): + restart_task = asyncio.create_task(_restart(), name=f"restart-{c.name}") + app_state.collector_tasks[f"_restart_{c.name}"] = restart_task + + asyncio.get_running_loop().call_soon(_schedule_restart) + + task = asyncio.create_task( + c.run(app_state.ssh_manager, app_state), + name=f"collector-{c.name}", + ) + task.add_done_callback(_on_done) + return task + async def reload_configuration(): """ Reload configuration without restarting the entire process. - Stops metrics collection, closes SSH connections, reloads config, reinitializes, and restarts. + Uses topology-diff to restart only collectors whose config actually changed. Returns: dict: Status of reload operation with success/error details """ - try: - logger.info("Starting configuration reload...") - - # 1. 
Stop metrics collection and periodic probe - if app_state.is_collecting: - logger.info("Stopping metrics collection and periodic probe...") - app_state.is_collecting = False - if app_state.collection_task: - app_state.collection_task.cancel() - try: - await app_state.collection_task - except asyncio.CancelledError: - pass - if app_state.probe_task: - app_state.probe_task.cancel() - try: - await app_state.probe_task - except asyncio.CancelledError: - pass + async with _reload_lock: + return await _reload_configuration_inner() + - # 2. Close existing SSH connections - if app_state.ssh_manager: - logger.info("Closing existing SSH connections...") - app_state.ssh_manager.destroy_clients() - app_state.ssh_manager = None - - # 3. Clear cached data - app_state.latest_metrics = {} - app_state.node_failure_count = {} - app_state.node_health_status = {} - app_state.cached_gpu_software = {} - app_state.cached_nic_software = {} - app_state.cached_nic_advanced = {} - app_state.gpu_software_cache_time = 0 - app_state.nic_software_cache_time = 0 - app_state.nic_advanced_cache_time = 0 - - # 4. Reload configuration from files - logger.info("Reloading configuration from cluster.yaml and nodes.txt...") - from app.core.simple_config import SimpleConfig - - new_config = SimpleConfig() +async def _reload_configuration_inner(): + try: + logger.info("Starting configuration reload (topology-diff)...") + + # 1. Snapshot old settings and load new settings + from app.core.config import Settings + import app.core.config as config_module + + old_settings = config_module.settings + new_config = Settings() # re-reads YAML and env vars + + # 2. 
Determine which config sections changed + ssh_changed = ( + old_settings.ssh.model_dump() != new_config.ssh.model_dump() + ) + rccl_changed = ( + old_settings.rccl.model_dump() != new_config.rccl.model_dump() + ) + polling_changed = ( + old_settings.polling.model_dump() != new_config.polling.model_dump() + ) + + logger.info( + f"Config diff: ssh_changed={ssh_changed}, rccl_changed={rccl_changed}, " + f"polling_changed={polling_changed}" + ) + + # 3. Update the global settings reference + config_module.settings = new_config + + # 4. Determine which collectors need restart + collectors_to_restart: set[str] = set() + if polling_changed: + # Interval changed — restart all polling collectors + collectors_to_restart = {cls.name for cls in REGISTERED_COLLECTORS} + else: + if ssh_changed: + collectors_to_restart.update({"gpu", "nic"}) # SSH-dependent + if rccl_changed: + collectors_to_restart.add("rccl") # 5. Load new nodes nodes = new_config.load_nodes_from_file() @@ -165,7 +248,7 @@ async def reload_configuration(): logger.info(f"Checking for SSH key (key-based auth): {key_file_expanded}") if not os.path.exists(key_file_expanded): - logger.warning(f"❌ SSH key file not found: {key_file_expanded}") + logger.warning(f"SSH key file not found: {key_file_expanded}") logger.warning("Please upload SSH keys via Configuration UI or run refresh-ssh-keys.sh") return { "success": False, @@ -174,7 +257,7 @@ async def reload_configuration(): "requires_key_upload": True, } else: - logger.info(f"✅ SSH key file found: {key_file_expanded}") + logger.info(f"SSH key file found: {key_file_expanded}") # List the key file to verify import subprocess @@ -184,58 +267,128 @@ async def reload_configuration(): except: pass else: - logger.info("✅ Using password authentication - no key file check needed") + logger.info("Using password authentication - no key file check needed") - # 7. 
Reinitialize SSH manager with new configuration - try: - if new_config.ssh.jump_host.enabled and new_config.ssh.jump_host.host: - num_nodes = len(nodes) - min(num_nodes, 5) + # 7. Cancel only affected collector tasks + for name in collectors_to_restart: + task = app_state.collector_tasks.get(name) + if task: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + # Also cancel any pending restart tasks + restart_key = f"_restart_{name}" + restart_task = app_state.collector_tasks.get(restart_key) + if restart_task: + restart_task.cancel() + try: + await restart_task + except asyncio.CancelledError: + pass - logger.info(f"Reinitializing with jump host: {new_config.ssh.jump_host.host}") - logger.info(f"Jump Host Username: {new_config.ssh.jump_host.username}") - logger.info(f"Cluster Nodes: {len(nodes)} nodes") - logger.info(f"Cluster Username: {new_config.node_username_via_jumphost}") + # If nothing changed, we can skip SSH and collector restart + if not collectors_to_restart and not ssh_changed: + logger.info("No config sections changed — nothing to restart") + return { + "success": True, + "message": "Configuration reloaded (no changes detected)", + "nodes_count": len(nodes), + "jump_host_enabled": new_config.ssh.jump_host.enabled, + } - # Use JumpHostPssh - working approach from test_auth_script.py - app_state.ssh_manager = JumpHostPssh( - jump_host=new_config.ssh.jump_host.host, - jump_user=new_config.ssh.jump_host.username, - jump_password=new_config.ssh.jump_host.password, - jump_pkey=new_config.ssh.jump_host.key_file if not new_config.ssh.jump_host.password else None, - target_hosts=nodes, - target_user=new_config.node_username_via_jumphost, - target_pkey=new_config.node_key_file_on_jumphost, - max_parallel=min(len(nodes), 5), # Limit to 5 to avoid exhausting paramiko channels (conservative) - timeout=new_config.ssh.timeout, - ) - logger.info("JumpHostPssh initialized successfully") - else: - logger.info("Reinitializing with direct SSH (no 
jump host)") - logger.info(f"Username: {new_config.ssh.username}") - logger.info(f"Nodes: {len(nodes)} nodes") + # 8. Recreate SSH manager only if SSH config changed + if ssh_changed: + # Stop probe task — it depends on the SSH manager + if app_state.probe_task: + app_state.probe_task.cancel() + try: + await app_state.probe_task + except asyncio.CancelledError: + pass - app_state.ssh_manager = Pssh( - log=logger, - host_list=nodes, - user=new_config.ssh.username, - password=app_state.ssh_password, # Use in-memory password - pkey=new_config.ssh.key_file, - timeout=new_config.ssh.timeout, - stop_on_errors=False, - ) - logger.info("Direct SSH manager reinitialized") - except Exception as e: - logger.error(f"Failed to reinitialize SSH manager: {e}") - return {"success": False, "error": f"Failed to initialize SSH manager: {str(e)}", "nodes_count": len(nodes)} + if app_state.ssh_manager: + logger.info("Closing existing SSH connections (ssh config changed)...") + app_state.ssh_manager.destroy_clients() + app_state.ssh_manager = None + + # Clear cached data (node topology may have changed) + app_state.latest_metrics = {} + app_state.node_failure_count = {} + app_state.node_health_status = {} + app_state.cached_gpu_software = {} + app_state.cached_nic_software = {} + app_state.cached_nic_advanced = {} + app_state.gpu_software_cache_time = 0 + app_state.nic_software_cache_time = 0 + app_state.nic_advanced_cache_time = 0 - # 7. 
Restart metrics collection and periodic probe + try: + if new_config.ssh.jump_host.enabled and new_config.ssh.jump_host.host: + num_nodes = len(nodes) + min(num_nodes, 5) + + logger.info(f"Reinitializing with jump host: {new_config.ssh.jump_host.host}") + logger.info(f"Jump Host Username: {new_config.ssh.jump_host.username}") + logger.info(f"Cluster Nodes: {len(nodes)} nodes") + logger.info(f"Cluster Username: {new_config.ssh.jump_host.node_username}") + + # Use JumpHostPssh - working approach from test_auth_script.py + app_state.ssh_manager = JumpHostPssh( + jump_host=new_config.ssh.jump_host.host, + jump_user=new_config.ssh.jump_host.username, + jump_password=new_config.ssh.jump_host.password, + jump_pkey=new_config.ssh.jump_host.key_file if not new_config.ssh.jump_host.password else None, + target_hosts=nodes, + target_user=new_config.ssh.jump_host.node_username, + target_pkey=new_config.ssh.jump_host.node_key_file, + max_parallel=min(len(nodes), 5), # Limit to 5 to avoid exhausting paramiko channels (conservative) + timeout=new_config.ssh.timeout, + ) + logger.info("JumpHostPssh initialized successfully") + else: + logger.info("Reinitializing with direct SSH (no jump host)") + logger.info(f"Username: {new_config.ssh.username}") + logger.info(f"Nodes: {len(nodes)} nodes") + + app_state.ssh_manager = Pssh( + log=logger, + host_list=nodes, + user=new_config.ssh.username, + password=app_state.ssh_password, # Use in-memory password + pkey=new_config.ssh.key_file, + timeout=new_config.ssh.timeout, + stop_on_errors=False, + ) + logger.info("Direct SSH manager reinitialized") + except Exception as e: + logger.error(f"Failed to reinitialize SSH manager: {e}") + return {"success": False, "error": f"Failed to initialize SSH manager: {str(e)}", "nodes_count": len(nodes)} + + # Restart probe task with new SSH manager + app_state.probe_requested = asyncio.Event() + app_state.probe_task = asyncio.create_task(periodic_host_probe()) + + # 9. 
Restart only the affected collectors if app_state.ssh_manager and nodes: - logger.info("Restarting metrics collection and periodic probe...") app_state.is_collecting = True - app_state.collection_task = asyncio.create_task(collect_metrics_loop()) - app_state.probe_task = asyncio.create_task(periodic_host_probe()) - logger.info("Metrics collection and periodic probe restarted") + for cls in REGISTERED_COLLECTORS: + if cls.name in collectors_to_restart: + old_collector = app_state.collectors.get(cls.name) + c = cls() + # Transfer stateful fields so a config reload doesn't emit a + # spurious job_start event (new instance initialises to NO_JOB). + if old_collector is not None: + if hasattr(old_collector, 'job_state') and hasattr(c, 'job_state'): + c.job_state = old_collector.job_state + if hasattr(c, '_bootstrapped'): + c._bootstrapped = True # skip bootstrap — state already known + app_state.collectors[c.name] = c + app_state.collector_tasks[c.name] = _start_collector_task(c) + logger.info(f"Restarted collector: {c.name}") + else: + logger.info(f"Collector unchanged, kept running: {cls.name}") logger.info("Configuration reload completed successfully!") return { @@ -243,6 +396,7 @@ async def reload_configuration(): "message": "Configuration reloaded successfully", "nodes_count": len(nodes), "jump_host_enabled": new_config.ssh.jump_host.enabled, + "collectors_restarted": list(collectors_to_restart), } except Exception as e: @@ -287,94 +441,88 @@ def update_node_status(node: str, is_error: bool, error_type: str = 'unreachable return app_state.node_health_status[node] -async def collect_metrics_loop(): - """Background task to collect metrics periodically.""" - logger.info("Starting metrics collection loop") - - while app_state.is_collecting: +class ConnectionManager: + """ + WebSocket connection manager with per-client bounded queues. + Slow clients are disconnected instead of blocking the broadcast loop. 
+ """ + def __init__(self, max_queue_size: int = 64): + self._clients: dict[int, WebSocket] = {} + self._queues: dict[int, asyncio.Queue] = {} + self._send_tasks: dict[int, asyncio.Task] = {} + self._max_queue_size = max_queue_size + self._closing: set[int] = set() # guard against concurrent double-close + + async def connect(self, websocket: WebSocket): + await websocket.accept() + client_id = id(websocket) + self._clients[client_id] = websocket + q: asyncio.Queue = asyncio.Queue(maxsize=self._max_queue_size) + self._queues[client_id] = q + self._send_tasks[client_id] = asyncio.create_task( + self._sender(client_id, websocket, q) + ) + + async def _sender(self, client_id: int, ws: WebSocket, queue: asyncio.Queue): try: - logger.info("Collecting metrics...") - - # Collect GPU and NIC metrics with connection error handling - try: - gpu_metrics = await app_state.gpu_collector.collect_all_metrics(app_state.ssh_manager) - nic_metrics = await app_state.nic_collector.collect_all_metrics(app_state.ssh_manager) - except ConnectionError as e: - # Connection error during metrics collection - trigger immediate re-probe - logger.error(f"ConnectionError during metrics collection: {e}") - logger.info("Triggering immediate host re-probe...") - - # Trigger immediate re-probe - if app_state.ssh_manager: - changed = await asyncio.to_thread(app_state.ssh_manager.refresh_host_reachability) - if changed: - await asyncio.to_thread(app_state.ssh_manager.recreate_client) - logger.info("SSH client recreated with updated reachable hosts") - - # Continue to next iteration (skip this round) - logger.info("Skipping this metrics collection round, will retry next interval") - await asyncio.sleep(settings.polling.interval) - continue - - # Package metrics - metrics_payload = { - "timestamp": gpu_metrics.get("timestamp") if isinstance(gpu_metrics, dict) else None, - "gpu": gpu_metrics if not isinstance(gpu_metrics, Exception) else {"error": str(gpu_metrics)}, - "nic": nic_metrics if not 
isinstance(nic_metrics, Exception) else {"error": str(nic_metrics)}, - } - - # Update node status based on metrics collection success/failure - # Check each node and update failure counters - if isinstance(gpu_metrics, dict): - util_data = gpu_metrics.get("utilization", {}) - for node in app_state.ssh_manager.host_list: - has_error = False - - if node in util_data: - node_data = util_data[node] - if isinstance(node_data, dict) and 'error' in node_data: - has_error = True - - # Update status with stability check (5 consecutive failures required) - update_node_status(node, has_error, 'unreachable') + while True: + message = await queue.get() + await ws.send_json(message) + except Exception as e: + logger.debug(f"WebSocket sender error for client {client_id}: {e}") + finally: + await self._remove(client_id) - # Store in app state - app_state.latest_metrics = metrics_payload + async def disconnect(self, websocket: WebSocket): + await self._remove(id(websocket)) - # Broadcast to WebSocket clients - await broadcast_metrics(metrics_payload) + async def _remove(self, client_id: int): + if client_id in self._closing: + return + self._closing.add(client_id) + try: + task = self._send_tasks.pop(client_id, None) + if task and not task.done(): + task.cancel() + self._queues.pop(client_id, None) + ws = self._clients.pop(client_id, None) + if ws: + try: + await ws.close() + except Exception: + pass + finally: + self._closing.discard(client_id) - logger.info(f"Metrics collected successfully. 
{len(app_state.websocket_clients)} clients notified") + def broadcast(self, message: dict): + """Non-blocking broadcast: enqueues to each client's queue.""" + to_remove = [] + for client_id, q in self._queues.items(): + try: + q.put_nowait(message) + except asyncio.QueueFull: + logger.warning(f"WebSocket client {client_id} queue full — disconnecting") + to_remove.append(client_id) + for client_id in to_remove: + asyncio.create_task(self._remove(client_id)) - except asyncio.CancelledError: - logger.info("Metrics collection task cancelled") - raise - except Exception as e: - logger.error(f"Error in metrics collection loop: {e}", exc_info=True) + @property + def client_count(self) -> int: + return len(self._clients) - # Wait for next interval - await asyncio.sleep(settings.polling.interval) - logger.info("Metrics collection loop stopped") +metrics_ws_manager = ConnectionManager() +rccl_ws_manager = ConnectionManager() async def broadcast_metrics(metrics: dict): - """Broadcast metrics to all connected WebSocket clients.""" - if not app_state.websocket_clients: - return + """Broadcast metrics to all connected WebSocket clients (non-blocking).""" + metrics_ws_manager.broadcast({"type": "metrics", "data": metrics}) - disconnected_clients = [] - for client in app_state.websocket_clients: - try: - await client.send_json({"type": "metrics", "data": metrics}) - except Exception as e: - logger.warning(f"Failed to send metrics to client: {e}") - disconnected_clients.append(client) - - # Remove disconnected clients - for client in disconnected_clients: - app_state.websocket_clients.remove(client) +async def broadcast_rccl(snapshot: dict): + """Broadcast RCCL snapshot to /ws/rccl WebSocket clients (non-blocking).""" + rccl_ws_manager.broadcast({"type": "rccl_snapshot", "data": snapshot}) async def periodic_host_probe(): @@ -391,7 +539,11 @@ async def periodic_host_probe(): while app_state.is_collecting: try: - await asyncio.sleep(PROBE_INTERVAL) + try: + await 
asyncio.wait_for(app_state.probe_requested.wait(), timeout=300) + app_state.probe_requested.clear() + except asyncio.TimeoutError: + pass # normal 5-minute periodic probe if not app_state.ssh_manager: logger.debug("Skipping periodic probe - no SSH manager") @@ -458,10 +610,13 @@ async def lifespan(app: FastAPI): """Application lifespan manager.""" logger.info("Starting CVS Cluster Monitor") + # Initialize probe_requested event (before collectors start) + app_state.probe_requested = asyncio.Event() + # Load nodes from file nodes = settings.load_nodes_from_file() - # Initialize collectors (lightweight, no SSH needed) + # Also set legacy fields for backward-compat with existing API endpoints app_state.gpu_collector = GPUMetricsCollector() app_state.nic_collector = NICMetricsCollector() logger.info("Collectors initialized") @@ -479,7 +634,7 @@ async def lifespan(app: FastAPI): logger.info(f"Initializing with jump host: {settings.ssh.jump_host.host}") logger.info(f"Jump Host Username: {settings.ssh.jump_host.username}") logger.info(f"Cluster Nodes: {len(nodes)} nodes") - logger.info(f"Cluster Username: {settings.node_username_via_jumphost}") + logger.info(f"Cluster Username: {settings.ssh.jump_host.node_username}") app_state.ssh_manager = JumpHostPssh( jump_host=settings.ssh.jump_host.host, @@ -487,8 +642,8 @@ async def lifespan(app: FastAPI): jump_password=settings.ssh.jump_host.password, jump_pkey=settings.ssh.jump_host.key_file if not settings.ssh.jump_host.password else None, target_hosts=nodes, - target_user=settings.node_username_via_jumphost, - target_pkey=settings.node_key_file_on_jumphost, + target_user=settings.ssh.jump_host.node_username, + target_pkey=settings.ssh.jump_host.node_key_file, max_parallel=min(len(nodes), 5), timeout=settings.ssh.timeout, ) @@ -509,31 +664,79 @@ async def lifespan(app: FastAPI): ) logger.info("✅ Direct SSH manager initialized") - # Start metrics collection automatically if SSH manager initialized - if app_state.ssh_manager: - 
logger.info("Starting metrics collection and periodic probe...") - app_state.is_collecting = True - app_state.collection_task = asyncio.create_task(collect_metrics_loop()) - app_state.probe_task = asyncio.create_task(periodic_host_probe()) - logger.info("✅ Metrics collection started automatically") - except Exception as e: logger.error(f"Failed to auto-initialize SSH manager: {e}", exc_info=True) logger.warning("Will wait for manual configuration via web UI") + # Initialize Redis (optional — app continues without it) + try: + redis_kwargs = { + "db": settings.storage.redis.db, + "decode_responses": True, + } + if settings.storage.redis.password: + redis_kwargs["password"] = settings.storage.redis.password + app_state.redis = aioredis.from_url( + settings.storage.redis.url, + **redis_kwargs, + ) + await app_state.redis.ping() + logger.info(f"Redis connected: {settings.storage.redis.url}") + except Exception as e: + logger.warning(f"Redis unavailable: {e}. History features disabled.") + app_state.redis = None + + # Initialize RCCL data store (uses app_state.redis, degrades if None) + from app.collectors.rccl_data_store import RCCLDataStore + app_state.rccl_data_store = RCCLDataStore( + app_state.redis, + snapshot_max=settings.storage.redis.snapshot_max_entries, + event_max=settings.storage.redis.event_max_entries, + ) + + # Start metrics collection using unified collector registry + if app_state.ssh_manager: + logger.info("Starting metrics collection (BaseCollector pattern)...") + + # Pre-seed node_health_status so RCCL collector can pick a leader on its + # first poll cycle, before any GPU/NIC poll has completed. 
+ startup_nodes = settings.load_nodes_from_file() + for node in startup_nodes: + if node not in app_state.node_health_status: + app_state.node_health_status[node] = "healthy" + app_state.node_failure_count[node] = 0 + + app_state.is_collecting = True + + for cls in REGISTERED_COLLECTORS: + c = cls() + app_state.collectors[c.name] = c + app_state.collector_tasks[c.name] = _start_collector_task(c) + + app_state.probe_task = asyncio.create_task(periodic_host_probe()) + logger.info("✅ Metrics collection started") + yield # Shutdown logger.info("Shutting down CVS Cluster Monitor") - # Stop metrics collection and periodic probe + # 1. Signal all background loops to stop accepting new work app_state.is_collecting = False - if app_state.collection_task: - app_state.collection_task.cancel() - try: - await app_state.collection_task - except asyncio.CancelledError: - pass + + # 2. Cancel collector tasks. Collectors call SSH inside asyncio.to_thread() + # which cannot be interrupted once the thread has started. Set client to + # None first so any in-flight thread that finishes and tries to issue the + # next SSH command gets the "no client" early-return rather than an + # AttributeError after destroy_clients() deletes the attribute. + if app_state.ssh_manager: + app_state.ssh_manager.client = None # type: ignore[assignment] + + for task in app_state.collector_tasks.values(): + task.cancel() + if app_state.collector_tasks: + await asyncio.gather(*app_state.collector_tasks.values(), return_exceptions=True) + if app_state.probe_task: app_state.probe_task.cancel() try: @@ -541,7 +744,11 @@ async def lifespan(app: FastAPI): except asyncio.CancelledError: pass - # Close SSH connections + # 3. Close Redis + if app_state.redis: + await app_state.redis.aclose() + + # 4. 
Destroy SSH connections (client already None, this cleans up port-forward state) if app_state.ssh_manager: app_state.ssh_manager.destroy_clients() @@ -568,31 +775,23 @@ async def lifespan(app: FastAPI): # WebSocket endpoint @app.websocket("/ws/metrics") -async def websocket_metrics(websocket: WebSocket): - """WebSocket endpoint for real-time metrics streaming.""" - await websocket.accept() - app_state.websocket_clients.append(websocket) - logger.info(f"WebSocket client connected. Total clients: {len(app_state.websocket_clients)}") - +async def websocket_endpoint(websocket: WebSocket): + await metrics_ws_manager.connect(websocket) try: - # Send initial metrics - if app_state.latest_metrics: - await websocket.send_json({"type": "metrics", "data": app_state.latest_metrics}) - - # Keep connection alive while True: - # Wait for client messages (ping/pong) - data = await websocket.receive_text() - if data == "ping": - await websocket.send_text("pong") + await websocket.receive_text() + except WebSocketDisconnect: + await metrics_ws_manager.disconnect(websocket) + +@app.websocket("/ws/rccl") +async def websocket_rccl(websocket: WebSocket): + await rccl_ws_manager.connect(websocket) + try: + while True: + await websocket.receive_text() except WebSocketDisconnect: - logger.info("WebSocket client disconnected") - except Exception as e: - logger.error(f"WebSocket error: {e}") - finally: - if websocket in app_state.websocket_clients: - app_state.websocket_clients.remove(websocket) + await rccl_ws_manager.disconnect(websocket) # Include API router FIRST (highest priority) @@ -607,7 +806,7 @@ async def health(): "status": "healthy", "ssh_manager": app_state.ssh_manager is not None, "collecting": app_state.is_collecting, - "clients": len(app_state.websocket_clients), + "clients": metrics_ws_manager.client_count, } @@ -628,7 +827,7 @@ async def root(): "name": settings.app_name, "version": "0.1.0", "status": "running", - "nodes": len(settings.nodes) if settings.nodes else 0, + 
"nodes": len(settings.load_nodes_from_file()), "collecting": app_state.is_collecting, "note": "Frontend not built. Run 'cd frontend && npm run build' to build the UI.", } diff --git a/cvs/monitors/cluster-mon/backend/app/models/rccl_models.py b/cvs/monitors/cluster-mon/backend/app/models/rccl_models.py new file mode 100644 index 00000000..96c97d4a --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/app/models/rccl_models.py @@ -0,0 +1,172 @@ +""" +Pydantic data models for RCCL monitoring. +RCCLJobState is the canonical definition — import from here, not from collectors. +""" + +import time +from enum import Enum +from typing import Optional +from pydantic import BaseModel + + +class RCCLJobState(str, Enum): + NO_JOB = "no_job" # Connection refused — no RCCL job running + UNREACHABLE = "unreachable" # SSH/TCP timeout — node down + HEALTHY = "healthy" # Job running, all communicators healthy + DEGRADED = "degraded" # Job running, some ranks missing or async errors + ERROR = "error" # Unexpected protocol error + + +class NCCLFunction(str, Enum): + """NCCL collective function names as they appear in VERBOSE STATUS text output. + + Used as typed dict keys after the text parser extracts named op counts from + lines like 'AllReduce=N AllGather=N ...'. Enum order does NOT affect correctness — + parsing is by string name, not by index. 
+ """ + BROADCAST = "Broadcast" + REDUCE = "Reduce" + ALL_GATHER = "AllGather" + REDUCE_SCATTER = "ReduceScatter" + ALL_REDUCE = "AllReduce" + GATHER = "Gather" + SCATTER = "Scatter" + ALL_TO_ALL = "AllToAll" + ALL_TO_ALL_V = "AllToAllv" + SEND = "Send" + RECV = "Recv" + SEND_RECV = "SendRecv" + + +class RCCLRankStatus(BaseModel): + init_state: int # ncclResult_t value (0 = ncclSuccess) + async_error: int # ncclResult_t value + finalize_called: bool + destroy_flag: bool + abort_flag: bool + + +class RCCLRank(BaseModel): + comm_rank: int + node_addr: str # IP address of the node + pid: int + cuda_dev: int # CUDA device index (CUDA_VISIBLE_DEVICES-aware) + nvml_dev: int # NVML device index (raw hardware index) + coll_op_counts: dict[NCCLFunction, int] # Keyed by NCCLFunction string enum + status: RCCLRankStatus + + +class RCCLCommunicator(BaseModel): + comm_hash: str # Hex string of the 3-component commId hash + total_ranks: int # commNRanks from RAS collective + responding_ranks: int # nRanks — ranks we received data from + missing_ranks: int # nMissingRanks — declared missing by other ranks + ranks: list[RCCLRank] + health: RCCLJobState # Derived: HEALTHY/DEGRADED/ERROR + + +class RCCLPeer(BaseModel): + addr: str + pid: int + cuda_devs: int # Bitmask + nvml_devs: int # Bitmask + is_dead: bool + + +class RCCLJobSummary(BaseModel): + total_nodes: int + total_processes: int + total_gpus: int + rccl_version: str + hip_runtime_version: int + amdgpu_driver_version: int + inconsistent_topology: bool # True when nodes have different process/GPU counts + + +class RCCLSnapshot(BaseModel): + timestamp: float + state: RCCLJobState + job_summary: Optional[RCCLJobSummary] = None + communicators: list[RCCLCommunicator] = [] + peers: list[RCCLPeer] = [] + dead_peers: list[str] = [] # IP:port strings of declared-dead peers + errors: list[str] = [] # Raw error lines from the Errors section + + @classmethod + def empty(cls, state: RCCLJobState = RCCLJobState.NO_JOB) -> 
"RCCLSnapshot": + return cls(timestamp=time.time(), state=state) + + +class RCCLEvent(BaseModel): + timestamp: float + event_type: str # "lifecycle" or "trace" (Phase 3+) + source_node: str + details: str + peer_addr: Optional[str] = None + + +class RCCLMarker(BaseModel): + """Posted by the PyTorch callback via POST /api/rccl/markers.""" + type: str # e.g., "training_step" + step: int + loss: Optional[float] = None + rank: int + timestamp: str # ISO 8601 + + +# --------------------------------------------------------------------------- +# Inspector plugin models +# --------------------------------------------------------------------------- + +class InspectorKernelChannel(BaseModel): + """Per-channel kernel timing from NCCL_INSPECTOR_DUMP_VERBOSE=1.""" + channel_id: int + # Sequence numbers — monotonic counters across all profiler events + kernel_start_sn: Optional[int] = None + kernel_stop_sn: Optional[int] = None + kernel_record_sn: Optional[int] = None + # Raw timestamps — units depend on timing_source: + # kernel_gpu: GPU clock ticks (not directly convertible without clock freq) + # kernel_cpu / collective_cpu: microseconds + kernel_start_ts: Optional[int] = None + kernel_stop_ts: Optional[int] = None + kernel_record_ts: Optional[int] = None + + +class InspectorEventTrace(BaseModel): + """Verbose event trace from coll_perf.event_trace_sn / event_trace_ts.""" + coll_start_sn: Optional[int] = None + coll_stop_sn: Optional[int] = None + coll_start_ts: Optional[int] = None + coll_stop_ts: Optional[int] = None + channels: list[InspectorKernelChannel] = [] + + +class InspectorCollPerf(BaseModel): + """Single completed collective record from one Inspector JSONL line.""" + timestamp: float # dump_timestamp_us / 1e6 (Unix seconds) + comm_hash: str # header.id — communicator identity + rank: int # header.rank + nranks: int # header.n_ranks + nnodes: int # header.nnodes + hostname: str # metadata.hostname + pid: int # metadata.pid + collective: str # coll_perf.coll (e.g. 
"AllReduce") + sequence_num: int # coll_perf.coll_sn — monotonic counter per comm + msg_size_bytes: int # coll_perf.coll_msg_size_bytes + exec_time_us: int # coll_perf.coll_exec_time_us + timing_source: str # coll_perf.coll_timing_source + algo_bw_gbps: float # coll_perf.coll_algobw_gbs + bus_bw_gbps: float # coll_perf.coll_busbw_gbs + event_trace: Optional[InspectorEventTrace] = None # present only with NCCL_INSPECTOR_DUMP_VERBOSE=1 + + +class InspectorSnapshot(BaseModel): + """Aggregated Inspector data across all ranks collected in one poll cycle.""" + timestamp: float + records: list[InspectorCollPerf] + avg_bus_bw_gbps: Optional[float] = None + min_bus_bw_gbps: Optional[float] = None + max_bus_bw_gbps: Optional[float] = None + slowest_rank: Optional[int] = None + collective_breakdown: dict[str, int] = {} # collective name → count diff --git a/cvs/monitors/cluster-mon/backend/pytest.ini b/cvs/monitors/cluster-mon/backend/pytest.ini new file mode 100644 index 00000000..2f4c80e3 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/cvs/monitors/cluster-mon/backend/requirements.txt b/cvs/monitors/cluster-mon/backend/requirements.txt index 941df69d..858b6647 100644 --- a/cvs/monitors/cluster-mon/backend/requirements.txt +++ b/cvs/monitors/cluster-mon/backend/requirements.txt @@ -1,8 +1,8 @@ fastapi==0.109.0 uvicorn[standard]==0.27.0 websockets==12.0 -pydantic==2.5.3 -pydantic-settings==2.1.0 +pydantic==2.9.2 +pydantic-settings==2.4.0 parallel-ssh==2.12.0 paramiko==3.4.0 scp==0.14.5 diff --git a/cvs/monitors/cluster-mon/backend/tests/__init__.py b/cvs/monitors/cluster-mon/backend/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cvs/monitors/cluster-mon/backend/tests/fixtures/inspector_sample.jsonl b/cvs/monitors/cluster-mon/backend/tests/fixtures/inspector_sample.jsonl new file mode 100644 index 00000000..a9b17b81 --- /dev/null +++ 
b/cvs/monitors/cluster-mon/backend/tests/fixtures/inspector_sample.jsonl @@ -0,0 +1,7 @@ +{"header": {"id": "0x7f8c496ae9f661", "rank": 0, "n_ranks": 16, "nnodes": 2}, "metadata": {"inspector_output_format_version": "v4.0", "git_rev": "abc123def", "rec_mechanism": "nccl_profiler_interface", "dump_timestamp_us": 1711800000000000, "hostname": "gpu-node-01", "pid": 12345}, "coll_perf": {"coll": "AllReduce", "coll_sn": 4217, "coll_msg_size_bytes": 2097152, "coll_exec_time_us": 412, "coll_timing_source": "kernel_gpu", "coll_algobw_gbs": 193.6, "coll_busbw_gbs": 387.2}} +{"header": {"id": "0x7f8c496ae9f661", "rank": 1, "n_ranks": 16, "nnodes": 2}, "metadata": {"inspector_output_format_version": "v4.0", "git_rev": "abc123def", "rec_mechanism": "nccl_profiler_interface", "dump_timestamp_us": 1711800000000000, "hostname": "gpu-node-01", "pid": 12346}, "coll_perf": {"coll": "AllReduce", "coll_sn": 4217, "coll_msg_size_bytes": 2097152, "coll_exec_time_us": 445, "coll_timing_source": "kernel_gpu", "coll_algobw_gbs": 179.3, "coll_busbw_gbs": 358.6}} +{"header": {"id": "0x7f8c496ae9f661", "rank": 8, "n_ranks": 16, "nnodes": 2}, "metadata": {"inspector_output_format_version": "v4.0", "git_rev": "abc123def", "rec_mechanism": "nccl_profiler_interface", "dump_timestamp_us": 1711800000000000, "hostname": "gpu-node-02", "pid": 22345}, "coll_perf": {"coll": "AllReduce", "coll_sn": 4217, "coll_msg_size_bytes": 2097152, "coll_exec_time_us": 418, "coll_timing_source": "kernel_gpu", "coll_algobw_gbs": 191.0, "coll_busbw_gbs": 382.0}} +{"header": {"id": "0xaabbccdd11223344", "rank": 0, "n_ranks": 8, "nnodes": 1}, "metadata": {"inspector_output_format_version": "v4.0", "git_rev": "abc123def", "rec_mechanism": "nccl_profiler_interface", "dump_timestamp_us": 1711800005000000, "hostname": "gpu-node-01", "pid": 12345}, "coll_perf": {"coll": "ReduceScatter", "coll_sn": 891, "coll_msg_size_bytes": 524288, "coll_exec_time_us": 88, "coll_timing_source": "kernel_gpu", "coll_algobw_gbs": 47.8, 
"coll_busbw_gbs": 44.2}} +this line is malformed JSON and should be skipped silently +{"header": {"id": "0x7f8c496ae9f661", "rank": 2, "n_ranks": 16, "nnodes": 2}, "metadata": {"inspector_output_format_version": "v4.0", "git_rev": "abc123def", "rec_mechanism": "nccl_profiler_interface", "dump_timestamp_us": 1711800000000000, "hostname": "gpu-node-01", "pid": 12347}, "coll_perf": {"coll": "AllGather", "coll_sn": 102, "coll_msg_size_bytes": 131072, "coll_exec_time_us": 0, "coll_timing_source": "collective_cpu", "coll_algobw_gbs": 0.0, "coll_busbw_gbs": 0.0}} +{"header": {"id": "0x7f8c496ae9f661", "rank": 3, "n_ranks": 16, "nnodes": 2}, "metadata": {"inspector_output_format_version": "v4.0", "git_rev": "abc123def", "rec_mechanism": "nccl_profiler_interface", "dump_timestamp_us": 1711800000000000, "hostname": "gpu-node-02", "pid": 22346}, "coll_perf": {"coll": "AllReduce", "coll_sn": 4217, "coll_msg_size_bytes": 2097152, "coll_exec_time_us": 425, "coll_timing_source": "kernel_gpu", "coll_algobw_gbs": 187.4, "coll_busbw_gbs": 374.8, "event_trace_sn": {"coll_start_sn": 8430, "coll_stop_sn": 8445, "kernel_events": [{"channel_id": 0, "kernel_start_sn": 8431, "kernel_stop_sn": 8438, "kernel_record_sn": 8439}, {"channel_id": 1, "kernel_start_sn": 8432, "kernel_stop_sn": 8440, "kernel_record_sn": 8441}]}, "event_trace_ts": {"coll_start_ts": 1234567890, "coll_stop_ts": 1234568315, "kernel_events": [{"channel_id": 0, "kernel_start_ts": 1234567900, "kernel_stop_ts": 1234568200, "kernel_record_ts": 1234568210}, {"channel_id": 1, "kernel_start_ts": 1234567910, "kernel_stop_ts": 1234568300, "kernel_record_ts": 1234568310}]}}} diff --git a/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_connection_reset.txt b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_connection_reset.txt new file mode 100644 index 00000000..ea390f5e --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_connection_reset.txt @@ -0,0 +1 @@ 
+read socket: Connection reset by peer diff --git a/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_degraded.txt b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_degraded.txt new file mode 100644 index 00000000..ee940217 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_degraded.txt @@ -0,0 +1,30 @@ +RCCL version 2.28.3 compiled with ROCm "7.2.0.0-43-fc0010cf6a" +HIP runtime version 70226015, amdgpu driver version 70226015 + +Job summary +=========== + + Nodes Processes GPUs Processes GPUs +(total) per node per process (total) (total) + 1 8 1 8 8 + +Communicators... (0.00s) +============= + +Group Comms Nodes Ranks Ranks Ranks Status Errors + # in group per comm per node per comm in group + 0 1 1 7 8 7 RUNNING INCOMPLETE + +Errors +====== + +INCOMPLETE + Missing communicator data from 1 job process + Process 3467978 on node 10.194.132.77 managing GPU 6 + +#0-0 (3b2fe521bf43bc04) INCOMPLETE + Missing communicator data from 1 rank + Rank 6 -- GPU 6 managed by process 3467978 on node 10.194.132.77 + +Warnings +======== diff --git a/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_degraded_2node.txt b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_degraded_2node.txt new file mode 100644 index 00000000..a928686f --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_degraded_2node.txt @@ -0,0 +1,30 @@ +RCCL version 2.28.3 compiled with ROCm "7.0.2.0-56-9428210" +HIP runtime version 70051831, amdgpu driver version 70051831 + +Job summary +=========== + + Nodes Processes GPUs Processes GPUs +(total) per node per process (total) (total) + 2 8 1 16 16 + +Communicators... 
(0.00s) +============= + +Group Comms Nodes Ranks Ranks Ranks Status Errors + # in group per comm per node per comm in group + 0 1 2 7-8 16 15 RUNNING INCOMPLETE + +Errors +====== + +INCOMPLETE + Missing communicator data from 1 job process + Process 3534512 on node 10.245.40.180 managing GPU 3 + +#0-0 (488feb6f1c97a0e1) INCOMPLETE + Missing communicator data from 1 rank + Rank 3 -- GPU 3 managed by process 3534512 on node 10.245.40.180 + +Warnings +======== diff --git a/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_healthy.txt b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_healthy.txt new file mode 100644 index 00000000..35515a03 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/fixtures/rccl_verbose_status_healthy.txt @@ -0,0 +1,23 @@ +RCCL version 2.28.3 compiled with ROCm "7.2.0.0-43-fc0010cf6a" +HIP runtime version 70226015, amdgpu driver version 70226015 + +Job summary +=========== + + Nodes Processes GPUs Processes GPUs +(total) per node per process (total) (total) + 1 8 1 8 8 + +Communicators... (0.00s) +============= + +Group Comms Nodes Ranks Ranks Ranks Status Errors + # in group per comm per node per comm in group + 0 1 1 8 8 8 RUNNING OK + +Errors +====== + +Warnings +======== + diff --git a/cvs/monitors/cluster-mon/backend/tests/mock_rcclras_server.py b/cvs/monitors/cluster-mon/backend/tests/mock_rcclras_server.py new file mode 100644 index 00000000..0f38c520 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/mock_rcclras_server.py @@ -0,0 +1,57 @@ +""" +Mock rcclras TCP server for unit testing. +Replays a fixture response over TCP so collector tests run without a real RCCL job. +""" + +import asyncio +from typing import Optional + + +class MockRcclRasServer: + """Minimal rcclras server for unit testing. 
Replays a fixed fixture response.""" + + def __init__(self, fixture_data: bytes, protocol_version: int = 2): + self.fixture_data = fixture_data + self.protocol_version = protocol_version + self._server: Optional[asyncio.Server] = None + + async def handle( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + try: + line = await asyncio.wait_for(reader.readline(), timeout=2.0) + assert line.lower().startswith(b"client protocol"), \ + f"Expected CLIENT PROTOCOL, got: {line!r}" + + writer.write(f"SERVER PROTOCOL {self.protocol_version}\n".encode()) + await writer.drain() + + # Handle optional TIMEOUT command + line = await asyncio.wait_for(reader.readline(), timeout=2.0) + if line.lower().startswith(b"timeout"): + writer.write(b"OK\n") + await writer.drain() + line = await asyncio.wait_for(reader.readline(), timeout=2.0) + + # Expect STATUS or VERBOSE STATUS + assert b"status" in line.lower(), \ + f"Expected STATUS command, got: {line!r}" + + writer.write(self.fixture_data) + await writer.drain() + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + + async def start(self, host: str = "127.0.0.1", port: int = 0) -> int: + """Start the server. Returns the bound port (use port=0 for ephemeral).""" + self._server = await asyncio.start_server(self.handle, host, port) + return self._server.sockets[0].getsockname()[1] + + async def stop(self) -> None: + if self._server: + self._server.close() + await self._server.wait_closed() diff --git a/cvs/monitors/cluster-mon/backend/tests/test_base_collector.py b/cvs/monitors/cluster-mon/backend/tests/test_base_collector.py new file mode 100644 index 00000000..361406a5 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_base_collector.py @@ -0,0 +1,214 @@ +""" +Tests for BaseCollector ABC, CollectorResult, CollectorState. +Uses TDD - written before/alongside implementation. 
+""" +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.collectors.base import ( + BaseCollector, + CollectorResult, + CollectorState, +) + + +# ── Concrete test subclass ────────────────────────────────────────────────── + + +class FakeCollector(BaseCollector): + name = "fake" + poll_interval = 1 + collect_timeout = 5.0 + critical = False + + def __init__(self, result: CollectorResult = None, raise_exc=None): + self._result = result or CollectorResult( + collector_name="fake", + timestamp=CollectorResult.now_iso(), + state=CollectorState.OK, + data={"value": 42}, + ) + self._raise_exc = raise_exc + self.collect_call_count = 0 + + async def collect(self, ssh_manager) -> CollectorResult: + self.collect_call_count += 1 + if self._raise_exc: + raise self._raise_exc + return self._result + + +class HangingCollector(BaseCollector): + name = "hanging" + poll_interval = 1 + collect_timeout = 0.1 # very short — will timeout + critical = False + + async def collect(self, ssh_manager) -> CollectorResult: + # Use an Event that is never set so collect() genuinely blocks, + # even if asyncio.sleep is patched in the test. 
+ await asyncio.Event().wait() + return CollectorResult( + collector_name=self.name, + timestamp=CollectorResult.now_iso(), + state=CollectorState.OK, + data={}, + ) + + +# ── CollectorResult tests ─────────────────────────────────────────────────── + + +def test_collector_result_now_iso_is_utc(): + ts = CollectorResult.now_iso() + assert "T" in ts + assert ts.endswith("+00:00") or ts.endswith("Z") or "UTC" in ts or "+00" in ts + + +def test_collector_result_defaults(): + result = CollectorResult( + collector_name="gpu", + timestamp="2026-01-01T00:00:00+00:00", + state=CollectorState.OK, + data={"x": 1}, + ) + assert result.error is None + assert result.node_errors == {} + + +# ── CollectorState tests ──────────────────────────────────────────────────── + + +def test_collector_state_values(): + assert CollectorState.OK == "ok" + assert CollectorState.NO_SERVICE == "no_service" + assert CollectorState.UNREACHABLE == "unreachable" + assert CollectorState.ERROR == "error" + + +# ── BaseCollector.collect() is abstract ───────────────────────────────────── + + +def test_base_collector_is_abstract(): + with pytest.raises(TypeError): + BaseCollector() # cannot instantiate abstract class + + +# ── BaseCollector.run() — timeout enforcement ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_run_times_out_and_produces_error_result(): + """If collect() hangs beyond collect_timeout, run() produces an ERROR result.""" + collector = HangingCollector() + ssh_manager = MagicMock() + + app_state = MagicMock() + app_state.is_collecting = True + app_state.collector_results = {} + app_state.latest_metrics = {} + app_state.probe_requested = None # no probe event + + # Run one iteration then stop + call_count = 0 + original_sleep = asyncio.sleep + + async def stop_after_one(seconds): + nonlocal call_count + call_count += 1 + app_state.is_collecting = False + await original_sleep(0) + + with patch("asyncio.sleep", side_effect=stop_after_one): + with 
patch("app.collectors.base._update_node_status_via_app_state"): + # broadcast_metrics is imported lazily inside run(); the import + # will fail in the test environment, but that's caught by the + # try/except inside run(). No need to patch it. + task = asyncio.create_task(collector.run(ssh_manager, app_state)) + await original_sleep(0.5) # let it run + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # The collector_results should have an ERROR entry from timeout + assert "hanging" in app_state.collector_results, \ + "Expected collector_results to have 'hanging' entry after timeout" + result = app_state.collector_results["hanging"] + assert result.state == CollectorState.ERROR + assert "timed out" in result.error.lower() + + +@pytest.mark.asyncio +async def test_run_cancelled_error_propagates(): + """CancelledError must propagate out of run() without being swallowed.""" + collector = FakeCollector() + ssh_manager = MagicMock() + app_state = MagicMock() + app_state.is_collecting = True + app_state.collector_results = {} + app_state.latest_metrics = {} + app_state.probe_requested = None + + task = asyncio.create_task(collector.run(ssh_manager, app_state)) + await asyncio.sleep(0.01) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_run_connection_error_sets_probe_requested(): + """ConnectionError in collect() should call probe_requested.set().""" + collector = FakeCollector(raise_exc=ConnectionError("SSH timeout")) + ssh_manager = MagicMock() + + probe_event = asyncio.Event() + app_state = MagicMock() + app_state.is_collecting = True + app_state.collector_results = {} + app_state.latest_metrics = {} + app_state.probe_requested = probe_event + + task = asyncio.create_task(collector.run(ssh_manager, app_state)) + await asyncio.sleep(0.05) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert probe_event.is_set() + + +# ── Config tests 
───────────────────────────────────────────────────────────── + + +def test_settings_defaults(): + from app.core.config import Settings + s = Settings() + assert s.polling.interval == 60 + assert s.polling.failure_threshold == 5 + assert s.rccl.ras_port == 28028 + assert s.storage.redis.url == "redis://localhost:6379" + + +def test_settings_backward_compat_properties(): + from app.core.config import Settings + s = Settings() + # These properties must exist for the existing main.py to keep working + assert hasattr(s, 'node_username_via_jumphost') + assert hasattr(s, 'node_key_file_on_jumphost') + assert hasattr(s, 'ssh_username') + + +def test_collector_state_str_enum(): + # CollectorState must be usable as a string (for JSON serialization) + # In Python 3.10, str() on (str, Enum) returns "ClassName.MEMBER", + # but the value and f-string formatting yield the raw string. + assert CollectorState.OK.value == "ok" + assert f"{CollectorState.OK}" == "ok" + # Enum equality with str works because CollectorState inherits from str + assert CollectorState.OK == "ok" diff --git a/cvs/monitors/cluster-mon/backend/tests/test_collectors.py b/cvs/monitors/cluster-mon/backend/tests/test_collectors.py new file mode 100644 index 00000000..02382de2 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_collectors.py @@ -0,0 +1,81 @@ +""" +Tests for GPUMetricsCollector and NICMetricsCollector as BaseCollector subclasses. 
+""" +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.collectors.base import CollectorState, CollectorResult +from app.collectors.gpu_collector import GPUMetricsCollector +from app.collectors.nic_collector import NICMetricsCollector + + +def test_gpu_collector_has_base_collector_attrs(): + assert hasattr(GPUMetricsCollector, 'name') + assert GPUMetricsCollector.name == "gpu" + assert hasattr(GPUMetricsCollector, 'poll_interval') + assert hasattr(GPUMetricsCollector, 'collect_timeout') + assert GPUMetricsCollector.critical is True + + +def test_nic_collector_has_base_collector_attrs(): + assert hasattr(NICMetricsCollector, 'name') + assert NICMetricsCollector.name == "nic" + assert hasattr(NICMetricsCollector, 'poll_interval') + assert hasattr(NICMetricsCollector, 'collect_timeout') + assert NICMetricsCollector.critical is True + + +@pytest.mark.asyncio +async def test_gpu_collector_collect_returns_collector_result(): + collector = GPUMetricsCollector() + ssh_manager = MagicMock() + + # Mock collect_all_metrics to return a simple metrics dict + fake_metrics = { + "timestamp": "2026-01-01T00:00:00+00:00", + "utilization": {"node1": {"gpu0": 80}}, + } + collector.collect_all_metrics = AsyncMock(return_value=fake_metrics) + + result = await collector.collect(ssh_manager) + + assert isinstance(result, CollectorResult) + assert result.collector_name == "gpu" + assert result.state == CollectorState.OK + assert result.data == fake_metrics + + +@pytest.mark.asyncio +async def test_gpu_collector_collect_handles_exception(): + collector = GPUMetricsCollector() + ssh_manager = MagicMock() + collector.collect_all_metrics = AsyncMock(side_effect=RuntimeError("SSH failed")) + + result = await collector.collect(ssh_manager) + + assert result.state == CollectorState.ERROR + assert "SSH failed" in result.error + + +@pytest.mark.asyncio +async def test_nic_collector_collect_returns_collector_result(): + collector = 
NICMetricsCollector() + ssh_manager = MagicMock() + fake_metrics = {"rdma_links": {"node1": {}}} + collector.collect_all_metrics = AsyncMock(return_value=fake_metrics) + + result = await collector.collect(ssh_manager) + + assert isinstance(result, CollectorResult) + assert result.collector_name == "nic" + assert result.state == CollectorState.OK + + +def test_registered_collectors_list(): + pytest.importorskip("fastapi") + from app.main import REGISTERED_COLLECTORS + from app.collectors.base import BaseCollector + assert len(REGISTERED_COLLECTORS) >= 2 + for cls in REGISTERED_COLLECTORS: + assert issubclass(cls, BaseCollector) diff --git a/cvs/monitors/cluster-mon/backend/tests/test_collectors_api.py b/cvs/monitors/cluster-mon/backend/tests/test_collectors_api.py new file mode 100644 index 00000000..d4e0ec11 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_collectors_api.py @@ -0,0 +1,66 @@ +""" +Tests for GET /api/collectors/status endpoint logic. +""" +import pytest +from unittest.mock import MagicMock + +from app.api.collectors import _compute_overall_status +from app.collectors.base import CollectorResult, CollectorState + + +def _make_result(state: CollectorState, name: str = "gpu") -> CollectorResult: + return CollectorResult( + collector_name=name, + timestamp="2026-01-01T00:00:00+00:00", + state=state, + data={}, + ) + + +def test_overall_status_healthy_when_all_ok(): + results = { + "gpu": _make_result(CollectorState.OK, "gpu"), + "nic": _make_result(CollectorState.OK, "nic"), + } + meta = {"gpu": {"critical": True}, "nic": {"critical": True}} + assert _compute_overall_status(results, meta) == "healthy" + + +def test_overall_status_healthy_when_no_results(): + assert _compute_overall_status({}, {}) == "healthy" + + +def test_overall_status_healthy_with_no_service(): + """NO_SERVICE (e.g., no RCCL job) is not an error — still healthy.""" + results = { + "gpu": _make_result(CollectorState.OK, "gpu"), + "rccl": 
_make_result(CollectorState.NO_SERVICE, "rccl"), + } + meta = {"gpu": {"critical": True}, "rccl": {"critical": False}} + assert _compute_overall_status(results, meta) == "healthy" + + +def test_overall_status_critical_when_critical_collector_errors(): + results = { + "gpu": _make_result(CollectorState.ERROR, "gpu"), + "nic": _make_result(CollectorState.OK, "nic"), + } + meta = {"gpu": {"critical": True}, "nic": {"critical": True}} + assert _compute_overall_status(results, meta) == "critical" + + +def test_overall_status_degraded_when_non_critical_errors(): + results = { + "gpu": _make_result(CollectorState.OK, "gpu"), + "rccl": _make_result(CollectorState.ERROR, "rccl"), + } + meta = {"gpu": {"critical": True}, "rccl": {"critical": False}} + assert _compute_overall_status(results, meta) == "degraded" + + +def test_overall_status_critical_on_unreachable_critical(): + results = { + "gpu": _make_result(CollectorState.UNREACHABLE, "gpu"), + } + meta = {"gpu": {"critical": True}} + assert _compute_overall_status(results, meta) == "critical" diff --git a/cvs/monitors/cluster-mon/backend/tests/test_config.py b/cvs/monitors/cluster-mon/backend/tests/test_config.py new file mode 100644 index 00000000..312e6a59 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_config.py @@ -0,0 +1,74 @@ +"""Tests for the new pydantic Settings config.""" +import pytest +from app.core.config import ( + Settings, + JumpHostConfig, + SSHConfig, + PollingConfig, + RCCLConfig, + StorageConfig, + RedisConfig, +) + + +def test_jump_host_config_defaults(): + cfg = JumpHostConfig() + assert cfg.enabled is False + assert cfg.node_username == "root" + assert cfg.node_key_file == "~/.ssh/id_rsa" + + +def test_ssh_config_has_jump_host(): + cfg = SSHConfig() + assert isinstance(cfg.jump_host, JumpHostConfig) + assert cfg.timeout == 30 + + +def test_polling_config_defaults(): + cfg = PollingConfig() + assert cfg.interval == 60 + assert cfg.failure_threshold == 5 + + +def 
test_rccl_config_defaults(): + cfg = RCCLConfig() + assert cfg.ras_port == 28028 + assert cfg.poll_interval == 30 + assert cfg.collective_timeout_secs == 10 + assert cfg.debug_log_path is None + + +def test_storage_redis_config_defaults(): + cfg = StorageConfig() + assert cfg.redis.url == "redis://localhost:6379" + assert cfg.redis.db == 0 + assert cfg.redis.snapshot_max_entries == 1000 + assert cfg.redis.event_max_entries == 10000 + + +def test_settings_has_all_sections(): + s = Settings() + assert hasattr(s, 'ssh') + assert hasattr(s, 'polling') + assert hasattr(s, 'alerts') + assert hasattr(s, 'storage') + assert hasattr(s, 'rccl') + + +def test_settings_env_nested_delimiter(monkeypatch): + """Verify env_nested_delimiter is set so POLLING__INTERVAL=30 works.""" + from app.core.config import Settings + monkeypatch.setenv('POLLING__INTERVAL', '45') + s = Settings() + assert s.polling.interval == 45 + # monkeypatch automatically cleans up after the test + + +def test_settings_load_nodes_from_file_missing(): + """load_nodes_from_file returns [] when no file exists.""" + from app.core.config import Settings + s = Settings() + # In test environment, no nodes.txt at the expected paths + # Result should be [] (not an exception) + result = s.load_nodes_from_file() + assert isinstance(result, list) diff --git a/cvs/monitors/cluster-mon/backend/tests/test_inspector_parser.py b/cvs/monitors/cluster-mon/backend/tests/test_inspector_parser.py new file mode 100644 index 00000000..84927b5a --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_inspector_parser.py @@ -0,0 +1,377 @@ +""" +Unit tests for InspectorParser and aggregate_snapshot. 
+""" + +import json +from pathlib import Path + +import pytest + +from app.collectors.inspector_parser import InspectorParser, aggregate_snapshot +from app.models.rccl_models import InspectorCollPerf, InspectorEventTrace + + +FIXTURE_PATH = Path(__file__).parent / "fixtures" / "inspector_sample.jsonl" + + +# --------------------------------------------------------------------------- +# parse_lines +# --------------------------------------------------------------------------- + +class TestParseLines: + def setup_method(self): + self.parser = InspectorParser() + + def test_parses_valid_allreduce_record(self): + line = json.dumps({ + "header": {"id": "0xabc", "rank": 0, "n_ranks": 8, "nnodes": 1}, + "metadata": { + "inspector_output_format_version": "v4.0", + "git_rev": "deadbeef", + "rec_mechanism": "nccl_profiler_interface", + "dump_timestamp_us": 1_711_800_000_000_000, + "hostname": "gpu-node-01", + "pid": 9999, + }, + "coll_perf": { + "coll": "AllReduce", + "coll_sn": 1, + "coll_msg_size_bytes": 1048576, + "coll_exec_time_us": 200, + "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 5.0, + "coll_busbw_gbs": 9.375, + }, + }) + records = self.parser.parse_lines(line) + assert len(records) == 1 + r = records[0] + assert r.comm_hash == "0xabc" + assert r.rank == 0 + assert r.nranks == 8 + assert r.nnodes == 1 + assert r.hostname == "gpu-node-01" + assert r.pid == 9999 + assert r.collective == "AllReduce" + assert r.sequence_num == 1 + assert r.msg_size_bytes == 1048576 + assert r.exec_time_us == 200 + assert r.timing_source == "kernel_gpu" + assert r.algo_bw_gbps == pytest.approx(5.0) + assert r.bus_bw_gbps == pytest.approx(9.375) + assert r.timestamp == pytest.approx(1_711_800_000.0) + + def test_skips_malformed_json_silently(self): + text = "not json at all\n" + json.dumps({ + "header": {"id": "0x1", "rank": 0, "n_ranks": 2, "nnodes": 1}, + "metadata": {"dump_timestamp_us": 1000000, "hostname": "h", "pid": 1}, + "coll_perf": {"coll": "AllReduce", "coll_sn": 
1, "coll_msg_size_bytes": 64, + "coll_exec_time_us": 10, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 1.0, "coll_busbw_gbs": 1.875}, + }) + records = self.parser.parse_lines(text) + assert len(records) == 1 + + def test_skips_missing_field_silently(self): + # Missing coll_perf entirely + line = json.dumps({"header": {"id": "0x1", "rank": 0}, "metadata": {}}) + records = self.parser.parse_lines(line) + assert records == [] + + def test_empty_text_returns_empty(self): + assert self.parser.parse_lines("") == [] + + def test_blank_lines_skipped(self): + assert self.parser.parse_lines("\n\n \n") == [] + + def test_timestamp_conversion(self): + """dump_timestamp_us is microseconds; should convert to Unix seconds.""" + line = json.dumps({ + "header": {"id": "0x1", "rank": 0, "n_ranks": 1, "nnodes": 1}, + "metadata": {"dump_timestamp_us": 2_000_000_000_000, "hostname": "h", "pid": 1}, + "coll_perf": {"coll": "AllReduce", "coll_sn": 1, "coll_msg_size_bytes": 64, + "coll_exec_time_us": 10, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 1.0, "coll_busbw_gbs": 1.875}, + }) + records = self.parser.parse_lines(line) + assert records[0].timestamp == pytest.approx(2_000_000.0) + + def test_zero_exec_time_record_parsed(self): + """Zero exec_time (timing fallback) should parse, not be skipped.""" + line = json.dumps({ + "header": {"id": "0x1", "rank": 0, "n_ranks": 1, "nnodes": 1}, + "metadata": {"dump_timestamp_us": 1000000, "hostname": "h", "pid": 1}, + "coll_perf": {"coll": "AllGather", "coll_sn": 5, "coll_msg_size_bytes": 128, + "coll_exec_time_us": 0, "coll_timing_source": "collective_cpu", + "coll_algobw_gbs": 0.0, "coll_busbw_gbs": 0.0}, + }) + records = self.parser.parse_lines(line) + assert len(records) == 1 + assert records[0].exec_time_us == 0 + assert records[0].bus_bw_gbps == 0.0 + + def test_multiple_valid_lines(self): + lines = [] + for rank in range(4): + lines.append(json.dumps({ + "header": {"id": "0xfeed", "rank": rank, "n_ranks": 4, 
"nnodes": 1}, + "metadata": {"dump_timestamp_us": 5_000_000_000_000, "hostname": f"node{rank}", "pid": rank + 100}, + "coll_perf": {"coll": "ReduceScatter", "coll_sn": 10, "coll_msg_size_bytes": 256, + "coll_exec_time_us": 50 + rank * 5, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 2.0, "coll_busbw_gbs": 1.875}, + })) + records = self.parser.parse_lines("\n".join(lines)) + assert len(records) == 4 + assert {r.rank for r in records} == {0, 1, 2, 3} + + +# --------------------------------------------------------------------------- +# parse_file +# --------------------------------------------------------------------------- + +class TestParseFile: + def setup_method(self): + self.parser = InspectorParser() + + def test_parses_fixture_file(self): + records = self.parser.parse_file(FIXTURE_PATH) + # fixture has 5 valid lines + 1 malformed + 1 zero-bw + # All valid JSON lines should parse (including zero-bw) + assert len(records) >= 4 + + def test_fixture_contains_allreduce_records(self): + records = self.parser.parse_file(FIXTURE_PATH) + collectives = {r.collective for r in records} + assert "AllReduce" in collectives + + def test_fixture_contains_multiple_hosts(self): + records = self.parser.parse_file(FIXTURE_PATH) + hosts = {r.hostname for r in records} + assert len(hosts) >= 2 # gpu-node-01 and gpu-node-02 + + def test_missing_file_returns_empty(self): + records = self.parser.parse_file(Path("/nonexistent/path/inspector.log")) + assert records == [] + + def test_tail_limits_records(self, tmp_path): + """tail=2 should only parse the last 2 lines.""" + lines = [] + for i in range(10): + lines.append(json.dumps({ + "header": {"id": "0x1", "rank": i, "n_ranks": 10, "nnodes": 1}, + "metadata": {"dump_timestamp_us": 1000000, "hostname": "h", "pid": i + 1}, + "coll_perf": {"coll": "AllReduce", "coll_sn": i, "coll_msg_size_bytes": 128, + "coll_exec_time_us": 10, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 1.0, "coll_busbw_gbs": 1.875}, + })) + 
log_file = tmp_path / "test.log" + log_file.write_text("\n".join(lines)) + records = self.parser.parse_file(log_file, tail=2) + assert len(records) == 2 + assert records[0].rank == 8 + assert records[1].rank == 9 + + +# --------------------------------------------------------------------------- +# aggregate_snapshot +# --------------------------------------------------------------------------- + +def _make_record(rank: int, bus_bw: float, collective: str = "AllReduce") -> InspectorCollPerf: + return InspectorCollPerf( + timestamp=1711800000.0, + comm_hash="0xtest", + rank=rank, + nranks=4, + nnodes=1, + hostname=f"node{rank}", + pid=1000 + rank, + collective=collective, + sequence_num=1, + msg_size_bytes=1048576, + exec_time_us=100 if bus_bw > 0 else 0, + timing_source="kernel_gpu", + algo_bw_gbps=bus_bw / 1.875, + bus_bw_gbps=bus_bw, + ) + + +class TestAggregateSnapshot: + def test_avg_min_max_computed(self): + records = [_make_record(i, bw) for i, bw in enumerate([300.0, 350.0, 400.0, 200.0])] + snap = aggregate_snapshot(records) + assert snap.avg_bus_bw_gbps == pytest.approx(312.5) + assert snap.min_bus_bw_gbps == pytest.approx(200.0) + assert snap.max_bus_bw_gbps == pytest.approx(400.0) + + def test_slowest_rank_identified(self): + records = [_make_record(i, bw) for i, bw in enumerate([300.0, 200.0, 350.0, 400.0])] + snap = aggregate_snapshot(records) + assert snap.slowest_rank == 1 # rank 1 has 200 GB/s + + def test_zero_bw_excluded_from_stats(self): + """Zero-bw records (exec_time=0) should not affect bandwidth statistics.""" + records = [ + _make_record(0, 300.0), + _make_record(1, 0.0), # timing fallback — excluded + _make_record(2, 400.0), + ] + snap = aggregate_snapshot(records) + assert snap.avg_bus_bw_gbps == pytest.approx(350.0) # (300+400)/2, not /3 + + def test_collective_breakdown_counts(self): + records = [ + _make_record(0, 300.0, "AllReduce"), + _make_record(1, 300.0, "AllReduce"), + _make_record(2, 200.0, "ReduceScatter"), + ] + snap = 
aggregate_snapshot(records) + assert snap.collective_breakdown["AllReduce"] == 2 + assert snap.collective_breakdown["ReduceScatter"] == 1 + + def test_empty_records_returns_valid_snapshot(self): + snap = aggregate_snapshot([]) + assert snap.avg_bus_bw_gbps is None + assert snap.min_bus_bw_gbps is None + assert snap.max_bus_bw_gbps is None + assert snap.slowest_rank is None + assert snap.collective_breakdown == {} + assert snap.records == [] + + def test_all_zero_bw_returns_none_stats(self): + records = [_make_record(i, 0.0) for i in range(3)] + snap = aggregate_snapshot(records) + assert snap.avg_bus_bw_gbps is None + assert snap.slowest_rank is None + + def test_records_deduplicated_to_latest_per_rank(self): + """records field should contain only the highest sequence_num per (rank, comm_hash).""" + def _make_seq(rank: int, seq: int, bw: float) -> InspectorCollPerf: + r = _make_record(rank, bw) + return r.model_copy(update={"sequence_num": seq}) + + records = [ + _make_seq(0, 10, 300.0), + _make_seq(0, 11, 320.0), # newer — should win + _make_seq(1, 10, 280.0), + _make_seq(1, 9, 260.0), # older — should lose + ] + snap = aggregate_snapshot(records) + # Only 2 display records (one per rank) + assert len(snap.records) == 2 + by_rank = {r.rank: r for r in snap.records} + assert by_rank[0].sequence_num == 11 + assert by_rank[1].sequence_num == 10 + # Stats still computed from all 4 records + assert snap.avg_bus_bw_gbps == pytest.approx((300 + 320 + 280 + 260) / 4) + + # ------------------------------------------------------------------ + # Verbose / event_trace tests + # ------------------------------------------------------------------ + +class TestVerboseParsing: + def setup_method(self): + self.parser = InspectorParser() + + def _verbose_line(self, rank=0, n_channels=2) -> str: + kernel_sn = [ + {"channel_id": ch, "kernel_start_sn": 100 + ch, "kernel_stop_sn": 110 + ch, "kernel_record_sn": 111 + ch} + for ch in range(n_channels) + ] + kernel_ts = [ + {"channel_id": 
ch, "kernel_start_ts": 1000 + ch * 100, "kernel_stop_ts": 1050 + ch * 100, "kernel_record_ts": 1051 + ch * 100} + for ch in range(n_channels) + ] + return json.dumps({ + "header": {"id": "0xverb", "rank": rank, "n_ranks": 4, "nnodes": 1}, + "metadata": {"dump_timestamp_us": 2_000_000_000_000, "hostname": f"node{rank}", "pid": rank + 1}, + "coll_perf": { + "coll": "AllReduce", "coll_sn": 1, "coll_msg_size_bytes": 1048576, + "coll_exec_time_us": 200, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 5.0, "coll_busbw_gbs": 9.375, + "event_trace_sn": {"coll_start_sn": 99, "coll_stop_sn": 120, "kernel_events": kernel_sn}, + "event_trace_ts": {"coll_start_ts": 900, "coll_stop_ts": 1200, "kernel_events": kernel_ts}, + }, + }) + + def test_non_verbose_record_has_no_event_trace(self): + line = json.dumps({ + "header": {"id": "0x1", "rank": 0, "n_ranks": 1, "nnodes": 1}, + "metadata": {"dump_timestamp_us": 1000000, "hostname": "h", "pid": 1}, + "coll_perf": {"coll": "AllReduce", "coll_sn": 1, "coll_msg_size_bytes": 64, + "coll_exec_time_us": 10, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 1.0, "coll_busbw_gbs": 1.875}, + }) + records = self.parser.parse_lines(line) + assert records[0].event_trace is None + + def test_verbose_record_has_event_trace(self): + records = self.parser.parse_lines(self._verbose_line()) + assert records[0].event_trace is not None + + def test_verbose_coll_level_sn_and_ts(self): + records = self.parser.parse_lines(self._verbose_line()) + et = records[0].event_trace + assert et.coll_start_sn == 99 + assert et.coll_stop_sn == 120 + assert et.coll_start_ts == 900 + assert et.coll_stop_ts == 1200 + + def test_verbose_channel_count(self): + records = self.parser.parse_lines(self._verbose_line(n_channels=3)) + assert len(records[0].event_trace.channels) == 3 + + def test_verbose_channel_fields(self): + records = self.parser.parse_lines(self._verbose_line(n_channels=2)) + ch0 = records[0].event_trace.channels[0] + assert ch0.channel_id == 
0 + assert ch0.kernel_start_sn == 100 + assert ch0.kernel_stop_sn == 110 + assert ch0.kernel_record_sn == 111 + assert ch0.kernel_start_ts == 1000 + assert ch0.kernel_stop_ts == 1050 + assert ch0.kernel_record_ts == 1051 + + def test_verbose_channels_sorted_by_channel_id(self): + records = self.parser.parse_lines(self._verbose_line(n_channels=4)) + ids = [ch.channel_id for ch in records[0].event_trace.channels] + assert ids == sorted(ids) + + def test_fixture_contains_verbose_record(self): + records = InspectorParser().parse_file(FIXTURE_PATH) + verbose = [r for r in records if r.event_trace is not None] + assert len(verbose) >= 1 + assert verbose[0].event_trace.channels[0].channel_id == 0 + assert verbose[0].event_trace.channels[1].channel_id == 1 + + def test_malformed_event_trace_falls_back_to_none(self): + """Malformed event_trace content should not fail the whole record.""" + line = json.dumps({ + "header": {"id": "0x1", "rank": 0, "n_ranks": 1, "nnodes": 1}, + "metadata": {"dump_timestamp_us": 1000000, "hostname": "h", "pid": 1}, + "coll_perf": { + "coll": "AllReduce", "coll_sn": 1, "coll_msg_size_bytes": 64, + "coll_exec_time_us": 10, "coll_timing_source": "kernel_gpu", + "coll_algobw_gbs": 1.0, "coll_busbw_gbs": 1.875, + "event_trace_sn": "not_an_object", # malformed + }, + }) + records = self.parser.parse_lines(line) + assert len(records) == 1 + assert records[0].event_trace is None + + + def test_collective_breakdown_uses_all_records(self): + """collective_breakdown counts ALL records, not just deduplicated ones.""" + def _make_seq(rank: int, seq: int, coll: str) -> InspectorCollPerf: + r = _make_record(rank, 300.0, coll) + return r.model_copy(update={"sequence_num": seq}) + + records = [ + _make_seq(0, 1, "AllReduce"), + _make_seq(0, 2, "AllReduce"), # same rank, newer seq + _make_seq(1, 1, "ReduceScatter"), + ] + snap = aggregate_snapshot(records) + assert snap.collective_breakdown["AllReduce"] == 2 + assert snap.collective_breakdown["ReduceScatter"] == 
1 + assert len(snap.records) == 2 # deduplicated diff --git a/cvs/monitors/cluster-mon/backend/tests/test_rccl_collector.py b/cvs/monitors/cluster-mon/backend/tests/test_rccl_collector.py new file mode 100644 index 00000000..96938ba5 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_rccl_collector.py @@ -0,0 +1,135 @@ +"""Tests for RCCLCollector.""" +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.collectors.base import CollectorState +from app.collectors.rccl_collector import RCCLCollector +from app.models.rccl_models import RCCLJobState + + +def _make_app_state(healthy_nodes=None): + app_state = MagicMock() + app_state.node_health_status = {n: "healthy" for n in (healthy_nodes or ["node1"])} + app_state.rccl_data_store = None + app_state.latest_rccl_snapshot = None + app_state.collector_results = {} + app_state.latest_metrics = {} + app_state.is_collecting = False + app_state.probe_requested = None + return app_state + + +def test_rccl_collector_attrs(): + assert RCCLCollector.name == "rccl" + assert RCCLCollector.critical is False + assert hasattr(RCCLCollector, 'poll_interval') + assert hasattr(RCCLCollector, 'collect_timeout') + + +def test_healthy_nodes_returns_all_healthy(): + collector = RCCLCollector() + app_state = _make_app_state(["node1", "node2"]) + nodes = collector._healthy_nodes(app_state) + assert set(nodes) == {"node1", "node2"} + + +def test_healthy_nodes_returns_empty_when_all_unhealthy(): + collector = RCCLCollector() + app_state = MagicMock() + app_state.node_health_status = {"node1": "unhealthy", "node2": "unreachable"} + assert collector._healthy_nodes(app_state) == [] + + +@pytest.mark.asyncio +async def test_collect_returns_unreachable_when_no_healthy_nodes(): + collector = RCCLCollector() + collector._app_state = _make_app_state() + collector._app_state.node_health_status = {} # no nodes + + ssh_manager = MagicMock() + result = await collector.collect(ssh_manager) + assert 
result.state == CollectorState.UNREACHABLE + assert collector.job_state == RCCLJobState.UNREACHABLE + + +@pytest.mark.asyncio +async def test_collect_returns_no_service_on_connection_refused(): + collector = RCCLCollector() + app_state = _make_app_state(["node1"]) + collector._app_state = app_state + + ssh_manager = MagicMock() + ssh_mock_ctx = AsyncMock() + ssh_mock_ctx.__aenter__ = AsyncMock(side_effect=ConnectionRefusedError("refused")) + ssh_mock_ctx.__aexit__ = AsyncMock(return_value=False) + ssh_manager.open_port_forward = MagicMock(return_value=ssh_mock_ctx) + + result = await collector.collect(ssh_manager) + assert result.state == CollectorState.NO_SERVICE + assert collector.job_state == RCCLJobState.NO_JOB + + +@pytest.mark.asyncio +async def test_collect_returns_error_when_app_state_not_set(): + collector = RCCLCollector() + # _app_state is None (run() not called) + result = await collector.collect(MagicMock()) + assert result.state == CollectorState.ERROR + + +def test_health_from_snapshot_healthy(): + from app.models.rccl_models import RCCLSnapshot, RCCLCommunicator, RCCLJobState + collector = RCCLCollector() + snapshot = RCCLSnapshot( + timestamp=1.0, + state=RCCLJobState.HEALTHY, + communicators=[], + ) + assert collector._health_from_snapshot(snapshot) == RCCLJobState.HEALTHY + + +@pytest.mark.asyncio +async def test_state_change_event_emitted_on_job_start(): + """Transition NO_JOB → HEALTHY should push a job_start event.""" + collector = RCCLCollector() + collector.job_state = RCCLJobState.NO_JOB + + data_store = MagicMock() + data_store.push_event = AsyncMock() + + app_state = MagicMock() + app_state.rccl_data_store = data_store + + await collector._push_state_event( + RCCLJobState.NO_JOB, RCCLJobState.HEALTHY, app_state, leader="node1" + ) + + data_store.push_event.assert_called_once() + event = data_store.push_event.call_args[0][0] + assert event["event_type"] == "job_start" + assert event["from_state"] == RCCLJobState.NO_JOB + assert 
event["to_state"] == RCCLJobState.HEALTHY + assert event["leader_node"] == "node1" + + +@pytest.mark.asyncio +async def test_no_event_when_state_unchanged(): + """No event should be emitted when state does not change.""" + collector = RCCLCollector() + data_store = MagicMock() + data_store.push_event = AsyncMock() + app_state = MagicMock() + app_state.rccl_data_store = data_store + + await collector._push_state_event( + RCCLJobState.HEALTHY, RCCLJobState.HEALTHY, app_state + ) + data_store.push_event.assert_not_called() + + +def test_rccl_endpoints_importable(): + from app.api.rccl_endpoints import router + routes = [r.path for r in router.routes] + assert any("status" in r for r in routes) + assert any("markers" in r for r in routes) diff --git a/cvs/monitors/cluster-mon/backend/tests/test_rccl_ras_client.py b/cvs/monitors/cluster-mon/backend/tests/test_rccl_ras_client.py new file mode 100644 index 00000000..d97e495f --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_rccl_ras_client.py @@ -0,0 +1,130 @@ +""" +Tests for RCCLRasClient against MockRcclRasServer. 
+""" +import asyncio +import pytest + +from app.collectors.rccl_ras_client import ( + RCCLRasClient, + ProtocolError, + ProtocolVersionError, + ProtocolVersion, +) +from tests.mock_rcclras_server import MockRcclRasServer + + +SAMPLE_STATUS = b"""RCCL version 2.28.3 compiled with ROCm 6.0 + +Job summary +=========== + + Nodes Processes +(total) per node + 2 8 + +Communicator abc123 HEALTHY + Ranks: 16 total, 16 responding, 0 missing +""" + + +@pytest.fixture +async def mock_server(): + server = MockRcclRasServer(fixture_data=SAMPLE_STATUS, protocol_version=2) + port = await server.start() + yield port + await server.stop() + + +async def _connect(port: int): + reader, writer = await asyncio.open_connection("127.0.0.1", port) + return RCCLRasClient(reader, writer) + + +@pytest.mark.asyncio +async def test_handshake_returns_server_version(mock_server): + client = await _connect(mock_server) + version = await client.handshake() + assert version == 2 + assert client.server_protocol == 2 + client._writer.close() + + +@pytest.mark.asyncio +async def test_set_timeout_ok(mock_server): + client = await _connect(mock_server) + await client.handshake() + await client.set_timeout(10) + client._writer.close() + + +@pytest.mark.asyncio +async def test_get_status_returns_text(mock_server): + client = await _connect(mock_server) + await client.handshake() + await client.set_timeout(10) + text = await client.get_status(verbose=True) + assert "RCCL version" in text + assert "Communicator" in text + + +@pytest.mark.asyncio +async def test_set_format_raises_on_protocol_2(mock_server): + """set_format requires protocol 3+ — should raise on a protocol 2 server.""" + client = await _connect(mock_server) + await client.handshake() + assert client.server_protocol == 2 + with pytest.raises(ProtocolVersionError): + await client.set_format("json") + client._writer.close() + + +@pytest.mark.asyncio +async def test_start_monitor_raises_on_protocol_2(mock_server): + """start_monitor requires 
protocol 4+ — should raise on a protocol 2 server.""" + client = await _connect(mock_server) + await client.handshake() + with pytest.raises(ProtocolVersionError): + async for _ in client.start_monitor(): + pass + client._writer.close() + + +def test_rccl_models_import(): + from app.models.rccl_models import ( + RCCLSnapshot, + RCCLJobState, + NCCLFunction, + RCCLMarker, + ) + snapshot = RCCLSnapshot.empty() + assert snapshot.state == RCCLJobState.NO_JOB + assert snapshot.communicators == [] + + +@pytest.mark.asyncio +async def test_rccl_data_store_degrades_without_redis(): + """Without Redis, in-memory fallback buffers are used.""" + from app.collectors.rccl_data_store import RCCLDataStore + store = RCCLDataStore(redis_client=None) + await store.push_snapshot({"timestamp": 1.0}) + await store.push_event({"timestamp": 1.0}) + result = await store.get_recent_snapshots() + assert result == [{"timestamp": 1.0}] + result = await store.get_current_snapshot() + assert result == {"timestamp": 1.0} + events = await store.get_events_in_range(0.0, 2.0) + assert events == [{"timestamp": 1.0}] + + +def test_ncclfunction_enum_str_values(): + from app.models.rccl_models import NCCLFunction + assert NCCLFunction.ALL_REDUCE == "AllReduce" + assert NCCLFunction.ALL_GATHER == "AllGather" + assert NCCLFunction.SEND == "Send" + + +def test_rccl_job_state_values(): + from app.models.rccl_models import RCCLJobState + assert RCCLJobState.NO_JOB == "no_job" + assert RCCLJobState.HEALTHY == "healthy" + assert RCCLJobState.DEGRADED == "degraded" diff --git a/cvs/monitors/cluster-mon/backend/tests/test_rccl_text_parser.py b/cvs/monitors/cluster-mon/backend/tests/test_rccl_text_parser.py new file mode 100644 index 00000000..ec3c0dc4 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_rccl_text_parser.py @@ -0,0 +1,189 @@ +""" +Tests for RCCL RAS text output parser. +Written test-first against real captured rcclras -v output from a live MI300X cluster. 
+""" +import pytest +from pathlib import Path + +from app.collectors.rccl_text_parser import RCCLTextParser +from app.models.rccl_models import RCCLJobState + +FIXTURES_DIR = Path(__file__).parent / "fixtures" + + +@pytest.fixture +def healthy_output(): + return (FIXTURES_DIR / "rccl_verbose_status_healthy.txt").read_text() + + +@pytest.fixture +def degraded_output(): + return (FIXTURES_DIR / "rccl_verbose_status_degraded.txt").read_text() + + +@pytest.fixture +def connection_reset_output(): + return (FIXTURES_DIR / "rccl_verbose_status_connection_reset.txt").read_text() + + +@pytest.fixture +def parser(): + return RCCLTextParser() + + +# -- Healthy fixture tests --------------------------------------------------- + +def test_parse_rccl_version(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.job_summary is not None + assert snapshot.job_summary.rccl_version == "2.28.3" + + +def test_parse_hip_runtime_version(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.job_summary.hip_runtime_version == 70226015 + + +def test_parse_driver_version(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.job_summary.amdgpu_driver_version == 70226015 + + +def test_parse_job_summary_nodes(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.job_summary.total_nodes == 1 + + +def test_parse_job_summary_processes(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.job_summary.total_processes == 8 + + +def test_parse_job_summary_gpus(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.job_summary.total_gpus == 8 + + +def test_parse_healthy_state(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.state == RCCLJobState.HEALTHY + + +def test_parse_healthy_communicators(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert 
len(snapshot.communicators) == 1 + comm = snapshot.communicators[0] + assert comm.total_ranks == 8 + assert comm.responding_ranks == 8 + assert comm.missing_ranks == 0 + assert comm.health == RCCLJobState.HEALTHY + + +def test_parse_healthy_no_dead_peers(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.dead_peers == [] + + +def test_parse_healthy_no_errors(parser, healthy_output): + """Errors section is empty in healthy fixture.""" + snapshot = parser.parse(healthy_output) + assert snapshot.state == RCCLJobState.HEALTHY + + +# -- Degraded fixture tests -------------------------------------------------- + +def test_parse_degraded_state(parser, degraded_output): + snapshot = parser.parse(degraded_output) + assert snapshot.state == RCCLJobState.DEGRADED + + +def test_parse_degraded_communicator_ranks(parser, degraded_output): + snapshot = parser.parse(degraded_output) + assert len(snapshot.communicators) >= 1 + comm = snapshot.communicators[0] + # 7 responding out of 8 total + assert comm.total_ranks == 8 + assert comm.responding_ranks == 7 + assert comm.missing_ranks == 1 + + +def test_parse_degraded_communicator_health(parser, degraded_output): + snapshot = parser.parse(degraded_output) + comm = snapshot.communicators[0] + assert comm.health == RCCLJobState.DEGRADED + + +def test_parse_degraded_has_communicator_hash(parser, degraded_output): + """The degraded fixture contains communicator hash 3b2fe521bf43bc04 in the Errors section.""" + snapshot = parser.parse(degraded_output) + # The parser should extract comm hash from error section if available + # At minimum, the communicator should be parsed from the table + assert len(snapshot.communicators) >= 1 + + +def test_parse_degraded_errors_section_not_empty(parser, degraded_output): + """The degraded fixture has INCOMPLETE errors -- parser should detect this.""" + snapshot = parser.parse(degraded_output) + assert snapshot.state == RCCLJobState.DEGRADED + + +# -- 2-node degraded 
fixture (ranks_per_node shown as range "7-8") ------------ + +@pytest.fixture +def degraded_2node_output(): + return (FIXTURES_DIR / "rccl_verbose_status_degraded_2node.txt").read_text() + + +def test_parse_degraded_2node_state(parser, degraded_2node_output): + snapshot = parser.parse(degraded_2node_output) + assert snapshot.state == RCCLJobState.DEGRADED + + +def test_parse_degraded_2node_communicator_parsed(parser, degraded_2node_output): + """ranks_per_node='7-8' range must not prevent communicator row from matching.""" + snapshot = parser.parse(degraded_2node_output) + assert len(snapshot.communicators) == 1 + comm = snapshot.communicators[0] + assert comm.total_ranks == 16 + assert comm.responding_ranks == 15 + assert comm.missing_ranks == 1 + assert comm.health == RCCLJobState.DEGRADED + + +def test_parse_degraded_2node_job_summary(parser, degraded_2node_output): + snapshot = parser.parse(degraded_2node_output) + assert snapshot.job_summary is not None + assert snapshot.job_summary.total_nodes == 2 + assert snapshot.job_summary.total_gpus == 16 + + +# -- Connection reset / error tests ------------------------------------------ + +def test_parse_connection_reset(parser, connection_reset_output): + snapshot = parser.parse(connection_reset_output) + assert snapshot.state in (RCCLJobState.NO_JOB, RCCLJobState.ERROR) + + +def test_parse_empty_string(parser): + snapshot = parser.parse("") + assert snapshot.state == RCCLJobState.NO_JOB + + +def test_parse_connection_refused(parser): + text = "Connecting to 127.0.0.1:28028: Connection refused\nFailed to connect to the NCCL RAS service!" 
+ snapshot = parser.parse(text) + assert snapshot.state == RCCLJobState.NO_JOB + + +# -- Edge cases --------------------------------------------------------------- + +def test_parse_inconsistent_topology_single_node(parser, healthy_output): + snapshot = parser.parse(healthy_output) + assert snapshot.job_summary.inconsistent_topology is False + + +def test_parser_does_not_crash_on_garbage(parser): + """Parser should return ERROR state on unparseable text, not crash.""" + snapshot = parser.parse("some random garbage that is not rcclras output") + # Should return a snapshot (not raise), with NO_JOB or ERROR state + assert snapshot.state in (RCCLJobState.NO_JOB, RCCLJobState.ERROR) diff --git a/cvs/monitors/cluster-mon/backend/tests/test_reload.py b/cvs/monitors/cluster-mon/backend/tests/test_reload.py new file mode 100644 index 00000000..28fa88bf --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_reload.py @@ -0,0 +1,23 @@ +"""Tests for reload_configuration topology-diff logic.""" +import pytest +from unittest.mock import patch, MagicMock, AsyncMock + + +def test_settings_model_dump_comparison(): + """Verify that Settings.model_dump() can detect config changes.""" + from app.core.config import Settings + s1 = Settings() + s2 = Settings() + # Same defaults = equal dumps + assert s1.ssh.model_dump() == s2.ssh.model_dump() + assert s1.polling.model_dump() == s2.polling.model_dump() + assert s1.rccl.model_dump() == s2.rccl.model_dump() + + +def test_settings_model_dump_detects_change(monkeypatch): + """Different env vars should produce different model_dump().""" + from app.core.config import Settings + s1 = Settings() + monkeypatch.setenv('POLLING__INTERVAL', '99') + s2 = Settings() + assert s1.polling.model_dump() != s2.polling.model_dump() diff --git a/cvs/monitors/cluster-mon/backend/tests/test_ssh_port_forward.py b/cvs/monitors/cluster-mon/backend/tests/test_ssh_port_forward.py new file mode 100644 index 00000000..c22c42a8 --- /dev/null +++ 
b/cvs/monitors/cluster-mon/backend/tests/test_ssh_port_forward.py @@ -0,0 +1,112 @@ +""" +Tests for SSH port-forwarding bridge (_run_bridge). +Tests use real socketpairs and mock paramiko channels. +""" +import asyncio +import socket +import threading +import time +import pytest +from unittest.mock import MagicMock, patch + +from app.core.ssh_port_forward import _run_bridge + + +class MockChannel: + """Minimal mock paramiko channel backed by a real socket.""" + + def __init__(self, sock: socket.socket): + self._sock = sock + self.closed = False + + def recv(self, nbytes: int) -> bytes: + try: + return self._sock.recv(nbytes) + except Exception: + return b"" + + def sendall(self, data: bytes) -> None: + self._sock.sendall(data) + + def close(self) -> None: + self.closed = True + try: + self._sock.close() + except Exception: + pass + + +def test_run_bridge_forwards_data_ch_to_sock(): + """Data written to channel side arrives on the socket side.""" + # Create two connected socket pairs to simulate channel <-> bridge <-> user + ch_side, bridge_ch_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + bridge_sock_end, user_sock_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + + channel = MockChannel(ch_side) + _run_bridge(channel, bridge_sock_end) + + # Write to channel side -> should arrive at user_sock_end + bridge_ch_end.sendall(b"hello from channel") + user_sock_end.settimeout(2.0) + data = user_sock_end.recv(100) + assert data == b"hello from channel" + + bridge_ch_end.close() + user_sock_end.close() + + +def test_run_bridge_forwards_data_sock_to_ch(): + """Data written to the socket side arrives on the channel side.""" + ch_side, bridge_ch_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + bridge_sock_end, user_sock_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + + channel = MockChannel(ch_side) + _run_bridge(channel, bridge_sock_end) + + # Write to user_sock_end -> should arrive at bridge_ch_end (channel side) + 
user_sock_end.sendall(b"hello from socket") + bridge_ch_end.settimeout(2.0) + data = bridge_ch_end.recv(100) + assert data == b"hello from socket" + + bridge_ch_end.close() + user_sock_end.close() + + +def test_run_bridge_closes_both_on_channel_close(): + """When the channel closes, the socket side also closes (no thread leak).""" + ch_side, bridge_ch_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + bridge_sock_end, user_sock_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + + channel = MockChannel(ch_side) + _run_bridge(channel, bridge_sock_end) + + # Close the channel side -- this sends EOF + bridge_ch_end.close() + ch_side.close() + + # The user_sock_end should eventually get EOF too + user_sock_end.settimeout(2.0) + data = user_sock_end.recv(100) + assert data == b"" # EOF propagated + + user_sock_end.close() + + +def test_run_bridge_daemon_threads(): + """Bridge threads must be daemon threads (don't block process exit).""" + ch_side, bridge_ch_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + bridge_sock_end, user_sock_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + + before = {t.name for t in threading.enumerate()} + channel = MockChannel(ch_side) + _run_bridge(channel, bridge_sock_end) + + # Find the new threads + after = {t for t in threading.enumerate() if t.name not in before} + bridge_threads = [t for t in after if "bridge" in t.name] + assert len(bridge_threads) == 2 + assert all(t.daemon for t in bridge_threads) + + bridge_ch_end.close() + user_sock_end.close() diff --git a/cvs/monitors/cluster-mon/backend/tests/test_websocket_manager.py b/cvs/monitors/cluster-mon/backend/tests/test_websocket_manager.py new file mode 100644 index 00000000..024b6df7 --- /dev/null +++ b/cvs/monitors/cluster-mon/backend/tests/test_websocket_manager.py @@ -0,0 +1,71 @@ +"""Tests for ConnectionManager WebSocket pattern.""" +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.main 
import ConnectionManager + + +@pytest.fixture +def manager(): + return ConnectionManager(max_queue_size=4) + + +@pytest.mark.asyncio +async def test_broadcast_is_nonblocking(manager): + """broadcast() should return immediately even with no clients.""" + manager.broadcast({"type": "test"}) # should not raise or block + + +@pytest.mark.asyncio +async def test_broadcast_to_connected_client(manager): + """broadcast() should put messages in client queues.""" + ws = AsyncMock() + ws.accept = AsyncMock() + ws.close = AsyncMock() + await manager.connect(ws) + + manager.broadcast({"type": "test", "data": "hello"}) + + # Give the sender task a moment to process + await asyncio.sleep(0.05) + + # The sender task should have called send_json + ws.send_json.assert_called() + assert manager.client_count == 1 + + +@pytest.mark.asyncio +async def test_slow_client_disconnected_on_full_queue(manager): + """Client with full queue should be disconnected.""" + ws = AsyncMock() + ws.accept = AsyncMock() + ws.close = AsyncMock() + # Make send_json block forever to fill the queue + async def _block(_msg): + await asyncio.sleep(100) + + ws.send_json = AsyncMock(side_effect=_block) + + await manager.connect(ws) + assert manager.client_count == 1 + + # Fill the queue (maxsize=4) + for i in range(10): + manager.broadcast({"msg": i}) + + await asyncio.sleep(0.1) # let cleanup tasks run + # Client should have been removed due to full queue + assert manager.client_count == 0 + + +@pytest.mark.asyncio +async def test_disconnect_cleans_up(manager): + ws = AsyncMock() + ws.accept = AsyncMock() + ws.close = AsyncMock() + await manager.connect(ws) + assert manager.client_count == 1 + + await manager.disconnect(ws) + assert manager.client_count == 0 diff --git a/cvs/monitors/cluster-mon/config/cluster.yaml.example b/cvs/monitors/cluster-mon/config/cluster.yaml.example index 342b9949..eefa06af 100644 --- a/cvs/monitors/cluster-mon/config/cluster.yaml.example +++ 
b/cvs/monitors/cluster-mon/config/cluster.yaml.example @@ -23,3 +23,24 @@ cluster: alerts: gpu_temp_threshold: 85.0 gpu_util_threshold: 95.0 + + storage: + redis: + url: redis://localhost:6379 + db: 0 + password: null # set to a string to enable Redis AUTH + snapshot_max_entries: 1000 + event_max_entries: 10000 + + rccl: + ras_port: 28028 + poll_interval: 30 + collective_timeout_secs: 10 + debug_log_path: null + + inspector: + enabled: false # set to true when an RCCL job is running with Inspector plugin + mode: ssh # "ssh" (remote NFS, default) or "file" (local NFS) + dump_dir: null # NFS path: e.g. /nfs/shared/inspector-logs/ + poll_interval: 10 # seconds between collection cycles + max_records_per_file: 500 # tail last N lines per .log file diff --git a/cvs/monitors/cluster-mon/docker-compose.yml b/cvs/monitors/cluster-mon/docker-compose.yml index c4481d1c..05a2c26b 100644 --- a/cvs/monitors/cluster-mon/docker-compose.yml +++ b/cvs/monitors/cluster-mon/docker-compose.yml @@ -6,6 +6,18 @@ version: '3.8' # Option 2: Use network_mode: "host" (uncomment below) services: + redis: + image: redis:7-alpine + restart: unless-stopped + volumes: + - redis_data:/data + command: redis-server --appendonly yes --appendfsync everysec --requirepass ${REDIS_PASSWORD:-cvs_cluster_mon} + healthcheck: + test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD:-cvs_cluster_mon}", "ping"] + interval: 10s + timeout: 5s + retries: 3 + cvs-cluster-monitor: build: context: . @@ -14,8 +26,14 @@ services: # Network mode: Use "host" to avoid iptables issues (recommended for production) # Uncomment the line below and comment out the ports section + # NOTE: If using network_mode: "host", remove depends_on and set + # STORAGE__REDIS__URL=redis://localhost:6379 instead. 
# network_mode: "host" + depends_on: + redis: + condition: service_healthy + # Port mapping (used with bridge network, default) # Comment out if using network_mode: "host" ports: @@ -32,6 +50,9 @@ services: - DEBUG=${DEBUG:-false} # Cluster monitor home - CLUSTER_MONITOR_HOME=/app + # Redis connection (uses Docker service name as hostname) + - STORAGE__REDIS__URL=redis://redis:6379 + - STORAGE__REDIS__PASSWORD=${REDIS_PASSWORD:-cvs_cluster_mon} restart: unless-stopped # Health check disabled - uncomment if needed # healthcheck: @@ -40,3 +61,6 @@ services: # timeout: 10s # retries: 3 # start_period: 40s + +volumes: + redis_data: diff --git a/cvs/monitors/cluster-mon/frontend/src/App.tsx b/cvs/monitors/cluster-mon/frontend/src/App.tsx index 0b5c72be..d58f5faf 100644 --- a/cvs/monitors/cluster-mon/frontend/src/App.tsx +++ b/cvs/monitors/cluster-mon/frontend/src/App.tsx @@ -9,6 +9,10 @@ import { TopologyPage } from './pages/TopologyPage' import { GPUSoftwarePage } from './pages/GPUSoftwarePage' import { NICSoftwarePage } from './pages/NICSoftwarePage' import { LogsPage } from './pages/LogsPage' +import { RCCLHealthPage } from './pages/RCCLHealthPage' +import { RCCLTopologyPage } from './pages/RCCLTopologyPage' +import { RCCLTimelinePage } from './pages/RCCLTimelinePage' +import { RCCLPerformancePage } from './pages/RCCLPerformancePage' import { NodeDetailsModal } from './components/NodeDetailsModal' function App() { @@ -26,6 +30,10 @@ function App() { } /> } /> } /> + } /> + } /> + } /> + } /> diff --git a/cvs/monitors/cluster-mon/frontend/src/components/Layout/Sidebar.tsx b/cvs/monitors/cluster-mon/frontend/src/components/Layout/Sidebar.tsx index 1a3c1994..3c940f45 100644 --- a/cvs/monitors/cluster-mon/frontend/src/components/Layout/Sidebar.tsx +++ b/cvs/monitors/cluster-mon/frontend/src/components/Layout/Sidebar.tsx @@ -1,8 +1,8 @@ import { NavLink } from 'react-router-dom' -import { LayoutDashboard, Settings, Cpu, Network, Activity, Package, HardDrive, Share2, 
FileText } from 'lucide-react' +import { LayoutDashboard, Settings, Cpu, Network, Activity, Package, HardDrive, Share2, FileText, Clock, GitFork, Zap } from 'lucide-react' import { cn } from '@/utils/cn' -const navigation = [ +const mainNav = [ { name: 'Dashboard', href: '/', icon: LayoutDashboard }, { name: 'Configuration', href: '/config', icon: Settings }, { name: 'GPU Metrics', href: '/gpu-metrics', icon: Cpu }, @@ -13,6 +13,32 @@ const navigation = [ { name: 'Logs', href: '/logs', icon: FileText }, ] +const rcclNav = [ + { name: 'RCCL Health', href: '/rccl-health', icon: Activity }, + { name: 'RAS Topology', href: '/rccl-topology', icon: GitFork }, + { name: 'Timeline', href: '/rccl-timeline', icon: Clock }, + { name: 'Performance', href: '/rccl-performance', icon: Zap }, +] + +function NavItem({ name, href, icon: Icon }: { name: string; href: string; icon: React.ElementType }) { + return ( + + cn( + 'flex items-center gap-3 px-4 py-3 rounded-lg transition-colors', + isActive + ? 'bg-blue-600 text-white' + : 'text-gray-300 hover:bg-gray-800 hover:text-white' + ) + } + > + + {name} + + ) +} + export function Sidebar() { return (
@@ -26,27 +52,13 @@ export function Sidebar() {
{/* Navigation */} -