2828log_database = cfg .get ("log.database" , False )
2929log_database_pool = cfg .get ("log.database_pool" , False )
3030
31+
32+ # This provides one session per *thread*
33+ ThreadSession = scoped_session (sessionmaker ())
34+
35+
36+ # This provides one session per *handler*
37+ # It is recommended to use the handler session via
38+ # self.Session, which has some knowledge
39+ # of the current user. See `handlers/base.py`
40+ #
41+ # DBSession has been renamed to HandlerSession
42+ # to make it clearer what it is doing.
43+
44+ # We've renamed DBSession
45+ # It is not recommended to use DBSession directly;
3146session_context_id = contextvars .ContextVar ("request_id" , default = None )
32- # left here for backward compatibility:
3347DBSession = scoped_session (sessionmaker (), scopefunc = session_context_id .get )
48+ HandlerSession = DBSession
3449
3550
3651class _VerifiedSession (sa .orm .session .Session ):
@@ -61,7 +76,7 @@ def __init__(self, user_or_token, **kwargs):
6176 or be generating an unverified session to only query
6277 the user with a certain id. Example:
6378
64- with DBSession () as session:
79+ with HandlerSession () as session:
6580 user = session.scalars(
6681 sa.select(User).where(User.id == user_id)
6782 ).first()
@@ -159,7 +174,7 @@ def bulk_verify(mode, collection, accessor):
159174 ).subquery ()
160175
161176 inaccessible_row_ids = (
162- DBSession ()
177+ HandlerSession ()
163178 .scalars (
164179 sa .select (record_cls .id )
165180 .outerjoin (
@@ -258,7 +273,7 @@ def init_db(
258273 "max_overflow" : 10 ,
259274 "pool_recycle" : 3600 ,
260275 }
261- conn = sa .create_engine (
276+ engine = sa .create_engine (
262277 url ,
263278 client_encoding = "utf8" ,
264279 executemany_mode = "values_plus_batch" ,
@@ -268,8 +283,15 @@ def init_db(
268283 ** {** default_engine_args , ** engine_args },
269284 )
270285
271- DBSession .configure (bind = conn , autoflush = autoflush , future = True )
272- Base .metadata .bind = conn
286+ HandlerSession .configure (bind = engine , autoflush = autoflush , future = True )
287+ # Convenience attribute to easily access the engine, otherwise would need
288+ # HandlerSession.session_factory.kw["bind"]
289+ HandlerSession .engine = engine
290+
291+ ThreadSession .configure (bind = engine , autoflush = autoflush , future = True )
292+ ThreadSession .engine = engine
293+
294+ Base .metadata .bind = engine
273295
274296 return conn
275297
@@ -478,8 +500,8 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
478500 """
479501 # return only selected columns if requested
480502 if columns is not None :
481- return DBSession ().query (* columns ).select_from (cls )
482- return DBSession ().query (cls )
503+ return HandlerSession ().query (* columns ).select_from (cls )
504+ return HandlerSession ().query (cls )
483505
484506 def select_accessible_rows (self , cls , user_or_token , columns = None ):
485507 """Construct a Select object that, when executed, returns the rows of a
@@ -571,9 +593,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
571593
572594 # return only selected columns if requested
573595 if columns is not None :
574- query = DBSession ().query (* columns ).select_from (cls )
596+ query = HandlerSession ().query (* columns ).select_from (cls )
575597 else :
576- query = DBSession ().query (cls )
598+ query = HandlerSession ().query (cls )
577599
578600 # traverse the relationship chain via sequential JOINs
579601 for relationship_name in self .relationship_names :
@@ -735,9 +757,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
735757
736758 # return only selected columns if requested
737759 if columns is None :
738- base = DBSession ().query (cls )
760+ base = HandlerSession ().query (cls )
739761 else :
740- base = DBSession ().query (* columns ).select_from (cls )
762+ base = HandlerSession ().query (* columns ).select_from (cls )
741763
742764 # ensure the target class has all the relationships referred to
743765 # in this instance
@@ -922,9 +944,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
922944
923945 # retrieve specified columns if requested
924946 if columns is not None :
925- query = DBSession ().query (* columns ).select_from (cls )
947+ query = HandlerSession ().query (* columns ).select_from (cls )
926948 else :
927- query = DBSession ().query (cls )
949+ query = HandlerSession ().query (cls )
928950
929951 # keep track of columns that will be null in the case of an unsuccessful
930952 # match for OR logic.
@@ -1076,9 +1098,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
10761098 # otherwise, all records are inaccessible
10771099 if columns is not None :
10781100 return (
1079- DBSession ().query (* columns ).select_from (cls ).filter (sa .literal (False ))
1101+ HandlerSession ().query (* columns ).select_from (cls ).filter (sa .literal (False ))
10801102 )
1081- return DBSession ().query (cls ).filter (sa .literal (False ))
1103+ return HandlerSession ().query (cls ).filter (sa .literal (False ))
10821104
10831105 def select_accessible_rows (self , cls , user_or_token , columns = None ):
10841106 """Construct a Select object that, when executed, returns the rows of a
@@ -1148,7 +1170,7 @@ def __init__(self, query_or_query_generator):
11481170
11491171 Query (SQLA 1.4):
11501172 >>>> CustomUserAccessControl(
1151- DBSession ().query(Department).join(Employee).group_by(
1173+ HandlerSession ().query(Department).join(Employee).group_by(
11521174 Department.id
11531175 ).having(sa.func.bool_and(Employee.is_manager.is_(True)))
11541176 )
@@ -1166,8 +1188,8 @@ def __init__(self, query_or_query_generator):
11661188 Query (SQLA 1.4):
11671189 >>>> def access_logic(cls, user_or_token):
11681190 ... if user_or_token.is_system_admin:
1169- ... return DBSession ().query(cls)
1170- ... return DBSession ().query(cls).join(Employee).group_by(
1191+ ... return HandlerSession ().query(cls)
1192+ ... return HandlerSession ().query(cls).join(Employee).group_by(
11711193 ... cls.id
11721194 ... ).having(sa.func.bool_and(Employee.is_manager.is_(True)))
11731195 >>>> CustomUserAccessControl(access_logic)
@@ -1303,7 +1325,7 @@ def is_accessible_by(self, user_or_token, mode="read"):
13031325
13041326 # Query for the value of the access_func for this particular record and
13051327 # return the result.
1306- result = DBSession ().execute (stmt ).scalar_one () > 0
1328+ result = HandlerSession ().execute (stmt ).scalar_one () > 0
13071329 if result is None :
13081330 result = False
13091331
@@ -1353,7 +1375,7 @@ def get_if_accessible_by(
13531375
13541376 # TODO: vectorize this
13551377 for pk in standardized :
1356- instance = DBSession ().query (cls ).options (options ).get (pk .item ())
1378+ instance = HandlerSession ().query (cls ).options (options ).get (pk .item ())
13571379 if instance is None or not instance .is_accessible_by (
13581380 user_or_token , mode = mode
13591381 ):
@@ -1467,7 +1489,7 @@ def get(
14671489 standardized = np .atleast_1d (id_or_list )
14681490 result = []
14691491
1470- with DBSession () as session :
1492+ with HandlerSession () as session :
14711493 # TODO: vectorize this
14721494 for pk in standardized :
14731495 if options :
@@ -1519,7 +1541,7 @@ def get_all(
15191541 If columns is specified, will return a list of tuples
15201542 containing the data from each column requested.
15211543 """
1522- with DBSession () as session :
1544+ with HandlerSession () as session :
15231545 stmt = cls .select (user_or_token , mode , options , columns )
15241546 values = session .scalars (stmt ).all ()
15251547
@@ -1566,7 +1588,7 @@ def select(
15661588 stmt = stmt .options (option )
15671589 return stmt
15681590
1569- query = DBSession .query_property ()
1591+ query = HandlerSession .query_property ()
15701592
15711593 id = sa .Column (
15721594 sa .Integer ,
@@ -1608,8 +1630,8 @@ def __repr__(self):
16081630 def to_dict (self ):
16091631 """Serialize this object to a Python dictionary."""
16101632 if sa .inspection .inspect (self ).expired :
1611- self = DBSession ().merge (self )
1612- DBSession ().refresh (self )
1633+ self = HandlerSession ().merge (self )
1634+ HandlerSession ().refresh (self )
16131635 return {k : v for k , v in self .__dict__ .items () if not k .startswith ("_" )}
16141636
16151637 @classmethod
@@ -1632,7 +1654,7 @@ def get_if_readable_by(cls, ident, user_or_token, options=[]):
16321654 obj : baselayer.app.models.Base
16331655 The requested entity.
16341656 """
1635- obj = DBSession ().query (cls ).options (options ).get (ident )
1657+ obj = HandlerSession ().query (cls ).options (options ).get (ident )
16361658
16371659 if obj is not None and not obj .is_readable_by (user_or_token ):
16381660 raise AccessError ("Insufficient permissions." )
@@ -1659,7 +1681,7 @@ def is_readable_by(self, user_or_token):
16591681 def create_or_get (cls , id ):
16601682 """Return a new `cls` if an instance with the specified primary key
16611683 does not exist, else return the existing instance."""
1662- obj = DBSession ().query (cls ).get (id )
1684+ obj = HandlerSession ().query (cls ).get (id )
16631685 if obj is not None :
16641686 return obj
16651687 else :
@@ -1876,7 +1898,7 @@ class User(Base):
18761898 role_ids = association_proxy (
18771899 "roles" ,
18781900 "id" ,
1879- creator = lambda r : DBSession ().query (Role ).get (r ),
1901+ creator = lambda r : HandlerSession ().query (Role ).get (r ),
18801902 )
18811903 tokens = relationship (
18821904 "Token" ,
@@ -1982,7 +2004,7 @@ class Token(Base):
19822004 lazy = "selectin" ,
19832005 )
19842006 acl_ids = association_proxy (
1985- "acls" , "id" , creator = lambda acl : DBSession ().query (ACL ).get (acl )
2007+ "acls" , "id" , creator = lambda acl : HandlerSession ().query (ACL ).get (acl )
19862008 )
19872009 permissions = acl_ids
19882010
0 commit comments