Merged
Commits
32 commits
3621296
added datasets jobs linking
ilongin Nov 26, 2025
664f543
refactoring
ilongin Nov 26, 2025
bec3daa
setting job_id to latest one
ilongin Nov 26, 2025
92882bc
improving performance
ilongin Nov 26, 2025
e34d0d4
refactor
ilongin Nov 26, 2025
7912330
refactor
ilongin Nov 26, 2025
f08bd4a
fixing logic
ilongin Nov 27, 2025
ba589ee
fixing tests
ilongin Nov 27, 2025
fdd26a8
Merge branch 'main' into ilongin/1477-fix-dataset-job-association
ilongin Nov 27, 2025
38b8b5c
Merge branch 'main' into ilongin/1477-fix-dataset-job-association
ilongin Nov 28, 2025
5c3ebca
refactoring
ilongin Nov 28, 2025
5721ef3
refactoring
ilongin Nov 28, 2025
63bd802
Merge branch 'main' into ilongin/1477-fix-dataset-job-association
ilongin Dec 1, 2025
42e12b0
fixing foreing keys
ilongin Dec 1, 2025
485bfcb
removing prints
ilongin Dec 1, 2025
d87b837
added missing methods
ilongin Dec 1, 2025
e7cc0f7
returned number of max recursion to 100 from 1000
ilongin Dec 2, 2025
72320d3
handling max depth error
ilongin Dec 2, 2025
ad78b32
added missing index and some refactoring
ilongin Dec 2, 2025
ad8de8e
added debug logs
ilongin Dec 2, 2025
23146e4
fixing removing dataset version table and added test
ilongin Dec 2, 2025
6dc3739
removing not needed check
ilongin Dec 2, 2025
aeb9307
added extra defense mechanism
ilongin Dec 3, 2025
799bdfc
fixing tests
ilongin Dec 4, 2025
4b6ec5d
connecting temp dataset created with persist with a job as well
ilongin Dec 4, 2025
54c436e
added transaction
ilongin Dec 5, 2025
46dafae
fixing test
ilongin Dec 7, 2025
9d82c96
make link_dataset_version_to_job to update dataset version job_id as …
ilongin Dec 8, 2025
bfca529
Merge branch 'main' into ilongin/1477-fix-dataset-job-association
ilongin Dec 8, 2025
dd9070b
added print
ilongin Dec 8, 2025
14588e3
Merge branch 'main' into ilongin/1477-fix-dataset-job-association
ilongin Dec 9, 2025
2d6de13
fixing print
ilongin Dec 9, 2025
3 changes: 2 additions & 1 deletion src/datachain/catalog/catalog.py
@@ -30,6 +30,7 @@
create_dataset_uri,
parse_dataset_name,
parse_dataset_uri,
parse_schema,
)
from datachain.error import (
DataChainError,
@@ -1581,7 +1582,7 @@ def _instantiate(ds_uri: str) -> None:
leave=False,
)

schema = DatasetRecord.parse_schema(remote_ds_version.schema)
schema = parse_schema(remote_ds_version.schema)

local_ds = self.create_dataset(
local_ds_name,
212 changes: 208 additions & 4 deletions src/datachain/data_storage/metastore.py
@@ -16,10 +16,12 @@
Column,
DateTime,
ForeignKey,
Index,
Integer,
Table,
Text,
UniqueConstraint,
cast,
desc,
literal,
select,
@@ -39,6 +41,7 @@
DatasetStatus,
DatasetVersion,
StorageURI,
parse_schema,
)
from datachain.error import (
CheckpointNotFoundError,
@@ -78,6 +81,7 @@ class AbstractMetastore(ABC, Serializable):
namespace_class: type[Namespace] = Namespace
project_class: type[Project] = Project
dataset_class: type[DatasetRecord] = DatasetRecord
dataset_version_class: type[DatasetVersion] = DatasetVersion
dataset_list_class: type[DatasetListRecord] = DatasetListRecord
dataset_list_version_class: type[DatasetListVersion] = DatasetListVersion
dependency_class: type[DatasetDependency] = DatasetDependency
@@ -484,6 +488,38 @@ def create_checkpoint(
) -> Checkpoint:
"""Creates new checkpoint"""

#
# Dataset Version Jobs (many-to-many)
#

@abstractmethod
def link_dataset_version_to_job(
self,
dataset_version_id: int,
job_id: str,
is_creator: bool = False,
conn=None,
) -> None:
"""Link dataset version to job."""

@abstractmethod
def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
"""Get all ancestor job IDs for a given job."""

@abstractmethod
def get_dataset_version_for_job_ancestry(
self,
dataset_name: str,
namespace_name: str,
project_name: str,
job_id: str,
conn=None,
) -> DatasetVersion | None:
"""
Find the dataset version that was created by any job in the ancestry.
Returns the most recently linked version from these jobs.
"""


class AbstractDBMetastore(AbstractMetastore):
"""
@@ -498,6 +534,7 @@ class AbstractDBMetastore(AbstractMetastore):
DATASET_TABLE = "datasets"
DATASET_VERSION_TABLE = "datasets_versions"
DATASET_DEPENDENCY_TABLE = "datasets_dependencies"
DATASET_VERSION_JOBS_TABLE = "dataset_version_jobs"
JOBS_TABLE = "jobs"
CHECKPOINTS_TABLE = "checkpoints"

@@ -1125,7 +1162,7 @@ def update_dataset(
dataset_values[field] = None
else:
values[field] = json.dumps(value)
dataset_values[field] = DatasetRecord.parse_schema(value)
dataset_values[field] = parse_schema(value)
elif field == "project_id":
if not value:
raise ValueError("Cannot set empty project_id for dataset")
@@ -1176,9 +1213,7 @@ def update_dataset_version(

if field == "schema":
values[field] = json.dumps(value) if value else None
version_values[field] = (
DatasetRecord.parse_schema(value) if value else None
)
version_values[field] = parse_schema(value) if value else None
elif field == "feature_schema":
if value is None:
values[field] = None
@@ -1850,6 +1885,52 @@ def _checkpoints(self) -> "Table":
@abstractmethod
def _checkpoints_insert(self) -> "Insert": ...

@staticmethod
def _dataset_version_jobs_columns() -> "list[SchemaItem]":
"""Junction table for dataset versions and jobs many-to-many relationship."""
return [
Column("id", Integer, primary_key=True),
Column(
"dataset_version_id",
Integer,
ForeignKey("datasets_versions.id", ondelete="CASCADE"),
nullable=False,
),
Column(
"job_id",
Text,
ForeignKey("jobs.id", ondelete="CASCADE"),
nullable=False,
),
Column("is_creator", Boolean, nullable=False, default=False),
Column("created_at", DateTime(timezone=True)),
UniqueConstraint("dataset_version_id", "job_id"),
Index("dc_idx_dvj_query", "job_id", "is_creator", "created_at"),
]

@cached_property
def _dataset_version_jobs_fields(self) -> list[str]:
return [c.name for c in self._dataset_version_jobs_columns() if c.name] # type: ignore[attr-defined]

@cached_property
def _dataset_version_jobs(self) -> "Table":
return Table(
self.DATASET_VERSION_JOBS_TABLE,
self.db.metadata,
*self._dataset_version_jobs_columns(),
)

@abstractmethod
def _dataset_version_jobs_insert(self) -> "Insert": ...

def _dataset_version_jobs_select(self, *columns) -> "Select":
if not columns:
return self._dataset_version_jobs.select()
return select(*columns)

def _dataset_version_jobs_delete(self) -> "Delete":
return self._dataset_version_jobs.delete()

def _checkpoints_select(self, *columns) -> "Select":
if not columns:
return self._checkpoints.select()
@@ -1928,3 +2009,126 @@ def get_last_checkpoint(self, job_id: str, conn=None) -> Checkpoint | None:
if not rows:
return None
return self.checkpoint_class.parse(*rows[0])

def link_dataset_version_to_job(
self,
dataset_version_id: int,
job_id: str,
is_creator: bool = False,
conn=None,
) -> None:
query = self._dataset_version_jobs_insert().values(
dataset_version_id=dataset_version_id,
job_id=job_id,
is_creator=is_creator,
created_at=datetime.now(timezone.utc),
)
if hasattr(query, "on_conflict_do_nothing"):
query = query.on_conflict_do_nothing(
index_elements=["dataset_version_id", "job_id"]
)
self.db.execute(query, conn=conn)
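
The `hasattr` guard plus the unique constraint on (dataset_version_id, job_id) make linking idempotent: re-linking an existing pair is silently skipped instead of raising an integrity error. A self-contained sketch of that behavior, assuming SQLAlchemy 1.4+ and an in-memory SQLite database (table and column names are illustrative):

from sqlalchemy import (
    Column, Integer, MetaData, Table, UniqueConstraint, create_engine,
)
from sqlalchemy.dialects import sqlite

metadata = MetaData()
links = Table(
    "links",
    metadata,
    Column("version_id", Integer, nullable=False),
    Column("job_id", Integer, nullable=False),
    UniqueConstraint("version_id", "job_id"),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)

stmt = sqlite.insert(links).values(version_id=1, job_id=1)
stmt = stmt.on_conflict_do_nothing(index_elements=["version_id", "job_id"])
with engine.begin() as conn:
    conn.execute(stmt)
    conn.execute(stmt)  # duplicate link: a no-op instead of an IntegrityError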

def get_ancestor_job_ids(self, job_id: str, conn=None) -> list[str]:
# Use a recursive CTE to walk up the parent chain:
#   WITH RECURSIVE ancestors(id, parent_job_id, depth) AS (...)
# Depth tracking prevents infinite recursion if the chain ever
# contains a cycle.
max_depth = 100

ancestors_cte = (
self._jobs_select(
self._jobs.c.id.label("id"),
self._jobs.c.parent_job_id.label("parent_job_id"),
literal(0).label("depth"),
)
.where(self._jobs.c.id == job_id)
.cte(name="ancestors", recursive=True)
)

# Recursive part: join with parent jobs, incrementing depth and checking limit
ancestors_recursive = ancestors_cte.union_all(
self._jobs_select(
self._jobs.c.id.label("id"),
self._jobs.c.parent_job_id.label("parent_job_id"),
(ancestors_cte.c.depth + 1).label("depth"),
).select_from(
self._jobs.join(
ancestors_cte,
(
self._jobs.c.id
== cast(ancestors_cte.c.parent_job_id, self._jobs.c.id.type)
)
& (ancestors_cte.c.depth < max_depth),
)
)
)

# Select all ancestor IDs except the starting job itself
query = select(ancestors_recursive.c.id).where(
ancestors_recursive.c.id != job_id
)

results = list(self.db.execute(query, conn=conn))
return [str(row[0]) for row in results]
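
In plain Python terms, the CTE walks the parent_job_id chain upward, stopping at the root or after roughly max_depth hops. A rough, self-contained equivalent of what the SQL computes (illustrative only; the real implementation runs in the database):

def ancestor_job_ids_py(
    parents: dict[str, str | None], job_id: str, max_depth: int = 100
) -> list[str]:
    """`parents` maps each job id to its parent_job_id (None for roots)."""
    ancestors: list[str] = []
    current = parents.get(job_id)
    # The depth cap also terminates the walk if the chain contains a cycle.
    while current is not None and len(ancestors) < max_depth:
        ancestors.append(current)
        current = parents.get(current)
    return ancestors

assert ancestor_job_ids_py({"c": "b", "b": "a", "a": None}, "c") == ["b", "a"]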

def _get_dataset_version_for_job_ancestry_query(
self,
dataset_name: str,
namespace_name: str,
project_name: str,
job_ancestry: list[str],
) -> "Select":
"""Build query to find dataset version created by job ancestry."""
return (
self._datasets_versions_select()
.select_from(
self._dataset_version_jobs.join(
self._datasets_versions,
self._dataset_version_jobs.c.dataset_version_id
== self._datasets_versions.c.id,
)
.join(
self._datasets,
self._datasets_versions.c.dataset_id == self._datasets.c.id,
)
.join(
self._projects,
self._datasets.c.project_id == self._projects.c.id,
)
.join(
self._namespaces,
self._projects.c.namespace_id == self._namespaces.c.id,
)
)
.where(
self._datasets.c.name == dataset_name,
self._namespaces.c.name == namespace_name,
self._projects.c.name == project_name,
self._dataset_version_jobs.c.job_id.in_(job_ancestry),
self._dataset_version_jobs.c.is_creator == True, # noqa: E712
)
.order_by(desc(self._dataset_version_jobs.c.created_at))
.limit(1)
)
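
When a multi-join query like this misbehaves, rendering the SQL that SQLAlchemy generates is often the quickest diagnostic. A standalone illustration using lightweight table/column stand-ins (not the real schema objects):

from sqlalchemy import column, select, table

dvj = table(
    "dataset_version_jobs",
    column("job_id"), column("is_creator"), column("created_at"),
)
query = (
    select(dvj.c.job_id)
    .where(dvj.c.is_creator == True)  # noqa: E712
    .order_by(dvj.c.created_at.desc())
    .limit(1)
)
print(query.compile(compile_kwargs={"literal_binds": True}))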

def get_dataset_version_for_job_ancestry(
self,
dataset_name: str,
namespace_name: str,
project_name: str,
job_id: str,
conn=None,
) -> DatasetVersion | None:
# Get job ancestry (current job + all ancestors)
job_ancestry = [job_id, *self.get_ancestor_job_ids(job_id, conn=conn)]

query = self._get_dataset_version_for_job_ancestry_query(
dataset_name, namespace_name, project_name, job_ancestry
)

results = list(self.db.execute(query, conn=conn))
if not results:
return None

return self.dataset_version_class.parse(*results[0])
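
The net effect is that dataset versions become visible across the job hierarchy: a child job resolves versions created by any of its ancestors. A hedged sketch of that flow; the metastore instance, IDs, and names are assumptions, only the two method calls come from this diff:

# Parent job records the version it created.
metastore.link_dataset_version_to_job(version_id, parent_job_id, is_creator=True)

# For a job created with parent_job_id as its parent,
# get_ancestor_job_ids(child_job_id) yields [parent_job_id], so the
# ancestry lookup below finds the link made by the parent job.
version = metastore.get_dataset_version_for_job_ancestry(
    "my_dataset", "my_namespace", "my_project", child_job_id
)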
5 changes: 5 additions & 0 deletions src/datachain/data_storage/sqlite.py
@@ -471,6 +471,8 @@ def _init_tables(self) -> None:
self.default_table_names.append(self._jobs.name)
self.db.create_table(self._checkpoints, if_not_exists=True)
self.default_table_names.append(self._checkpoints.name)
self.db.create_table(self._dataset_version_jobs, if_not_exists=True)
self.default_table_names.append(self._dataset_version_jobs.name)

def _init_namespaces_projects(self) -> None:
"""
@@ -581,6 +583,9 @@ def _jobs_insert(self) -> "Insert":
def _checkpoints_insert(self) -> "Insert":
return sqlite.insert(self._checkpoints)

def _dataset_version_jobs_insert(self) -> "Insert":
return sqlite.insert(self._dataset_version_jobs)

#
# Namespaces
#