11import math
22from typing import Any , Callable , TYPE_CHECKING
3- from asyncio import TaskGroup
43from logging import getLogger
54from fastapi import Request
65from sqlalchemy import (
@@ -305,7 +304,7 @@ async def create(self, data: CruddyModel, request: Request | None = None):
305304 inserted_row = result .first ()
306305 if inserted_row is None :
307306 raise ValueError (f"The payload { values } failed to create a new record" )
308- created_record = self .model (** inserted_row ._mapping )
307+ created_record = self .model (** inserted_row ._mapping )
309308 if created_record is not None :
310309 if self .lifecycle ["after_create" ]:
311310 await self .lifecycle ["after_create" ](created_record )
@@ -315,19 +314,19 @@ async def create(self, data: CruddyModel, request: Request | None = None):
315314
316315 async def get_by_id (
317316 self , id : possible_id_values , where : Json = None , request : Request | None = None
318- ) -> Any | None :
317+ ):
319318 # retrieve user data by id
320- async with self .adapter .getSession (request ) as session :
321- if self .lifecycle ["before_get_one" ]:
322- await self .lifecycle ["before_get_one" ](id , where )
323- selectables = list (self .view_keys )
324- columns = [getattr (self .model , x ) for x in selectables ]
325- query = select (* columns ).where (
326- and_ (
327- self .identity_function (id ),
328- * self .query_forge (model = self .model , where = where ),
329- )
319+ if self .lifecycle ["before_get_one" ]:
320+ await self .lifecycle ["before_get_one" ](id , where )
321+ selectables = list (self .view_keys )
322+ columns = [getattr (self .model , x ) for x in selectables ]
323+ query = select (* columns ).where (
324+ and_ (
325+ self .identity_function (id ),
326+ * self .query_forge (model = self .model , where = where ),
330327 )
328+ )
329+ async with self .adapter .getSession (request ) as session :
331330 result = (await session .execute (query )).fetchone ()
332331 if result is not None :
333332 result = self .view_model (** result ._mapping )
@@ -460,19 +459,12 @@ def splitter(sort_string: str):
460459 query_conf ["limit" ]
461460 )
462461 # total record
463-
464- async with (
465- self .adapter .getSession (request ) as session1 ,
466- self .adapter .getSession (request ) as session2 ,
467- ):
468- async with TaskGroup () as tg :
469- task1 = tg .create_task (session1 .execute (count_query ))
470- task2 = tg .create_task (session2 .execute (query ))
471- count : Result = task1 .result ()
472- records : Result = task2 .result ()
462+ async with self .adapter .getSession (request ) as session :
463+ records : Result = await session .execute (query )
464+ await session .flush ()
465+ count : Result = await session .execute (count_query )
473466 total_record = count .scalar () or 0
474467 result = records .fetchall ()
475-
476468 # possible pass in outside functions to map/alter data?
477469 # total page
478470 total_page = math .ceil (total_record / query_conf ["limit" ])
@@ -566,15 +558,10 @@ def splitter(sort_string: str):
566558 )
567559 # total record
568560
569- async with (
570- self .adapter .getSession (request ) as session1 ,
571- self .adapter .getSession (request ) as session2 ,
572- ):
573- async with TaskGroup () as tg :
574- task1 = tg .create_task (session1 .execute (count_query ))
575- task2 = tg .create_task (session2 .execute (query ))
576- count : Result = task1 .result ()
577- records : Result = task2 .result ()
561+ async with self .adapter .getSession (request ) as session :
562+ records : Result = await session .execute (query )
563+ await session .flush ()
564+ count : Result = await session .execute (count_query )
578565 total_record = count .scalar () or 0
579566 result = records .fetchall ()
580567
0 commit comments