Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
)

if TYPE_CHECKING:
from src.business_logic.services import JWTService, PasswordHash
from src.business_logic.services import PasswordHash
from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager


FactoryMethod = Callable[..., AuthServiceProtocol]
Expand All @@ -41,7 +42,7 @@ def __init__(
persistent_grant_repo: PersistentGrantRepository,
device_repo: DeviceRepository,
password_service: PasswordHash,
jwt_service: JWTService,
jwt_service: JWTManager,
) -> None:
self.session = session
self._client_repo = client_repo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

if TYPE_CHECKING:
from src.business_logic.authorization.interfaces import AuthServiceProtocol
from src.business_logic.services import JWTService, PasswordHash
from src.business_logic.services import PasswordHash # JWTService,
from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager
from src.data_access.postgresql.repositories import (
ClientRepository,
UserRepository,
Expand All @@ -26,7 +27,7 @@ def _create_id_token_auth_service(
client_repo: ClientRepository,
user_repo: UserRepository,
password_service: PasswordHash,
jwt_service: JWTService,
jwt_service: JWTManager,
**kwargs: Any,
) -> AuthServiceProtocol:
return IdTokenAuthService(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

if TYPE_CHECKING:
from src.business_logic.authorization.interfaces import AuthServiceProtocol
from src.business_logic.services import JWTService, PasswordHash
from src.business_logic.services import PasswordHash # JWTService,
from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager
from src.data_access.postgresql.repositories import (
ClientRepository,
UserRepository,
Expand All @@ -28,7 +29,7 @@ def _create_id_token_token_auth_service(
client_repo: ClientRepository,
user_repo: UserRepository,
password_service: PasswordHash,
jwt_service: JWTService,
jwt_service: JWTManager,
**kwargs: Any,
) -> AuthServiceProtocol:
return IdTokenTokenAuthService(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

if TYPE_CHECKING:
from src.business_logic.authorization.interfaces import AuthServiceProtocol
from src.business_logic.services import JWTService, PasswordHash
from src.business_logic.services import PasswordHash #JWTService,
from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager
from src.data_access.postgresql.repositories import (
ClientRepository,
UserRepository,
Expand All @@ -26,7 +27,7 @@ def _create_token_auth_service(
client_repo: ClientRepository,
user_repo: UserRepository,
password_service: PasswordHash,
jwt_service: JWTService,
jwt_service: JWTManager,
**kwargs: Any,
) -> AuthServiceProtocol:
return TokenAuthService(
Expand Down
6 changes: 3 additions & 3 deletions src/business_logic/get_tokens/service_impls/auth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ async def _get_access_token(self, request_data: RequestTokenModel, user_id: int,
jti=str(uuid.uuid4()),
acr=0,
)
return self._jwt_manager.encode(payload=payload, algorithm='RS256')
return await self._jwt_manager.encode(payload=payload, algorithm='RS256')

async def _get_refresh_token(self, request_data: RequestTokenModel) -> str:
payload = RefreshTokenPayload(
jti=str(uuid.uuid4())
)
return self._jwt_manager.encode(payload=payload, algorithm='RS256')
return await self._jwt_manager.encode(payload=payload, algorithm='RS256')

async def _get_id_token(self, request_data: RequestTokenModel, user_id: int, unix_time: int) -> str:
payload = IdTokenPayload(
Expand All @@ -123,4 +123,4 @@ async def _get_id_token(self, request_data: RequestTokenModel, user_id: int, uni
jti=str(uuid.uuid4()),
acr=0,
)
return self._jwt_manager.encode(payload=payload, algorithm='RS256')
return await self._jwt_manager.encode(payload=payload, algorithm='RS256')
3 changes: 2 additions & 1 deletion src/business_logic/jwt_manager/dto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
AccessTokenPayload,
RefreshTokenPayload,
IdTokenPayload,
AdminUIPayload
)


__all__ = ['AccessTokenPayload', 'RefreshTokenPayload', 'IdTokenPayload']
__all__ = ['AccessTokenPayload', 'RefreshTokenPayload', 'IdTokenPayload', 'AdminUIPayload']
12 changes: 9 additions & 3 deletions src/business_logic/jwt_manager/dto/input.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from pydantic import BaseModel
from typing import Optional, Any, Union

from src.dyna_config import DOMAIN_NAME
import uuid

class BaseJWTPayload(BaseModel):
sub: Union[int, str] # user id
iss: str # auth service uri
iat: int # time of creation
exp: int # time when token will expire
aud: Optional[Union[str, list[str]]] # name for whom token was generated
client_id: str # id of the client who issued a token
jti: str # uniques identifier for token, UUID4
acr: Optional[int] # default 0
jti: str = str(uuid.uuid4()) # uniques identifier for token, UUID4
iss: str = DOMAIN_NAME # auth service uri


class AccessTokenPayload(BaseJWTPayload):
Expand All @@ -31,3 +32,8 @@ class IdTokenPayload(BaseJWTPayload):
picture: Optional[str] = None
zoneinfo: Optional[str] = None
locale: Optional[str] = None

class AdminUIPayload(BaseModel):
sub:int
exp:int
aud:list[str] = ["oidc:admin_ui"]
58 changes: 41 additions & 17 deletions src/business_logic/jwt_manager/service_impls/jwt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,44 @@
import logging
import jwt
from typing import Any, Optional, Union

from src.data_access.postgresql.tables.rsa_keys import RSA_keys
from src.business_logic.jwt_manager.dto import (
AccessTokenPayload,
RefreshTokenPayload,
IdTokenPayload,
AdminUIPayload
)
from src.di.providers.rsa_keys import provide_rsa_keys

logger = logging.getLogger(__name__)

Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload]
Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload, AdminUIPayload]

class JWTManager:
def __init__(
self,
keys: RSA_keys
) -> None:
self.keys = keys
# print(f"print:!!!jwt_service.py;!!! self.keys: {self.keys}")

def encode(self, payload: Payload, algorithm: str, secret: Optional[str] = None) -> str:
if secret:
key = secret
else:
key = self.keys.private_key
print(f"jwt_service.py; keys: {self.keys.private_key}")
def __init__(self) -> None:
self.algorithm = "RS256"
self.algorithms = ["RS256"]
self.keys = None


def check_rsa_keys(self) -> None:
if not self.keys:
self.keys = provide_rsa_keys()
if self.keys is None:
raise ValueError("Keys don't exist or Docker is not running")

async def encode(self, payload: Payload, algorithm: Optional[str] = None, secret: Optional[str] = None) -> str:
self.check_rsa_keys()
key = secret or self.keys.private_key
token = jwt.encode(
payload=payload.dict(exclude_none=True), key=key, algorithm=algorithm
payload=payload.dict(exclude_none=True),
key=key,
algorithm=(algorithm or self.algorithm)
)
return token

def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> dict[str, Any]:
async def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> dict[str, Any]:
self.check_rsa_keys()
token = token.replace("Bearer ", "")
if audience:
decoded_info = jwt.decode(token, key=self.keys.public_key, algorithms=self.algorithms,
Expand All @@ -44,3 +49,22 @@ def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> d
**kwargs,)

return decoded_info

async def verify_token(self, token: str, aud: str = None) -> bool:
self.check_rsa_keys()
try:
if aud:
await self.decode(token=token, audience=aud)
else:
await self.decode(token)
return True
except:
return False

async def get_module(self) -> int:
self.check_rsa_keys()
return self.keys.n

async def get_pub_key_expanent(self) -> int:
self.check_rsa_keys()
return self.keys.e
8 changes: 5 additions & 3 deletions src/business_logic/services/admin_api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
# https://docs.google.com/spreadsheets/d/1zJAcCxaGz2CV9zKlqeBhRE8xfiqyJMy5-XPRH1I8HPg/edit#gid=0
from typing import Optional, Any
from typing import Optional, Any, TYPE_CHECKING
from src.data_access.postgresql.repositories import RoleRepository, UserRepository, PersistentGrantRepository, GroupRepository
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from src.business_logic.services.jwt_token import JWTService
# from src.di.providers import provide_jwt_manager
from src.business_logic.services.password import PasswordHash
from src.data_access.postgresql.tables import Group, Role, User
# if TYPE_CHECKING:
# from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager


class AdminService():
def __init__(
self,
jwt_service: JWTService,
jwt_service#: JWTManager,
) -> None:
self.jwt_service = jwt_service

Expand Down
19 changes: 9 additions & 10 deletions src/business_logic/services/admin_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import time
from src.business_logic.dto import AdminCredentialsDTO
from src.business_logic.services.password import PasswordHash
from src.business_logic.services.jwt_token import JWTService
from src.di.providers import provide_jwt_manager
from src.data_access.postgresql.repositories import UserRepository, PersistentGrantRepository
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi.responses import RedirectResponse
from src.dyna_config import DOMAIN_NAME

from src.business_logic.jwt_manager.dto import AdminUIPayload
logger = logging.getLogger(__name__)


Expand All @@ -17,7 +17,7 @@ def __init__(
self,
user_repo: UserRepository,
password_service = PasswordHash(),
jwt_service = JWTService(),
jwt_service = provide_jwt_manager(),
) -> None:
self.user_repo = user_repo
self.password_service = password_service
Expand All @@ -33,17 +33,16 @@ async def authorize(
credentials.password,
user_hash_password
)
return await self.jwt_service.encode_jwt(
payload={
"sub":user_id,
"exp": exp_time + int(time.time()),
"aud":["admin","introspection", "revoke"]
}
return await self.jwt_service.encode(
payload=AdminUIPayload(
sub=user_id,
exp=exp_time + int(time.time()),
)
)


async def authenticate(self, token: str) -> Union[None, RedirectResponse]:
if await self.jwt_service.verify_token(token=token, aud="admin"):
if await self.jwt_service.verify_token(token=token, aud="oidc:admin_ui"):
return None
else:
# return None
Expand Down
4 changes: 2 additions & 2 deletions src/business_logic/services/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from src.config.settings.app import AppSettings

from src.dyna_config import DOMAIN_NAME
from src.business_logic.services.jwt_token import JWTService
from src.di.providers import provide_jwt_manager
from src.business_logic.services.password import PasswordHash
from src.business_logic.services.tokens import get_single_token
from src.data_access.postgresql.repositories import (
Expand All @@ -32,7 +32,7 @@ def __init__(
persistent_grant_repo: PersistentGrantRepository,
device_repo: DeviceRepository,
password_service: PasswordHash = PasswordHash(),
jwt_service: JWTService = JWTService(),
jwt_service: JWTManager = provide_jwt_manager(),
) -> None:
self._request_model: Optional[DataRequestModel] = None
self.client_repo = client_repo
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from typing import Optional
from typing import Optional, TYPE_CHECKING

from src.business_logic.services.authorization.response_type_handlers.factory import (
ResponseTypeHandlerFactory,
)
from src.business_logic.services.jwt_token import JWTService
# from src.di.providers import provide_jwt_manager
from src.business_logic.services.password import PasswordHash
from src.data_access.postgresql.errors import ClientScopesError
from src.data_access.postgresql.repositories import (
Expand All @@ -14,6 +14,8 @@
UserRepository,
)
from src.presentation.api.models import DataRequestModel
if TYPE_CHECKING:
from src.business_logic.jwt_manager.service_impls.jwt_service import JWTManager

logger = logging.getLogger(__name__)

Expand All @@ -26,7 +28,7 @@ def __init__(
persistent_grant_repo: PersistentGrantRepository,
device_repo: DeviceRepository,
password_service: PasswordHash,
jwt_service: JWTService,
jwt_service,
) -> None:
self._request_model: Optional[DataRequestModel] = None
self.client_repo = client_repo
Expand Down
8 changes: 4 additions & 4 deletions src/business_logic/services/endsession.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from src.business_logic.jwt_manager import JWTManager
from src.presentation.api.models.endsession import RequestEndSessionModel
from src.data_access.postgresql.repositories.client import ClientRepository
from src.data_access.postgresql.repositories.persistent_grant import (
PersistentGrantRepository,
)
from src.business_logic.services.jwt_token import JWTService
from src.di.providers import provide_jwt_manager
from typing import Union, Optional, Any
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -14,11 +15,10 @@ def __init__(
session:AsyncSession,
client_repo: ClientRepository,
persistent_grant_repo: PersistentGrantRepository,
jwt_service = JWTService()
) -> None:
self.client_repo = client_repo
self.persistent_grant_repo = persistent_grant_repo
self.jwt_service = jwt_service
self.jwt_service = provide_jwt_manager()
self._request_model: Optional[RequestEndSessionModel] = None
self.session = session

Expand Down Expand Up @@ -53,7 +53,7 @@ async def end_session(self) -> Optional[str]:
async def _decode_id_token_hint(
self, id_token_hint: str
) -> dict[str, Any]:
decoded_data = await self.jwt_service.decode_token(token=id_token_hint)
decoded_data = await self.jwt_service.decode(token=id_token_hint)
return decoded_data

async def _logout(self, client_id: str, user_id: int) -> None:
Expand Down
Loading