diff --git a/src/ecs-mcp-server/awslabs/ecs_mcp_server/api/security_analysis.py b/src/ecs-mcp-server/awslabs/ecs_mcp_server/api/security_analysis.py new file mode 100644 index 0000000000..d84e87c1af --- /dev/null +++ b/src/ecs-mcp-server/awslabs/ecs_mcp_server/api/security_analysis.py @@ -0,0 +1,550 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +API for ECS security analysis operations. + +This module provides functions for analyzing ECS cluster security configurations +and generating security recommendations. +""" + +import logging +import os +from typing import Any, Dict + +import boto3 + +from awslabs.ecs_mcp_server.api.resource_management import ecs_api_operation +from awslabs.ecs_mcp_server.utils.aws import get_aws_config + +logger = logging.getLogger(__name__) + + +def get_target_region() -> str: + """ + Get the target AWS region for security analysis from environment variable. + + Returns: + AWS region name from AWS_REGION environment variable (defaults to 'us-east-1') + """ + region = os.environ.get("AWS_REGION", "us-east-1") + logger.info(f"Using region from environment: {region}") + return region + + +async def get_clusters_with_metadata(region: str) -> list[Dict[str, Any]]: + """ + Get all ECS clusters in the specified region with their metadata. + + Args: + region: AWS region to get clusters from + + Returns: + List of cluster dictionaries with metadata + + Raises: + Exception: If retrieving clusters fails + """ + logger.info(f"Listing ECS clusters in region: {region}") + + try: + # List cluster ARNs + list_response = await ecs_api_operation(api_operation="ListClusters", api_params={}) + + cluster_arns = list_response.get("clusterArns", []) + + if not cluster_arns: + logger.info(f"No clusters found in region {region}") + return [] + + logger.info(f"Found {len(cluster_arns)} cluster(s) in region {region}") + + # Describe clusters to get metadata + describe_response = await ecs_api_operation( + api_operation="DescribeClusters", + api_params={ + "clusters": cluster_arns, + "include": ["ATTACHMENTS", "SETTINGS", "STATISTICS", "TAGS"], + }, + ) + + clusters = describe_response.get("clusters", []) + + # Format cluster information + all_clusters = [] + for cluster in clusters: + cluster_info = { + "cluster_name": cluster.get("clusterName"), + "cluster_arn": cluster.get("clusterArn"), + "status": cluster.get("status"), + "running_tasks_count": cluster.get("runningTasksCount", 0), + "pending_tasks_count": cluster.get("pendingTasksCount", 0), + "active_services_count": cluster.get("activeServicesCount", 0), + "registered_container_instances_count": cluster.get( + "registeredContainerInstancesCount", 0 + ), + "tags": {tag["key"]: tag["value"] for tag in cluster.get("tags", [])}, + } + all_clusters.append(cluster_info) + + logger.info(f"Successfully retrieved metadata for {len(all_clusters)} cluster(s)") + return all_clusters + + except Exception as e: + logger.error(f"Error retrieving clusters in region {region}: {e}") + raise Exception(f"Failed to retrieve clusters in region '{region}': {str(e)}") from e + + +def format_clusters_for_display(clusters: list[Dict[str, Any]], region: str) -> str: + """ + Format cluster data into a user-friendly display string. + + Args: + clusters: List of cluster dictionaries + region: AWS region name + + Returns: + Formatted string with cluster information for display + """ + if not clusters: + return f""" +No ECS clusters found in region: {region} + +To create a cluster, you can use the AWS CLI: +```bash +aws ecs create-cluster --cluster-name my-cluster --region {region} +``` + +Or use the AWS Console to create a cluster in the ECS service. +""" + + # Build the formatted output + lines = [ + "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━", + f"📋 ECS CLUSTERS IN REGION: {region}", + "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━", + "", + f"Found {len(clusters)} cluster(s):", + "", + ] + + for i, cluster in enumerate(clusters, 1): + cluster_name = cluster.get("cluster_name", "Unknown") + status = cluster.get("status", "Unknown") + running_tasks = cluster.get("running_tasks_count", 0) + active_services = cluster.get("active_services_count", 0) + + lines.extend( + [ + f"{i}. {cluster_name}", + f" Status: {status}", + f" Running Tasks: {running_tasks}", + f" Active Services: {active_services}", + "", + ] + ) + + lines.extend( + [ + "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━", + "", + "To analyze a specific cluster, call this tool again with:", + " cluster_names=['cluster-name']", + "", + "Example:", + f" analyze_ecs_security(" + f"cluster_names=['{clusters[0].get('cluster_name')}'], " + f"region='{region}')", + "", + ] + ) + + return "\n".join(lines) + + +class ClusterNotFoundError(Exception): + """Raised when one or more clusters cannot be found.""" + + pass + + +async def validate_clusters(cluster_names: list[str], region: str) -> list[str]: + """ + Validate that the specified clusters exist and return their ARNs. + + Args: + cluster_names: List of cluster names to validate + region: AWS region to check clusters in + + Returns: + List of validated cluster ARNs + + Raises: + ClusterNotFoundError: If one or more clusters are not found + """ + logger.info(f"Validating {len(cluster_names)} cluster(s) in region {region}") + + try: + # Describe clusters to validate they exist + describe_response = await ecs_api_operation( + api_operation="DescribeClusters", + api_params={"clusters": cluster_names, "include": ["TAGS"]}, + ) + + found_clusters = describe_response.get("clusters", []) + failures = describe_response.get("failures", []) + + if failures: + failed_names = [f["arn"] for f in failures] + raise ClusterNotFoundError(f"Clusters not found in region '{region}': {failed_names}") + + if len(found_clusters) != len(cluster_names): + found_names = [c["clusterName"] for c in found_clusters] + missing = set(cluster_names) - set(found_names) + raise ClusterNotFoundError(f"Clusters not found in region '{region}': {list(missing)}") + + cluster_arns = [cluster["clusterArn"] for cluster in found_clusters] + logger.info(f"Successfully validated {len(cluster_arns)} cluster(s)") + return cluster_arns + + except ClusterNotFoundError: + raise + except Exception as e: + logger.error(f"Error validating clusters: {e}") + raise Exception(f"Failed to validate clusters: {str(e)}") from e + + +async def collect_cluster_configuration(region: str, cluster_name: str) -> Dict[str, Any]: + """ + Collect comprehensive configuration for an ECS cluster. + + This function gathers all security-relevant configuration data for analysis: + - Cluster metadata and settings + - Service configurations + - Task definition configurations + - Security group configurations + - IAM role references + + Args: + region: AWS region containing the cluster + cluster_name: Name of the cluster to analyze + + Returns: + Dictionary containing complete cluster configuration + + Note: + This function collects data but does not perform security analysis. + The analysis is performed by AI agents using the collected data. + """ + logger.info(f"Collecting configuration for cluster '{cluster_name}' in region {region}") + + cluster_config = { + "cluster_name": cluster_name, + "region": region, + "cluster_metadata": {}, + "services": [], + "task_definitions": [], + "security_groups": [], + "collection_errors": [], + } + + try: + # Step 1: Collect cluster metadata + logger.info(f"Step 1: Collecting cluster metadata for '{cluster_name}'") + describe_response = await ecs_api_operation( + api_operation="DescribeClusters", + api_params={ + "clusters": [cluster_name], + "include": ["ATTACHMENTS", "SETTINGS", "STATISTICS", "TAGS"], + }, + ) + + clusters = describe_response.get("clusters", []) + if not clusters: + raise Exception(f"Cluster '{cluster_name}' not found") + + cluster = clusters[0] + cluster_config["cluster_metadata"] = { + "cluster_arn": cluster.get("clusterArn"), + "cluster_name": cluster.get("clusterName"), + "status": cluster.get("status"), + "running_tasks_count": cluster.get("runningTasksCount", 0), + "pending_tasks_count": cluster.get("pendingTasksCount", 0), + "active_services_count": cluster.get("activeServicesCount", 0), + "registered_container_instances_count": cluster.get( + "registeredContainerInstancesCount", 0 + ), + "statistics": cluster.get("statistics", []), + "tags": {tag["key"]: tag["value"] for tag in cluster.get("tags", [])}, + "settings": cluster.get("settings", []), + "configuration": cluster.get("configuration", {}), + "service_connect_defaults": cluster.get("serviceConnectDefaults", {}), + "attachments": cluster.get("attachments", []), + } + + logger.info(f"Successfully collected cluster metadata for '{cluster_name}'") + + except Exception as e: + error_msg = f"Failed to collect cluster metadata: {str(e)}" + logger.warning(error_msg) + cluster_config["collection_errors"].append(error_msg) + + # Step 2: Collect service configurations + try: + logger.info(f"Step 2: Collecting service configurations for cluster '{cluster_name}'") + services_response = await ecs_api_operation( + api_operation="ListServices", api_params={"cluster": cluster_name} + ) + + service_arns = services_response.get("serviceArns", []) + logger.info(f"Found {len(service_arns)} service(s) in cluster '{cluster_name}'") + + if service_arns: + # Process services in batches (DescribeServices has a limit) + batch_size = 10 + services_list = [] + services_with_errors = [] + + for i in range(0, len(service_arns), batch_size): + batch_arns = service_arns[i : i + batch_size] + try: + describe_services_response = await ecs_api_operation( + api_operation="DescribeServices", + api_params={"cluster": cluster_name, "services": batch_arns}, + ) + + services = describe_services_response.get("services", []) + + for service in services: + try: + service_name = service.get("serviceName") + service_status = service.get("status") + + # Log warning for non-active services + if service_status not in ["ACTIVE", "DRAINING"]: + logger.warning( + f"Service {service_name} is in {service_status} state" + ) + + # Collect security group information + security_group_details = [] + network_config = service.get("networkConfiguration", {}) + awsvpc_config = network_config.get("awsvpcConfiguration", {}) + security_group_ids = awsvpc_config.get("securityGroups", []) + + if security_group_ids: + try: + # Describe security groups + ec2_client = boto3.client( + "ec2", region_name=region, config=get_aws_config() + ) + sg_response = ec2_client.describe_security_groups( + GroupIds=security_group_ids + ) + security_group_details.extend( + [ + { + "group_id": sg["GroupId"], + "group_name": sg.get("GroupName", ""), + "description": sg.get("Description", ""), + "vpc_id": sg.get("VpcId", ""), + "ingress_rules": sg.get("IpPermissions", []), + "egress_rules": sg.get("IpPermissionsEgress", []), + "tags": { + tag["Key"]: tag["Value"] + for tag in sg.get("Tags", []) + }, + } + for sg in sg_response.get("SecurityGroups", []) + ] + ) + logger.info( + f"Collected {len(security_group_details)} " + f"security group(s) for service {service_name}" + ) + except Exception as e: + error_msg = ( + f"Failed to describe security groups " + f"{security_group_ids}: {str(e)}" + ) + logger.error(f"Service {service_name}: {error_msg}") + services_with_errors.append( + {"service_name": service_name, "error": error_msg} + ) + + # Build service configuration + service_config = { + "service_name": service_name, + "service_arn": service.get("serviceArn"), + "cluster_arn": service.get("clusterArn"), + "task_definition": service.get("taskDefinition"), + "desired_count": service.get("desiredCount", 0), + "running_count": service.get("runningCount", 0), + "pending_count": service.get("pendingCount", 0), + "status": service_status, + "launch_type": service.get("launchType"), + "capacity_provider_strategy": service.get( + "capacityProviderStrategy", [] + ), + "platform_version": service.get("platformVersion"), + "platform_family": service.get("platformFamily"), + "network_configuration": network_config, + "security_groups": security_group_details, + "load_balancers": service.get("loadBalancers", []), + "service_registries": service.get("serviceRegistries", []), + "tags": { + tag["key"]: tag["value"] for tag in service.get("tags", []) + }, + "enable_execute_command": service.get( + "enableExecuteCommand", False + ), + "health_check_grace_period_seconds": service.get( + "healthCheckGracePeriodSeconds" + ), + "scheduling_strategy": service.get("schedulingStrategy"), + "deployment_controller": service.get("deploymentController", {}), + "service_connect_configuration": service.get( + "serviceConnectConfiguration", {} + ), + } + services_list.append(service_config) + + except Exception as e: + logger.error(f"Failed to process service: {e}") + services_with_errors.append( + { + "service_name": service.get("serviceName", "Unknown"), + "error": str(e), + } + ) + + except Exception as e: + logger.error(f"Failed to describe services batch: {e}") + for arn in batch_arns: + service_name = arn.split("/")[-1] + services_with_errors.append({"service_name": service_name, "error": str(e)}) + + cluster_config["services"] = services_list + if services_with_errors: + cluster_config["collection_errors"].extend( + [f"Service {s['service_name']}: {s['error']}" for s in services_with_errors] + ) + logger.warning( + f"{len(services_with_errors)} service(s) had errors during collection" + ) + + except Exception as e: + error_msg = f"Failed to collect service configurations: {str(e)}" + logger.warning(error_msg) + cluster_config["collection_errors"].append(error_msg) + + # Step 3: Collect task definition configurations + try: + logger.info( + f"Step 3: Collecting task definition configurations for cluster '{cluster_name}'" + ) + + # Get unique task definition ARNs from services + task_def_arns = set() + for service in cluster_config["services"]: + task_def_arn = service.get("task_definition") + if task_def_arn: + task_def_arns.add(task_def_arn) + + logger.info(f"Found {len(task_def_arns)} unique task definition(s) to describe") + + task_definitions_list = [] + for task_def_arn in task_def_arns: + try: + # Describe individual task definition + task_def_response = await ecs_api_operation( + api_operation="DescribeTaskDefinition", + api_params={"taskDefinition": task_def_arn, "include": ["TAGS"]}, + ) + + task_def = task_def_response.get("taskDefinition", {}) + task_def_config = { + "family": task_def.get("family"), + "task_definition_arn": task_def.get("taskDefinitionArn"), + "revision": task_def.get("revision"), + "status": task_def.get("status"), + "requires_compatibilities": task_def.get("requiresCompatibilities", []), + "network_mode": task_def.get("networkMode"), + "cpu": task_def.get("cpu"), + "memory": task_def.get("memory"), + "task_role_arn": task_def.get("taskRoleArn"), + "execution_role_arn": task_def.get("executionRoleArn"), + "container_definitions": task_def.get("containerDefinitions", []), + "volumes": task_def.get("volumes", []), + "placement_constraints": task_def.get("placementConstraints", []), + "requires_attributes": task_def.get("requiresAttributes", []), + "pid_mode": task_def.get("pidMode"), + "ipc_mode": task_def.get("ipcMode"), + "proxy_configuration": task_def.get("proxyConfiguration", {}), + "inference_accelerators": task_def.get("inferenceAccelerators", []), + "ephemeral_storage": task_def.get("ephemeralStorage", {}), + "runtime_platform": task_def.get("runtimePlatform", {}), + "tags": {tag["key"]: tag["value"] for tag in task_def_response.get("tags", [])}, + } + task_definitions_list.append(task_def_config) + + except Exception as e: + logger.warning(f"Failed to describe task definition '{task_def_arn}': {e}") + cluster_config["collection_errors"].append( + f"Task definition {task_def_arn}: {str(e)}" + ) + + cluster_config["task_definitions"] = task_definitions_list + logger.info( + f"Successfully collected {len(task_definitions_list)} task definition configuration(s)" + ) + + except Exception as e: + error_msg = f"Failed to collect task definition configurations: {str(e)}" + logger.warning(error_msg) + cluster_config["collection_errors"].append(error_msg) + + # Step 4: Collect unique security groups (deduplicated from services) + try: + logger.info("Step 4: Deduplicating security group configurations") + unique_security_groups = {} + + for service in cluster_config["services"]: + for sg in service.get("security_groups", []): + sg_id = sg.get("group_id") + if sg_id and sg_id not in unique_security_groups: + unique_security_groups[sg_id] = sg + + cluster_config["security_groups"] = list(unique_security_groups.values()) + logger.info(f"Found {len(unique_security_groups)} unique security group(s)") + + except Exception as e: + error_msg = f"Failed to process security groups: {str(e)}" + logger.warning(error_msg) + cluster_config["collection_errors"].append(error_msg) + + # Log collection summary + logger.info( + f"Configuration collection complete for '{cluster_name}': " + f"{len(cluster_config['services'])} services, " + f"{len(cluster_config['task_definitions'])} task definitions, " + f"{len(cluster_config['security_groups'])} security groups" + ) + + if cluster_config["collection_errors"]: + logger.warning( + f"Collection completed with {len(cluster_config['collection_errors'])} error(s)" + ) + + return cluster_config diff --git a/src/ecs-mcp-server/awslabs/ecs_mcp_server/main.py b/src/ecs-mcp-server/awslabs/ecs_mcp_server/main.py index a1dbaa58a9..96a305b9cf 100755 --- a/src/ecs-mcp-server/awslabs/ecs_mcp_server/main.py +++ b/src/ecs-mcp-server/awslabs/ecs_mcp_server/main.py @@ -31,6 +31,7 @@ deployment_status, infrastructure, resource_management, + security_analysis, troubleshooting, ) from awslabs.ecs_mcp_server.utils.config import get_config @@ -134,6 +135,7 @@ def _create_ecs_mcp_server() -> Tuple[FastMCP, Dict[str, Any]]: resource_management.register_module(mcp) troubleshooting.register_module(mcp) delete.register_module(mcp) + security_analysis.register_module(mcp) # Register all proxies aws_knowledge_proxy.register_proxy(mcp) diff --git a/src/ecs-mcp-server/awslabs/ecs_mcp_server/modules/security_analysis.py b/src/ecs-mcp-server/awslabs/ecs_mcp_server/modules/security_analysis.py new file mode 100644 index 0000000000..36c1f38043 --- /dev/null +++ b/src/ecs-mcp-server/awslabs/ecs_mcp_server/modules/security_analysis.py @@ -0,0 +1,221 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Security Analysis module for ECS MCP Server. +This module provides tools and prompts for analyzing ECS security configurations. +""" + +import json +import logging +from datetime import datetime + +from fastmcp import FastMCP +from pydantic import Field + +from awslabs.ecs_mcp_server.api.security_analysis import ( + collect_cluster_configuration, + format_clusters_for_display, + get_clusters_with_metadata, + get_target_region, + validate_clusters, +) + +logger = logging.getLogger(__name__) + + +def register_module(mcp: FastMCP) -> None: + """Register security analysis module tools and prompts with the MCP server.""" + + # Define pydantic Field descriptions for all parameters + cluster_names_field = Field( + default=None, + description=( + "Optional list of specific cluster names to analyze. " + "If not provided, lists all available clusters for user selection. " + "When provided, collects complete configuration data for the specified clusters. " + "Example: ['prod-cluster', 'staging-cluster']" + ), + ) + + @mcp.tool(name="analyze_ecs_security", annotations=None) + async def mcp_analyze_ecs_security( + cluster_names: list[str] | None = cluster_names_field, + ) -> str: + """ + Analyze ECS cluster security configurations or list available clusters. + + The region is determined from the AWS_REGION environment variable (defaults to 'us-east-1'). + + This tool provides two modes of operation: + 1. Cluster Discovery: List all available clusters (when cluster_names not provided) + 2. Configuration Collection: Collect detailed configuration data + (when cluster_names provided) + + Interactive Workflow: + + Step 1: List Available Clusters + - Call with NO cluster_names to list clusters in configured region + - Returns formatted list of available clusters with metadata + - User can review and select which clusters to analyze + + Step 2: Collect Configuration Data + - Call with specific cluster_names to collect configuration + - Returns comprehensive JSON configuration data + - Includes services, task definitions, security groups, IAM roles + + Usage Examples: + + Example 1 - List clusters: + analyze_ecs_security() + + Example 2 - Analyze specific cluster: + analyze_ecs_security(cluster_names=["prod-cluster"]) + + Example 3 - Analyze multiple clusters: + analyze_ecs_security(cluster_names=["web-cluster", "api-cluster"]) + + Configuration Data Collected: + - Cluster metadata (settings, statistics, tags) + - Service configurations (network, load balancers, capacity) + - Task definition details (containers, IAM roles, volumes) + - Security group rules (ingress/egress permissions) + - IAM role references and policies + - Container security settings + - Network configurations + + Parameters: + cluster_names: Optional list of cluster names. If None, lists available clusters. + + Returns: + - If cluster_names not provided: Formatted list of available clusters + - If cluster_names provided: JSON configuration data for analysis + + Error Handling: + - Cluster not found: Returns error with available cluster names + - Partial failures: Returns data with error details in collection_errors + """ + try: + # Step 1: Get target region from environment + logger.info("Step 1: Getting target region from environment") + target_region = get_target_region() + + # Step 2: Check operation mode + if cluster_names is None: + # Mode 1: List clusters for user selection + logger.info( + f"Step 2: Listing clusters in region '{target_region}' for user selection" + ) + clusters = await get_clusters_with_metadata(target_region) + return format_clusters_for_display(clusters, target_region) + else: + # Mode 2: Collect configuration data for specified clusters + logger.info( + f"Step 2: Collecting configuration for {len(cluster_names)} " + f"cluster(s) in region '{target_region}'" + ) + + # Step 2a: Validate clusters exist + logger.info("Step 2a: Validating cluster existence") + validated_arns = await validate_clusters(cluster_names, target_region) + logger.info(f"Successfully validated {len(validated_arns)} cluster(s)") + + # Step 2b: Collect configuration for each cluster + logger.info("Step 2b: Collecting cluster configurations") + cluster_configs = [] + for cluster_name in cluster_names: + try: + config = await collect_cluster_configuration(target_region, cluster_name) + cluster_configs.append(config) + except Exception as e: + logger.error( + f"Failed to collect configuration for cluster '{cluster_name}': {e}" + ) + # Add error info to the config + error_config = { + "cluster_name": cluster_name, + "region": target_region, + "collection_error": str(e), + "cluster_metadata": {}, + "services": [], + "task_definitions": [], + "security_groups": [], + "collection_errors": [f"Failed to collect configuration: {str(e)}"], + } + cluster_configs.append(error_config) + + # Step 2c: Format response + response_data = { + "analysis_type": "ecs_security_configuration", + "region": target_region, + "clusters_analyzed": len(cluster_configs), + "cluster_configurations": cluster_configs, + "collection_timestamp": datetime.utcnow().isoformat() + "Z", + } + + logger.info( + f"Configuration collection complete: " + f"{len(cluster_configs)} cluster(s) processed" + ) + return json.dumps(response_data, indent=2, default=str) + + except Exception as e: + import traceback + + error_msg = f"Error during security analysis: {str(e)}" + logger.error(error_msg) + logger.error(f"Traceback: {traceback.format_exc()}") + return f"❌ {error_msg}\n\nDetailed error:\n{traceback.format_exc()}" + + # Register prompt patterns for security analysis + + @mcp.prompt("analyze ecs security") + def analyze_ecs_security_prompt(): + """User wants to analyze ECS security""" + return ["analyze_ecs_security"] + + @mcp.prompt("check ecs security") + def check_ecs_security_prompt(): + """User wants to check ECS security""" + return ["analyze_ecs_security"] + + @mcp.prompt("ecs security audit") + def ecs_security_audit_prompt(): + """User wants to perform an ECS security audit""" + return ["analyze_ecs_security"] + + @mcp.prompt("security best practices") + def security_best_practices_prompt(): + """User wants to check security best practices""" + return ["analyze_ecs_security"] + + @mcp.prompt("security recommendations") + def security_recommendations_prompt(): + """User wants security recommendations""" + return ["analyze_ecs_security"] + + @mcp.prompt("scan ecs clusters") + def scan_ecs_clusters_prompt(): + """User wants to scan ECS clusters for security issues""" + return ["analyze_ecs_security"] + + @mcp.prompt("ecs security scan") + def ecs_security_scan_prompt(): + """User wants to perform an ECS security scan""" + return ["analyze_ecs_security"] + + @mcp.prompt("list ecs clusters") + def list_ecs_clusters_prompt(): + """User wants to list ECS clusters""" + return ["analyze_ecs_security"] diff --git a/src/ecs-mcp-server/tests/unit/test_security_analysis_api.py b/src/ecs-mcp-server/tests/unit/test_security_analysis_api.py new file mode 100644 index 0000000000..23916cab77 --- /dev/null +++ b/src/ecs-mcp-server/tests/unit/test_security_analysis_api.py @@ -0,0 +1,773 @@ +""" +Unit tests for security analysis API functions. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from awslabs.ecs_mcp_server.api.security_analysis import ( + ClusterNotFoundError, + collect_cluster_configuration, + format_clusters_for_display, + get_clusters_with_metadata, + get_target_region, + validate_clusters, +) + + +@patch.dict("os.environ", {"AWS_REGION": "eu-central-1"}) +def test_get_target_region_from_env(): + """Test get_target_region when using environment variable.""" + result = get_target_region() + + assert result == "eu-central-1" + + +@patch.dict("os.environ", {}, clear=True) +def test_get_target_region_default(): + """Test get_target_region when no region specified and no env var.""" + result = get_target_region() + + assert result == "us-east-1" + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_get_clusters_with_metadata_success(mock_ecs_api): + """Test get_clusters_with_metadata with successful response.""" + # Mock ListClusters response + mock_ecs_api.side_effect = [ + { + "clusterArns": [ + "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-1", + "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-2", + ] + }, + { + "clusters": [ + { + "clusterName": "test-cluster-1", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-1", + "status": "ACTIVE", + "runningTasksCount": 5, + "pendingTasksCount": 0, + "activeServicesCount": 3, + "registeredContainerInstancesCount": 2, + "tags": [{"key": "Environment", "value": "Production"}], + }, + { + "clusterName": "test-cluster-2", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-2", + "status": "ACTIVE", + "runningTasksCount": 2, + "pendingTasksCount": 1, + "activeServicesCount": 1, + "registeredContainerInstancesCount": 1, + "tags": [], + }, + ] + }, + ] + + result = await get_clusters_with_metadata("us-east-1") + + assert len(result) == 2 + assert result[0]["cluster_name"] == "test-cluster-1" + assert result[0]["status"] == "ACTIVE" + assert result[0]["running_tasks_count"] == 5 + assert result[0]["tags"] == {"Environment": "Production"} + assert result[1]["cluster_name"] == "test-cluster-2" + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_get_clusters_with_metadata_empty(mock_ecs_api): + """Test get_clusters_with_metadata when no clusters exist.""" + mock_ecs_api.return_value = {"clusterArns": []} + + result = await get_clusters_with_metadata("us-east-1") + + assert result == [] + mock_ecs_api.assert_called_once() + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_get_clusters_with_metadata_error(mock_ecs_api): + """Test get_clusters_with_metadata when API call fails.""" + mock_ecs_api.side_effect = Exception("API Error") + + with pytest.raises(Exception) as exc_info: + await get_clusters_with_metadata("us-east-1") + + assert "Failed to retrieve clusters" in str(exc_info.value) + + +def test_format_clusters_for_display_with_clusters(): + """Test format_clusters_for_display with multiple clusters.""" + clusters = [ + { + "cluster_name": "prod-cluster", + "status": "ACTIVE", + "running_tasks_count": 10, + "active_services_count": 5, + }, + { + "cluster_name": "staging-cluster", + "status": "ACTIVE", + "running_tasks_count": 3, + "active_services_count": 2, + }, + ] + + result = format_clusters_for_display(clusters, "us-east-1") + + assert "ECS CLUSTERS IN REGION: us-east-1" in result + assert "Found 2 cluster(s)" in result + assert "prod-cluster" in result + assert "staging-cluster" in result + assert "Running Tasks: 10" in result + assert "Active Services: 5" in result + + +def test_format_clusters_for_display_empty(): + """Test format_clusters_for_display with no clusters.""" + result = format_clusters_for_display([], "us-west-2") + + assert "No ECS clusters found in region: us-west-2" in result + assert "create-cluster" in result + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_validate_clusters_success(mock_ecs_api): + """Test validate_clusters with successful validation.""" + mock_ecs_api.return_value = { + "clusters": [ + { + "clusterName": "test-cluster-1", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-1", + }, + { + "clusterName": "test-cluster-2", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-2", + }, + ], + "failures": [], + } + + result = await validate_clusters(["test-cluster-1", "test-cluster-2"], "us-east-1") + + assert len(result) == 2 + assert "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-1" in result + assert "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-2" in result + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_validate_clusters_not_found(mock_ecs_api): + """Test validate_clusters when clusters are not found.""" + mock_ecs_api.return_value = { + "clusters": [], + "failures": [ + {"arn": "nonexistent-cluster", "reason": "MISSING"}, + ], + } + + with pytest.raises(ClusterNotFoundError) as exc_info: + await validate_clusters(["nonexistent-cluster"], "us-east-1") + + assert "Clusters not found" in str(exc_info.value) + assert "nonexistent-cluster" in str(exc_info.value) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_validate_clusters_partial_failure(mock_ecs_api): + """Test validate_clusters with partial failures.""" + mock_ecs_api.return_value = { + "clusters": [ + { + "clusterName": "test-cluster-1", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster-1", + } + ], + "failures": [], + } + + with pytest.raises(ClusterNotFoundError) as exc_info: + await validate_clusters(["test-cluster-1", "test-cluster-2"], "us-east-1") + + assert "Clusters not found" in str(exc_info.value) + assert "test-cluster-2" in str(exc_info.value) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +@patch("awslabs.ecs_mcp_server.api.security_analysis.boto3") +async def test_collect_cluster_configuration_success(mock_boto3, mock_ecs_api): + """Test collect_cluster_configuration with successful collection.""" + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "runningTasksCount": 5, + "tags": [{"key": "Environment", "value": "Test"}], + "settings": [], + "configuration": {}, + } + ] + }, + # ListServices response + {"serviceArns": ["arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service"]}, + # DescribeServices response + { + "services": [ + { + "serviceName": "test-service", + "serviceArn": ( + "arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service" + ), + "status": "ACTIVE", + "taskDefinition": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "desiredCount": 2, + "networkConfiguration": { + "awsvpcConfiguration": {"securityGroups": ["sg-12345"]} + }, + "tags": [], + } + ] + }, + # DescribeTaskDefinition response + { + "taskDefinition": { + "family": "test-task", + "taskDefinitionArn": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "revision": 1, + "status": "ACTIVE", + "networkMode": "awsvpc", + "containerDefinitions": [ + { + "name": "test-container", + "image": "nginx:latest", + "memory": 512, + } + ], + }, + "tags": [], + }, + ] + + # Mock EC2 client for security groups + mock_ec2_client = MagicMock() + mock_ec2_client.describe_security_groups.return_value = { + "SecurityGroups": [ + { + "GroupId": "sg-12345", + "GroupName": "test-sg", + "Description": "Test security group", + "VpcId": "vpc-12345", + "IpPermissions": [], + "IpPermissionsEgress": [], + "Tags": [], + } + ] + } + mock_boto3.client.return_value = mock_ec2_client + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + assert result["cluster_name"] == "test-cluster" + assert result["region"] == "us-east-1" + assert len(result["services"]) == 1 + assert len(result["task_definitions"]) == 1 + assert len(result["security_groups"]) == 1 + assert result["services"][0]["service_name"] == "test-service" + assert result["task_definitions"][0]["family"] == "test-task" + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_collect_cluster_configuration_cluster_not_found(mock_ecs_api): + """Test collect_cluster_configuration when cluster is not found.""" + mock_ecs_api.return_value = {"clusters": []} + + result = await collect_cluster_configuration("us-east-1", "nonexistent-cluster") + + assert result["cluster_name"] == "nonexistent-cluster" + assert len(result["collection_errors"]) > 0 + assert "not found" in str(result["collection_errors"]) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_collect_cluster_configuration_partial_failure(mock_ecs_api): + """Test collect_cluster_configuration with partial failures.""" + # Mock successful cluster describe but failed services + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices fails + Exception("Service listing failed"), + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + assert result["cluster_name"] == "test-cluster" + assert len(result["collection_errors"]) > 0 + assert "Service listing failed" in str(result["collection_errors"]) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +@patch("awslabs.ecs_mcp_server.api.security_analysis.boto3") +async def test_collect_cluster_configuration_no_services(mock_boto3, mock_ecs_api): + """Test collect_cluster_configuration with no services.""" + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response - empty + {"serviceArns": []}, + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + assert result["cluster_name"] == "test-cluster" + assert len(result["services"]) == 0 + assert len(result["task_definitions"]) == 0 + assert len(result["security_groups"]) == 0 + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_validate_clusters_api_error(mock_ecs_api): + """Test validate_clusters with API error.""" + mock_ecs_api.side_effect = Exception("API error") + + with pytest.raises(Exception) as exc_info: + await validate_clusters(["test-cluster"], "us-east-1") + + assert "Failed to validate clusters" in str(exc_info.value) + assert "API error" in str(exc_info.value) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +@patch("awslabs.ecs_mcp_server.api.security_analysis.boto3") +async def test_collect_cluster_configuration_security_group_error(mock_boto3, mock_ecs_api): + """Test collect_cluster_configuration with security group error.""" + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response + {"serviceArns": ["arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service"]}, + # DescribeServices response + { + "services": [ + { + "serviceName": "test-service", + "serviceArn": ( + "arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service" + ), + "status": "ACTIVE", + "taskDefinition": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "desiredCount": 2, + "networkConfiguration": { + "awsvpcConfiguration": {"securityGroups": ["sg-12345"]} + }, + "tags": [], + } + ] + }, + ] + + # Mock EC2 client to raise error + mock_ec2_client = MagicMock() + mock_ec2_client.describe_security_groups.side_effect = Exception("Security group error") + mock_boto3.client.return_value = mock_ec2_client + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + # Should still return data but with errors + assert result["cluster_name"] == "test-cluster" + assert len(result["collection_errors"]) > 0 + assert any("Security group error" in str(e) for e in result["collection_errors"]) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_collect_cluster_configuration_service_processing_error(mock_ecs_api): + """Test collect_cluster_configuration with service processing error.""" + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response + {"serviceArns": ["arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service"]}, + # DescribeServices fails + Exception("Service describe error"), + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + # Should still return data but with errors + assert result["cluster_name"] == "test-cluster" + assert len(result["collection_errors"]) > 0 + assert any("Service describe error" in str(e) for e in result["collection_errors"]) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +async def test_collect_cluster_configuration_task_definition_error(mock_ecs_api): + """Test collect_cluster_configuration with task definition error.""" + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response + {"serviceArns": ["arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service"]}, + # DescribeServices response + { + "services": [ + { + "serviceName": "test-service", + "serviceArn": ( + "arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service" + ), + "status": "ACTIVE", + "taskDefinition": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "desiredCount": 2, + "networkConfiguration": {}, + "tags": [], + } + ] + }, + # DescribeTaskDefinition fails + Exception("Task definition error"), + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + # Should still return data but with errors + assert result["cluster_name"] == "test-cluster" + assert len(result["collection_errors"]) > 0 + assert any("Task definition error" in str(e) for e in result["collection_errors"]) + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +@patch("awslabs.ecs_mcp_server.api.security_analysis.boto3") +async def test_collect_cluster_configuration_service_config_error(mock_boto3, mock_ecs_api): + """Test collect_cluster_configuration with service config building error.""" + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response + {"serviceArns": ["arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service"]}, + # DescribeServices response with malformed service + { + "services": [ + { + # Missing serviceName to trigger error in processing + "serviceArn": ( + "arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service" + ), + "status": "ACTIVE", + } + ] + }, + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + # Should handle the error gracefully + assert result["cluster_name"] == "test-cluster" + # May have errors depending on how the service is processed + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +@patch("awslabs.ecs_mcp_server.api.security_analysis.boto3") +async def test_collect_cluster_configuration_with_draining_service(mock_boto3, mock_ecs_api): + """Test collect_cluster_configuration with DRAINING service status.""" + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response + {"serviceArns": ["arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service"]}, + # DescribeServices response with DRAINING status + { + "services": [ + { + "serviceName": "test-service", + "serviceArn": ( + "arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service" + ), + "status": "DRAINING", + "taskDefinition": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "desiredCount": 0, + "networkConfiguration": {}, + "tags": [], + } + ] + }, + # DescribeTaskDefinition response + { + "taskDefinition": { + "family": "test-task", + "taskDefinitionArn": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "revision": 1, + "status": "ACTIVE", + "networkMode": "awsvpc", + "containerDefinitions": [], + }, + "tags": [], + }, + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + # Should successfully collect even with DRAINING service + assert result["cluster_name"] == "test-cluster" + assert len(result["services"]) == 1 + assert result["services"][0]["status"] == "DRAINING" + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +@patch("awslabs.ecs_mcp_server.api.security_analysis.boto3") +async def test_collect_cluster_configuration_with_inactive_service(mock_boto3, mock_ecs_api): + """Test collect_cluster_configuration with INACTIVE service status.""" + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response + {"serviceArns": ["arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service"]}, + # DescribeServices response with INACTIVE status + { + "services": [ + { + "serviceName": "test-service", + "serviceArn": ( + "arn:aws:ecs:us-east-1:123456789012:service/test-cluster/test-service" + ), + "status": "INACTIVE", + "taskDefinition": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "desiredCount": 0, + "networkConfiguration": {}, + "tags": [], + } + ] + }, + # DescribeTaskDefinition response + { + "taskDefinition": { + "family": "test-task", + "taskDefinitionArn": ( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-task:1" + ), + "revision": 1, + "status": "ACTIVE", + "networkMode": "awsvpc", + "containerDefinitions": [], + }, + "tags": [], + }, + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + # Should successfully collect even with INACTIVE service + assert result["cluster_name"] == "test-cluster" + assert len(result["services"]) == 1 + assert result["services"][0]["status"] == "INACTIVE" + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.api.security_analysis.ecs_api_operation") +@patch("awslabs.ecs_mcp_server.api.security_analysis.boto3") +async def test_collect_cluster_configuration_batch_services(mock_boto3, mock_ecs_api): + """Test collect_cluster_configuration with multiple service batches.""" + # Create 15 service ARNs to test batch processing (batch size is 10) + service_arns = [ + f"arn:aws:ecs:us-east-1:123456789012:service/test-cluster/service-{i}" for i in range(15) + ] + + # Mock ECS API responses + mock_ecs_api.side_effect = [ + # DescribeClusters response + { + "clusters": [ + { + "clusterName": "test-cluster", + "clusterArn": "arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster", + "status": "ACTIVE", + "tags": [], + "settings": [], + } + ] + }, + # ListServices response with 15 services + {"serviceArns": service_arns}, + # First batch (10 services) + { + "services": [ + { + "serviceName": f"service-{i}", + "serviceArn": service_arns[i], + "status": "ACTIVE", + "taskDefinition": ( + f"arn:aws:ecs:us-east-1:123456789012:task-definition/task-{i}:1" + ), + "desiredCount": 1, + "networkConfiguration": {}, + "tags": [], + } + for i in range(10) + ] + }, + # Second batch (5 services) + { + "services": [ + { + "serviceName": f"service-{i}", + "serviceArn": service_arns[i], + "status": "ACTIVE", + "taskDefinition": ( + f"arn:aws:ecs:us-east-1:123456789012:task-definition/task-{i}:1" + ), + "desiredCount": 1, + "networkConfiguration": {}, + "tags": [], + } + for i in range(10, 15) + ] + }, + ] + [ + # DescribeTaskDefinition responses for each unique task definition + { + "taskDefinition": { + "family": f"task-{i}", + "taskDefinitionArn": ( + f"arn:aws:ecs:us-east-1:123456789012:task-definition/task-{i}:1" + ), + "revision": 1, + "status": "ACTIVE", + "networkMode": "awsvpc", + "containerDefinitions": [], + }, + "tags": [], + } + for i in range(15) + ] + + result = await collect_cluster_configuration("us-east-1", "test-cluster") + + # Should successfully collect all 15 services across 2 batches + assert result["cluster_name"] == "test-cluster" + assert len(result["services"]) == 15 + assert len(result["task_definitions"]) == 15 diff --git a/src/ecs-mcp-server/tests/unit/test_security_analysis_module.py b/src/ecs-mcp-server/tests/unit/test_security_analysis_module.py new file mode 100644 index 0000000000..a9734bdd46 --- /dev/null +++ b/src/ecs-mcp-server/tests/unit/test_security_analysis_module.py @@ -0,0 +1,325 @@ +""" +Unit tests for security analysis module registration. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import FastMCP + + +def test_module_registration(): + """Test that security_analysis module registers correctly.""" + from awslabs.ecs_mcp_server.modules import security_analysis + + # Create a mock FastMCP instance + mock_mcp = MagicMock(spec=FastMCP) + + # Register the module + security_analysis.register_module(mock_mcp) + + # Verify tool was registered + assert mock_mcp.tool.called + # Verify prompts were registered + assert mock_mcp.prompt.called + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.modules.security_analysis.get_target_region") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.get_clusters_with_metadata") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.format_clusters_for_display") +async def test_tool_execution_list_clusters(mock_format, mock_list_clusters, mock_get_region): + """Test tool execution for listing clusters.""" + from awslabs.ecs_mcp_server.modules import security_analysis + + # Setup mocks + mock_get_region.return_value = "us-east-1" + mock_list_clusters.return_value = [ + { + "cluster_name": "test-cluster", + "status": "ACTIVE", + "running_tasks_count": 5, + "active_services_count": 3, + } + ] + mock_format.return_value = "Formatted cluster list" + + # Create a mock FastMCP instance + mock_mcp = MagicMock(spec=FastMCP) + + # Track the registered tool function + registered_tool = None + + def capture_tool(name, annotations): + def decorator(func): + nonlocal registered_tool + registered_tool = func + return func + + return decorator + + mock_mcp.tool = capture_tool + + # Register the module + security_analysis.register_module(mock_mcp) + + # Execute the tool (explicitly pass cluster_names=None for list mode) + result = await registered_tool(cluster_names=None) + + # Verify the result + assert result == "Formatted cluster list" + mock_get_region.assert_called_once() + mock_list_clusters.assert_called_once_with("us-east-1") + mock_format.assert_called_once() + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.modules.security_analysis.get_target_region") +async def test_tool_execution_error_handling(mock_get_region): + """Test tool execution error handling.""" + from awslabs.ecs_mcp_server.modules import security_analysis + + # Setup mock to raise an exception + mock_get_region.side_effect = Exception("Test error") + + # Create a mock FastMCP instance + mock_mcp = MagicMock(spec=FastMCP) + + # Track the registered tool function + registered_tool = None + + def capture_tool(name, annotations): + def decorator(func): + nonlocal registered_tool + registered_tool = func + return func + + return decorator + + mock_mcp.tool = capture_tool + + # Register the module + security_analysis.register_module(mock_mcp) + + # Execute the tool + result = await registered_tool() + + # Verify error message is returned + assert "❌" in result + assert "Test error" in result + + +def test_prompt_patterns_registered(): + """Test that all prompt patterns are registered.""" + from awslabs.ecs_mcp_server.modules import security_analysis + + # Create a mock FastMCP instance + mock_mcp = MagicMock(spec=FastMCP) + + # Track registered prompts + registered_prompts = [] + + def capture_prompt(pattern): + def decorator(func): + registered_prompts.append(pattern) + return func + + return decorator + + mock_mcp.prompt = capture_prompt + + # Register the module + security_analysis.register_module(mock_mcp) + + # Verify expected prompts are registered + expected_prompts = [ + "analyze ecs security", + "check ecs security", + "ecs security audit", + "security best practices", + "security recommendations", + "scan ecs clusters", + "ecs security scan", + "list ecs clusters", + ] + + for expected in expected_prompts: + assert expected in registered_prompts + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.modules.security_analysis.get_target_region") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.validate_clusters") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.collect_cluster_configuration") +async def test_tool_execution_collect_configuration( + mock_collect_config, mock_validate, mock_get_region +): + """Test tool execution for collecting cluster configuration.""" + from awslabs.ecs_mcp_server.modules import security_analysis + + # Setup mocks + mock_get_region.return_value = "us-east-1" + mock_validate.return_value = ["arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster"] + mock_collect_config.return_value = { + "cluster_name": "test-cluster", + "region": "us-east-1", + "cluster_metadata": {"status": "ACTIVE"}, + "services": [], + "task_definitions": [], + "security_groups": [], + "collection_errors": [], + } + + # Create a mock FastMCP instance + mock_mcp = MagicMock(spec=FastMCP) + + # Track the registered tool function + registered_tool = None + + def capture_tool(name, annotations): + def decorator(func): + nonlocal registered_tool + registered_tool = func + return func + + return decorator + + mock_mcp.tool = capture_tool + + # Register the module + security_analysis.register_module(mock_mcp) + + # Execute the tool with cluster_names + result = await registered_tool(cluster_names=["test-cluster"]) + + # Verify the result is JSON + import json + + result_data = json.loads(result) + assert result_data["analysis_type"] == "ecs_security_configuration" + assert result_data["region"] == "us-east-1" + assert result_data["clusters_analyzed"] == 1 + assert len(result_data["cluster_configurations"]) == 1 + + # Verify mocks were called + mock_get_region.assert_called_once() + mock_validate.assert_called_once_with(["test-cluster"], "us-east-1") + mock_collect_config.assert_called_once_with("us-east-1", "test-cluster") + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.modules.security_analysis.get_target_region") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.validate_clusters") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.collect_cluster_configuration") +async def test_tool_execution_multiple_clusters( + mock_collect_config, mock_validate, mock_get_region +): + """Test tool execution with multiple clusters.""" + from awslabs.ecs_mcp_server.modules import security_analysis + + # Setup mocks + mock_get_region.return_value = "us-east-1" + mock_validate.return_value = [ + "arn:aws:ecs:us-east-1:123456789012:cluster/cluster-1", + "arn:aws:ecs:us-east-1:123456789012:cluster/cluster-2", + ] + mock_collect_config.side_effect = [ + { + "cluster_name": "cluster-1", + "region": "us-east-1", + "cluster_metadata": {}, + "services": [], + "task_definitions": [], + "security_groups": [], + "collection_errors": [], + }, + { + "cluster_name": "cluster-2", + "region": "us-east-1", + "cluster_metadata": {}, + "services": [], + "task_definitions": [], + "security_groups": [], + "collection_errors": [], + }, + ] + + # Create a mock FastMCP instance + mock_mcp = MagicMock(spec=FastMCP) + + # Track the registered tool function + registered_tool = None + + def capture_tool(name, annotations): + def decorator(func): + nonlocal registered_tool + registered_tool = func + return func + + return decorator + + mock_mcp.tool = capture_tool + + # Register the module + security_analysis.register_module(mock_mcp) + + # Execute the tool with multiple clusters + result = await registered_tool(cluster_names=["cluster-1", "cluster-2"]) + + # Verify the result + import json + + result_data = json.loads(result) + assert result_data["clusters_analyzed"] == 2 + assert len(result_data["cluster_configurations"]) == 2 + + # Verify collect_config was called twice + assert mock_collect_config.call_count == 2 + + +@pytest.mark.anyio +@patch("awslabs.ecs_mcp_server.modules.security_analysis.get_target_region") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.validate_clusters") +@patch("awslabs.ecs_mcp_server.modules.security_analysis.collect_cluster_configuration") +async def test_tool_execution_with_collection_error( + mock_collect_config, mock_validate, mock_get_region +): + """Test tool execution when configuration collection fails for a cluster.""" + from awslabs.ecs_mcp_server.modules import security_analysis + + # Setup mocks + mock_get_region.return_value = "us-east-1" + mock_validate.return_value = ["arn:aws:ecs:us-east-1:123456789012:cluster/test-cluster"] + mock_collect_config.side_effect = Exception("Collection failed") + + # Create a mock FastMCP instance + mock_mcp = MagicMock(spec=FastMCP) + + # Track the registered tool function + registered_tool = None + + def capture_tool(name, annotations): + def decorator(func): + nonlocal registered_tool + registered_tool = func + return func + + return decorator + + mock_mcp.tool = capture_tool + + # Register the module + security_analysis.register_module(mock_mcp) + + # Execute the tool with cluster_names + result = await registered_tool(cluster_names=["test-cluster"]) + + # Verify the result includes error information + import json + + result_data = json.loads(result) + assert result_data["clusters_analyzed"] == 1 + assert len(result_data["cluster_configurations"]) == 1 + # Should have error config + assert "collection_error" in result_data["cluster_configurations"][0] + assert "Collection failed" in result_data["cluster_configurations"][0]["collection_error"]