diff --git a/src/business_logic/authorization/factories/auth_service_factory.py b/src/business_logic/authorization/factories/auth_service_factory.py index fa31c350..2b6d5b43 100644 --- a/src/business_logic/authorization/factories/auth_service_factory.py +++ b/src/business_logic/authorization/factories/auth_service_factory.py @@ -23,7 +23,8 @@ ) if TYPE_CHECKING: - from src.business_logic.services import JWTService, PasswordHash + from src.business_logic.services import PasswordHash + from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager FactoryMethod = Callable[..., AuthServiceProtocol] @@ -41,7 +42,7 @@ def __init__( persistent_grant_repo: PersistentGrantRepository, device_repo: DeviceRepository, password_service: PasswordHash, - jwt_service: JWTService, + jwt_service: JWTManager, ) -> None: self.session = session self._client_repo = client_repo diff --git a/src/business_logic/authorization/factories/factory_methods/create_id_token_auth_service.py b/src/business_logic/authorization/factories/factory_methods/create_id_token_auth_service.py index 82d552b8..c5f8cfd9 100644 --- a/src/business_logic/authorization/factories/factory_methods/create_id_token_auth_service.py +++ b/src/business_logic/authorization/factories/factory_methods/create_id_token_auth_service.py @@ -14,7 +14,8 @@ if TYPE_CHECKING: from src.business_logic.authorization.interfaces import AuthServiceProtocol - from src.business_logic.services import JWTService, PasswordHash + from src.business_logic.services import PasswordHash # JWTService, + from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager from src.data_access.postgresql.repositories import ( ClientRepository, UserRepository, @@ -26,7 +27,7 @@ def _create_id_token_auth_service( client_repo: ClientRepository, user_repo: UserRepository, password_service: PasswordHash, - jwt_service: JWTService, + jwt_service: JWTManager, **kwargs: Any, ) -> AuthServiceProtocol: return IdTokenAuthService( diff --git a/src/business_logic/authorization/factories/factory_methods/create_id_token_token_auth_service.py b/src/business_logic/authorization/factories/factory_methods/create_id_token_token_auth_service.py index 148b8975..21226b2d 100644 --- a/src/business_logic/authorization/factories/factory_methods/create_id_token_token_auth_service.py +++ b/src/business_logic/authorization/factories/factory_methods/create_id_token_token_auth_service.py @@ -16,7 +16,8 @@ if TYPE_CHECKING: from src.business_logic.authorization.interfaces import AuthServiceProtocol - from src.business_logic.services import JWTService, PasswordHash + from src.business_logic.services import PasswordHash # JWTService, + from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager from src.data_access.postgresql.repositories import ( ClientRepository, UserRepository, @@ -28,7 +29,7 @@ def _create_id_token_token_auth_service( client_repo: ClientRepository, user_repo: UserRepository, password_service: PasswordHash, - jwt_service: JWTService, + jwt_service: JWTManager, **kwargs: Any, ) -> AuthServiceProtocol: return IdTokenTokenAuthService( diff --git a/src/business_logic/authorization/factories/factory_methods/create_token_auth_service.py b/src/business_logic/authorization/factories/factory_methods/create_token_auth_service.py index d969f41d..ee6f2bb6 100644 --- a/src/business_logic/authorization/factories/factory_methods/create_token_auth_service.py +++ b/src/business_logic/authorization/factories/factory_methods/create_token_auth_service.py @@ -14,7 +14,8 @@ if TYPE_CHECKING: from src.business_logic.authorization.interfaces import AuthServiceProtocol - from src.business_logic.services import JWTService, PasswordHash + from src.business_logic.services import PasswordHash #JWTService, + from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager from src.data_access.postgresql.repositories import ( ClientRepository, UserRepository, @@ -26,7 +27,7 @@ def _create_token_auth_service( client_repo: ClientRepository, user_repo: UserRepository, password_service: PasswordHash, - jwt_service: JWTService, + jwt_service: JWTManager, **kwargs: Any, ) -> AuthServiceProtocol: return TokenAuthService( diff --git a/src/business_logic/get_tokens/service_impls/auth_code.py b/src/business_logic/get_tokens/service_impls/auth_code.py index 24da3083..4bbc23b1 100644 --- a/src/business_logic/get_tokens/service_impls/auth_code.py +++ b/src/business_logic/get_tokens/service_impls/auth_code.py @@ -105,13 +105,13 @@ async def _get_access_token(self, request_data: RequestTokenModel, user_id: int, jti=str(uuid.uuid4()), acr=0, ) - return self._jwt_manager.encode(payload=payload, algorithm='RS256') + return await self._jwt_manager.encode(payload=payload, algorithm='RS256') async def _get_refresh_token(self, request_data: RequestTokenModel) -> str: payload = RefreshTokenPayload( jti=str(uuid.uuid4()) ) - return self._jwt_manager.encode(payload=payload, algorithm='RS256') + return await self._jwt_manager.encode(payload=payload, algorithm='RS256') async def _get_id_token(self, request_data: RequestTokenModel, user_id: int, unix_time: int) -> str: payload = IdTokenPayload( @@ -123,4 +123,4 @@ async def _get_id_token(self, request_data: RequestTokenModel, user_id: int, uni jti=str(uuid.uuid4()), acr=0, ) - return self._jwt_manager.encode(payload=payload, algorithm='RS256') + return await self._jwt_manager.encode(payload=payload, algorithm='RS256') diff --git a/src/business_logic/jwt_manager/dto/__init__.py b/src/business_logic/jwt_manager/dto/__init__.py index c937fe95..5635ff0a 100644 --- a/src/business_logic/jwt_manager/dto/__init__.py +++ b/src/business_logic/jwt_manager/dto/__init__.py @@ -2,7 +2,8 @@ AccessTokenPayload, RefreshTokenPayload, IdTokenPayload, + AdminUIPayload ) -__all__ = ['AccessTokenPayload', 'RefreshTokenPayload', 'IdTokenPayload'] +__all__ = ['AccessTokenPayload', 'RefreshTokenPayload', 'IdTokenPayload', 'AdminUIPayload'] diff --git a/src/business_logic/jwt_manager/dto/input.py b/src/business_logic/jwt_manager/dto/input.py index 00bf94f5..ffba94b8 100644 --- a/src/business_logic/jwt_manager/dto/input.py +++ b/src/business_logic/jwt_manager/dto/input.py @@ -1,16 +1,17 @@ from pydantic import BaseModel from typing import Optional, Any, Union - +from src.dyna_config import DOMAIN_NAME +import uuid class BaseJWTPayload(BaseModel): sub: Union[int, str] # user id - iss: str # auth service uri iat: int # time of creation exp: int # time when token will expire aud: Optional[Union[str, list[str]]] # name for whom token was generated client_id: str # id of the client who issued a token - jti: str # uniques identifier for token, UUID4 acr: Optional[int] # default 0 + jti: str = str(uuid.uuid4()) # uniques identifier for token, UUID4 + iss: str = DOMAIN_NAME # auth service uri class AccessTokenPayload(BaseJWTPayload): @@ -31,3 +32,8 @@ class IdTokenPayload(BaseJWTPayload): picture: Optional[str] = None zoneinfo: Optional[str] = None locale: Optional[str] = None + +class AdminUIPayload(BaseModel): + sub:int + exp:int + aud:list[str] = ["oidc:admin_ui"] \ No newline at end of file diff --git a/src/business_logic/jwt_manager/service_impls/jwt_service.py b/src/business_logic/jwt_manager/service_impls/jwt_service.py index b0cc95d7..3bbf62e8 100644 --- a/src/business_logic/jwt_manager/service_impls/jwt_service.py +++ b/src/business_logic/jwt_manager/service_impls/jwt_service.py @@ -2,39 +2,44 @@ import logging import jwt from typing import Any, Optional, Union - from src.data_access.postgresql.tables.rsa_keys import RSA_keys from src.business_logic.jwt_manager.dto import ( AccessTokenPayload, RefreshTokenPayload, IdTokenPayload, + AdminUIPayload ) +from src.di.providers.rsa_keys import provide_rsa_keys logger = logging.getLogger(__name__) -Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload] +Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload, AdminUIPayload] class JWTManager: - def __init__( - self, - keys: RSA_keys - ) -> None: - self.keys = keys - # print(f"print:!!!jwt_service.py;!!! self.keys: {self.keys}") - - def encode(self, payload: Payload, algorithm: str, secret: Optional[str] = None) -> str: - if secret: - key = secret - else: - key = self.keys.private_key - print(f"jwt_service.py; keys: {self.keys.private_key}") + def __init__(self) -> None: + self.algorithm = "RS256" + self.algorithms = ["RS256"] + self.keys = None + + + def check_rsa_keys(self) -> None: + if not self.keys: + self.keys = provide_rsa_keys() + if self.keys is None: + raise ValueError("Keys don't exist or Docker is not running") + async def encode(self, payload: Payload, algorithm: Optional[str] = None, secret: Optional[str] = None) -> str: + self.check_rsa_keys() + key = secret or self.keys.private_key token = jwt.encode( - payload=payload.dict(exclude_none=True), key=key, algorithm=algorithm + payload=payload.dict(exclude_none=True), + key=key, + algorithm=(algorithm or self.algorithm) ) return token - def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> dict[str, Any]: + async def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> dict[str, Any]: + self.check_rsa_keys() token = token.replace("Bearer ", "") if audience: decoded_info = jwt.decode(token, key=self.keys.public_key, algorithms=self.algorithms, @@ -44,3 +49,22 @@ def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> d **kwargs,) return decoded_info + + async def verify_token(self, token: str, aud: str = None) -> bool: + self.check_rsa_keys() + try: + if aud: + await self.decode(token=token, audience=aud) + else: + await self.decode(token) + return True + except: + return False + + async def get_module(self) -> int: + self.check_rsa_keys() + return self.keys.n + + async def get_pub_key_expanent(self) -> int: + self.check_rsa_keys() + return self.keys.e diff --git a/src/business_logic/services/admin_api.py b/src/business_logic/services/admin_api.py index 67ada7a7..1de19bea 100644 --- a/src/business_logic/services/admin_api.py +++ b/src/business_logic/services/admin_api.py @@ -1,17 +1,19 @@ # https://docs.google.com/spreadsheets/d/1zJAcCxaGz2CV9zKlqeBhRE8xfiqyJMy5-XPRH1I8HPg/edit#gid=0 -from typing import Optional, Any +from typing import Optional, Any, TYPE_CHECKING from src.data_access.postgresql.repositories import RoleRepository, UserRepository, PersistentGrantRepository, GroupRepository from sqlalchemy.ext.asyncio import AsyncSession from typing import Optional -from src.business_logic.services.jwt_token import JWTService +# from src.di.providers import provide_jwt_manager from src.business_logic.services.password import PasswordHash from src.data_access.postgresql.tables import Group, Role, User +# if TYPE_CHECKING: +# from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager class AdminService(): def __init__( self, - jwt_service: JWTService, + jwt_service#: JWTManager, ) -> None: self.jwt_service = jwt_service diff --git a/src/business_logic/services/admin_auth.py b/src/business_logic/services/admin_auth.py index f931a707..9f2606c4 100644 --- a/src/business_logic/services/admin_auth.py +++ b/src/business_logic/services/admin_auth.py @@ -3,12 +3,12 @@ import time from src.business_logic.dto import AdminCredentialsDTO from src.business_logic.services.password import PasswordHash -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.repositories import UserRepository, PersistentGrantRepository from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import RedirectResponse from src.dyna_config import DOMAIN_NAME - +from src.business_logic.jwt_manager.dto import AdminUIPayload logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ def __init__( self, user_repo: UserRepository, password_service = PasswordHash(), - jwt_service = JWTService(), + jwt_service = provide_jwt_manager(), ) -> None: self.user_repo = user_repo self.password_service = password_service @@ -33,17 +33,16 @@ async def authorize( credentials.password, user_hash_password ) - return await self.jwt_service.encode_jwt( - payload={ - "sub":user_id, - "exp": exp_time + int(time.time()), - "aud":["admin","introspection", "revoke"] - } + return await self.jwt_service.encode( + payload=AdminUIPayload( + sub=user_id, + exp=exp_time + int(time.time()), + ) ) async def authenticate(self, token: str) -> Union[None, RedirectResponse]: - if await self.jwt_service.verify_token(token=token, aud="admin"): + if await self.jwt_service.verify_token(token=token, aud="oidc:admin_ui"): return None else: # return None diff --git a/src/business_logic/services/authorization.py b/src/business_logic/services/authorization.py index 95f73568..278a8943 100644 --- a/src/business_logic/services/authorization.py +++ b/src/business_logic/services/authorization.py @@ -7,7 +7,7 @@ from src.config.settings.app import AppSettings from src.dyna_config import DOMAIN_NAME -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.business_logic.services.password import PasswordHash from src.business_logic.services.tokens import get_single_token from src.data_access.postgresql.repositories import ( @@ -32,7 +32,7 @@ def __init__( persistent_grant_repo: PersistentGrantRepository, device_repo: DeviceRepository, password_service: PasswordHash = PasswordHash(), - jwt_service: JWTService = JWTService(), + jwt_service: JWTManager = provide_jwt_manager(), ) -> None: self._request_model: Optional[DataRequestModel] = None self.client_repo = client_repo diff --git a/src/business_logic/services/authorization/authorization_service.py b/src/business_logic/services/authorization/authorization_service.py index 9e5d90bd..5481bb52 100644 --- a/src/business_logic/services/authorization/authorization_service.py +++ b/src/business_logic/services/authorization/authorization_service.py @@ -1,10 +1,10 @@ import logging -from typing import Optional +from typing import Optional, TYPE_CHECKING from src.business_logic.services.authorization.response_type_handlers.factory import ( ResponseTypeHandlerFactory, ) -from src.business_logic.services.jwt_token import JWTService +# from src.di.providers import provide_jwt_manager from src.business_logic.services.password import PasswordHash from src.data_access.postgresql.errors import ClientScopesError from src.data_access.postgresql.repositories import ( @@ -14,6 +14,8 @@ UserRepository, ) from src.presentation.api.models import DataRequestModel +if TYPE_CHECKING: + from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager logger = logging.getLogger(__name__) @@ -26,7 +28,7 @@ def __init__( persistent_grant_repo: PersistentGrantRepository, device_repo: DeviceRepository, password_service: PasswordHash, - jwt_service: JWTService, + jwt_service, ) -> None: self._request_model: Optional[DataRequestModel] = None self.client_repo = client_repo diff --git a/src/business_logic/services/endsession.py b/src/business_logic/services/endsession.py index af6c6a2e..cc003088 100644 --- a/src/business_logic/services/endsession.py +++ b/src/business_logic/services/endsession.py @@ -1,9 +1,10 @@ +from src.business_logic.jwt_manager import JWTManager from src.presentation.api.models.endsession import RequestEndSessionModel from src.data_access.postgresql.repositories.client import ClientRepository from src.data_access.postgresql.repositories.persistent_grant import ( PersistentGrantRepository, ) -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from typing import Union, Optional, Any from sqlalchemy.ext.asyncio import AsyncSession @@ -14,11 +15,10 @@ def __init__( session:AsyncSession, client_repo: ClientRepository, persistent_grant_repo: PersistentGrantRepository, - jwt_service = JWTService() ) -> None: self.client_repo = client_repo self.persistent_grant_repo = persistent_grant_repo - self.jwt_service = jwt_service + self.jwt_service = provide_jwt_manager() self._request_model: Optional[RequestEndSessionModel] = None self.session = session @@ -53,7 +53,7 @@ async def end_session(self) -> Optional[str]: async def _decode_id_token_hint( self, id_token_hint: str ) -> dict[str, Any]: - decoded_data = await self.jwt_service.decode_token(token=id_token_hint) + decoded_data = await self.jwt_service.decode(token=id_token_hint) return decoded_data async def _logout(self, client_id: str, user_id: int) -> None: diff --git a/src/business_logic/services/introspection.py b/src/business_logic/services/introspection.py index 4d4fedba..f7c24e5c 100644 --- a/src/business_logic/services/introspection.py +++ b/src/business_logic/services/introspection.py @@ -1,8 +1,9 @@ -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING from fastapi import Request from jwt.exceptions import ExpiredSignatureError, PyJWTError + from src.data_access.postgresql.errors import TokenIncorrectError -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.repositories.client import ClientRepository from src.data_access.postgresql.repositories.persistent_grant import ( PersistentGrantRepository, @@ -13,6 +14,8 @@ BodyRequestIntrospectionModel, ) from sqlalchemy.ext.asyncio import AsyncSession +# if TYPE_CHECKING: +from src.business_logic.jwt_manager import JWTManager class IntrospectionService: @@ -56,7 +59,7 @@ def __init__( user_repo: UserRepository, client_repo: ClientRepository, persistent_grant_repo: PersistentGrantRepository, - jwt: JWTService = JWTService(), + jwt: JWTManager = provide_jwt_manager(), ) -> None: self.jwt = jwt self.request: Optional[Request] = None @@ -102,7 +105,7 @@ async def analyze_token(self) -> dict[str, Any]: response: dict[str, Any] = {} try: - decoded_token = await self.jwt.decode_token( + decoded_token = await self.jwt.decode( token=self.request_body.token, audience="introspection" ) except ExpiredSignatureError: diff --git a/src/business_logic/services/jwt_token.py b/src/business_logic/services/jwt_token.py index 70051f05..bcf42f8c 100644 --- a/src/business_logic/services/jwt_token.py +++ b/src/business_logic/services/jwt_token.py @@ -2,21 +2,26 @@ import jwt from typing import Any, no_type_check, Optional - -from src.di.providers.rsa_keys import provide_rsa_keys from src.config.rsa_keys import RSAKeypair logger = logging.getLogger(__name__) class JWTService: - def __init__(self, keys: RSAKeypair = provide_rsa_keys) -> None: + def __init__(self) -> None: self.algorithm = "RS256" self.algorithms = ["RS256"] - self.keys = keys + self.keys:Optional[RSAKeypair] = None + + def check_rsa_keys(self): + if not self.keys: + self.keys = 123 + if self.keys is None: + raise ValueError("Keys don't exist or Docker is not running") @no_type_check async def encode_jwt(self, payload: dict[str, Any] = {}, secret: None = None) -> str: + self.check_rsa_keys() token = jwt.encode( payload=payload, key=self.keys().private_key, algorithm=self.algorithm ) @@ -27,7 +32,7 @@ async def encode_jwt(self, payload: dict[str, Any] = {}, secret: None = None) -> @no_type_check async def decode_token(self, token: str, audience: str =None ,**kwargs: Any) -> dict[str, Any]: - + self.check_rsa_keys() token = token.replace("Bearer ", "") if audience: decoded = jwt.decode( @@ -47,6 +52,7 @@ async def decode_token(self, token: str, audience: str =None ,**kwargs: Any) -> return decoded async def verify_token(self, token: str, aud:str=None) -> bool: + self.check_rsa_keys() try: if aud: await self.decode_token(token=token, audience=aud) @@ -57,7 +63,9 @@ async def verify_token(self, token: str, aud:str=None) -> bool: return False async def get_module(self) -> int: + self.check_rsa_keys() return self.keys().n async def get_pub_key_expanent(self) -> int: + self.check_rsa_keys() return self.keys().e diff --git a/src/business_logic/services/tokens.py b/src/business_logic/services/tokens.py index 3bac2006..0d03507c 100644 --- a/src/business_logic/services/tokens.py +++ b/src/business_logic/services/tokens.py @@ -18,8 +18,6 @@ from itsdangerous import base64_encode from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession - -# from src.business_logic.services.jwt_token import JWTService from src.config.settings.app import AppSettings from src.data_access.postgresql.errors import ( ClaimsNotFoundError, @@ -41,7 +39,7 @@ ) from src.dyna_config import DOMAIN_NAME if TYPE_CHECKING: - from src.business_logic.services.jwt_token import JWTService + from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager from src.presentation.api.models import ( BodyRequestRevokeModel, BodyRequestTokenModel @@ -83,7 +81,7 @@ def get_base_payload( async def get_single_token( client_id: str, - jwt_service: JWTService, + jwt_service: JWTManager, expiration_time: int, scope: Optional[str] = None, **kwargs: Any, @@ -100,7 +98,7 @@ async def get_single_token( **kwargs, ) full_payload = {**base_payload} - access_token = await jwt_service.encode_jwt(payload=full_payload) + access_token = await jwt_service.encode(payload=full_payload) return access_token @@ -114,7 +112,7 @@ def __init__( device_repo: DeviceRepository, blacklisted_repo: BlacklistedTokenRepository, code_challenge_repo: CodeChallengeRepository, - jwt_service: JWTService, + jwt_service: JWTManager, ) -> None: self.session = session self.request: Optional[Request] = None @@ -207,7 +205,7 @@ def __init__(self, token_service: TokenService) -> None: token_service.code_challenge_repo ) self.user_repo: UserRepository = token_service.user_repo - self.jwt_service: JWTService = token_service.jwt_service + self.jwt_service: JWTManager = token_service.jwt_service self.blacklisted_repo: BlacklistedTokenRepository = ( token_service.blacklisted_repo ) @@ -521,7 +519,7 @@ async def create(self) -> Dict[str, Any]: scopes = ["No scope"] audience = ["admin", "introspection", "revoke"] - access_token = await self.jwt_service.encode_jwt( + access_token = await self.jwt_service.encode( { # "arc" : client_from_db.arc, # # ACR value is a set of arbitrary values that the client and idp agreed upon to communicate the level of authentication that happened. This is to give the client a level of confidence on the qualify of the authentication that took place. diff --git a/src/business_logic/services/userinfo.py b/src/business_logic/services/userinfo.py index 789d8934..b50a66e7 100644 --- a/src/business_logic/services/userinfo.py +++ b/src/business_logic/services/userinfo.py @@ -1,8 +1,9 @@ from typing import Any, Optional -from jwt.exceptions import PyJWTError +# from jwt.exceptions import PyJWTError -from src.business_logic.services.jwt_token import JWTService +from src.business_logic.jwt_manager import JWTManager +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.repositories.client import ClientRepository from src.data_access.postgresql.repositories.persistent_grant import ( PersistentGrantRepository, @@ -18,7 +19,7 @@ def __init__( user_repo: UserRepository, client_repo: ClientRepository, persistent_grant_repo: PersistentGrantRepository, - jwt: JWTService = JWTService(), + jwt: JWTManager = provide_jwt_manager(), ) -> None: self.jwt = jwt self.authorization: Optional[str] = None @@ -92,5 +93,5 @@ async def get_user_info( async def get_user_info_jwt(self) -> str: result = await self.get_user_info() - token = await self.jwt.encode_jwt(payload=result) + token = await self.jwt.encode(payload=result) return token diff --git a/src/business_logic/services/well_known.py b/src/business_logic/services/well_known.py index b33a7d07..9bc3b649 100644 --- a/src/business_logic/services/well_known.py +++ b/src/business_logic/services/well_known.py @@ -1,11 +1,16 @@ + +from src.data_access.postgresql.tables.persistent_grant import TYPES_OF_GRANTS +from src.di.providers import provide_jwt_manager + from src.business_logic.dto.open_id_config import OpenIdConfiguration from src.business_logic.services.jwt_token import JWTService + from jwkest import long_to_base64, base64_to_long import logging from src.dyna_config import DOMAIN_NAME from jwkest import base64_to_long, long_to_base64 from fastapi import Request -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from typing import Any, Union from src.data_access.postgresql.repositories import WellKnownRepository from sqlalchemy.ext.asyncio import AsyncSession @@ -67,6 +72,7 @@ def get_all_urls(self, result: dict[str, Any]) -> dict[str, Any]: for route in self.request.app.routes } | {"false": "/ Not ready yet"} + def get_algorithms(self) -> list[str]: """Retrieves a list of algorithms from the JWTService. @@ -166,6 +172,7 @@ async def get_openid_configuration( return result_model async def get_jwks(self) -> dict[str, Any]: + """Retrieves the JWKS (JSON Web Key Set). Returns: diff --git a/src/config/rsa_keys/rsa_keys_service.py b/src/config/rsa_keys/rsa_keys_service.py index b7968e7b..ff470c05 100644 --- a/src/config/rsa_keys/rsa_keys_service.py +++ b/src/config/rsa_keys/rsa_keys_service.py @@ -1,6 +1,6 @@ from Crypto.PublicKey import RSA from sqlalchemy.orm import sessionmaker - +from sqlalchemy.exc import OperationalError from src.data_access.postgresql.repositories import RSAKeysRepository from src.data_access.postgresql.tables.rsa_keys import RSA_keys from .dto import RSAKeypair @@ -16,13 +16,16 @@ def __init__( self.rsa_keys_repo = rsa_keys_repo def get_rsa_keys(self) -> RSA_keys: - with self.session() as session: - if self.rsa_keys_repo.validate_keys_exists(session=session): - self.rsa_keys = self.rsa_keys_repo.get_keys_from_repository(session) - else: - self.rsa_keys = self.create_rsa_keys() # RSAKeypair - self.rsa_keys_repo.put_keys_to_repository(rsa_keys=self.rsa_keys, session=session) - self.rsa_keys = self.rsa_keys_repo.get_keys_from_repository(session=session) # RSA_keys + try: + with self.session() as session: + if self.rsa_keys_repo.validate_keys_exists(session=session): + self.rsa_keys = self.rsa_keys_repo.get_keys_from_repository(session) + else: + self.rsa_keys = self.create_rsa_keys() # RSAKeypair + self.rsa_keys_repo.put_keys_to_repository(rsa_keys=self.rsa_keys, session=session) + self.rsa_keys = self.rsa_keys_repo.get_keys_from_repository(session=session) # RSA_keys + except OperationalError: + self.rsa_keys = None return self.rsa_keys def create_rsa_keys(self) -> RSAKeypair: # or -> RSA_keys diff --git a/src/data_access/postgresql/database.py b/src/data_access/postgresql/database.py index 9d3d184c..51b9b637 100644 --- a/src/data_access/postgresql/database.py +++ b/src/data_access/postgresql/database.py @@ -42,15 +42,29 @@ def _create_connection_pool(self, db_url: str, max_connection_count: int) -> Asy async def get_connection(self) -> AsyncSession: async with self.__session_factory() as session: yield session - - +from sqlalchemy import text +from time import sleep class DatabaseSync: def __init__(self, database_url: str): - self.__sync_engine = self._create_sync_connection_pool(database_url) - self.__sync_session_factory = sessionmaker( - self.__sync_engine - ) - + # retry_count = 0 + # while retry_count<6: + # try: + self.__sync_engine = self._create_sync_connection_pool(database_url) + self.__sync_session_factory = sessionmaker( + self.__sync_engine + ) + # with self.__sync_session_factory() as sess: + # print(sess) + # a = sess.execute(text("SELECT * FROM clients")) + # print(a) + # break + # except Exception as err: + # retry_count+=1 + # if retry_count<2: + # print(err) + # sleep(2) + + @property def sync_engine(self) -> Engine: return self.__sync_engine diff --git a/src/di/providers/__init__.py b/src/di/providers/__init__.py index 353f037f..c386a1dd 100644 --- a/src/di/providers/__init__.py +++ b/src/di/providers/__init__.py @@ -1,45 +1,19 @@ from .config import provide_config from .db import provide_db, provide_db_only from .repositories import ( - provide_wellknown_repo, - provide_client_repo, - provide_device_repo, - provide_group_repo, - provide_persistent_grant_repo, - provide_role_repo, - provide_third_party_oidc_repo, - provide_user_repo, - provide_blacklisted_repo, + # provide_wellknown_repo, + # provide_client_repo, + # provide_device_repo, + # provide_group_repo, + # provide_persistent_grant_repo, + # provide_role_repo, + # provide_third_party_oidc_repo, + # provide_user_repo, + # provide_blacklisted_repo, provide_async_session, provide_async_session_stub, ProviderSession, ) -# from .services import ( -# provide_wellknown_service, -# provide_admin_auth_service, -# provide_admin_group_service, -# provide_admin_role_service, -# provide_admin_user_service, -# provide_auth_service, -# provide_auth_third_party_linkedin_service, -# provide_auth_third_party_oidc_service, -# provide_third_party_google_service, -# provide_third_party_facebook_service, -# provide_third_party_gitlab_service, -# provide_third_party_microsoft_service, -# provide_device_service, -# provide_endsession_service, -# provide_introspection_service, -# provide_jwt_service, -# provide_login_form_service, -# provide_password_service, -# provide_token_service, -# provide_userinfo_service, -# provide_client_service, -# ) -# from .services_factory import ( -# provide_auth_service_factory, -# # provide_third_party_auth_service_factory, -# ) + from .jwt_manager import provide_jwt_manager, provide_jwt_manager_stub from .token_factory import provide_token_service_factory \ No newline at end of file diff --git a/src/di/providers/jwt_manager.py b/src/di/providers/jwt_manager.py index b9286f83..97e24637 100644 --- a/src/di/providers/jwt_manager.py +++ b/src/di/providers/jwt_manager.py @@ -1,4 +1,3 @@ -from src.di.providers.rsa_keys import provide_rsa_keys from src.business_logic.jwt_manager import JWTManager from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol @@ -9,5 +8,4 @@ def provide_jwt_manager_stub() -> None: def provide_jwt_manager() -> JWTManagerProtocol: - rsa_keys = provide_rsa_keys() - return JWTManager(keys=rsa_keys) + return JWTManager() diff --git a/src/di/providers/rsa_keys.py b/src/di/providers/rsa_keys.py index 5dbee010..cec6b3fa 100644 --- a/src/di/providers/rsa_keys.py +++ b/src/di/providers/rsa_keys.py @@ -8,7 +8,8 @@ def provide_rsa_keys_stub() -> None: def provide_rsa_keys() -> RSA_keys: - sync_session_factory = DatabaseSync(database_url=DB_URL).sync_session_factory + db_url = DB_URL.replace("+asyncpg",'') + sync_session_factory = DatabaseSync(database_url=db_url).sync_session_factory rsa_keys = RSAKeysService( sync_session_factory=sync_session_factory, rsa_keys_repo=RSAKeysRepository() diff --git a/src/di/providers/services.py b/src/di/providers/services.py index d7bb4bbb..738df463 100644 --- a/src/di/providers/services.py +++ b/src/di/providers/services.py @@ -1,320 +1,6 @@ -# from httpx import AsyncClient -# from sqlalchemy.ext.asyncio import AsyncSession -# from src.business_logic.services.admin_auth import AdminAuthService -# from src.business_logic.services.admin_api import ( -# AdminGroupService, -# AdminRoleService, -# AdminUserService, -# ) -# from src.business_logic.services.authorization.authorization_service import ( -# AuthorizationService, -# ) -# from src.business_logic.services.third_party_oidc_service import ( -# AuthThirdPartyOIDCService, -# ) -# from src.business_logic.services.device_auth import DeviceService -# from src.business_logic.services.endsession import EndSessionService -# from src.business_logic.services.introspection import IntrospectionService -# from src.business_logic.services.jwt_token import JWTService -# from src.business_logic.services.login_form_service import LoginFormService -# from src.business_logic.services.password import PasswordHash -# from src.business_logic.services.third_party_oidc_service import ( -# ThirdPartyFacebookService, -# ThirdPartyGitLabService, -# ThirdPartyGoogleService, -# ThirdPartyLinkedinService, -# ThirdPartyMicrosoftService, -# ) -# from src.business_logic.services.tokens import TokenService -# from src.business_logic.services.userinfo import UserInfoService -# from src.business_logic.services.well_known import WellKnownService -# from src.business_logic.services.client import ClientService -# -# -# from src.data_access.postgresql.repositories import ( -# BlacklistedTokenRepository, -# ClientRepository, -# DeviceRepository, -# GroupRepository, -# PersistentGrantRepository, -# RoleRepository, -# ThirdPartyOIDCRepository, -# UserRepository, -# WellKnownRepository, -# CodeChallengeRepository, -# ) -# -# -# def provide_auth_service( -# client_repo: ClientRepository, -# user_repo: UserRepository, -# persistent_grant_repo: PersistentGrantRepository, -# device_repo: DeviceRepository, -# password_service: PasswordHash, -# jwt_service: JWTService, -# ) -> AuthorizationService: -# return AuthorizationService( -# client_repo=client_repo, -# user_repo=user_repo, -# persistent_grant_repo=persistent_grant_repo, -# device_repo=device_repo, -# password_service=password_service, -# jwt_service=jwt_service, -# ) -# -# -# def provide_password_service() -> PasswordHash: -# return PasswordHash() -# -# -# def provide_endsession_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# persistent_grant_repo: PersistentGrantRepository, -# jwt_service: JWTService, -# ) -> EndSessionService: -# return EndSessionService( -# session=session, -# client_repo=client_repo, -# persistent_grant_repo=persistent_grant_repo, -# jwt_service=jwt_service, -# ) -# -# -# def provide_jwt_service() -> JWTService: -# return JWTService() -# -# -# def provide_introspection_service( -# session: AsyncSession, -# jwt: JWTService, -# user_repo: UserRepository, -# client_repo: ClientRepository, -# persistent_grant_repo: PersistentGrantRepository, -# ) -> IntrospectionService: -# return IntrospectionService( -# jwt=jwt, -# user_repo=user_repo, -# client_repo=client_repo, -# persistent_grant_repo=persistent_grant_repo, -# ) -# -# -# def provide_token_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# persistent_grant_repo: PersistentGrantRepository, -# user_repo: UserRepository, -# device_repo: DeviceRepository, -# code_challenge_repo: CodeChallengeRepository, -# jwt_service: JWTService, -# blacklisted_repo: BlacklistedTokenRepository, -# ) -> TokenService: -# return TokenService( -# session=session, -# client_repo=client_repo, -# persistent_grant_repo=persistent_grant_repo, -# user_repo=user_repo, -# device_repo=device_repo, -# code_challenge_repo=code_challenge_repo, -# jwt_service=jwt_service, -# blacklisted_repo=blacklisted_repo, -# ) -# -# -# def provide_admin_user_service( -# user_repo: UserRepository, role_repo: RoleRepository, session: AsyncSession -# ) -> AdminUserService: -# return AdminUserService( -# user_repo=user_repo, role_repo=role_repo, session=session -# ) -# -# -# def provide_admin_group_service( -# session: AsyncSession, group_repo: GroupRepository -# ) -> AdminGroupService: -# return AdminGroupService(session=session, group_repo=group_repo) -# -# -# def provide_admin_role_service( -# session: AsyncSession, -# role_repo: RoleRepository, -# ) -> AdminRoleService: -# return AdminRoleService( -# session=session, -# role_repo=role_repo, -# ) -# -# -# def provide_wellknown_service( -# session: AsyncSession, -# wlk_repo: WellKnownRepository, -# ) -> WellKnownService: -# return WellKnownService( -# session=session, -# wlk_repo=wlk_repo, -# ) -# -# -# def provide_userinfo_service( -# session: AsyncSession, -# jwt: JWTService, -# user_repo: UserRepository, -# client_repo: ClientRepository, -# persistent_grant_repo: PersistentGrantRepository, -# ) -> UserInfoService: -# return UserInfoService( -# session=session, -# jwt=jwt, -# user_repo=user_repo, -# client_repo=client_repo, -# persistent_grant_repo=persistent_grant_repo, -# ) -# -# -# def provide_login_form_service( -# client_repo: ClientRepository, -# oidc_repo: ThirdPartyOIDCRepository, -# session: AsyncSession, -# ) -> LoginFormService: -# return LoginFormService( -# client_repo=client_repo, oidc_repo=oidc_repo, session=session -# ) -# -# -# def provide_admin_auth_service( -# user_repo: UserRepository, -# password_service: PasswordHash, -# jwt_service: JWTService, -# ) -> AdminAuthService: -# return AdminAuthService( -# user_repo=user_repo, -# password_service=password_service, -# jwt_service=jwt_service, -# ) -# -# -# def provide_device_service( -# client_repo: ClientRepository, -# device_repo: DeviceRepository, -# session: AsyncSession, -# ) -> DeviceService: -# return DeviceService( -# session=session, client_repo=client_repo, device_repo=device_repo -# ) -# -# -# def provide_auth_third_party_oidc_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# user_repo: UserRepository, -# oidc_repo: ThirdPartyOIDCRepository, -# persistent_grant_repo: PersistentGrantRepository, -# http_client: AsyncClient, -# ) -> AuthThirdPartyOIDCService: -# return AuthThirdPartyOIDCService( -# session=session, -# client_repo=client_repo, -# user_repo=user_repo, -# persistent_grant_repo=persistent_grant_repo, -# oidc_repo=oidc_repo, -# http_client=http_client, -# ) -# -# -# def provide_auth_third_party_linkedin_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# user_repo: UserRepository, -# oidc_repo: ThirdPartyOIDCRepository, -# persistent_grant_repo: PersistentGrantRepository, -# http_client: AsyncClient, -# ) -> ThirdPartyLinkedinService: -# return ThirdPartyLinkedinService( -# session=session, -# client_repo=client_repo, -# user_repo=user_repo, -# persistent_grant_repo=persistent_grant_repo, -# oidc_repo=oidc_repo, -# http_client=http_client, -# ) -# -# -# def provide_third_party_google_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# user_repo: UserRepository, -# oidc_repo: ThirdPartyOIDCRepository, -# persistent_grant_repo: PersistentGrantRepository, -# http_client: AsyncClient, -# ) -> ThirdPartyGoogleService: -# return ThirdPartyGoogleService( -# session=session, -# client_repo=client_repo, -# user_repo=user_repo, -# persistent_grant_repo=persistent_grant_repo, -# oidc_repo=oidc_repo, -# http_client=http_client, -# ) -# -# -# def provide_third_party_facebook_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# user_repo: UserRepository, -# oidc_repo: ThirdPartyOIDCRepository, -# persistent_grant_repo: PersistentGrantRepository, -# http_client: AsyncClient, -# ) -> ThirdPartyFacebookService: -# return ThirdPartyFacebookService( -# session=session, -# client_repo=client_repo, -# user_repo=user_repo, -# persistent_grant_repo=persistent_grant_repo, -# oidc_repo=oidc_repo, -# http_client=http_client, -# ) -# -# -# def provide_third_party_gitlab_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# user_repo: UserRepository, -# oidc_repo: ThirdPartyOIDCRepository, -# persistent_grant_repo: PersistentGrantRepository, -# http_client: AsyncClient, -# ) -> ThirdPartyGitLabService: -# return ThirdPartyGitLabService( -# session=session, -# client_repo=client_repo, -# user_repo=user_repo, -# persistent_grant_repo=persistent_grant_repo, -# oidc_repo=oidc_repo, -# http_client=http_client, -# ) -# -# -# def provide_third_party_microsoft_service( -# session: AsyncSession, -# client_repo: ClientRepository, -# user_repo: UserRepository, -# oidc_repo: ThirdPartyOIDCRepository, -# persistent_grant_repo: PersistentGrantRepository, -# http_client: AsyncClient, -# ) -> ThirdPartyMicrosoftService: -# return ThirdPartyMicrosoftService( -# session=session, -# client_repo=client_repo, -# user_repo=user_repo, -# persistent_grant_repo=persistent_grant_repo, -# oidc_repo=oidc_repo, -# http_client=http_client, -# ) -# -# -# def provide_client_service( -# client_repo: ClientRepository, -# ) -> ClientService: -# return ClientService( -# client_repo=client_repo, -# ) +from src.business_logic.services import ( + ClientService, +) + +from src.data_access.postgresql.repositories import ( + ClientRepository, diff --git a/src/di/providers/services_factory.py b/src/di/providers/services_factory.py index 94fa8077..0d10a536 100644 --- a/src/di/providers/services_factory.py +++ b/src/di/providers/services_factory.py @@ -15,28 +15,6 @@ ) from src.business_logic.services import JWTService, PasswordHash - - -def provide_auth_service_factory( - session: AsyncSession, - client_repo: ClientRepository, - persistent_grant_repo: PersistentGrantRepository, - user_repo: UserRepository, - device_repo: DeviceRepository, - jwt_service: JWTService, - password_service: PasswordHash, -) -> AuthServiceFactory: - return AuthServiceFactory( - session=session, - client_repo=client_repo, - persistent_grant_repo=persistent_grant_repo, - user_repo=user_repo, - device_repo=device_repo, - jwt_service=jwt_service, - password_service=password_service, - ) - - def provide_third_party_auth_service_factory( session: AsyncSession, client_repo: ClientRepository, diff --git a/src/main.py b/src/main.py index 809130cb..fcb82c33 100644 --- a/src/main.py +++ b/src/main.py @@ -23,8 +23,6 @@ import src.presentation.admin_ui.controllers as ui import src.di.providers as prov -# from src.di.providers.rsa_keys import provide_rsa_keys, provide_rsa_keys_stub # removed from `prov` due to circular import -from src.di.providers import provide_jwt_manager_stub, provide_jwt_manager import logging from src.log import LOGGING_CONFIG from src.data_access.postgresql.repositories import UserRepository @@ -51,9 +49,8 @@ def get_application(test: bool = False) -> NewFastApi: allow_methods=["*"], allow_headers=["*"], ) - application = setup_exception_handlers(application) - - application = setup_exception_handlers(application) + #application = setup_exception_handlers(application) + setup_di(application) container = Container() container.db() diff --git a/src/presentation/admin_api/routes/group.py b/src/presentation/admin_api/routes/group.py index 4b6a874c..4c6efd46 100644 --- a/src/presentation/admin_api/routes/group.py +++ b/src/presentation/admin_api/routes/group.py @@ -7,7 +7,7 @@ from src.data_access.postgresql.repositories import GroupRepository from src.business_logic.services.admin_api import AdminGroupService -from src.data_access.postgresql.errors.user import DuplicationError +# from src.data_access.postgresql.errors.user import DuplicationError from src.presentation.admin_api.models.group import * from src.di.providers import provide_async_session_stub diff --git a/src/presentation/api/routes/authorization.py b/src/presentation/api/routes/authorization.py index 9efde474..6b1952a6 100644 --- a/src/presentation/api/routes/authorization.py +++ b/src/presentation/api/routes/authorization.py @@ -11,7 +11,7 @@ from src.business_logic.authorization import AuthServiceFactory from src.business_logic.authorization.dto import AuthRequestModel -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.business_logic.services.login_form_service import LoginFormService from src.business_logic.services.password import PasswordHash from src.data_access.postgresql.repositories import ( @@ -90,7 +90,7 @@ async def post_authorize( persistent_grant_repo=PersistentGrantRepository(session), device_repo=DeviceRepository(session), password_service=PasswordHash(), - jwt_service=JWTService(), + jwt_service=provide_jwt_manager(), ) setattr(request_body, "user_code", user_code) auth_service: AuthServiceProtocol = auth_service_factory.get_service_impl( diff --git a/src/presentation/api/routes/debug_swagger.py b/src/presentation/api/routes/debug_swagger.py index 74b5a0eb..c17d8b74 100644 --- a/src/presentation/api/routes/debug_swagger.py +++ b/src/presentation/api/routes/debug_swagger.py @@ -3,7 +3,7 @@ from fastapi import APIRouter -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager debug_router = APIRouter( prefix="/userinfo", @@ -18,7 +18,7 @@ async def get_default_token( exp: Optional[int] = int(time.time() + 10000), scope: str = "profile", ) -> str: - jwt = JWTService() + jwt = provide_jwt_manager() payload: dict[str, Any] = {"sub": "1", "scope": scope} if with_iss_me: payload["iss"] = "me" @@ -26,7 +26,7 @@ async def get_default_token( payload["aud"] = ["admin", "userinfo", "introspection", "revoke"] if exp: payload["exp"] = exp - return await jwt.encode_jwt(payload) + return await jwt.encode(payload) @debug_router.get("/decode_token", response_model=dict) @@ -35,7 +35,7 @@ async def get_decode_token( issuer: Optional[str] = None, audience: Optional[str] = None, ) -> dict[str, Any]: - jwt = JWTService() + jwt = provide_jwt_manager() kwargs = {} if issuer is not None: kwargs["issuer"] = issuer diff --git a/src/presentation/api/routes/registration.py b/src/presentation/api/routes/registration.py index 8c46dc57..4f0a4260 100644 --- a/src/presentation/api/routes/registration.py +++ b/src/presentation/api/routes/registration.py @@ -58,7 +58,7 @@ async def get_all_clients( access_token: str = Header(description="Access token"), auth: None = Depends(access_token_middleware), session: AsyncSession = Depends(provide_async_session_stub), -) -> dict[str, list[dict[str, Any]]]: + ): client_service = ClientService( session=session, client_repo=ClientRepository(session) ) diff --git a/src/presentation/api/routes/revoke.py b/src/presentation/api/routes/revoke.py index 6d489354..1ea45157 100644 --- a/src/presentation/api/routes/revoke.py +++ b/src/presentation/api/routes/revoke.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, Header, Request from sqlalchemy.ext.asyncio import AsyncSession -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.business_logic.services.tokens import TokenService from src.data_access.postgresql.repositories import ( PersistentGrantRepository, @@ -49,7 +49,7 @@ async def post_revoke_token( device_repo=DeviceRepository(session), code_challenge_repo=CodeChallengeRepository(session), blacklisted_repo=BlacklistedTokenRepository(session), - jwt_service=JWTService(), + jwt_service=provide_jwt_manager(), ) token_class.request = request token_class.request_body = request_body diff --git a/src/presentation/api/routes/userinfo.py b/src/presentation/api/routes/userinfo.py index dc6ff550..96443fb5 100644 --- a/src/presentation/api/routes/userinfo.py +++ b/src/presentation/api/routes/userinfo.py @@ -105,4 +105,4 @@ async def get_userinfo_jwt( logger.info("Collecting Claims from DataBase.") result = await userinfo_class.get_user_info() result = {k: v for k, v in result.items() if v is not None} - return await userinfo_class.jwt.encode_jwt(payload=result) + return await userinfo_class.jwt.encode(payload=result, algorithm='RS256') diff --git a/src/presentation/middleware/access_token_validation.py b/src/presentation/middleware/access_token_validation.py index cce35682..57dba22e 100644 --- a/src/presentation/middleware/access_token_validation.py +++ b/src/presentation/middleware/access_token_validation.py @@ -2,7 +2,8 @@ from jwt.exceptions import InvalidAudienceError, ExpiredSignatureError, InvalidKeyError, MissingRequiredClaimError from fastapi import Request, Depends from typing import Any - +from src.business_logic.jwt_manager.interfaces import JWTManagerProtocol +from src.di.providers import provide_jwt_manager from src.business_logic.services.jwt_token import JWTService from src.data_access.postgresql.repositories.blacklisted_token import BlacklistedTokenRepository from sqlalchemy.ext.asyncio import AsyncSession @@ -18,7 +19,7 @@ async def access_token_middleware( request: Request, session: AsyncSession = Depends(provide_async_session_stub), ) -> Any: - jwt_service = JWTService() + jwt_service = provide_jwt_manager() token = request.headers.get("access-token") blacklisted_repo = BlacklistedTokenRepository(session) @@ -27,7 +28,7 @@ async def access_token_middleware( if await blacklisted_repo.exists(token=token): raise IncorrectAuthTokenError("Access Token revoked") try: - await jwt_service.decode_token(token, audience='admin') + await jwt_service.decode(token, audience='admin') except (InvalidAudienceError, MissingRequiredClaimError): raise IncorrectAuthTokenError("Access Token doesn't have admin permissions") except ExpiredSignatureError: diff --git a/src/presentation/middleware/authorization_validation.py b/src/presentation/middleware/authorization_validation.py index 831f3c9b..07abba0f 100644 --- a/src/presentation/middleware/authorization_validation.py +++ b/src/presentation/middleware/authorization_validation.py @@ -2,10 +2,9 @@ from typing import Any from fastapi import Request, Depends from sqlalchemy.ext.asyncio import AsyncSession - from src.data_access.postgresql.errors.auth_token import IncorrectAuthTokenError from typing import Any -from src.business_logic.services.jwt_token import JWTService +# from src.di.providers import provide_jwt_manager from src.data_access.postgresql.repositories import BlacklistedTokenRepository from src.di.providers import provide_async_session_stub, provide_jwt_manager from jwt.exceptions import InvalidAudienceError, ExpiredSignatureError, InvalidKeyError, MissingRequiredClaimError @@ -17,7 +16,7 @@ async def authorization_middleware( request: Request, session: AsyncSession = Depends(provide_async_session_stub), ) -> Any: - jwt_service = JWTService() + jwt_service = provide_jwt_manager() token = request.headers.get("authorization") or request.headers.get("auth-swagger") if token is None: raise IncorrectAuthTokenError("No authorization or auth-swagger in Request") @@ -34,15 +33,16 @@ async def authorization_middleware( try: if aud == "revoke": aud = "revocation" - await jwt_service.decode_token(token=token, audience=[aud, 'admin']) + await jwt_service.decode(token=token, audience=[aud, 'admin']) except (InvalidAudienceError, MissingRequiredClaimError): raise IncorrectAuthTokenError(f"Authorization Token doesn't have {aud} permissions") except ExpiredSignatureError: raise IncorrectAuthTokenError("Authorization Token expired") except InvalidKeyError: raise IncorrectAuthTokenError("Authorization Token can not be decoded with our private key") - except: - raise IncorrectAuthTokenError("Authorization Token can not be decoded") + except Exception as e: + # raise IncorrectAuthTokenError("Authorization Token can not be decoded") + raise e else: logger.info("Authorization Passed") diff --git a/tests/conftest.py b/tests/conftest.py index 4996415e..4eada55f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,7 +39,7 @@ CodeChallengeRepository, ) from src.business_logic.services.password import PasswordHash -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.business_logic.services.introspection import IntrospectionService from src.business_logic.services.tokens import TokenService from src.business_logic.services.login_form_service import LoginFormService @@ -188,7 +188,7 @@ async def authorization_service( persistent_grant_repo=PersistentGrantRepository(session=connection), device_repo=DeviceRepository(session=connection), password_service=PasswordHash(), - jwt_service=JWTService(), + jwt_service=provide_jwt_manager(), ) return auth_service @@ -199,7 +199,7 @@ async def end_session_service(connection: AsyncSession) -> EndSessionService: session=connection, client_repo=ClientRepository(session=connection), persistent_grant_repo=PersistentGrantRepository(session=connection), - jwt_service=JWTService(), + #jwt_service=provide_jwt_manager(), ) return end_sess_service @@ -213,7 +213,7 @@ async def introspection_service( client_repo=ClientRepository(session=connection), persistent_grant_repo=PersistentGrantRepository(session=connection), user_repo=UserRepository(session=connection), - jwt=JWTService(), + jwt=provide_jwt_manager(), ) return intro_service @@ -222,7 +222,7 @@ async def introspection_service( async def user_info_service(connection: AsyncSession) -> UserInfoService: user_info = UserInfoService( session=connection, - jwt=JWTService(), + jwt=provide_jwt_manager(), client_repo=ClientRepository(session=connection), persistent_grant_repo=PersistentGrantRepository(session=connection), user_repo=UserRepository(session=connection), @@ -239,7 +239,7 @@ async def token_service(connection: AsyncSession) -> TokenService: user_repo=UserRepository(session=connection), device_repo=DeviceRepository(session=connection), code_challenge_repo=CodeChallengeRepository(session=connection), - jwt_service=JWTService(), + jwt_service=provide_jwt_manager(), blacklisted_repo=BlacklistedTokenRepository(session=connection), ) return tk_service @@ -339,7 +339,7 @@ async def admin_auth_service(connection: AsyncSession) -> AdminAuthService: admin_auth_service = AdminAuthService( user_repo=UserRepository(session=connection), password_service=PasswordHash(), - jwt_service=JWTService(), + jwt_service=provide_jwt_manager(), ) return admin_auth_service diff --git a/tests/test_admin_api/test_admin_group.py b/tests/test_admin_api/test_admin_group.py index febf6eff..55eca38d 100644 --- a/tests/test_admin_api/test_admin_group.py +++ b/tests/test_admin_api/test_admin_group.py @@ -5,7 +5,7 @@ from sqlalchemy import insert, delete from sqlalchemy.ext.asyncio import AsyncSession -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from src.data_access.postgresql.repositories.user import UserRepository from src.data_access.postgresql.repositories.groups import GroupRepository @@ -13,6 +13,8 @@ import logging from sqlalchemy import exc from typing import Any +from src.business_logic.jwt_manager.dto import AccessTokenPayload +import time logger = logging.getLogger(__name__) @@ -21,11 +23,15 @@ @pytest.mark.asyncio class TestAdminGroupEndpoint: async def setup_base(self, connection:AsyncSession, user_id: int = 1000) -> None: - self.access_token = await JWTService().encode_jwt( - payload={ - "stand": "CrazyDiamond", - "aud":["admin"] - } + self.access_token = await provide_jwt_manager().encode( + payload=AccessTokenPayload( + sub = 1, + iat=1, + exp=int(time.time()) + 100000, + client_id='123123', + arc=1, + aud = "admin" + ) ) self.group_repo = GroupRepository(connection) self.role_repo = RoleRepository(connection) @@ -224,12 +230,14 @@ async def test_delete_group(self, connection: AsyncSession, client: AsyncClient) ) assert response.status_code == status.HTTP_200_OK headers = {"access-token": self.access_token} - response = await client.request( - "GET", - f"/administration/groups/{group_id}", - headers=headers - ) - assert response.status_code == status.HTTP_404_NOT_FOUND + # response = await client.request( + # "GET", + # f"/administration/groups/{group_id}", + # headers=headers + # ) + with pytest.raises(ValueError): + await self.group_repo.get_by_id(group_id=group_id) + # assert response.status_code == status.HTTP_404_NOT_FOUND async def test_create_update_group(self, connection: AsyncSession, client: AsyncClient) -> None: await self.setup_base( diff --git a/tests/test_admin_api/test_admin_role.py b/tests/test_admin_api/test_admin_role.py index 98ecfc4d..55a2b626 100644 --- a/tests/test_admin_api/test_admin_role.py +++ b/tests/test_admin_api/test_admin_role.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import sessionmaker -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from src.data_access.postgresql.repositories.user import UserRepository from src.data_access.postgresql.repositories.groups import GroupRepository @@ -15,6 +15,8 @@ from sqlalchemy import exc from sqlalchemy.ext.asyncio.engine import AsyncEngine from typing import Any +from src.business_logic.jwt_manager.dto import AccessTokenPayload +import time logger = logging.getLogger(__name__) @@ -23,11 +25,15 @@ @pytest.mark.asyncio class TestAdminRoleEndpoint: async def setup_base(self, connection:AsyncSession, user_id: int = 1000) -> None: - self.access_token = await JWTService().encode_jwt( - payload={ - "stand": "CrazyDiamond", - "aud":["admin"] - } + self.access_token = await provide_jwt_manager().encode( + payload=AccessTokenPayload( + sub = 1, + iat=1, + exp=int(time.time()) + 100000, + client_id='123123', + arc=1, + aud = "admin" + ) ) self.group_repo = GroupRepository(connection) self.role_repo = RoleRepository(connection) @@ -198,12 +204,8 @@ async def test_delete_role(self, connection: AsyncSession, client: AsyncClient) ) assert response.status_code == status.HTTP_200_OK headers = {"access-token": self.access_token} - response = await client.request( - "GET", - f"/administration/roles/{role_id}", - headers=headers, - ) - assert response.status_code == status.HTTP_404_NOT_FOUND + with pytest.raises(ValueError): + await self.role_repo.get_role_by_id(role_id=role_id) async def test_create_update_role(self, connection: AsyncSession, client: AsyncClient) -> None: await self.setup_base(connection) diff --git a/tests/test_admin_api/test_admin_user.py b/tests/test_admin_api/test_admin_user.py index 5c43b3b8..6de96329 100644 --- a/tests/test_admin_api/test_admin_user.py +++ b/tests/test_admin_api/test_admin_user.py @@ -5,8 +5,8 @@ from sqlalchemy import insert, delete, text from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import sessionmaker - -from src.business_logic.services.jwt_token import JWTService +import time +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables.persistent_grant import PersistentGrant from src.data_access.postgresql.repositories.user import UserRepository from src.data_access.postgresql.repositories.groups import GroupRepository @@ -15,6 +15,7 @@ from sqlalchemy import exc from sqlalchemy.ext.asyncio.engine import AsyncEngine from typing import Any +from src.business_logic.jwt_manager.dto import AccessTokenPayload logger = logging.getLogger(__name__) @@ -23,11 +24,15 @@ @pytest.mark.asyncio class TestAdminUserEndpoint: async def setup_base(self, connection:AsyncSession, user_id: int = 1000) -> None: - self.access_token = await JWTService().encode_jwt( - payload={ - "stand": "CrazyDiamond", - "aud":["admin"] - } + self.access_token = await provide_jwt_manager().encode( + payload=AccessTokenPayload( + sub = 1, + iat=1, + exp=int(time.time()) + 100000, + client_id='123123', + arc=1, + aud = "admin" + ) ) self.group_repo = GroupRepository(connection) self.role_repo = RoleRepository(connection) diff --git a/tests/test_api/test_client_endpoint.py b/tests/test_api/test_client_endpoint.py index f67ec286..17989224 100644 --- a/tests/test_api/test_client_endpoint.py +++ b/tests/test_api/test_client_endpoint.py @@ -7,7 +7,7 @@ from sqlalchemy import select, insert from sqlalchemy.orm import sessionmaker -from src.business_logic.services import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.repositories import ClientRepository from src.data_access.postgresql.tables.client import Client @@ -102,7 +102,7 @@ async def test_unsuccessful_update_non_existent_client_id(self, client: AsyncCli class TestClientAllEndpointGET: @pytest.mark.asyncio async def test_successful_get_all_clients(self, client: AsyncClient) -> None: - self.access_token = await JWTService().encode_jwt( + self.access_token = await provide_jwt_manager().encode( payload={ "stand": "CrazyDiamond", "aud": ["admin"] diff --git a/tests/test_api/test_introspection_endpoint.py b/tests/test_api/test_introspection_endpoint.py index 3e951a07..3e5e5d42 100644 --- a/tests/test_api/test_introspection_endpoint.py +++ b/tests/test_api/test_introspection_endpoint.py @@ -5,7 +5,7 @@ from fastapi import status from httpx import AsyncClient from typing import Any -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.repositories import ( UserRepository, PersistentGrantRepository, @@ -24,7 +24,7 @@ class TestIntrospectionEndpoint: async def test_successful_introspection_request( self, connection: AsyncSession, client: AsyncClient ) -> None: - jwt = JWTService() + jwt = provide_jwt_manager() persistent_grant_repo = PersistentGrantRepository(connection) grant_type = "authorization_code" payload = { @@ -33,8 +33,8 @@ async def test_successful_introspection_request( "client_id": "test_client", "aud": ["introspection"], } - introspection_token = await jwt.encode_jwt(payload=payload) - access_token = await jwt.encode_jwt( + introspection_token = await jwt.encode(payload=payload) + access_token = await jwt.encode( payload={ "sub": "1", "client_id": "test_client", @@ -88,11 +88,11 @@ async def test_successful_introspection_request( async def test_successful_introspection_request_spoiled_token( self, connection: AsyncSession, client: AsyncClient ) -> None: - jwt = JWTService() + jwt = provide_jwt_manager() persistent_grant_repo = PersistentGrantRepository(connection) grant_type = "authorization_code" payload = {"sub": 1, "exp": time.time(), "aud": ["introspection"]} - introspection_token = await jwt.encode_jwt(payload=payload) + introspection_token = await jwt.encode(payload=payload) await persistent_grant_repo.create( grant_type=grant_type, @@ -102,7 +102,7 @@ async def test_successful_introspection_request_spoiled_token( expiration_time=1, ) headers = { - "authorization": await jwt.encode_jwt( + "authorization": await jwt.encode( payload={"sub": "1", "aud": ["introspection"]} ), "Content-Type": "application/x-www-form-urlencoded", @@ -122,7 +122,7 @@ async def test_successful_introspection_request_spoiled_token( async def test_unsuccessful_introspection_request_incorrect_token( self, connection: AsyncSession, client: AsyncClient ) -> None: - jwt = JWTService() + jwt = provide_jwt_manager() persistent_grant_repo = PersistentGrantRepository(connection) grant_type = "authorization_code" payload = { @@ -131,8 +131,8 @@ async def test_unsuccessful_introspection_request_incorrect_token( "client_id": "test_client", "aud": ["introspection"], } - introspection_token = await jwt.encode_jwt(payload=payload) - access_token = await jwt.encode_jwt( + introspection_token = await jwt.encode(payload=payload) + access_token = await jwt.encode( payload={ "sub": "1", "client_id": "test_client", diff --git a/tests/test_api/test_revoke_endpoint.py b/tests/test_api/test_revoke_endpoint.py index 50e2a1ad..c6a0871d 100644 --- a/tests/test_api/test_revoke_endpoint.py +++ b/tests/test_api/test_revoke_endpoint.py @@ -4,7 +4,7 @@ from fastapi import status from httpx import AsyncClient -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.repositories.persistent_grant import ( PersistentGrantRepository, ) @@ -18,7 +18,7 @@ class TestRevokationEndpoint: async def test_successful_revoke_request( self, connection: AsyncSession, client: AsyncClient ) -> None: - jwt = JWTService() + jwt = provide_jwt_manager() persistent_grant_repo = PersistentGrantRepository(connection) grant_type = "refresh_token" revoke_token = "----token_to_delete-----" @@ -32,7 +32,7 @@ async def test_successful_revoke_request( ) await connection.commit() headers = { - "authorization": await jwt.encode_jwt( + "authorization": await jwt.encode( payload={"sub": "1", "aud":["revocation"]} ), "Content-Type": "application/x-www-form-urlencoded", @@ -50,11 +50,11 @@ async def test_successful_revoke_request( async def test_token_does_not_exists( self, client: AsyncClient ) -> None: - jwt = JWTService() + jwt = provide_jwt_manager() grant_type = "code" revoke_token = "----token_not_exists-----" headers = { - "authorization": await jwt.encode_jwt( + "authorization": await jwt.encode( payload={"sub": "1", "aud":["revocation"]} ), "Content-Type": "application/x-www-form-urlencoded", diff --git a/tests/test_api/test_token_endpoint.py b/tests/test_api/test_token_endpoint.py index 2338bedf..b14c2e6c 100644 --- a/tests/test_api/test_token_endpoint.py +++ b/tests/test_api/test_token_endpoint.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio.engine import AsyncEngine -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.business_logic.services.tokens import TokenService from src.data_access.postgresql.repositories.persistent_grant import ( PersistentGrantRepository, @@ -17,7 +17,7 @@ @pytest.mark.asyncio class TestTokenEndpoint: - jwt_service = JWTService() + jwt_service = provide_jwt_manager() refresh_token = None content_type = "application/x-www-form-urlencoded" @@ -243,7 +243,7 @@ async def test_refresh_token_authorization( to 'refresh_token' in params """ - test_token = await self.jwt_service.encode_jwt( + test_token = await self.jwt_service.encode( payload={"sub": 1, "exp": time.time() + 3600} ) diff --git a/tests/test_api/test_userinfo_endpoint.py b/tests/test_api/test_userinfo_endpoint.py index d6d9efb1..ce992984 100644 --- a/tests/test_api/test_userinfo_endpoint.py +++ b/tests/test_api/test_userinfo_endpoint.py @@ -6,7 +6,10 @@ from sqlalchemy import delete, insert from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine from sqlalchemy.orm import sessionmaker -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager +from src.business_logic.services import UserInfoService +from src.data_access.postgresql.repositories.persistent_grant import PersistentGrantRepository +# from src.business_logic.services.jwt_token import JWTService from src.business_logic.services import UserInfoService from src.data_access.postgresql.repositories.persistent_grant import ( PersistentGrantRepository, @@ -40,10 +43,8 @@ async def test_successful_userinfo_get_request( client: AsyncClient, engine: AsyncEngine, ) -> None: - jwt = JWTService() - token = await jwt.encode_jwt( - payload={"sub": 1, "scope": "profile", "aud": ["userinfo"]} - ) + jwt = provide_jwt_manager() + token = await jwt.encode(payload={"sub": 1, 'scope': 'profile', "aud":["userinfo"]}) headers = { "authorization": token, } @@ -60,9 +61,7 @@ async def test_successful_userinfo_jwt_get_request( client: AsyncClient, engine: AsyncEngine, ) -> None: - token = await user_info_service.jwt.encode_jwt( - payload={"sub": 1, "scope": "profile", "aud": ["userinfo"]} - ) + token = await user_info_service.jwt.encode(payload={"sub": 1, 'scope': 'profile', "aud":["userinfo"]}) headers = {"authorization": token, "accept": "application/json"} user_info_service.authorization = token response = await client.request("GET", "/userinfo/jwt", headers=headers) @@ -81,8 +80,8 @@ async def test_userinfo_and_userinfo_jwt_get_requests_with_incorrect_token( client: AsyncClient, engine: AsyncEngine, ) -> None: - token = await user_info_service.jwt.encode_jwt( - payload={"blablabla": "blablabla", "aud": ["userinfo"]} + token = await user_info_service.jwt.encode( + payload={"blablabla": "blablabla", "aud":["userinfo"]} ) for url in ("/userinfo/", "/userinfo/jwt"): headers = { @@ -100,9 +99,7 @@ async def test_userinfo_and_userinfo_jwt_get_requests_with_user_without_claims( client: AsyncClient, engine: AsyncEngine, ) -> None: - token = await user_info_service.jwt.encode_jwt( - payload={"sub": "2", "aud": ["userinfo"]} - ) + token = await user_info_service.jwt.encode(payload={"sub": "2", "aud":["userinfo"]}) for url in ("/userinfo/", "/userinfo/jwt"): headers = { "authorization": token, @@ -144,8 +141,8 @@ async def test_userinfo_post_request_with_incorrect_token( client: AsyncClient, engine: AsyncEngine, ) -> None: - token = await user_info_service.jwt.encode_jwt( - payload={"blablabla": "blablabla", "aud": ["userinfo"]} + token = await user_info_service.jwt.encode( + payload={"blablabla": "blablabla", "aud":["userinfo"]} ) headers = {"authorization": token} response = await client.request("POST", "/userinfo/", headers=headers) @@ -160,9 +157,7 @@ async def test_userinfo_post_request_with_user_without_claims( client: AsyncClient, engine: AsyncEngine, ) -> None: - token = await user_info_service.jwt.encode_jwt( - payload={"sub": "2", "aud": ["userinfo"]} - ) + token = await user_info_service.jwt.encode(payload={"sub": "2", "aud":["userinfo"]}) headers = { "authorization": token, } diff --git a/tests/test_api/test_well_known_endpoint.py b/tests/test_api/test_well_known_endpoint.py index 0aac30af..c634e85b 100644 --- a/tests/test_api/test_well_known_endpoint.py +++ b/tests/test_api/test_well_known_endpoint.py @@ -3,7 +3,7 @@ import pytest from fastapi import status from httpx import AsyncClient -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from jwkest import base64_to_long from Crypto.PublicKey.RSA import construct @@ -44,9 +44,9 @@ async def test_successful_jwks_request(self, client: AsyncClient) -> None: assert type(response_content["keys"]) == list assert type(response_content["keys"][0]) == dict - jwt_service = JWTService() + jwt_service = provide_jwt_manager() response_content = response_content["keys"][0] - test_token = await jwt_service.encode_jwt(payload={"sub":1}) + test_token = await jwt_service.encode(payload={"sub":1}) if response_content["alg"] == "RS256": n = base64_to_long(response_content["n"]) diff --git a/tests/test_e2e/test_device_flow.py b/tests/test_e2e/test_device_flow.py index 4d33cf10..c345fdd8 100644 --- a/tests/test_e2e/test_device_flow.py +++ b/tests/test_e2e/test_device_flow.py @@ -4,7 +4,7 @@ from sqlalchemy import exists, insert, select from sqlalchemy.ext.asyncio import AsyncSession -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables import User, UserClaim from src.data_access.postgresql.tables.device import Device @@ -86,9 +86,9 @@ async def test_successful_device_flow(self, client: AsyncClient, connection: Asy assert response.status_code == status.HTTP_200_OK # Stage 7: User ends the session - jwt_service = JWTService() + jwt_service = provide_jwt_manager() TOKEN_HINT_DATA["sub"] = user_id - id_token_hint = await jwt_service.encode_jwt(payload=TOKEN_HINT_DATA) + id_token_hint = await jwt_service.encode(payload=TOKEN_HINT_DATA) end_session_params = {"id_token_hint": id_token_hint, "post_logout_redirect_uri": "http://thompson-chung.com/", "state": "test_state"} response = await client.request("GET", "/endsession/", params=end_session_params) diff --git a/tests/test_e2e/test_third_party_github_flow.py b/tests/test_e2e/test_third_party_github_flow.py index b8a15c16..512bbf40 100644 --- a/tests/test_e2e/test_third_party_github_flow.py +++ b/tests/test_e2e/test_third_party_github_flow.py @@ -8,7 +8,7 @@ IdentityProviderMapped, ) from src.data_access.postgresql.tables.users import UserClaim, User -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from sqlalchemy.ext.asyncio import AsyncSession from typing import Any @@ -99,14 +99,14 @@ async def replace_get(*args: Any, **kwargs: Any) -> str: assert response.status_code == status.HTTP_200_OK # Stage 5: EndSession endpoint deletes all records in the Persistent grant table for the corresponding user - jwt_service = JWTService() + jwt_service = provide_jwt_manager() token_hint_data = { "sub": user_id, "client_id": "spider_man", "type": "code", } - id_token_hint = await jwt_service.encode_jwt(payload=token_hint_data) + id_token_hint = await jwt_service.encode(payload=token_hint_data) logout_params = { "id_token_hint": id_token_hint, diff --git a/tests/test_e2e/test_third_party_gitlab_flow.py b/tests/test_e2e/test_third_party_gitlab_flow.py index 5281d4ce..d4baf598 100644 --- a/tests/test_e2e/test_third_party_gitlab_flow.py +++ b/tests/test_e2e/test_third_party_gitlab_flow.py @@ -8,7 +8,7 @@ IdentityProviderMapped, ) from src.data_access.postgresql.tables.users import UserClaim, User -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from sqlalchemy.ext.asyncio import AsyncSession from typing import Any @@ -99,14 +99,14 @@ async def replace_get(*args: Any, **kwargs: Any) -> str: assert response.status_code == status.HTTP_200_OK # Stage 5: EndSession endpoint deletes all records in the Persistent grant table for the corresponding user - jwt_service = JWTService() + jwt_service = JWTManager() token_hint_data = { "sub": user_id, "client_id": "spider_man", "type": "code", } - id_token_hint = await jwt_service.encode_jwt(payload=token_hint_data) + id_token_hint = await jwt_service.encode(payload=token_hint_data) logout_params = { "id_token_hint": id_token_hint, diff --git a/tests/test_e2e/test_third_party_google_flow.py b/tests/test_e2e/test_third_party_google_flow.py index b5617cc3..54369786 100644 --- a/tests/test_e2e/test_third_party_google_flow.py +++ b/tests/test_e2e/test_third_party_google_flow.py @@ -6,7 +6,7 @@ from sqlalchemy import insert, select from sqlalchemy.ext.asyncio import AsyncSession -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables.identity_resource import ( IdentityProviderMapped, IdentityProviderState, @@ -100,14 +100,14 @@ async def replace_get(*args: Any, **kwargs: Any) -> str: assert response.status_code == status.HTTP_200_OK # Stage 5: EndSession endpoint deletes all records in the Persistent grant table for the corresponding user - jwt_service = JWTService() + jwt_service = provide_jwt_manager() token_hint_data = { "sub": user_id, "client_id": "spider_man", "type": "code", } - id_token_hint = await jwt_service.encode_jwt(payload=token_hint_data) + id_token_hint = await jwt_service.encode(payload=token_hint_data) logout_params = { "id_token_hint": id_token_hint, diff --git a/tests/test_e2e/test_third_party_linkedin_flow.py b/tests/test_e2e/test_third_party_linkedin_flow.py index 03d48465..47c8bbc2 100644 --- a/tests/test_e2e/test_third_party_linkedin_flow.py +++ b/tests/test_e2e/test_third_party_linkedin_flow.py @@ -6,7 +6,7 @@ from sqlalchemy import insert, select from sqlalchemy.ext.asyncio import AsyncSession -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables.identity_resource import ( IdentityProviderMapped, IdentityProviderState, @@ -99,14 +99,14 @@ async def replace_get(*args: Any, **kwargs: Any) -> str: assert response.status_code == status.HTTP_200_OK # Stage 5: EndSession endpoint deletes all records in the Persistent grant table for the corresponding user - jwt_service = JWTService() + jwt_service = provide_jwt_manager() token_hint_data = { "sub": user_id, "client_id": "spider_man", "type": "code", } - id_token_hint = await jwt_service.encode_jwt(payload=token_hint_data) + id_token_hint = await jwt_service.encode(payload=token_hint_data) logout_params = { "id_token_hint": id_token_hint, diff --git a/tests/test_e2e/test_third_party_microsoft_flow.py b/tests/test_e2e/test_third_party_microsoft_flow.py index 69fe8fcc..a37f8af4 100644 --- a/tests/test_e2e/test_third_party_microsoft_flow.py +++ b/tests/test_e2e/test_third_party_microsoft_flow.py @@ -8,7 +8,7 @@ IdentityProviderMapped, ) from src.data_access.postgresql.tables.users import UserClaim, User -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from sqlalchemy.ext.asyncio import AsyncSession from typing import Any @@ -99,14 +99,14 @@ async def replace_get(*args: Any, **kwargs: Any) -> str: assert response.status_code == status.HTTP_200_OK # Stage 5: EndSession endpoint deletes all records in the Persistent grant table for the corresponding user - jwt_service = JWTService() + jwt_service = provide_jwt_manager() token_hint_data = { "sub": user_id, "client_id": "spider_man", "type": "code", } - id_token_hint = await jwt_service.encode_jwt(payload=token_hint_data) + id_token_hint = await jwt_service.encode(payload=token_hint_data) logout_params = { "id_token_hint": id_token_hint, diff --git a/tests/test_e2e/test_token_id_token_flow.py b/tests/test_e2e/test_token_id_token_flow.py index b58f3c3d..a0d36c12 100644 --- a/tests/test_e2e/test_token_id_token_flow.py +++ b/tests/test_e2e/test_token_id_token_flow.py @@ -4,7 +4,7 @@ from sqlalchemy import insert, delete from src.data_access.postgresql.tables.users import UserClaim -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio.engine import AsyncEngine diff --git a/tests/test_ui_admin/test_create.py b/tests/test_ui_admin/test_create.py index c66c224b..7ce867bf 100644 --- a/tests/test_ui_admin/test_create.py +++ b/tests/test_ui_admin/test_create.py @@ -50,7 +50,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, insert -from src.business_logic.services import JWTService +# from src.business_logic.services import JWTService async def fake_authenticate(*args, **kwargs): return None diff --git a/tests/test_ui_admin/test_delete.py b/tests/test_ui_admin/test_delete.py index 82c9d2e6..2ed1939f 100644 --- a/tests/test_ui_admin/test_delete.py +++ b/tests/test_ui_admin/test_delete.py @@ -3,7 +3,7 @@ from httpx import AsyncClient from sqlalchemy import insert, exists, select, delete from sqlalchemy.ext.asyncio import AsyncSession -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables.persistent_grant import PersistentGrant import logging from typing import Any diff --git a/tests/test_ui_admin/test_read.py b/tests/test_ui_admin/test_read.py index 08310d7b..f6ee9b6d 100644 --- a/tests/test_ui_admin/test_read.py +++ b/tests/test_ui_admin/test_read.py @@ -3,7 +3,7 @@ from httpx import AsyncClient from sqlalchemy import insert, exists, select, delete from sqlalchemy.ext.asyncio import AsyncSession -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.tables.persistent_grant import PersistentGrant import logging from typing import Any diff --git a/tests/test_unit/fixtures.py b/tests/test_unit/fixtures.py index f999e7cf..0142d537 100644 --- a/tests/test_unit/fixtures.py +++ b/tests/test_unit/fixtures.py @@ -5,7 +5,7 @@ import pytest_asyncio from pydantic import SecretStr -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.business_logic.third_party_auth.dto.request import ( StateRequestModel, ThirdPartyAccessTokenRequestModel, @@ -164,7 +164,7 @@ def device_request_model() -> DeviceRequestModel: return request_model -service = JWTService() +service = provide_jwt_manager() TOKEN_HINT_DATA = { "sub": 3, @@ -177,16 +177,16 @@ def device_request_model() -> DeviceRequestModel: class TokenHint: - sv = JWTService() + sv = provide_jwt_manager() @classmethod async def get_token_hint(cls) -> str: - token_hint = await cls.sv.encode_jwt(payload=TOKEN_HINT_DATA) + token_hint = await cls.sv.encode(payload=TOKEN_HINT_DATA) return token_hint @classmethod async def get_short_token_hint(cls) -> str: - short_token_hint = await cls.sv.encode_jwt( + short_token_hint = await cls.sv.encode( payload=SHORT_TOKEN_HINT_DATA ) return short_token_hint diff --git a/tests/test_unit/test_middleware/test_accesstoken.py b/tests/test_unit/test_middleware/test_accesstoken.py index 4e95ea40..51d06f5b 100644 --- a/tests/test_unit/test_middleware/test_accesstoken.py +++ b/tests/test_unit/test_middleware/test_accesstoken.py @@ -1,12 +1,13 @@ import pytest import mock -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.presentation.middleware.access_token_validation import access_token_middleware from starlette.types import ASGIApp from fastapi import status from src.presentation.api import router from src.data_access.postgresql.repositories import BlacklistedTokenRepository from sqlalchemy.ext.asyncio.engine import AsyncEngine +from src.business_logic.jwt_manager import JWTManager from typing import Any async def new_decode_token(*args: Any, **kwargs: Any) -> bool: @@ -44,7 +45,7 @@ class TestAccessTokenMiddleware: async def test_successful_auth(self, connection: AsyncSession) -> None: test_token = "Bearer AccessToken" with mock.patch.object( - JWTService, "decode_token", new=new_decode_token + JWTManager, "decode", new=new_decode_token ): request = NewRequest(connection) diff --git a/tests/test_unit/test_middleware/test_authorization.py b/tests/test_unit/test_middleware/test_authorization.py index 3d12838e..7a19746c 100644 --- a/tests/test_unit/test_middleware/test_authorization.py +++ b/tests/test_unit/test_middleware/test_authorization.py @@ -5,7 +5,9 @@ from fastapi import status from starlette.types import ASGIApp from sqlalchemy.ext.asyncio.engine import AsyncEngine -from src.business_logic.services.jwt_token import JWTService + +from src.business_logic.jwt_manager import JWTManager +from src.di.providers import provide_jwt_manager from src.presentation.api import router from typing import Any, Callable, MutableMapping from fastapi import Request @@ -54,7 +56,7 @@ async def test_successful_auth(self, connection: AsyncSession) -> None: request = RequestTest(connection) with mock.patch.object( - JWTService, "decode_token", new=new_decode_token + JWTManager, "decode", new=new_decode_token ): for request_with_auth in self.REQUESTS_WITH_AUTH: request = RequestTest(connection) @@ -69,7 +71,7 @@ async def test_successful_auth_with_swagger(self, connection: AsyncSession) -> N request = RequestTest(connection) with mock.patch.object( - JWTService, "decode_token", new=new_decode_token + JWTManager, "decode", new=new_decode_token ): for request_with_auth in self.REQUESTS_WITH_AUTH: request = RequestTest(connection) diff --git a/tests/test_unit/test_services/test_admin_auth_service.py b/tests/test_unit/test_services/test_admin_auth_service.py index 64945558..892fc8a1 100644 --- a/tests/test_unit/test_services/test_admin_auth_service.py +++ b/tests/test_unit/test_services/test_admin_auth_service.py @@ -3,12 +3,15 @@ from sqlalchemy import delete from src.business_logic.services.admin_auth import AdminAuthService from src.business_logic.dto.admin_credentials import AdminCredentialsDTO -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.data_access.postgresql.errors.user import UserNotFoundError from src.data_access.postgresql.errors import WrongPasswordError -from time import time, sleep +from src.business_logic.jwt_manager.dto import AccessTokenPayload +import time from jwt.exceptions import InvalidAudienceError, ExpiredSignatureError from pydantic import SecretStr +from starlette.responses import RedirectResponse + @pytest.mark.asyncio class TestAdminAuthService: @@ -38,33 +41,37 @@ async def test_authenticate( self, admin_auth_service: AdminAuthService ) -> None: service = admin_auth_service - payload = { - "aud": [ - "admin", - ], - "exp": int(time()) + 1000, - } - token = await JWTService().encode_jwt(payload) + payload = AccessTokenPayload( + sub=1, + iat=1, + client_id=1, + aud="oidc:admin_ui", + exp=int(time.time()) + 1000, + ) + token = await provide_jwt_manager().encode(payload) result = await service.authenticate(token=token) assert result is None - payload = { - "aud": [ - "not_admin", - ], - "exp": int(time()) + 1000, - } - token = await JWTService().encode_jwt(payload) + payload = AccessTokenPayload( + aud="not_admin", + sub=1, + iat=1, + client_id=1, + exp=int(time.time()) + 1000, + ) + + token = await provide_jwt_manager().encode(payload) result = await service.authenticate(token=token) - assert result is not None + assert type(result) is RedirectResponse - payload = { - "aud": [ - "admin", - ], - "exp": 0, - } - token = await JWTService().encode_jwt(payload) + payload = AccessTokenPayload( + sub=1, + iat=1, + client_id=1, + aud="oidc:admin_ui", + exp=0, + ) + token = await provide_jwt_manager().encode(payload) result = await service.authenticate(token=token) assert result is not None diff --git a/tests/test_unit/test_services/test_end_session_service.py b/tests/test_unit/test_services/test_end_session_service.py index 23055b20..71e92d4d 100644 --- a/tests/test_unit/test_services/test_end_session_service.py +++ b/tests/test_unit/test_services/test_end_session_service.py @@ -8,7 +8,7 @@ PersistentGrantNotFoundError, ) -from tests.test_unit.fixtures import end_session_request_model, TOKEN_HINT_DATA +# from tests.test_unit.fixtures import end_session_request_model, TOKEN_HINT_DATA from src.business_logic.services.endsession import EndSessionService from src.presentation.api.models.endsession import RequestEndSessionModel from sqlalchemy.ext.asyncio.engine import AsyncEngine diff --git a/tests/test_unit/test_services/test_userinfo_service.py b/tests/test_unit/test_services/test_userinfo_service.py index 17ea188d..ed7eecb2 100644 --- a/tests/test_unit/test_services/test_userinfo_service.py +++ b/tests/test_unit/test_services/test_userinfo_service.py @@ -23,7 +23,7 @@ async def test_get_user_info_and_get_user_info_jwt( "aud": ["userinfo"], } - token = await service.jwt.encode_jwt(payload=data_to_code) + token = await service.jwt.encode(payload=data_to_code) service.authorization = token expected_part_one = { diff --git a/tests/test_unit/test_services/test_wellknown_service.py b/tests/test_unit/test_services/test_wellknown_service.py index d6dd6724..c232a66f 100644 --- a/tests/test_unit/test_services/test_wellknown_service.py +++ b/tests/test_unit/test_services/test_wellknown_service.py @@ -5,7 +5,7 @@ from jwkest import base64_to_long from src.dyna_config import DOMAIN_NAME -from src.business_logic.services.jwt_token import JWTService +from src.di.providers import provide_jwt_manager from src.business_logic.services.well_known import WellKnownService from typing import Any, no_type_check @@ -142,9 +142,9 @@ async def test_jwks_RSA( wlk_services: WellKnownService, ) -> None: wks = wlk_services - jwt_service = JWTService() + jwt_service = provide_jwt_manager() result = await wks.get_jwks() - test_token = await jwt_service.encode_jwt(payload={"sub": 1}) + test_token = await jwt_service.encode(payload={"sub": 1}) if result["alg"] == "RS256": n = base64_to_long(result["n"])