Skip to content

Commit 1bcea32

Browse files
committed
Refactor Sessions to clarify their usage
DBSession has been renamed to HandlerSession. The following exist now: - HandlerSession: a DB session scoped to the current handler; should therefore only be used inside handlers. - ThreadSession: a DB session scoped to the current thread; should be used outside of handlers, e.g. when firing off tasks in threads, or inside of services. A special HandlerSession, with knowledge of the current user, and that verifies permissions on commit, is available as `self.Session` on the base handler. This is the *preferred way of accessing* HandlerSession. To make it easier to get hold of the current engine, the Sessions each have a `.engine` attribute: `HandlerSession.engine`. This avoids having to access the engine as `HandlerSession.session_factory.kw["bind"]`.
1 parent 57b7faa commit 1bcea32

File tree

8 files changed

+108
-83
lines changed

8 files changed

+108
-83
lines changed

app/access.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy.orm import joinedload
66

77
from baselayer.app.custom_exceptions import AccessError # noqa: F401
8-
from baselayer.app.models import DBSession, Role, Token, User # noqa: F401
8+
from baselayer.app.models import HandlerSession, Role, Token, User # noqa: F401
99

1010

1111
def auth_or_token(method):
@@ -26,7 +26,7 @@ def wrapper(self, *args, **kwargs):
2626
token_header = self.request.headers.get("Authorization", None)
2727
if token_header is not None and token_header.startswith("token "):
2828
token_id = token_header.replace("token", "").strip()
29-
with DBSession() as session:
29+
with HandlerSession() as session:
3030
token = session.scalars(
3131
sa.select(Token)
3232
.options(

app/handlers/base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ..env import load_env
2323
from ..flow import Flow
2424
from ..json_util import to_json
25-
from ..models import DBSession, User, VerifiedSession, bulk_verify, session_context_id
25+
from ..models import HandlerSession, User, VerifiedSession, bulk_verify, session_context_id
2626

2727
env, cfg = load_env()
2828
log = make_log("basehandler")
@@ -49,7 +49,7 @@ def get_current_user(self):
4949
user_id = int(self.user_id())
5050
oauth_uid = self.get_secure_cookie("user_oauth_uid")
5151
if user_id and oauth_uid:
52-
with DBSession() as session:
52+
with HandlerSession() as session:
5353
try:
5454
user = session.scalars(
5555
sqlalchemy.select(User).where(User.id == user_id)
@@ -74,7 +74,7 @@ def get_current_user(self):
7474
return None
7575

7676
def login_user(self, user):
77-
with DBSession() as session:
77+
with HandlerSession() as session:
7878
try:
7979
self.set_secure_cookie("user_id", str(user.id))
8080
user = session.scalars(
@@ -120,7 +120,7 @@ def log_exception(self, typ=None, value=None, tb=None):
120120
)
121121

122122
def on_finish(self):
123-
DBSession.remove()
123+
HandlerSession.remove()
124124

125125

126126
class BaseHandler(PSABaseHandler):
@@ -153,7 +153,7 @@ def Session(self):
153153
# must merge the user object with the current session
154154
# ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#adding-new-or-existing-items
155155
session.add(self.current_user)
156-
session.bind = DBSession.session_factory.kw["bind"]
156+
session.bind = HandlerSession.engine
157157
yield session
158158

159159
def verify_permissions(self):
@@ -164,20 +164,20 @@ def verify_permissions(self):
164164
"""
165165

166166
# get items to be inserted
167-
new_rows = [row for row in DBSession().new]
167+
new_rows = [row for row in HandlerSession().new]
168168

169169
# get items to be updated
170170
updated_rows = [
171-
row for row in DBSession().dirty if DBSession().is_modified(row)
171+
row for row in HandlerSession().dirty if HandlerSession().is_modified(row)
172172
]
173173

174174
# get items to be deleted
175-
deleted_rows = [row for row in DBSession().deleted]
175+
deleted_rows = [row for row in HandlerSession().deleted]
176176

177177
# get items that were read
178178
read_rows = [
179179
row
180-
for row in set(DBSession().identity_map.values())
180+
for row in set(HandlerSession().identity_map.values())
181181
- (set(updated_rows) | set(new_rows) | set(deleted_rows))
182182
]
183183

@@ -194,15 +194,15 @@ def verify_permissions(self):
194194
# update transaction state in DB, but don't commit yet. this updates
195195
# or adds rows in the database and uses their new state in joins,
196196
# for permissions checking purposes.
197-
DBSession().flush()
197+
HandlerSession().flush()
198198
bulk_verify("create", new_rows, self.current_user)
199199

200200
def verify_and_commit(self):
201201
"""Verify permissions on the current database session and commit if
202202
successful, otherwise raise an AccessError.
203203
"""
204204
self.verify_permissions()
205-
DBSession().commit()
205+
HandlerSession().commit()
206206

207207
def prepare(self):
208208
self.cfg = self.application.cfg
@@ -225,7 +225,7 @@ def prepare(self):
225225
N = 5
226226
for i in range(1, N + 1):
227227
try:
228-
assert DBSession.session_factory.kw["bind"] is not None
228+
assert HandlerSession.engine is not None
229229
except Exception as e:
230230
if i == N:
231231
raise e

app/model_util.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ def status(message):
2121
else:
2222
print(f"\r[✓] {message}")
2323
finally:
24-
models.DBSession().commit()
24+
models.HandlerSession().commit()
2525

2626

2727
def drop_tables():
28-
conn = models.DBSession.session_factory.kw["bind"]
29-
print(f"Dropping tables on database {conn.url.database}")
28+
engine = models.HandlerSession.engine
29+
print(f"Dropping tables on database {engine.url.database}")
3030
meta = sa.MetaData()
31-
meta.reflect(bind=conn)
32-
meta.drop_all(bind=conn)
31+
meta.reflect(bind=engine)
32+
meta.drop_all(bind=engine)
3333

3434

3535
def create_tables(retry=5, add=True):
@@ -45,17 +45,16 @@ def create_tables(retry=5, add=True):
4545
tables.
4646
4747
"""
48-
conn = models.DBSession.session_factory.kw["bind"]
4948
tables = models.Base.metadata.sorted_tables
5049
if tables and not add:
5150
print("Existing tables found; not creating additional tables")
5251
return
5352

5453
for i in range(1, retry + 1):
5554
try:
56-
conn = models.DBSession.session_factory.kw["bind"]
57-
print(f"Creating tables on database {conn.url.database}")
58-
models.Base.metadata.create_all(conn)
55+
engine = models.HandlerSession.engine
56+
print(f"Creating tables on database {engine.url.database}")
57+
models.Base.metadata.create_all(engine)
5958

6059
table_list = ", ".join(list(models.Base.metadata.tables.keys()))
6160
print(f"Refreshed tables: {table_list}")

app/models.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,24 @@
2828
log_database = cfg.get("log.database", False)
2929
log_database_pool = cfg.get("log.database_pool", False)
3030

31+
32+
# This provides one session per *thread*
33+
ThreadSession = scoped_session(sessionmaker())
34+
35+
36+
# This provides one session per *handler*
37+
# It is recommended to use the handler session via
38+
# self.Session, which has some knowledge
39+
# of the current user. See `handlers/base.py`
40+
#
41+
# DBSession has been renamed to HandlerSession
42+
# to make it clearer what it is doing.
43+
44+
# We've renamed DBSession
45+
# It is not recommended to use DBSession directly;
3146
session_context_id = contextvars.ContextVar("request_id", default=None)
32-
# left here for backward compatibility:
3347
DBSession = scoped_session(sessionmaker(), scopefunc=session_context_id.get)
48+
HandlerSession = DBSession
3449

3550

3651
class _VerifiedSession(sa.orm.session.Session):
@@ -61,7 +76,7 @@ def __init__(self, user_or_token, **kwargs):
6176
or be generating an unverified session to only query
6277
the user with a certain id. Example:
6378
64-
with DBSession() as session:
79+
with HandlerSession() as session:
6580
user = session.scalars(
6681
sa.select(User).where(User.id == user_id)
6782
).first()
@@ -159,7 +174,7 @@ def bulk_verify(mode, collection, accessor):
159174
).subquery()
160175

161176
inaccessible_row_ids = (
162-
DBSession()
177+
HandlerSession()
163178
.scalars(
164179
sa.select(record_cls.id)
165180
.outerjoin(
@@ -258,7 +273,7 @@ def init_db(
258273
"max_overflow": 10,
259274
"pool_recycle": 3600,
260275
}
261-
conn = sa.create_engine(
276+
engine = sa.create_engine(
262277
url,
263278
client_encoding="utf8",
264279
executemany_mode="values_plus_batch",
@@ -268,8 +283,15 @@ def init_db(
268283
**{**default_engine_args, **engine_args},
269284
)
270285

271-
DBSession.configure(bind=conn, autoflush=autoflush, future=True)
272-
Base.metadata.bind = conn
286+
HandlerSession.configure(bind=engine, autoflush=autoflush, future=True)
287+
# Convenience attribute to easily access the engine, otherwise would need
288+
# HandlerSession.session_factory.kw["bind"]
289+
HandlerSession.engine = engine
290+
291+
ThreadSession.configure(bind=engine, autoflush=autoflush, future=True)
292+
ThreadSession.engine = engine
293+
294+
Base.metadata.bind = engine
273295

274296
return conn
275297

@@ -478,8 +500,8 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
478500
"""
479501
# return only selected columns if requested
480502
if columns is not None:
481-
return DBSession().query(*columns).select_from(cls)
482-
return DBSession().query(cls)
503+
return HandlerSession().query(*columns).select_from(cls)
504+
return HandlerSession().query(cls)
483505

484506
def select_accessible_rows(self, cls, user_or_token, columns=None):
485507
"""Construct a Select object that, when executed, returns the rows of a
@@ -571,9 +593,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
571593

572594
# return only selected columns if requested
573595
if columns is not None:
574-
query = DBSession().query(*columns).select_from(cls)
596+
query = HandlerSession().query(*columns).select_from(cls)
575597
else:
576-
query = DBSession().query(cls)
598+
query = HandlerSession().query(cls)
577599

578600
# traverse the relationship chain via sequential JOINs
579601
for relationship_name in self.relationship_names:
@@ -735,9 +757,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
735757

736758
# return only selected columns if requested
737759
if columns is None:
738-
base = DBSession().query(cls)
760+
base = HandlerSession().query(cls)
739761
else:
740-
base = DBSession().query(*columns).select_from(cls)
762+
base = HandlerSession().query(*columns).select_from(cls)
741763

742764
# ensure the target class has all the relationships referred to
743765
# in this instance
@@ -922,9 +944,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
922944

923945
# retrieve specified columns if requested
924946
if columns is not None:
925-
query = DBSession().query(*columns).select_from(cls)
947+
query = HandlerSession().query(*columns).select_from(cls)
926948
else:
927-
query = DBSession().query(cls)
949+
query = HandlerSession().query(cls)
928950

929951
# keep track of columns that will be null in the case of an unsuccessful
930952
# match for OR logic.
@@ -1076,9 +1098,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None):
10761098
# otherwise, all records are inaccessible
10771099
if columns is not None:
10781100
return (
1079-
DBSession().query(*columns).select_from(cls).filter(sa.literal(False))
1101+
HandlerSession().query(*columns).select_from(cls).filter(sa.literal(False))
10801102
)
1081-
return DBSession().query(cls).filter(sa.literal(False))
1103+
return HandlerSession().query(cls).filter(sa.literal(False))
10821104

10831105
def select_accessible_rows(self, cls, user_or_token, columns=None):
10841106
"""Construct a Select object that, when executed, returns the rows of a
@@ -1148,7 +1170,7 @@ def __init__(self, query_or_query_generator):
11481170
11491171
Query (SQLA 1.4):
11501172
>>>> CustomUserAccessControl(
1151-
DBSession().query(Department).join(Employee).group_by(
1173+
HandlerSession().query(Department).join(Employee).group_by(
11521174
Department.id
11531175
).having(sa.func.bool_and(Employee.is_manager.is_(True)))
11541176
)
@@ -1166,8 +1188,8 @@ def __init__(self, query_or_query_generator):
11661188
Query (SQLA 1.4):
11671189
>>>> def access_logic(cls, user_or_token):
11681190
... if user_or_token.is_system_admin:
1169-
... return DBSession().query(cls)
1170-
... return DBSession().query(cls).join(Employee).group_by(
1191+
... return HandlerSession().query(cls)
1192+
... return HandlerSession().query(cls).join(Employee).group_by(
11711193
... cls.id
11721194
... ).having(sa.func.bool_and(Employee.is_manager.is_(True)))
11731195
>>>> CustomUserAccessControl(access_logic)
@@ -1303,7 +1325,7 @@ def is_accessible_by(self, user_or_token, mode="read"):
13031325

13041326
# Query for the value of the access_func for this particular record and
13051327
# return the result.
1306-
result = DBSession().execute(stmt).scalar_one() > 0
1328+
result = HandlerSession().execute(stmt).scalar_one() > 0
13071329
if result is None:
13081330
result = False
13091331

@@ -1353,7 +1375,7 @@ def get_if_accessible_by(
13531375

13541376
# TODO: vectorize this
13551377
for pk in standardized:
1356-
instance = DBSession().query(cls).options(options).get(pk.item())
1378+
instance = HandlerSession().query(cls).options(options).get(pk.item())
13571379
if instance is None or not instance.is_accessible_by(
13581380
user_or_token, mode=mode
13591381
):
@@ -1467,7 +1489,7 @@ def get(
14671489
standardized = np.atleast_1d(id_or_list)
14681490
result = []
14691491

1470-
with DBSession() as session:
1492+
with HandlerSession() as session:
14711493
# TODO: vectorize this
14721494
for pk in standardized:
14731495
if options:
@@ -1519,7 +1541,7 @@ def get_all(
15191541
If columns is specified, will return a list of tuples
15201542
containing the data from each column requested.
15211543
"""
1522-
with DBSession() as session:
1544+
with HandlerSession() as session:
15231545
stmt = cls.select(user_or_token, mode, options, columns)
15241546
values = session.scalars(stmt).all()
15251547

@@ -1566,7 +1588,7 @@ def select(
15661588
stmt = stmt.options(option)
15671589
return stmt
15681590

1569-
query = DBSession.query_property()
1591+
query = HandlerSession.query_property()
15701592

15711593
id = sa.Column(
15721594
sa.Integer,
@@ -1608,8 +1630,8 @@ def __repr__(self):
16081630
def to_dict(self):
16091631
"""Serialize this object to a Python dictionary."""
16101632
if sa.inspection.inspect(self).expired:
1611-
self = DBSession().merge(self)
1612-
DBSession().refresh(self)
1633+
self = HandlerSession().merge(self)
1634+
HandlerSession().refresh(self)
16131635
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
16141636

16151637
@classmethod
@@ -1632,7 +1654,7 @@ def get_if_readable_by(cls, ident, user_or_token, options=[]):
16321654
obj : baselayer.app.models.Base
16331655
The requested entity.
16341656
"""
1635-
obj = DBSession().query(cls).options(options).get(ident)
1657+
obj = HandlerSession().query(cls).options(options).get(ident)
16361658

16371659
if obj is not None and not obj.is_readable_by(user_or_token):
16381660
raise AccessError("Insufficient permissions.")
@@ -1659,7 +1681,7 @@ def is_readable_by(self, user_or_token):
16591681
def create_or_get(cls, id):
16601682
"""Return a new `cls` if an instance with the specified primary key
16611683
does not exist, else return the existing instance."""
1662-
obj = DBSession().query(cls).get(id)
1684+
obj = HandlerSession().query(cls).get(id)
16631685
if obj is not None:
16641686
return obj
16651687
else:
@@ -1876,7 +1898,7 @@ class User(Base):
18761898
role_ids = association_proxy(
18771899
"roles",
18781900
"id",
1879-
creator=lambda r: DBSession().query(Role).get(r),
1901+
creator=lambda r: HandlerSession().query(Role).get(r),
18801902
)
18811903
tokens = relationship(
18821904
"Token",
@@ -1982,7 +2004,7 @@ class Token(Base):
19822004
lazy="selectin",
19832005
)
19842006
acl_ids = association_proxy(
1985-
"acls", "id", creator=lambda acl: DBSession().query(ACL).get(acl)
2007+
"acls", "id", creator=lambda acl: HandlerSession().query(ACL).get(acl)
19862008
)
19872009
permissions = acl_ids
19882010

0 commit comments

Comments
 (0)