diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..3653525 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,41 @@ +name: Python Tests +on: + push: + branches: + - main + pull_request: +concurrency: + group: ${{ github.event.pull_request.number || github.ref }}-tests +permissions: + contents: read +env: + AZURE_CLIENT_SECRET: ${{ secrets.IDP_CLIENT_CREDENTIAL }} + AZURE_CLIENT_ID: ${{ secrets.IDP_CLIENT_ID }} + AZURE_TENANT_ID: ${{ secrets.IDP_TENANT_ID }} +jobs: + tests: + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy-3.9', 'pypy-3.10'] + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r dev_requirements.txt + - name: Run tests with Python version ${{ matrix.python-version }} + run: | + pytest --junitxml=test-results.xml -m "not managed_identity" + - name: Upload test results + if: matrix.python-version == '3.12' + uses: actions/upload-artifact@v4 + with: + name: test-results + path: test-results.xml \ No newline at end of file diff --git a/README.md b/README.md index 5481fb0..06d1394 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ You need to install the `redis-py` Entra ID package via the following command: pip install redis-entra-id ``` -The package depends on [redis-py](https://github.com/redis/redis-py/tree/v5.3.0b4) version `5.3.0b4`. +The package depends on [redis-py](https://github.com/redis/redis-py). ## Usage @@ -44,49 +44,56 @@ The package depends on [redis-py](https://github.com/redis/redis-py/tree/v5.3.0b After having installed the package, you can import its modules: ```python -import redis -from redis_entraid import identity_provider -from redis_entraid import cred_provider +from redis import Redis +from redis_entraid.cred_provider import * ``` -### Step 2 - Define your authority based on the tenant ID +### Step 2 - Create the credential provider via the factory method ```python -authority = "{}/{}".format("https://login.microsoftonline.com", "") +credential_provider = create_from_service_principal( + CLIENT_ID, + CLIENT_SECRET, + TENANT_ID +) ``` -> This step is going to be removed in the next pre-release version of `redis-py-entraid`. Instead, the factory method will allow to pass the tenant id direclty. +### Step 3 - Provide optional token renewal configuration -### Step 3 - Create the identity provider via the factory method - -```python -idp = identity_provider.create_provider_from_service_principal("", "", authority=authority) -``` - -### Step 4 - Initialize a credentials provider from the authentication configuration - -You can use the default configuration or customize the background task for token renewal. +The default configuration would be applied, but you're able to customise it. ```python -auth_config = TokenAuthConfig(idp) -cred_provider = EntraIdCredentialsProvider(auth_config) +credential_provider = create_from_service_principal( + CLIENT_ID, + CLIENT_SECRET, + TENANT_ID, + token_manager_config=TokenManagerConfig( + expiration_refresh_ratio=0.9, + lower_refresh_bound_millis=DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + retry_policy=RetryPolicy( + max_attempts=5, + delay_in_ms=50 + ) + ) +) ``` You can test the credentials provider by obtaining a token. The following example demonstrates both, a synchronous and an asynchronous approach: ```python # Synchronous -cred_provider.get_credentials() +credential_provider.get_credentials() # Asynchronous -await cred_provider.get_credentials_async() +await credential_provider.get_credentials_async() ``` -### Step 5 - Connect to Redis +### Step 4 - Connect to Redis When using Entra ID, Azure enforces TLS on your Redis connection. Here is an example that shows how to **test** the connection in an insecure way: ```python -client = redis.Redis(host="", port=, ssl=True, ssl_cert_reqs=None, credential_provider=cred_provider) +client = Redis(host=HOST, port=PORT, ssl=True, ssl_cert_reqs=None, credential_provider=credential_provider) print("The database size is: {}".format(client.dbsize())) ``` diff --git a/pytest.ini b/pytest.ini index d280de0..53f543b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,4 @@ [pytest] -asyncio_mode = auto \ No newline at end of file +asyncio_mode = auto +markers = + managed_identity: Tests that should be run on Azure VM to be able to reach managed identity service. \ No newline at end of file diff --git a/redis_entraid/cred_provider.py b/redis_entraid/cred_provider.py index ab706a2..7d4db6a 100644 --- a/redis_entraid/cred_provider.py +++ b/redis_entraid/cred_provider.py @@ -1,63 +1,40 @@ -from dataclasses import dataclass -from typing import Union, Tuple, Callable, Any, Awaitable +from typing import Union, Tuple, Callable, Any, Awaitable, Optional, List from redis.credentials import StreamingCredentialProvider from redis.auth.token_manager import TokenManagerConfig, RetryPolicy, TokenManager, CredentialsListener -from redis_entraid.identity_provider import EntraIDIdentityProvider - - -@dataclass -class TokenAuthConfig: - """ - Configuration for token authentication. - - Requires :class:`EntraIDIdentityProvider`. It's recommended to use an additional factory methods. - See :class:`EntraIDIdentityProvider` for more information. - """ - DEFAULT_EXPIRATION_REFRESH_RATIO = 0.8 - DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 0 - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS = 100 - DEFAULT_MAX_ATTEMPTS = 3 - DEFAULT_DELAY_IN_MS = 3 - - idp: EntraIDIdentityProvider - expiration_refresh_ratio: float = DEFAULT_EXPIRATION_REFRESH_RATIO - lower_refresh_bound_millis: int = DEFAULT_LOWER_REFRESH_BOUND_MILLIS - token_request_execution_timeout_in_ms: int = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS - max_attempts: int = DEFAULT_MAX_ATTEMPTS - delay_in_ms: int = DEFAULT_DELAY_IN_MS - - def get_token_manager_config(self) -> TokenManagerConfig: - return TokenManagerConfig( - self.expiration_refresh_ratio, - self.lower_refresh_bound_millis, - self.token_request_execution_timeout_in_ms, - RetryPolicy( - self.max_attempts, - self.delay_in_ms - ) - ) - - def get_identity_provider(self) -> EntraIDIdentityProvider: - return self.idp +from redis_entraid.identity_provider import ManagedIdentityType, ManagedIdentityIdType, \ + _create_provider_from_managed_identity, ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig, \ + _create_provider_from_service_principal +DEFAULT_EXPIRATION_REFRESH_RATIO = 0.7 +DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 0 +DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS = 100 +DEFAULT_MAX_ATTEMPTS = 3 +DEFAULT_DELAY_IN_MS = 3 class EntraIdCredentialsProvider(StreamingCredentialProvider): def __init__( self, - config: TokenAuthConfig, + idp_config: Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig], + token_manager_config: TokenManagerConfig, initial_delay_in_ms: float = 0, block_for_initial: bool = False, ): """ - :param config: + :param idp_config: Identity provider specific configuration. + :param token_manager_config: Token manager specific configuration. :param initial_delay_in_ms: Initial delay before run background refresh (valid for async only) :param block_for_initial: Block execution until initial token will be acquired (valid for async only) """ + if isinstance(idp_config, ManagedIdentityProviderConfig): + idp = _create_provider_from_managed_identity(idp_config) + else: + idp = _create_provider_from_service_principal(idp_config) + self._token_mgr = TokenManager( - config.get_identity_provider(), - config.get_token_manager_config() + idp, + token_manager_config ) self._listener = CredentialsListener() self._is_streaming = False @@ -65,6 +42,9 @@ def __init__( self._block_for_initial = block_for_initial def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + """ + Acquire token from the identity provider. + """ init_token = self._token_mgr.acquire_token() if self._is_streaming is False: @@ -77,6 +57,9 @@ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: return init_token.get_token().try_get('oid'), init_token.get_token().get_value() async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]: + """ + Acquire token from the identity provider in async mode. + """ init_token = await self._token_mgr.acquire_token_async() if self._is_streaming is False: @@ -98,3 +81,84 @@ def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]): def is_streaming(self) -> bool: return self._is_streaming + + +def create_from_managed_identity( + identity_type: ManagedIdentityType, + resource: str, + id_type: Optional[ManagedIdentityIdType] = None, + id_value: Optional[str] = '', + kwargs: Optional[dict] = {}, + token_manager_config: Optional[TokenManagerConfig] = TokenManagerConfig( + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + RetryPolicy( + DEFAULT_MAX_ATTEMPTS, + DEFAULT_DELAY_IN_MS + ) + ) +) -> EntraIdCredentialsProvider: + """ + Create a credential provider from a managed identity type. + + :param identity_type: Managed identity type. + :param resource: Identity provider resource. + :param id_type: Identity provider type. + :param id_value: Identity provider value. + :param kwargs: Optional keyword arguments to pass to identity provider. See: :class:`ManagedIdentityClient` + :param token_manager_config: Token manager specific configuration. + :return: EntraIdCredentialsProvider instance. + """ + managed_identity_config = ManagedIdentityProviderConfig( + identity_type=identity_type, + resource=resource, + id_type=id_type, + id_value=id_value, + kwargs=kwargs + ) + + return EntraIdCredentialsProvider(managed_identity_config, token_manager_config) + + +def create_from_service_principal( + client_id: str, + client_credential: Any, + tenant_id: Optional[str] = None, + scopes: Optional[List[str]] = None, + timeout: Optional[float] = None, + token_kwargs: Optional[dict] = {}, + app_kwargs: Optional[dict] = {}, + token_manager_config: Optional[TokenManagerConfig] = TokenManagerConfig( + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + RetryPolicy( + DEFAULT_MAX_ATTEMPTS, + DEFAULT_DELAY_IN_MS + ) + )) -> EntraIdCredentialsProvider: + """ + Create a credential provider from a service principal. + + :param client_credential: Service principal credentials. + :param client_id: Service principal client ID. + :param scopes: Service principal scopes. Fallback to default scopes if None. + :param timeout: Service principal timeout. + :param tenant_id: Service principal tenant ID. + :param token_kwargs: Optional token arguments to pass to service identity provider. + :param app_kwargs: Optional keyword arguments to pass to service principal application. + :param token_manager_config: Token manager specific configuration. + :return: EntraIdCredentialsProvider instance. + """ + service_principal_config = ServicePrincipalIdentityProviderConfig( + client_credential=client_credential, + client_id=client_id, + scopes=scopes, + timeout=timeout, + tenant_id=tenant_id, + app_kwargs=app_kwargs, + token_kwargs=token_kwargs, + ) + + return EntraIdCredentialsProvider(service_principal_config, token_manager_config) \ No newline at end of file diff --git a/redis_entraid/identity_provider.py b/redis_entraid/identity_provider.py index c8f4f66..f3a5c45 100644 --- a/redis_entraid/identity_provider.py +++ b/redis_entraid/identity_provider.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass, field from enum import Enum -from typing import Optional, Union, Callable +from typing import Optional, Union, Callable, Any, List import requests from msal import ( @@ -24,6 +25,26 @@ class ManagedIdentityIdType(Enum): RESOURCE_ID = "resource_id" +@dataclass +class ManagedIdentityProviderConfig: + identity_type: ManagedIdentityType + resource: str + id_type: Optional[ManagedIdentityIdType] = None + id_value: Optional[str] = '' + kwargs: Optional[dict] = field(default_factory=dict) + + +@dataclass +class ServicePrincipalIdentityProviderConfig: + client_credential: Any + client_id: str + scopes: Optional[List[str]] = None + timeout: Optional[float] = None + tenant_id: Optional[str] = None + token_kwargs: Optional[dict] = None + app_kwargs: Optional[dict] = field(default_factory=dict) + + class EntraIDIdentityProvider(IdentityProviderInterface): """ EntraID Identity Provider implementation. @@ -34,7 +55,7 @@ class EntraIDIdentityProvider(IdentityProviderInterface): def __init__( self, app: Union[ManagedIdentityClient, ConfidentialClientApplication], - scopes : list = [], + scopes : List = [], resource: str = '', **kwargs ): @@ -75,70 +96,54 @@ def _get_token(self, callback: Callable, **kwargs) -> JWToken: raise RequestTokenErr(e) -def create_provider_from_managed_identity( - identity_type: ManagedIdentityType, - resource: str, - id_type: Optional[ManagedIdentityIdType] = None, - id_value: Optional[str] = '', - **kwargs -) -> EntraIDIdentityProvider: +def _create_provider_from_managed_identity(config: ManagedIdentityProviderConfig) -> EntraIDIdentityProvider: """ Create an EntraID identity provider following Managed Identity auth flow. - :param identity_type: User Assigned or System Assigned. - :param resource: Resource for which token should be acquired. - :param id_type: Required for User Assigned identity type only. - :param id_value: Required for User Assigned identity type only. - :param kwargs: Additional arguments you may need during specify to request token. + :param config: Config for managed assigned identity provider See: :class:`ManagedIdentityClient` acquire_token_for_client method. :return: :class:`EntraIDIdentityProvider` """ - if identity_type == ManagedIdentityType.USER_ASSIGNED: - if id_type is None or id_value == '': + if config.identity_type == ManagedIdentityType.USER_ASSIGNED: + if config.id_type is None or config.id_value == '': raise ValueError("Id_type and id_value are required for User Assigned identity auth") kwargs = { - id_type.value: id_value + config.id_type.value: config.id_value } - managed_identity = identity_type.value(**kwargs) + managed_identity = config.identity_type.value(**kwargs) else: - managed_identity = identity_type.value() + managed_identity = config.identity_type.value() app = ManagedIdentityClient(managed_identity, http_client=requests.Session()) - return EntraIDIdentityProvider(app, [], resource, **kwargs) + return EntraIDIdentityProvider(app, [], config.resource, **config.kwargs) -def create_provider_from_service_principal( - client_credential, - client_id: str, - scopes: list = [], - timeout: Optional[float] = None, - token_kwargs: dict = {}, - **app_kwargs -) -> EntraIDIdentityProvider: +def _create_provider_from_service_principal(config: ServicePrincipalIdentityProviderConfig) -> EntraIDIdentityProvider: """ Create an EntraID identity provider following Service Principal auth flow. - :param client_credential: Can be secret string, PEM certificate and more. - See: :class:`ConfidentialClientApplication`. + :param config: Config for service principal identity provider - :param client_id: Application (Client) ID. - :param scopes: If no scopes will be provided, default will be used. - :param timeout: Timeout in seconds. - :param token_kwargs: Additional arguments you may need during token request. - :param app_kwargs: Additional arguments you may need to configure an application. :return: :class:`EntraIDIdentityProvider` + See: :class:`ConfidentialClientApplication`. """ - if len(scopes) == 0: - scopes.append("https://redis.azure.com/.default") + if config.scopes is None: + scopes = ["https://redis.azure.com/.default"] + else: + scopes = config.scopes + + authority = f"https://login.microsoftonline.com/{config.tenant_id}" \ + if config.tenant_id is not None else config.tenant_id app = ConfidentialClientApplication( - client_id=client_id, - client_credential=client_credential, - timeout=timeout, - **app_kwargs + client_id=config.client_id, + client_credential=config.client_credential, + timeout=config.timeout, + authority=authority, + **config.app_kwargs ) - return EntraIDIdentityProvider(app, scopes, **token_kwargs) + return EntraIDIdentityProvider(app, scopes, **config.token_kwargs) diff --git a/requirements.txt b/requirements.txt index fa28e3a..6574314 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ PyJWT~=2.9.0 msal~=1.31.0 -redis @ git+https://github.com/redis/redis-py.git/@vv-tba-support +redis==5.3.0b4 requests~=2.32.3 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index f49a20f..1455563 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,17 @@ import os from enum import Enum +from typing import Union import pytest -from _pytest.fixtures import SubRequest from redis import CredentialProvider -from redis.auth.idp import IdentityProviderInterface +from redis.auth.token_manager import TokenManagerConfig, RetryPolicy -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig -from redis_entraid.identity_provider import ManagedIdentityType, create_provider_from_managed_identity, \ - create_provider_from_service_principal, EntraIDIdentityProvider, ManagedIdentityIdType +from redis_entraid.cred_provider import DEFAULT_EXPIRATION_REFRESH_RATIO, \ + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, DEFAULT_MAX_ATTEMPTS, DEFAULT_DELAY_IN_MS, \ + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, create_from_service_principal, create_from_managed_identity +from redis_entraid.identity_provider import ManagedIdentityType, EntraIDIdentityProvider, ManagedIdentityIdType, \ + ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig, _create_provider_from_managed_identity, \ + _create_provider_from_service_principal class AuthType(Enum): @@ -16,7 +19,7 @@ class AuthType(Enum): SERVICE_PRINCIPAL = "service_principal" -def get_identity_provider(request) -> EntraIDIdentityProvider: +def get_identity_provider_config(request) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) else: @@ -25,13 +28,12 @@ def get_identity_provider(request) -> EntraIDIdentityProvider: auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) if auth_type == AuthType.MANAGED_IDENTITY: - return _get_managed_identity_provider(request) + return _get_managed_identity_provider_config(request) - return _get_service_principal_provider(request) + return _get_service_principal_provider_config(request) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) @@ -43,21 +45,20 @@ def _get_managed_identity_provider(request): identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( - identity_type=identity_type, - resource=resource, - id_type=id_type, - id_value=id_value, - authority=authority, - **kwargs - ) + return ManagedIdentityProviderConfig( + identity_type=identity_type, + resource=resource, + id_type=id_type, + id_value=id_value, + kwargs=kwargs + ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config(request) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -71,16 +72,15 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(',') - return create_provider_from_service_principal( - client_id=client_id, - client_credential=client_credential, - scopes=scopes, - timeout=timeout, - token_kwargs=token_kwargs, - authority=authority, - **kwargs - ) - + return ServicePrincipalIdentityProviderConfig( + client_id=client_id, + client_credential=client_credential, + scopes=scopes, + timeout=timeout, + token_kwargs=token_kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs + ) def get_credential_provider(request) -> CredentialProvider: if hasattr(request, "param"): @@ -88,32 +88,49 @@ def get_credential_provider(request) -> CredentialProvider: else: cred_provider_kwargs = {} - idp = get_identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) + idp_config = get_identity_provider_config(request) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "max_attempts", DEFAULT_MAX_ATTEMPTS ) delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + "delay_in_ms", DEFAULT_DELAY_IN_MS ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ) + ) - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + if isinstance(idp_config, ServicePrincipalIdentityProviderConfig): + return create_from_service_principal( + idp_config.client_id, + idp_config.client_credential, + idp_config.tenant_id, + idp_config.scopes, + idp_config.timeout, + idp_config.token_kwargs, + idp_config.app_kwargs, + token_mgr_config, + ) + + return create_from_managed_identity( + idp_config.identity_type, + idp_config.resource, + idp_config.id_type, + idp_config.id_value, + idp_config.kwargs, + token_mgr_config, ) @@ -121,7 +138,13 @@ def get_credential_provider(request) -> CredentialProvider: def credential_provider(request) -> CredentialProvider: return get_credential_provider(request) - @pytest.fixture() def identity_provider(request) -> EntraIDIdentityProvider: - return get_identity_provider(request) \ No newline at end of file + config = _identity_provider_config(request) + if isinstance(config, ManagedIdentityProviderConfig): + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) + +def _identity_provider_config(request) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + return get_identity_provider_config(request) \ No newline at end of file diff --git a/tests/test_cred_provider.py b/tests/test_cred_provider.py index 6b30182..f7b49b4 100644 --- a/tests/test_cred_provider.py +++ b/tests/test_cred_provider.py @@ -16,6 +16,18 @@ class TestEntraIdCredentialsProvider: { "idp_kwargs": {"auth_type": AuthType.SERVICE_PRINCIPAL}, }, + ], + ids=["Service principal"], + indirect=True, + ) + def test_get_credentials(self, credential_provider: EntraIdCredentialsProvider): + credentials = credential_provider.get_credentials() + assert len(credentials) == 2 + + + @pytest.mark.parametrize( + "credential_provider", + [ { "idp_kwargs": {"auth_type": AuthType.MANAGED_IDENTITY}, }, @@ -26,13 +38,15 @@ class TestEntraIdCredentialsProvider: }, } ], - ids=["Service principal", "Managed Identity (System-assigned)", "Managed Identity (User-assigned)"], + ids=["Managed Identity (System-assigned)", "Managed Identity (User-assigned)"], indirect=True, ) - def test_get_credentials(self, credential_provider: EntraIdCredentialsProvider): + @pytest.mark.managed_identity + def test_get_credentials_managed_identity(self, credential_provider: EntraIdCredentialsProvider): credentials = credential_provider.get_credentials() assert len(credentials) == 2 + @pytest.mark.parametrize( "credential_provider", [ @@ -40,6 +54,19 @@ def test_get_credentials(self, credential_provider: EntraIdCredentialsProvider): "cred_provider_kwargs": {"block_for_initial": False}, "idp_kwargs": {"auth_type": AuthType.SERVICE_PRINCIPAL}, }, + ], + ids=["Service principal"], + indirect=True, + ) + @pytest.mark.asyncio + async def test_get_credentials_async(self, credential_provider: EntraIdCredentialsProvider): + credentials = await credential_provider.get_credentials_async() + assert len(credentials) == 2 + + + @pytest.mark.parametrize( + "credential_provider", + [ { "cred_provider_kwargs": {"block_for_initial": True}, "idp_kwargs": {"auth_type": AuthType.MANAGED_IDENTITY}, @@ -51,14 +78,16 @@ def test_get_credentials(self, credential_provider: EntraIdCredentialsProvider): }, } ], - ids=["Service principal", "Managed Identity (System-assigned)", "Managed Identity (User-assigned)"], + ids=["Managed Identity (System-assigned)", "Managed Identity (User-assigned)"], indirect=True, ) @pytest.mark.asyncio - async def test_get_credentials_async(self, credential_provider: EntraIdCredentialsProvider): + @pytest.mark.managed_identity + async def test_get_credentials_async_managed_identity(self, credential_provider: EntraIdCredentialsProvider): credentials = await credential_provider.get_credentials_async() assert len(credentials) == 2 + @pytest.mark.parametrize( "credential_provider", [ diff --git a/tests/test_identity_provider.py b/tests/test_identity_provider.py index fbd8dd9..635cb3d 100644 --- a/tests/test_identity_provider.py +++ b/tests/test_identity_provider.py @@ -19,7 +19,7 @@ def test_request_token_from_service_principal_identity(self, identity_provider: ], indirect=True, ) - def test_request_token_caches_token_after_initial_request(self, identity_provider): + def test_request_token_caches_token_after_initial_request(self, identity_provider: EntraIDIdentityProvider): assert len(list(self.CUSTOM_CACHE.search(TokenCache.CredentialType.ACCESS_TOKEN))) == 0 token = identity_provider.request_token()