Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Python Tests
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.event.pull_request.number || github.ref }}-tests
permissions:
contents: read
env:
AZURE_CLIENT_SECRET: ${{ secrets.IDP_CLIENT_CREDENTIAL }}
AZURE_CLIENT_ID: ${{ secrets.IDP_CLIENT_ID }}
AZURE_TENANT_ID: ${{ secrets.IDP_TENANT_ID }}
jobs:
tests:
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy-3.9', 'pypy-3.10']
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r dev_requirements.txt
- name: Run tests with Python version ${{ matrix.python-version }}
run: |
pytest --junitxml=test-results.xml -m "not managed_identity"
- name: Upload test results
if: matrix.python-version == '3.12'
uses: actions/upload-artifact@v4
with:
name: test-results
path: test-results.xml
51 changes: 29 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ You need to install the `redis-py` Entra ID package via the following command:
pip install redis-entra-id
```

The package depends on [redis-py](https://github.com/redis/redis-py/tree/v5.3.0b4) version `5.3.0b4`.
The package depends on [redis-py](https://github.com/redis/redis-py).

## Usage

Expand All @@ -44,49 +44,56 @@ The package depends on [redis-py](https://github.com/redis/redis-py/tree/v5.3.0b
After having installed the package, you can import its modules:

```python
import redis
from redis_entraid import identity_provider
from redis_entraid import cred_provider
from redis import Redis
from redis_entraid.cred_provider import *
```

### Step 2 - Define your authority based on the tenant ID
### Step 2 - Create the credential provider via the factory method

```python
authority = "{}/{}".format("https://login.microsoftonline.com", "<TENANT_ID>")
credential_provider = create_from_service_principal(
CLIENT_ID,
CLIENT_SECRET,
TENANT_ID
)
```

> This step is going to be removed in the next pre-release version of `redis-py-entraid`. Instead, the factory method will allow to pass the tenant id direclty.
### Step 3 - Provide optional token renewal configuration

### Step 3 - Create the identity provider via the factory method

```python
idp = identity_provider.create_provider_from_service_principal("<CLIENT_SECRET>", "<CLIENT_ID>", authority=authority)
```

### Step 4 - Initialize a credentials provider from the authentication configuration

You can use the default configuration or customize the background task for token renewal.
The default configuration would be applied, but you're able to customise it.

```python
auth_config = TokenAuthConfig(idp)
cred_provider = EntraIdCredentialsProvider(auth_config)
credential_provider = create_from_service_principal(
CLIENT_ID,
CLIENT_SECRET,
TENANT_ID,
token_manager_config=TokenManagerConfig(
expiration_refresh_ratio=0.9,
lower_refresh_bound_millis=DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
retry_policy=RetryPolicy(
max_attempts=5,
delay_in_ms=50
)
)
)
```

You can test the credentials provider by obtaining a token. The following example demonstrates both, a synchronous and an asynchronous approach:

```python
# Synchronous
cred_provider.get_credentials()
credential_provider.get_credentials()

