Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0c0ac8d
C1: "locust" to .toml
Jun 1, 2023
4efe665
C3: locustfile.py
Jun 2, 2023
b21d0dc
C4: client, device, auth, token_2
Jun 5, 2023
ce2181f
C5: Userinfo - dont work.
Jun 6, 2023
47538b4
Merge branch 'main' into perfomance_testing
Jun 6, 2023
a0c0b06
C6: Userinfo
Jun 7, 2023
a7701d0
C7: User, Client, Authorization, Device, Userinfo, Endsession, Token.
Jun 9, 2023
474b4e8
C8: introspection
Jun 12, 2023
bea32f1
C9: revoke
Jun 12, 2023
d481ef5
C10: third party providder - POST
Jun 13, 2023
80d549b
everything is working
Jun 15, 2023
198c2a1
C1: structure, migrations
Jun 15, 2023
ac1f571
circular import
Jun 16, 2023
dbf1817
C2: logic is moved to jwt_service.py
Jun 16, 2023
a81da5d
Fixed bug in PKCE deciphering (#122)
maxamly Jun 16, 2023
e3cf059
C3: logic is moved to jwt_service.py
Jun 17, 2023
81018be
C4: provide_rsa_keys, Depends()
Jun 20, 2023
fd2103f
C5: repo, test_repo
Jun 21, 2023
5c53121
C6: one key for several processes
Jun 22, 2023
87b91cb
Merge branch 'perfomance_testing' into RSA_keys_for_several_processes
Jun 22, 2023
7562dd2
Merge branch 'main' into RSA_keys_for_several_processes
Jun 22, 2023
3a43343
C7: FastAPIError: Invalid args for response field!
Jun 23, 2023
7ed4384
C8: sync session, but Error 422
Jun 26, 2023
d59524d
C9: sync session, /token/
Jun 27, 2023
e20fe01
C10: clean
Jun 27, 2023
4e2c5d7
C1: endpoints in svager work
Jun 28, 2023
2e7b244
C1: endpoints in svager work
Jun 28, 2023
19677aa
C2: endpoints, tests work
Jun 28, 2023
f7c4a28
Merge remote-tracking branch 'origin/replace_container_in_JWTService'…
Jun 28, 2023
c7a65ce
Merge branch 'main' into replace_container_in_JWTService
Jun 28, 2023
b556bd2
C3: done
Jun 28, 2023
896ceab
C4: cleaned
Jun 28, 2023
ac47f47
poetry
DanyaKrats Jun 29, 2023
2287d14
fix
DanyaKrats Jun 30, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,213 changes: 508 additions & 1,705 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ types-mock = "^5.0.0"
fastapi-utils = "^0.2.1"
deptry = "^0.11.0"
bump2version = "^1.0.1"
locust = "^2.15.1"


[tool.deptry]
ignore_unused = [
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[pytest]
pythonpath = . src
norecursedirs = tests/performance
4 changes: 2 additions & 2 deletions src/business_logic/jwt_manager/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@


class JWTManagerProtocol(Protocol):
def encode(self, payload: Payload, algorithm: str) -> str:
async def encode(self, payload: Payload, algorithm: str) -> str:
raise NotImplementedError

def decode(self, token: str, audience: str,**kwargs: Any) -> dict[str, Any]:
async def decode(self, token: str, audience: str,**kwargs: Any) -> dict[str, Any]:
raise NotImplementedError
18 changes: 10 additions & 8 deletions src/business_logic/jwt_manager/service_impls/jwt_service.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
from __future__ import annotations
import logging
import jwt
from src.config.rsa_keys import RSAKeypair
from src.di import Container
from typing import Any, Optional, Union

from src.data_access.postgresql.tables.rsa_keys import RSA_keys
from src.business_logic.jwt_manager.dto import (
AccessTokenPayload,
RefreshTokenPayload,
IdTokenPayload,
)
from typing import Any, Optional, Union


logger = logging.getLogger(__name__)


Payload = Union[AccessTokenPayload, RefreshTokenPayload, IdTokenPayload]


class JWTManager:
def __init__(self, keys: RSAKeypair = Container().config().keys) -> None:
def __init__(
self,
keys: RSA_keys
) -> None:
self.keys = keys
# print(f"print:!!!jwt_service.py;!!! self.keys: {self.keys}")

def encode(self, payload: Payload, algorithm: str, secret: Optional[str] = None) -> str:
if secret:
key = secret
else:
key = self.keys.private_key
print(f"jwt_service.py; keys: {self.keys.private_key}")

token = jwt.encode(
payload=payload.dict(exclude_none=True), key=key, algorithm=algorithm
Expand All @@ -41,4 +43,4 @@ def decode(self, token: str, audience: Optional[str] = None, **kwargs: Any) -> d
decoded_info = jwt.decode(token, key=self.keys.public_key, algorithms=self.algorithms,
**kwargs,)

return decoded_info
return decoded_info
15 changes: 8 additions & 7 deletions src/business_logic/services/jwt_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@

import jwt
from typing import Any, no_type_check, Optional

from src.di.providers.rsa_keys import provide_rsa_keys
from src.config.rsa_keys import RSAKeypair
from src.di import Container

logger = logging.getLogger(__name__)


class JWTService:
def __init__(self, keys: RSAKeypair = Container().config().keys) -> None:
def __init__(self, keys: RSAKeypair = provide_rsa_keys) -> None:
self.algorithm = "RS256"
self.algorithms = ["RS256"]
self.keys = keys

@no_type_check
async def encode_jwt(self, payload: dict[str, Any] = {}, secret: None = None) -> str:
token = jwt.encode(
payload=payload, key=self.keys.private_key, algorithm=self.algorithm
payload=payload, key=self.keys().private_key, algorithm=self.algorithm
)

logger.info(f"Created token.")
Expand All @@ -31,15 +32,15 @@ async def decode_token(self, token: str, audience: str =None ,**kwargs: Any) ->
if audience:
decoded = jwt.decode(
token,
key=self.keys.public_key,
key=self.keys().public_key,
algorithms=self.algorithms,
audience=audience,
**kwargs,
)
return decoded
decoded = jwt.decode(
token,
key=self.keys.public_key,
key=self.keys().public_key,
algorithms=self.algorithms,
**kwargs,
)
Expand All @@ -56,7 +57,7 @@ async def verify_token(self, token: str, aud:str=None) -> bool:
return False

async def get_module(self) -> int:
return self.keys.n
return self.keys().n

async def get_pub_key_expanent(self) -> int:
return self.keys.e
return self.keys().e
4 changes: 2 additions & 2 deletions src/business_logic/services/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession

from src.business_logic.services.jwt_token import JWTService
# from src.business_logic.services.jwt_token import JWTService
from src.config.settings.app import AppSettings
from src.data_access.postgresql.errors import (
ClaimsNotFoundError,
Expand Down Expand Up @@ -186,7 +186,7 @@ async def revoke_token(self) -> None:
raise GrantNotFoundError
elif token_type_hint == "access_token":
decoded_token = await self.jwt_service.decode_token(
self.request_body.token, audience="revoke"
self.request_body.token, audience="revocation"
)
await self.blacklisted_repo.create(
token=self.request_body.token, expiration=decoded_token["exp"]
Expand Down
2 changes: 1 addition & 1 deletion src/business_logic/services/well_known.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def get_jwks(self) -> dict[str, Any]:
"alg": jwt_service.algorithm,
"use": "sig",
# "kid" : ... ,
"n": long_to_base64(await jwt_service.get_module()),
"n": long_to_base64(int(await jwt_service.get_module())),
"e": long_to_base64(await jwt_service.get_pub_key_expanent()),
}
logger.info(
Expand Down
1 change: 1 addition & 0 deletions src/config/rsa_keys/create_rsa_keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .dto import RSAKeypair

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class CreateRSAKeypair:
Expand Down
39 changes: 39 additions & 0 deletions src/config/rsa_keys/rsa_keys_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from Crypto.PublicKey import RSA
from sqlalchemy.orm import sessionmaker

from src.data_access.postgresql.repositories import RSAKeysRepository
from src.data_access.postgresql.tables.rsa_keys import RSA_keys
from .dto import RSAKeypair

class RSAKeysService:

def __init__(
self,
sync_session_factory: sessionmaker,
rsa_keys_repo: RSAKeysRepository
) -> None:
self.session = sync_session_factory
self.rsa_keys_repo = rsa_keys_repo

def get_rsa_keys(self) -> RSA_keys:
with self.session() as session:
if self.rsa_keys_repo.validate_keys_exists(session=session):
self.rsa_keys = self.rsa_keys_repo.get_keys_from_repository(session)
else:
self.rsa_keys = self.create_rsa_keys() # RSAKeypair
self.rsa_keys_repo.put_keys_to_repository(rsa_keys=self.rsa_keys, session=session)
self.rsa_keys = self.rsa_keys_repo.get_keys_from_repository(session=session) # RSA_keys
return self.rsa_keys

def create_rsa_keys(self) -> RSAKeypair: # or -> RSA_keys
key = RSA.generate(2048)
private_key = key.export_key("PEM")
public_key = key.public_key().export_key("PEM")

self.rsa_keys = RSAKeypair(
private_key=private_key,
public_key=public_key,
n=key.n,
e=key.e,
)
return self.rsa_keys
3 changes: 0 additions & 3 deletions src/config/settings/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from pydantic import PostgresDsn, SecretStr

from src.config.rsa_keys import CreateRSAKeypair, RSAKeypair
from src.config.settings.base import BaseAppSettings


Expand All @@ -27,8 +26,6 @@ class AppSettings(BaseAppSettings):

allowed_hosts: List[str] = ["*"]

keys: RSAKeypair = CreateRSAKeypair().execute()

class Config:
validate_assignment = True

Expand Down
2 changes: 1 addition & 1 deletion src/data_access/postgresql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .database import Database
from .database import Database, DatabaseSync
from .tables import (
Base,
Client,
Expand Down
33 changes: 31 additions & 2 deletions src/data_access/postgresql/database.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging

from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from src.config.settings.app import AppSettings

logger = logging.getLogger(__name__)

Expand All @@ -26,6 +26,10 @@ def session_factory(self) -> AsyncSession:
def engine(self) -> AsyncSession:
return self.__engine

@property
def sync_engine(self) -> Engine:
return self.__sync_engine

def _create_connection_pool(self, db_url: str, max_connection_count: int) -> AsyncEngine:
logger.info("Creating PostgreSQL connection pool.")

Expand All @@ -38,3 +42,28 @@ def _create_connection_pool(self, db_url: str, max_connection_count: int) -> Asy
async def get_connection(self) -> AsyncSession:
async with self.__session_factory() as session:
yield session


class DatabaseSync:
def __init__(self, database_url: str):
self.__sync_engine = self._create_sync_connection_pool(database_url)
self.__sync_session_factory = sessionmaker(
self.__sync_engine
)

@property
def sync_engine(self) -> Engine:
return self.__sync_engine

@property
def sync_session_factory(self) -> sessionmaker:
return self.__sync_session_factory

def _create_sync_connection_pool(self, db_url:str) -> Engine:
logger.info("Creating PostgreSQL sync engine.")
db_url = db_url.replace("asyncpg", "psycopg2")
sync_engine = create_engine(db_url)

logger.info("Sync engine created.")

return sync_engine
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""create rsa keys table

Revision ID: 6a2b3685577f
Revises: 432f08914e0c
Create Date: 2023-06-15 17:18:45.282222

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '6a2b3685577f'
down_revision = '432f08914e0c'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"rsa_keys",
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.Column("private_key", sa.LargeBinary),
sa.Column("public_key", sa.LargeBinary),
sa.Column('n', sa.Numeric),
sa.Column('e', sa.Integer),
sa.Column('expiration_encode', sa.Integer),
sa.Column('expiration_decode', sa.Integer),
)

def downgrade() -> None:
op.drop_table("rsa_keys")
1 change: 1 addition & 0 deletions src/data_access/postgresql/repositories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from .wellknown import WellKnownRepository
from .blacklisted_token import BlacklistedTokenRepository
from .code_challenge import CodeChallengeRepository
from .rsa_keys import RSAKeysRepository
42 changes: 42 additions & 0 deletions src/data_access/postgresql/repositories/rsa_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Optional

from sqlalchemy import exists, select, insert
from sqlalchemy.orm import Session

from src.config.rsa_keys import RSAKeypair
from src.data_access.postgresql.tables.rsa_keys import RSA_keys

class RSAKeysRepository:

def get_keys_from_repository(self, session: Session) -> Optional[RSA_keys]:
result = session.execute(select(RSA_keys))
rsa_keys_list = result.scalars().all()
rsa_keys = rsa_keys_list[-1] if len(rsa_keys_list) > 0 else None
return rsa_keys

def put_keys_to_repository(self, rsa_keys: RSAKeypair, session: Session) -> None:
session.execute(insert(RSA_keys).values(
private_key=rsa_keys.private_key,
public_key=rsa_keys.public_key,
n=rsa_keys.n,
e=rsa_keys.e
))
session.commit()

# def validate_keys_exists(self, session: Session) -> bool:
# result = session.execute(
# select([1]).where(exists().where(RSA_keys.id.isnot(None)).select_from(RSA_keys))
# )
# return result.scalars().first() is not None

# def validate_keys_exists(self, session: Session) -> bool:
# stmt = exists().where(RSA_keys.id.isnot(None))
# result = session.execute(select(stmt))
# return result.scalar() is not None

def validate_keys_exists(self, session: Session) -> bool:
res = session.execute((
select(RSA_keys.id)
))
return res.scalar() is not None

2 changes: 2 additions & 0 deletions src/data_access/postgresql/tables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .device import Device
from .blacklisted_token import BlacklistedToken
from .code_challenge import CodeChallenge, CodeChallengeMethod
# from .rsa_keys import RSA_keys

__all__ = [
Client,
Expand All @@ -52,4 +53,5 @@
ClientSecret,
ClientRedirectUri,
Base,
# RSA_keys
]
20 changes: 20 additions & 0 deletions src/data_access/postgresql/tables/rsa_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from sqlalchemy import (
Column,
Integer,
LargeBinary,
Numeric,
String,
Text
)

from .base import BaseModel

class RSA_keys(BaseModel):
__tablename__ = "rsa_keys"

private_key = Column(LargeBinary)
public_key = Column(LargeBinary)
n = Column(Numeric)
e = Column(Integer)
expiration_encode = Column(Integer, default=0)
expiration_decode = Column(Integer, default=0)
Loading