Commit 14ef1ad

cleanup and fix resource handling

1 parent ae6a529 commit 14ef1ad

19 files changed: +350 −201 lines

src/datachain/catalog/catalog.py (+42 −29)

@@ -9,6 +9,7 @@
 import time
 import traceback
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from contextlib import ExitStack
 from copy import copy
 from dataclasses import dataclass
 from functools import cached_property, reduce
@@ -314,6 +315,16 @@ def download(self, recursive: bool = False, pbar=None) -> None:
         if self.sources:
             self.client.fetch_nodes(self.iternodes(recursive), shared_progress_bar=pbar)
 
+    def close(self) -> None:
+        if self.listing:
+            self.listing.close()
+
+    def __enter__(self) -> "NodeGroup":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.close()
+
 
 def prepare_output_for_cp(
     node_groups: list[NodeGroup],
@@ -1918,38 +1929,40 @@ def cp(
             no_glob,
             client_config=client_config,
         )
+        with ExitStack() as stack:
+            for node_group in node_groups:
+                stack.enter_context(node_group)
+            always_copy_dir_contents, copy_to_filename = prepare_output_for_cp(
+                node_groups, output, force, no_cp
+            )
+            total_size, total_files = collect_nodes_for_cp(node_groups, recursive)
+            if not total_files:
+                return
 
-        always_copy_dir_contents, copy_to_filename = prepare_output_for_cp(
-            node_groups, output, force, no_cp
-        )
-        total_size, total_files = collect_nodes_for_cp(node_groups, recursive)
-        if not total_files:
-            return
-
-        desc_max_len = max(len(output) + 16, 19)
-        bar_format = (
-            "{desc:<"
-            f"{desc_max_len}"
-            "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
-            "[{elapsed}<{remaining}, {rate_fmt:>8}]"
-        )
+            desc_max_len = max(len(output) + 16, 19)
+            bar_format = (
+                "{desc:<"
+                f"{desc_max_len}"
+                "}{percentage:3.0f}%|{bar}| {n_fmt:>5}/{total_fmt:<5} "
+                "[{elapsed}<{remaining}, {rate_fmt:>8}]"
+            )
 
-        if not no_cp:
-            with get_download_bar(bar_format, total_size) as pbar:
-                for node_group in node_groups:
-                    node_group.download(recursive=recursive, pbar=pbar)
+            if not no_cp:
+                with get_download_bar(bar_format, total_size) as pbar:
+                    for node_group in node_groups:
+                        node_group.download(recursive=recursive, pbar=pbar)
 
-        instantiate_node_groups(
-            node_groups,
-            output,
-            bar_format,
-            total_files,
-            force,
-            recursive,
-            no_cp,
-            always_copy_dir_contents,
-            copy_to_filename,
-        )
+            instantiate_node_groups(
+                node_groups,
+                output,
+                bar_format,
+                total_files,
+                force,
+                recursive,
+                no_cp,
+                always_copy_dir_contents,
+                copy_to_filename,
+            )
 
     def du(
         self,
src/datachain/data_storage/metastore.py (+11 −0)

@@ -3,6 +3,7 @@
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Iterator
+from contextlib import contextmanager, suppress
 from datetime import datetime, timezone
 from functools import cached_property, reduce
 from itertools import groupby
@@ -118,6 +119,16 @@ def close_on_exit(self) -> None:
         differently."""
         self.close()
 
+    @contextmanager
+    def _init_guard(self):
+        """Ensure resources acquired during __init__ are released on failure."""
+        try:
+            yield
+        except Exception:
+            with suppress(Exception):
+                self.close_on_exit()
+            raise
+
     def cleanup_tables(self, temp_table_names: list[str]) -> None:
         """Cleanup temp tables."""
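
`_init_guard` plugs a hole that `close_on_exit` alone cannot: when `__init__` raises partway through, the half-built object is never handed to a caller, so no `with` block or explicit `close()` will ever run for the resources it already acquired. The `SQLiteMetastore.__init__` change in the next file applies this guard. A self-contained sketch of the idea, using a hypothetical `Store` class and an `io.StringIO` in place of a real connection:

import io
from contextlib import contextmanager, suppress


class Store:
    def __init__(self, fail: bool = False) -> None:
        self.conn = io.StringIO()  # acquired first; later steps may still fail
        with self._init_guard():
            if fail:
                raise RuntimeError("schema version mismatch")

    @contextmanager
    def _init_guard(self):
        try:
            yield
        except Exception:
            with suppress(Exception):  # never mask the original error
                self.close()
            raise

    def close(self) -> None:
        self.conn.close()


try:
    Store(fail=True)
except RuntimeError:
    pass  # self.conn was closed before the exception escaped __init__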

src/datachain/data_storage/sqlite.py (+6 −5)

@@ -374,11 +374,12 @@ def __init__(
 
         self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)
 
-        self._init_meta_table()
-        self._init_meta_schema_value()
-        self._check_schema_version()
-        self._init_tables()
-        self._init_namespaces_projects()
+        with self._init_guard():
+            self._init_meta_table()
+            self._init_meta_schema_value()
+            self._check_schema_version()
+            self._init_tables()
+            self._init_namespaces_projects()
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
         """Close connection upon exit from context manager."""

src/datachain/listing.py (+8 −1)

@@ -35,6 +35,7 @@ def __init__(
         self.dataset_name = dataset_name  # dataset representing bucket listing
         self.dataset_version = dataset_version  # dataset representing bucket listing
         self.column = column
+        self._closed = False
 
     def clone(self) -> "Listing":
         return self.__class__(
@@ -53,7 +54,13 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
         self.close()
 
     def close(self) -> None:
-        self.warehouse.close()
+        if self._closed:
+            return
+        self._closed = True
+        try:
+            self.warehouse.close_on_exit()
+        finally:
+            self.metastore.close_on_exit()
 
     @property
     def uri(self):
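
Two things changed in `Listing.close()`: a `_closed` flag makes it idempotent, since the method can now be reached both through `NodeGroup.close()` and through the listing's own `__exit__`, and the `try`/`finally` guarantees the metastore handle is released even if closing the warehouse raises. A generic sketch of that shape, with hypothetical `Handle` sub-resources:

class Handle:
    def __init__(self, name: str) -> None:
        self.name = name

    def close(self) -> None:
        print(f"closed {self.name}")


class PairedResources:
    def __init__(self, warehouse: Handle, metastore: Handle) -> None:
        self.warehouse = warehouse
        self.metastore = metastore
        self._closed = False

    def close(self) -> None:
        if self._closed:  # safe when several owners call close()
            return
        self._closed = True
        try:
            self.warehouse.close()
        finally:
            self.metastore.close()  # runs even if the warehouse close raised


pair = PairedResources(Handle("warehouse"), Handle("metastore"))
pair.close()  # closes both
pair.close()  # second call is a no-op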

src/datachain/query/session.py (+20 −5)

@@ -1,5 +1,4 @@
 import atexit
-import gc
 import logging
 import os
 import re
@@ -8,6 +7,7 @@
 from collections.abc import Callable
 from typing import TYPE_CHECKING, ClassVar
 from uuid import uuid4
+from weakref import WeakSet
 
 from datachain.catalog import get_catalog
 from datachain.data_storage import JobQueryType, JobStatus
@@ -57,6 +57,7 @@ class Session:
 
     GLOBAL_SESSION_CTX: "Session | None" = None
     SESSION_CONTEXTS: ClassVar[list["Session"]] = []
+    _ALL_SESSIONS: ClassVar[WeakSet["Session"]] = WeakSet()
    ORIGINAL_EXCEPT_HOOK = None
 
     # Job management - class-level to ensure one job per process
@@ -92,6 +93,7 @@ def __init__(
         self.catalog = catalog or get_catalog(
             client_config=client_config, in_memory=in_memory
         )
+        Session._ALL_SESSIONS.add(self)
 
     def __enter__(self):
         # Push the current context onto the stack
@@ -109,6 +111,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
         if Session.SESSION_CONTEXTS:
             Session.SESSION_CONTEXTS.pop()
+        Session._ALL_SESSIONS.discard(self)
 
     def get_or_create_job(self) -> "Job":
         """
@@ -311,6 +314,7 @@ def except_hook(exc_type, exc_value, exc_traceback):
 
     @classmethod
     def cleanup_for_tests(cls):
+        cls._close_all_contexts()
         if cls.GLOBAL_SESSION_CTX is not None:
             cls.GLOBAL_SESSION_CTX.__exit__(None, None, None)
             cls.GLOBAL_SESSION_CTX = None
@@ -333,15 +337,26 @@ def cleanup_for_tests(cls):
 
     @staticmethod
     def _global_cleanup():
+        Session._close_all_contexts()
         if Session.GLOBAL_SESSION_CTX is not None:
             Session.GLOBAL_SESSION_CTX.__exit__(None, None, None)
 
-        for obj in gc.get_objects():  # Get all tracked objects
+        for session in list(Session._ALL_SESSIONS):
             try:
-                if isinstance(obj, Session):
-                    # Cleanup temp dataset for session variables.
-                    obj.__exit__(None, None, None)
+                session.__exit__(None, None, None)
             except ReferenceError:
                 continue  # Object has been finalized already
             except Exception as e:  # noqa: BLE001
                 logger.error(f"Exception while cleaning up session: {e}")  # noqa: G004
+
+    @classmethod
+    def _close_all_contexts(cls) -> None:
+        while cls.SESSION_CONTEXTS:
+            session = cls.SESSION_CONTEXTS.pop()
+            try:
+                session.__exit__(None, None, None)
+            except Exception as exc:  # noqa: BLE001
+                logger.error(
+                    "Exception while closing session context during cleanup: %s",
+                    exc,
+                )
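
Swapping the `gc.get_objects()` sweep for a class-level `WeakSet` registry is the core of this change: sessions add themselves on construction, fall out of the set automatically when garbage-collected, and cleanup iterates only over live sessions instead of every tracked object in the process. The `ReferenceError` handler stays because a weakly-referenced session can be finalized while the snapshot is being walked. A minimal sketch of the registry, assuming only the standard library:

from weakref import WeakSet


class Tracked:
    _ALL: "WeakSet[Tracked]" = WeakSet()

    def __init__(self, name: str) -> None:
        self.name = name
        Tracked._ALL.add(self)

    def close(self) -> None:
        Tracked._ALL.discard(self)
        print(f"closed {self.name}")

    @staticmethod
    def close_all() -> None:
        # Snapshot with list(): the set mutates as members are closed or GC'd.
        for obj in list(Tracked._ALL):
            obj.close()


a = Tracked("a")
b = Tracked("b")
del b  # on CPython the refcount hits zero and "b" leaves the registry
Tracked.close_all()  # prints: closed a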

src/datachain/sql/sqlite/base.py (+30 −15)

@@ -3,6 +3,7 @@
 import sqlite3
 import warnings
 from collections.abc import Callable, Iterable
+from contextlib import closing
 from datetime import MAXYEAR, MINYEAR, datetime, timezone
 from functools import cache
 from types import MappingProxyType
@@ -111,7 +112,10 @@ def setup():
     compiles(numeric.int_hash_64, "sqlite")(compile_int_hash_64)
     compiles(numeric.bit_hamming_distance, "sqlite")(compile_bit_hamming_distance)
 
-    if load_usearch_extension(sqlite3.connect(":memory:")):
+    with closing(sqlite3.connect(":memory:")) as _usearch_conn:
+        usearch_available = load_usearch_extension(_usearch_conn)
+
+    if usearch_available:
         compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
         compiles(array.euclidean_distance, "sqlite")(compile_euclidean_distance_ext)
     else:
@@ -145,23 +149,34 @@ def functions_exist(
             f"Found value of type {type(n).__name__}: {n!r}"
         )
 
+    close_connection = False
     if connection is None:
         connection = sqlite3.connect(":memory:")
+        close_connection = True
 
-    if not names:
-        return True
-    column1 = sa.column("column1", sa.String)
-    func_name_query = column1.not_in(
-        sa.select(sa.column("name", sa.String)).select_from(func.pragma_function_list())
-    )
-    query = (
-        sa.select(func.count() == 0)
-        .select_from(sa.values(column1).data([(n,) for n in names]))
-        .where(func_name_query)
-    )
-    comp = query.compile(dialect=sqlite_dialect)
-    args = (comp.string, comp.params) if comp.params else (comp.string,)
-    return bool(connection.execute(*args).fetchone()[0])
+    try:
+        if not names:
+            return True
+        column1 = sa.column("column1", sa.String)
+        func_name_query = column1.not_in(
+            sa.select(sa.column("name", sa.String)).select_from(
+                func.pragma_function_list()
+            )
+        )
+        query = (
+            sa.select(func.count() == 0)
+            .select_from(sa.values(column1).data([(n,) for n in names]))
+            .where(func_name_query)
+        )
+        comp = query.compile(dialect=sqlite_dialect)
+        if comp.params:
+            result = connection.execute(comp.string, comp.params)
+        else:
+            result = connection.execute(comp.string)
+        return bool(result.fetchone()[0])
+    finally:
+        if close_connection:
+            connection.close()
 
 
 def create_user_defined_sql_functions(connection):
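
Both fixes here follow one rule: whoever opens a connection closes it. `setup()` scopes its usearch probe with `contextlib.closing`, which matters because an `sqlite3` connection used as a plain context manager commits or rolls back a transaction on exit but does not close, and `functions_exist` now records whether it created the fallback connection so the `finally` block closes only its own. A sketch of the ownership flag with a hypothetical `table_count` helper:

import sqlite3
from contextlib import closing


def table_count(connection: "sqlite3.Connection | None" = None) -> int:
    """Count tables, closing the connection only if this function opened it."""
    close_connection = False
    if connection is None:
        connection = sqlite3.connect(":memory:")
        close_connection = True
    try:
        cur = connection.execute(
            "SELECT count(*) FROM sqlite_master WHERE type = 'table'"
        )
        return cur.fetchone()[0]
    finally:
        if close_connection:
            connection.close()


# Caller-owned connection: still usable after the call.
with closing(sqlite3.connect(":memory:")) as conn:
    conn.execute("CREATE TABLE t (x)")
    assert table_count(conn) == 1
# Function-owned connection: opened and closed internally.
assert table_count() == 0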

tests/conftest.py (+16 −3)

@@ -2,6 +2,7 @@
 import os.path
 import signal
 import subprocess  # nosec B404
+import sys
 import uuid
 from collections.abc import Generator
 from datetime import datetime
@@ -40,6 +41,10 @@
 
 from .utils import DEFAULT_TREE, instantiate_tree, reset_session_job_state
 
+distributed_pythonpath = os.environ.get("DATACHAIN_DISTRIBUTED_PYTHONPATH")
+if distributed_pythonpath and distributed_pythonpath not in sys.path:
+    sys.path.insert(0, distributed_pythonpath)
+
 DEFAULT_DATACHAIN_BIN = "datachain"
 DEFAULT_DATACHAIN_GIT_REPO = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
@@ -109,10 +114,11 @@ def monkeypatch_session() -> Generator[MonkeyPatch, None, None]:
 
 
 @pytest.fixture(autouse=True)
-def clean_session() -> None:
+def clean_session() -> Generator[None, None, None]:
     """
-    Make sure we clean leftover session before each test case
+    Clean leftover sessions after each test while storage handles are still open.
     """
+    yield
     Session.cleanup_for_tests()
 
 
@@ -181,11 +187,13 @@ def metastore(monkeypatch):
 
         yield _metastore
 
+        Session.cleanup_for_tests()
         _metastore.cleanup_for_tests()
     else:
         _metastore = SQLiteMetastore(db_file=":memory:")
         yield _metastore
 
+        Session.cleanup_for_tests()
         cleanup_sqlite_db(_metastore.db.clone(), _metastore.default_table_names)
 
     # Close the connection so that the SQLite file is no longer open, to avoid
@@ -256,11 +264,13 @@ def metastore_tmpfile(monkeypatch, tmp_path):
 
         yield _metastore
 
+        Session.cleanup_for_tests()
         _metastore.cleanup_for_tests()
     else:
         _metastore = SQLiteMetastore(db_file=str(tmp_path / "test.db"))
         yield _metastore
 
+        Session.cleanup_for_tests()
         cleanup_sqlite_db(_metastore.db.clone(), _metastore.default_table_names)
 
     # Close the connection so that the SQLite file is no longer open, to avoid
@@ -529,7 +539,10 @@ def cloud_test_catalog(
     metastore,
     warehouse,
 ):
-    return get_cloud_test_catalog(cloud_server, tmp_path, metastore, warehouse)
+    catalog = get_cloud_test_catalog(cloud_server, tmp_path, metastore, warehouse)
+    yield catalog
+
+    reset_session_job_state()
 
 
 @pytest.fixture
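
The conftest changes are all about teardown ordering. Making `clean_session` a yield fixture moves `Session.cleanup_for_tests()` from test setup to test teardown, and calling it inside the metastore fixtures before `cleanup_sqlite_db(...)` lets sessions release their temp datasets while the database handles are still open. A small illustration of the ordering guarantee pytest provides, with made-up fixture names:

import pytest


@pytest.fixture
def storage():
    print("open storage")
    yield "db"
    print("close storage")  # runs last: fixture teardown is LIFO


@pytest.fixture
def session(storage):
    print("open session")
    yield f"session on {storage}"
    print("close session")  # runs first, while storage is still open


def test_order(session):
    # Output around this test:
    #   open storage, open session, close session, close storage
    assert session == "session on db"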

tests/func/model/test_yolo.py (+5 −1)

@@ -39,7 +39,11 @@ def running_img_masks() -> torch.Tensor:
     mask1_file = os.path.join(os.path.dirname(__file__), "data", "running-mask1.png")
     mask1_np = np.array(Image.open(mask1_file))
 
-    return torch.tensor([mask0_np.astype(np.float32), mask1_np.astype(np.float32)])
+    stacked = np.stack(
+        [mask0_np.astype(np.float32), mask1_np.astype(np.float32)],
+        axis=0,
+    )
+    return torch.from_numpy(stacked)
 
 
 def test_yolo_bbox_from_results_empty(running_img):
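
The fixture fix replaces `torch.tensor` over a Python list of arrays with an explicit `np.stack` followed by `torch.from_numpy`: the stack produces a single contiguous `(2, H, W)` float32 array, and `from_numpy` wraps it without a copy, avoiding the slow list-of-ndarrays path that PyTorch warns about. A standalone sketch with synthetic masks (shapes assumed, not the repo's test images):

import numpy as np
import torch

# Two synthetic (H, W) masks standing in for the running-mask*.png files.
mask0_np = np.zeros((4, 4), dtype=np.uint8)
mask1_np = np.ones((4, 4), dtype=np.uint8)

stacked = np.stack(
    [mask0_np.astype(np.float32), mask1_np.astype(np.float32)],
    axis=0,
)
masks = torch.from_numpy(stacked)  # zero-copy view over the numpy buffer

assert masks.shape == (2, 4, 4)
assert masks.dtype == torch.float32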