# Asynchronous
await cred_provider.get_credentials_async()
await credential_provider.get_credentials_async()
```

### Step 5 - Connect to Redis
### Step 4 - Connect to Redis

When using Entra ID, Azure enforces TLS on your Redis connection. Here is an example that shows how to **test** the connection in an insecure way:

```python
client = redis.Redis(host="<HOST>", port=<PORT>, ssl=True, ssl_cert_reqs=None, credential_provider=cred_provider)
client = Redis(host=HOST, port=PORT, ssl=True, ssl_cert_reqs=None, credential_provider=credential_provider)
print("The database size is: {}".format(client.dbsize()))
```
4 changes: 3 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
[pytest]
asyncio_mode = auto
asyncio_mode = auto
markers =
managed_identity: Tests that should be run on Azure VM to be able to reach managed identity service.
150 changes: 107 additions & 43 deletions redis_entraid/cred_provider.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,50 @@
from dataclasses import dataclass
from typing import Union, Tuple, Callable, Any, Awaitable
from typing import Union, Tuple, Callable, Any, Awaitable, Optional, List

from redis.credentials import StreamingCredentialProvider
from redis.auth.token_manager import TokenManagerConfig, RetryPolicy, TokenManager, CredentialsListener

from redis_entraid.identity_provider import EntraIDIdentityProvider


@dataclass
class TokenAuthConfig:
"""
Configuration for token authentication.

Requires :class:`EntraIDIdentityProvider`. It's recommended to use an additional factory methods.
See :class:`EntraIDIdentityProvider` for more information.
"""
DEFAULT_EXPIRATION_REFRESH_RATIO = 0.8
DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 0
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS = 100
DEFAULT_MAX_ATTEMPTS = 3
DEFAULT_DELAY_IN_MS = 3

idp: EntraIDIdentityProvider
expiration_refresh_ratio: float = DEFAULT_EXPIRATION_REFRESH_RATIO
lower_refresh_bound_millis: int = DEFAULT_LOWER_REFRESH_BOUND_MILLIS
token_request_execution_timeout_in_ms: int = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS
max_attempts: int = DEFAULT_MAX_ATTEMPTS
delay_in_ms: int = DEFAULT_DELAY_IN_MS

def get_token_manager_config(self) -> TokenManagerConfig:
return TokenManagerConfig(
self.expiration_refresh_ratio,
self.lower_refresh_bound_millis,
self.token_request_execution_timeout_in_ms,
RetryPolicy(
self.max_attempts,
self.delay_in_ms
)
)

def get_identity_provider(self) -> EntraIDIdentityProvider:
return self.idp
from redis_entraid.identity_provider import ManagedIdentityType, ManagedIdentityIdType, \
_create_provider_from_managed_identity, ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig, \
_create_provider_from_service_principal

DEFAULT_EXPIRATION_REFRESH_RATIO = 0.7
DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 0
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS = 100
DEFAULT_MAX_ATTEMPTS = 3
DEFAULT_DELAY_IN_MS = 3

class EntraIdCredentialsProvider(StreamingCredentialProvider):
def __init__(
self,
config: TokenAuthConfig,
idp_config: Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig],
token_manager_config: TokenManagerConfig,
initial_delay_in_ms: float = 0,
block_for_initial: bool = False,
):
"""
:param config:
:param idp_config: Identity provider specific configuration.
:param token_manager_config: Token manager specific configuration.
:param initial_delay_in_ms: Initial delay before run background refresh (valid for async only)
:param block_for_initial: Block execution until initial token will be acquired (valid for async only)
"""
if isinstance(idp_config, ManagedIdentityProviderConfig):
idp = _create_provider_from_managed_identity(idp_config)
else:
idp = _create_provider_from_service_principal(idp_config)

self._token_mgr = TokenManager(
config.get_identity_provider(),
config.get_token_manager_config()
idp,
token_manager_config
)
self._listener = CredentialsListener()
self._is_streaming = False
self._initial_delay_in_ms = initial_delay_in_ms
self._block_for_initial = block_for_initial

def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
"""
Acquire token from the identity provider.
"""
init_token = self._token_mgr.acquire_token()

if self._is_streaming is False:
Expand All @@ -77,6 +57,9 @@ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
return init_token.get_token().try_get('oid'), init_token.get_token().get_value()

async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
"""
Acquire token from the identity provider in async mode.
"""
init_token = await self._token_mgr.acquire_token_async()

if self._is_streaming is False:
Expand All @@ -98,3 +81,84 @@ def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]):

def is_streaming(self) -> bool:
return self._is_streaming


def create_from_managed_identity(
identity_type: ManagedIdentityType,
resource: str,
id_type: Optional[ManagedIdentityIdType] = None,
id_value: Optional[str] = '',
kwargs: Optional[dict] = {},
token_manager_config: Optional[TokenManagerConfig] = TokenManagerConfig(
DEFAULT_EXPIRATION_REFRESH_RATIO,
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
RetryPolicy(
DEFAULT_MAX_ATTEMPTS,
DEFAULT_DELAY_IN_MS
)
)
) -> EntraIdCredentialsProvider:
"""
Create a credential provider from a managed identity type.

:param identity_type: Managed identity type.
:param resource: Identity provider resource.
:param id_type: Identity provider type.
:param id_value: Identity provider value.
:param kwargs: Optional keyword arguments to pass to identity provider. See: :class:`ManagedIdentityClient`
:param token_manager_config: Token manager specific configuration.
:return: EntraIdCredentialsProvider instance.
"""
managed_identity_config = ManagedIdentityProviderConfig(
identity_type=identity_type,
resource=resource,
id_type=id_type,
id_value=id_value,
kwargs=kwargs
)

return EntraIdCredentialsProvider(managed_identity_config, token_manager_config)


def create_from_service_principal(
client_id: str,
client_credential: Any,
tenant_id: Optional[str] = None,
scopes: Optional[List[str]] = None,
timeout: Optional[float] = None,
token_kwargs: Optional[dict] = {},
app_kwargs: Optional[dict] = {},
token_manager_config: Optional[TokenManagerConfig] = TokenManagerConfig(
DEFAULT_EXPIRATION_REFRESH_RATIO,
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
RetryPolicy(
DEFAULT_MAX_ATTEMPTS,
DEFAULT_DELAY_IN_MS
)
)) -> EntraIdCredentialsProvider:
"""
Create a credential provider from a service principal.

:param client_credential: Service principal credentials.
:param client_id: Service principal client ID.
:param scopes: Service principal scopes. Fallback to default scopes if None.
:param timeout: Service principal timeout.
:param tenant_id: Service principal tenant ID.
:param token_kwargs: Optional token arguments to pass to service identity provider.
:param app_kwargs: Optional keyword arguments to pass to service principal application.
:param token_manager_config: Token manager specific configuration.
:return: EntraIdCredentialsProvider instance.
"""
service_principal_config = ServicePrincipalIdentityProviderConfig(
client_credential=client_credential,
client_id=client_id,
scopes=scopes,
timeout=timeout,
tenant_id=tenant_id,
app_kwargs=app_kwargs,
token_kwargs=token_kwargs,
)

return EntraIdCredentialsProvider(service_principal_config, token_manager_config)
Loading