diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index 88d5e605..f86a8799 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -10,7 +10,11 @@ import boto3 from pydantic import BaseModel -from bedrock_agentcore._utils.endpoints import CP_ENDPOINT_OVERRIDE, DP_ENDPOINT_OVERRIDE +from bedrock_agentcore._utils.endpoints import ( + CP_ENDPOINT_OVERRIDE, + DP_ENDPOINT_OVERRIDE, +) +from bedrock_agentcore._utils.snake_case import accept_snake_case_kwargs class TokenPoller(ABC): @@ -85,15 +89,48 @@ def __init__(self, region: str): self.dp_client = boto3.client("bedrock-agentcore", **dp_kwargs) self.logger = logging.getLogger("bedrock_agentcore.identity_client") - def create_oauth2_credential_provider(self, req): - """Create an OAuth2 credential provider.""" - self.logger.info("Creating OAuth2 credential provider...") - return self.cp_client.create_oauth2_credential_provider(**req) - - def create_api_key_credential_provider(self, req): - """Create an API key credential provider.""" - self.logger.info("Creating API key credential provider...") - return self.cp_client.create_api_key_credential_provider(**req) + # Pass-through + # ------------------------------------------------------------------------- + _ALLOWED_CP_METHODS = { + # OAuth2 credential provider CRUD + "create_oauth2_credential_provider", + "get_oauth2_credential_provider", + "list_oauth2_credential_providers", + "update_oauth2_credential_provider", + "delete_oauth2_credential_provider", + # API key credential provider CRUD + "create_api_key_credential_provider", + "get_api_key_credential_provider", + "list_api_key_credential_providers", + "delete_api_key_credential_provider", + # Workload identity + "get_workload_identity", + "update_workload_identity", + } + + _ALLOWED_DP_METHODS = { + "get_resource_oauth2_token", + "get_resource_api_key", + "get_workload_access_token_for_jwt", + "get_workload_access_token_for_user_id", + } + + def __getattr__(self, name: str): + """Dynamically forward allowlisted method calls to the appropriate boto3 client.""" + if name in self._ALLOWED_DP_METHODS and hasattr(self.dp_client, name): + method = getattr(self.dp_client, name) + return accept_snake_case_kwargs(method) + + if name in self._ALLOWED_CP_METHODS and hasattr(self.cp_client, name): + method = getattr(self.cp_client, name) + return accept_snake_case_kwargs(method) + + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'. " + f"Method not found on data plane or control plane client. " + f"Available methods can be found in the boto3 documentation for " + f"'bedrock-agentcore' and 'bedrock-agentcore-control' services." + ) def get_workload_access_token( self, workload_name: str, user_token: Optional[str] = None, user_id: Optional[str] = None @@ -125,20 +162,6 @@ def create_workload_identity( name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls or [] ) - def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_urls: list[str]) -> Dict: - """Update an existing workload identity with allowed resource OAuth2 callback urls.""" - self.logger.info( - "Updating workload identity '%s' with callback urls: %s", name, allowed_resource_oauth_2_return_urls - ) - return self.cp_client.update_workload_identity( - name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls - ) - - def get_workload_identity(self, name: str) -> Dict: - """Retrieves information about a workload identity.""" - self.logger.info("Fetching workload identity '%s'", name) - return self.cp_client.get_workload_identity(name=name) - def complete_resource_token_auth( self, session_uri: str, user_identifier: Union[UserTokenIdentifier, UserIdIdentifier] ): diff --git a/tests/bedrock_agentcore/services/test_identity.py b/tests/bedrock_agentcore/services/test_identity.py index 1c03f8a8..a3b72f42 100644 --- a/tests/bedrock_agentcore/services/test_identity.py +++ b/tests/bedrock_agentcore/services/test_identity.py @@ -32,44 +32,54 @@ def test_initialization(self): ) def test_create_oauth2_credential_provider(self): - """Test OAuth2 credential provider creation.""" + """Test OAuth2 credential provider creation via passthrough.""" region = "us-west-2" with patch("boto3.client") as mock_boto_client: - mock_client = Mock() - mock_boto_client.return_value = mock_client + mock_cp_client = Mock() + mock_dp_client = Mock() + mock_boto_client.side_effect = [mock_cp_client, mock_dp_client] identity_client = IdentityClient(region) - # Test data - req = {"name": "test-provider", "clientId": "test-client"} expected_response = {"providerId": "test-provider-id"} - mock_client.create_oauth2_credential_provider.return_value = expected_response + mock_cp_client.create_oauth2_credential_provider.return_value = expected_response - result = identity_client.create_oauth2_credential_provider(req) + result = identity_client.create_oauth2_credential_provider( + name="test-provider", + clientId="test-client", + ) assert result == expected_response - mock_client.create_oauth2_credential_provider.assert_called_once_with(**req) + mock_cp_client.create_oauth2_credential_provider.assert_called_once_with( + name="test-provider", + clientId="test-client", + ) def test_create_api_key_credential_provider(self): - """Test API key credential provider creation.""" + """Test API key credential provider creation via passthrough.""" region = "us-west-2" with patch("boto3.client") as mock_boto_client: - mock_client = Mock() - mock_boto_client.return_value = mock_client + mock_cp_client = Mock() + mock_dp_client = Mock() + mock_boto_client.side_effect = [mock_cp_client, mock_dp_client] identity_client = IdentityClient(region) - # Test data - req = {"name": "test-api-provider", "apiKeyName": "test-key"} expected_response = {"providerId": "test-api-provider-id"} - mock_client.create_api_key_credential_provider.return_value = expected_response + mock_cp_client.create_api_key_credential_provider.return_value = expected_response - result = identity_client.create_api_key_credential_provider(req) + result = identity_client.create_api_key_credential_provider( + name="test-api-provider", + apiKeyName="test-key", + ) assert result == expected_response - mock_client.create_api_key_credential_provider.assert_called_once_with(**req) + mock_cp_client.create_api_key_credential_provider.assert_called_once_with( + name="test-api-provider", + apiKeyName="test-key", + ) @pytest.mark.asyncio async def test_get_token_direct_response(self): @@ -505,11 +515,15 @@ def test_update_workload_identity(self): mock_cp_client.update_workload_identity.return_value = expected_response - result = identity_client.update_workload_identity(workload_name, allowed_urls) + result = identity_client.update_workload_identity( + name=workload_name, + allowedResourceOauth2ReturnUrls=allowed_urls, + ) assert result == expected_response mock_cp_client.update_workload_identity.assert_called_once_with( - name=workload_name, allowedResourceOauth2ReturnUrls=allowed_urls + name=workload_name, + allowedResourceOauth2ReturnUrls=allowed_urls, ) def test_get_workload_identity(self): @@ -523,15 +537,16 @@ def test_get_workload_identity(self): identity_client = IdentityClient(region) workload_name = "test-workload" - allowed_urls = ["https://unit-test.com/callback", "https://test.com/oauth"] - expected_response = {"name": workload_name, "allowedResourceOauth2ReturnUrls": allowed_urls} + expected_response = {"name": workload_name} mock_cp_client.get_workload_identity.return_value = expected_response - result = identity_client.get_workload_identity(workload_name) + result = identity_client.get_workload_identity(name=workload_name) assert result == expected_response - mock_cp_client.get_workload_identity.assert_called_once_with(name=workload_name) + mock_cp_client.get_workload_identity.assert_called_once_with( + name=workload_name, + ) def test_complete_resource_token_auth_with_user_id(self): region = "us-west-2" diff --git a/tests_integ/identity/test_identity_client.py b/tests_integ/identity/test_identity_client.py new file mode 100644 index 00000000..2ebd7404 --- /dev/null +++ b/tests_integ/identity/test_identity_client.py @@ -0,0 +1,136 @@ +"""Integration tests for IdentityClient passthrough and __getattr__ methods.""" + +import os +import time + +import pytest +from botocore.exceptions import ClientError + +from bedrock_agentcore._utils.polling import wait_until_deleted +from bedrock_agentcore.services.identity import IdentityClient + + +@pytest.mark.integration +class TestIdentityClientPassthrough: + """Integration tests for IdentityClient passthrough via __getattr__. + + Tests read-only operations that don't require pre-existing resources. + """ + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + cls.client = IdentityClient(region=cls.region) + + @pytest.mark.order(1) + def test_list_oauth2_credential_providers_passthrough(self): + response = self.client.list_oauth2_credential_providers() + assert "credentialProviders" in response + + @pytest.mark.order(2) + def test_list_api_key_credential_providers_passthrough(self): + response = self.client.list_api_key_credential_providers() + assert "credentialProviders" in response + + @pytest.mark.order(3) + def test_list_oauth2_snake_case(self): + response = self.client.list_oauth2_credential_providers( + max_results=10, + ) + assert "credentialProviders" in response + + @pytest.mark.order(4) + def test_get_nonexistent_oauth2_provider(self): + with pytest.raises(ClientError) as exc_info: + self.client.get_oauth2_credential_provider( + name="nonexistent-provider", + ) + assert exc_info.value.response["Error"]["Code"] in ( + "ResourceNotFoundException", + "AccessDeniedException", + ) + + @pytest.mark.order(5) + def test_get_nonexistent_api_key_provider(self): + with pytest.raises(ClientError) as exc_info: + self.client.get_api_key_credential_provider( + name="nonexistent-provider", + ) + assert exc_info.value.response["Error"]["Code"] in ( + "ResourceNotFoundException", + "AccessDeniedException", + ) + + @pytest.mark.order(6) + def test_non_allowlisted_method_raises(self): + with pytest.raises(AttributeError): + self.client.not_a_real_method() + + +@pytest.mark.integration +class TestIdentityClientOauth2Crud: + """Integration tests for OAuth2 credential provider CRUD via passthrough. + + Requires COGNITO_POOL_ID, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET. + """ + + @classmethod + def setup_class(cls): + cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") + cls.pool_id = os.environ.get("COGNITO_POOL_ID") + cls.client_id = os.environ.get("COGNITO_CLIENT_ID") + cls.client_secret = os.environ.get("COGNITO_CLIENT_SECRET") + if not all([cls.pool_id, cls.client_id, cls.client_secret]): + pytest.skip("COGNITO_POOL_ID, COGNITO_CLIENT_ID, and COGNITO_CLIENT_SECRET must all be set") + cls.client = IdentityClient(region=cls.region) + cls.discovery_url = ( + f"https://cognito-idp.{cls.region}.amazonaws.com/{cls.pool_id}/.well-known/openid-configuration" + ) + cls.provider_name = f"sdk-integ-{int(time.time())}" + + @classmethod + def teardown_class(cls): + try: + cls.client.delete_oauth2_credential_provider( + name=cls.provider_name, + ) + except Exception as e: + print(f"Teardown: {e}") + + @pytest.mark.order(10) + def test_create_oauth2_credential_provider(self): + self.client.create_oauth2_credential_provider( + name=self.provider_name, + credentialProviderVendor="CustomOauth2", + oauth2ProviderConfigInput={ + "customOauth2ProviderConfig": { + "oauthDiscovery": { + "discoveryUrl": self.discovery_url, + }, + "clientId": self.client_id, + "clientSecret": self.client_secret, + } + }, + ) + provider = self.client.get_oauth2_credential_provider( + name=self.provider_name, + ) + assert provider["name"] == self.provider_name + + @pytest.mark.order(11) + def test_get_oauth2_provider_passthrough(self): + provider = self.client.get_oauth2_credential_provider( + name=self.provider_name, + ) + assert provider["name"] == self.provider_name + + @pytest.mark.order(12) + def test_delete_oauth2_credential_provider(self): + self.client.delete_oauth2_credential_provider( + name=self.provider_name, + ) + wait_until_deleted( + lambda: self.client.get_oauth2_credential_provider( + name=self.provider_name, + ), + )