1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -26,6 +26,7 @@
 interpreter/cling/tools/packaging @vgvassilev
 /io/ @pcanal
 /io/xml/ @pcanal @linev
+/io/ml/ @vepadulano @siliataider
 /main/ @pcanal
 /math/ @hagebeck
 /math/minuit @guitargeek
7 changes: 2 additions & 5 deletions bindings/pyroot/pythonizations/python/CMakeLists.txt
@@ -16,7 +16,8 @@ if(dataframe)
     ROOT/_pythonization/_rdf_conversion_maps.py
     ROOT/_pythonization/_rdf_pyz.py
     ROOT/_pythonization/_rdisplay.py
-    ROOT/_pythonization/_rdf_namespace.py)
+    ROOT/_pythonization/_rdf_namespace.py
+    ROOT/_pythonization/_io_ml_dataloader.py)
 endif()
 
 if(roofit)
@@ -59,10 +60,6 @@ if(tmva)
     ROOT/_pythonization/_tmva/_tree_inference.py
     ROOT/_pythonization/_tmva/_utils.py
     ROOT/_pythonization/_tmva/_gnn.py)
-  if(dataframe)
-    list(APPEND PYROOT_EXTRA_PYTHON_SOURCES
-      ROOT/_pythonization/_tmva/_batchgenerator.py)
-  endif()
 endif()
 
 set(py_sources
14 changes: 12 additions & 2 deletions bindings/pyroot/pythonizations/python/ROOT/_facade.py
@@ -415,11 +415,9 @@ def TMVA(self):
         hasRDF = "dataframe" in self.gROOT.GetConfigFeatures()
         if hasRDF:
             try:
-                from ._pythonization._tmva import inject_rbatchgenerator
                 from ._pythonization._tmva._rtensor import _AsRTensor
                 from ._pythonization._tmva._tree_inference import SaveXGBoost
 
-                inject_rbatchgenerator(ns)
                 ns.Experimental.AsRTensor = _AsRTensor
                 ns.Experimental.SaveXGBoost = SaveXGBoost
             except Exception:
@@ -472,3 +470,15 @@ def uhi(self):
         except ImportError:
             raise Exception("Failed to pythonize the namespace uhi")
         return uhi_module
+
+    @property
+    def IO(self):
+        self._finalSetup()
+        ns = self._fallback_getattr("IO")
+
+        from ._pythonization._io_ml_dataloader import _inject_dataloader_api
+
+        _inject_dataloader_api(ns.ML.Experimental)
+
+        del type(self).IO
+        return ns
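The new IO property above relies on a self-deleting-property idiom: the first attribute access runs the pythonization, then del type(self).IO removes the property from the facade class so later lookups skip the setup. A minimal standalone sketch of the idiom, with all names (LazyFacade, _load_namespace) invented for illustration rather than taken from ROOT:

    class LazyFacade:
        def _load_namespace(self, name):
            # Stand-in for the one-time setup the real facade does via
            # _finalSetup() and _fallback_getattr().
            print(f"setting up namespace {name}")
            return type(name, (), {})()

        @property
        def IO(self):
            ns = self._load_namespace("IO")
            del type(self).IO  # drop the property from the class...
            self.IO = ns       # ...and cache the namespace on the instance
            return ns

    facade = LazyFacade()
    facade.IO  # prints "setting up namespace IO"
    facade.IO  # cached attribute; the setup does not run again

In the sketch the result is cached on the instance; the real facade instead falls back to its regular attribute lookup after the property deletes itself.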
bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_batchgenerator.py → bindings/pyroot/pythonizations/python/ROOT/_pythonization/_io_ml_dataloader.py (renamed)
@@ -33,7 +33,7 @@ def get_template(
     max_vec_sizes: dict[str, int] = dict(),
 ) -> Tuple[str, list[int]]:
     """
-    Generate a template for the RBatchGenerator based on the given
+    Generate a template for the DataLoader based on the given
     RDataFrame and columns.
 
     Args:
@@ -44,7 +44,7 @@
         max_vec_sizes (list[int]): The length of each vector based column.
 
     Returns:
-        template (str): Template for the RBatchGenerator
+        template (str): Template for the DataLoader
     """
 
     if not columns:
@@ -102,7 +102,7 @@ def __init__(
         sampling_ratio: float = 1.0,
         replacement: bool = False,
     ):
-        """Wrapper around the Cpp RBatchGenerator
+        """Wrapper around the Cpp DataLoader
 
         Args:
             rdataframe (RNode): Name of RNode object.
@@ -238,15 +238,15 @@ def __init__(

         self.train_columns = [c for c in self.all_columns if c not in self.target_columns + [self.weights_column]]
 
-        from ROOT import TMVA, EnableThreadSafety
+        import ROOT
 
-        # The RBatchGenerator will create a separate C++ thread for I/O.
+        # The DataLoader will create a separate C++ thread for I/O.
         # Enable thread safety in ROOT from here, to make sure there is no
         # interference between the main Python thread (which might call into
         # cling via cppyy) and the I/O thread.
-        EnableThreadSafety()
+        ROOT.EnableThreadSafety()
 
-        self.generator = TMVA.Experimental.Internal.RBatchGenerator(template)(
+        self.generator = ROOT.IO.ML.Experimental.Internal.RBatchGenerator(template)(
             self.noded_rdfs,
             chunk_size,
             block_size,
@@ -318,7 +318,7 @@ def GetSample(self):
         try:
             import numpy as np
         except ImportError:
-            raise ImportError("Failed to import numpy in batchgenerator init")
+            raise ImportError("Failed to import numpy needed for the ML dataloader")
 
         # Split the target and weight
         if not self.target_given:
@@ -349,15 +349,15 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray:
"""Convert a RTensor into a NumPy array

Args:
batch (RTensor): Batch returned from the RBatchGenerator
batch (RTensor): Batch returned from the DataLoader

Returns:
np.ndarray: converted batch
"""
try:
import numpy as np
except ImportError:
raise ImportError("Failed to import numpy in batchgenerator init")
raise ImportError("Failed to import numpy needed for the ML dataloader")

data = batch.GetData()
batch_size, num_columns = tuple(batch.GetShape())
@@ -391,7 +391,7 @@ def ConvertBatchToPyTorch(self, batch: Any) -> torch.Tensor:
"""Convert a RTensor into a PyTorch tensor

Args:
batch (RTensor): Batch returned from the RBatchGenerator
batch (RTensor): Batch returned from the DataLoader

Returns:
torch.Tensor: converted batch
@@ -432,7 +432,7 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
         Convert a RTensor into a TensorFlow tensor
 
         Args:
-            batch (RTensor): Batch returned from the RBatchGenerator
+            batch (RTensor): Batch returned from the DataLoader
 
         Returns:
             tensorflow.Tensor: converted batch
@@ -510,7 +510,7 @@ def __exit__(self, type, value, traceback):
         return True
 
 
-class TrainRBatchGenerator:
+class TrainDataLoader:
     def __init__(self, base_generator: BaseGenerator, conversion_function: Callable):
         """
         A generator that returns the training batches of the given
@@ -602,11 +602,11 @@ def __exit__(self, type, value, traceback):
         return True
 
 
-class ValidationRBatchGenerator:
+class ValidationDataLoader:
     def __init__(self, base_generator: BaseGenerator, conversion_function: Callable):
         """
         A generator that returns the validation batches of the given base
-        generator. NOTE: The ValidationRBatchGenerator only returns batches
+        generator. NOTE: The ValidationDataLoader only returns batches
         if the training has been run.
 
         Args:
@@ -692,7 +692,7 @@ def CreateNumPyGenerators(
     sampling_type: str = "",
     sampling_ratio: float = 1.0,
     replacement: bool = False,
-) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
+) -> Tuple[TrainDataLoader, ValidationDataLoader]:
     """
     Return two batch generators based on the given ROOT file and tree or RDataFrame
     The first generator returns training batches, while the second generator
@@ -758,9 +758,9 @@ def CreateNumPyGenerators(
             Requires load_eager = True and sampling_type = 'undersampling'. Defaults to False.
 
     Returns:
-        TrainRBatchGenerator or
-        Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
-            If validation split is 0, return TrainBatchGenerator.
+        TrainDataLoader or
+        Tuple[TrainDataLoader, ValidationDataLoader]:
+            If validation split is 0, return TrainDataLoader.
 
         Otherwise two generators are returned. One used to load training
         batches, and one to load validation batches. NOTE: the validation
@@ -789,12 +789,12 @@ def CreateNumPyGenerators(
         replacement,
     )
 
-    train_generator = TrainRBatchGenerator(base_generator, base_generator.ConvertBatchToNumpy)
+    train_generator = TrainDataLoader(base_generator, base_generator.ConvertBatchToNumpy)
 
     if validation_split == 0.0:
         return train_generator, None
 
-    validation_generator = ValidationRBatchGenerator(base_generator, base_generator.ConvertBatchToNumpy)
+    validation_generator = ValidationDataLoader(base_generator, base_generator.ConvertBatchToNumpy)
 
     return train_generator, validation_generator
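A hedged usage sketch of the NumPy flavour under its new home; the file, tree, and column names below are placeholders, and the keyword arguments are inferred from the docstring fragments above rather than from a reference:

    import ROOT

    # Hypothetical input file, tree, and target column.
    df = ROOT.RDataFrame("events", "data.root")
    train_gen, val_gen = ROOT.IO.ML.Experimental.CreateNumPyGenerators(
        df,
        batch_size=128,
        chunk_size=5000,
        target="label",
        validation_split=0.3,
    )
    # Iteration yields NumPy arrays: x holds the training columns,
    # y the target column(s). With validation_split=0.0 the second
    # element of the returned tuple is None (see above).
    for x, y in train_gen:
        print(x.shape, y.shape)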

@@ -884,9 +884,9 @@ def CreateTFDatasets(
             Requires load_eager = True and sampling_type = 'undersampling'. Defaults to False.
 
     Returns:
-        TrainRBatchGenerator or
-        Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
-            If validation split is 0, return TrainBatchGenerator.
+        TrainDataLoader or
+        Tuple[TrainDataLoader, ValidationDataLoader]:
+            If validation split is 0, return TrainDataLoader.
 
         Otherwise two generators are returned. One used to load training
         batches, and one to load validation batches. NOTE: the validation
@@ -916,8 +916,8 @@ def CreateTFDatasets(
         replacement,
     )
 
-    train_generator = TrainRBatchGenerator(base_generator, base_generator.ConvertBatchToTF)
-    validation_generator = ValidationRBatchGenerator(base_generator, base_generator.ConvertBatchToTF)
+    train_generator = TrainDataLoader(base_generator, base_generator.ConvertBatchToTF)
+    validation_generator = ValidationDataLoader(base_generator, base_generator.ConvertBatchToTF)
 
     num_train_columns = len(train_generator.train_columns)
     num_target_columns = len(train_generator.target_columns)
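Since CreateTFDatasets sizes the datasets from the train/target column counts computed above, the result plugs directly into Keras. A sketch under the same assumptions as the previous example (df reused from there; the toy model is a placeholder, not part of this PR):

    import tensorflow as tf

    train_ds, val_ds = ROOT.IO.ML.Experimental.CreateTFDatasets(
        df, batch_size=128, chunk_size=5000,
        target="label", validation_split=0.3,
    )
    # Toy binary classifier; the input width must match the number
    # of training columns in the dataset.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")
    model.fit(train_ds, validation_data=val_ds, epochs=2)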
@@ -984,7 +984,7 @@ def CreatePyTorchGenerators(
     sampling_type: str = "",
     sampling_ratio: float = 1.0,
     replacement: bool = False,
-) -> Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
+) -> Tuple[TrainDataLoader, ValidationDataLoader]:
     """
     Return two PyTorch generators based on the given ROOT file and tree or RDataFrame
     The first generator returns training batches, while the second generator
@@ -1050,9 +1050,9 @@ def CreatePyTorchGenerators(
             Requires load_eager = True and sampling_type = 'undersampling'. Defaults to False.
 
     Returns:
-        TrainRBatchGenerator or
-        Tuple[TrainRBatchGenerator, ValidationRBatchGenerator]:
-            If validation split is 0, return TrainBatchGenerator.
+        TrainDataLoader or
+        Tuple[TrainDataLoader, ValidationDataLoader]:
+            If validation split is 0, return TrainDataLoader.
 
         Otherwise two generators are returned. One used to load training
         batches, and one to load validation batches. NOTE: the validation
@@ -1080,11 +1080,28 @@ def CreatePyTorchGenerators(
         replacement,
     )
 
-    train_generator = TrainRBatchGenerator(base_generator, base_generator.ConvertBatchToPyTorch)
+    train_generator = TrainDataLoader(base_generator, base_generator.ConvertBatchToPyTorch)
 
     if validation_split == 0.0:
         return train_generator
 
-    validation_generator = ValidationRBatchGenerator(base_generator, base_generator.ConvertBatchToPyTorch)
+    validation_generator = ValidationDataLoader(base_generator, base_generator.ConvertBatchToPyTorch)
 
     return train_generator, validation_generator
+
+
+def _inject_dataloader_api(parentmodule):
+    """
+    Inject the public Python API in the ROOT.IO.ML namespace. This includes the
+    functions to create dataloaders for ML training.
+    """
+
+    fns = [
+        CreateNumPyGenerators,
+        CreateTFDatasets,
+        CreatePyTorchGenerators,
+    ]
+
+    for python_func in fns:
+        func_name = python_func.__name__
+        setattr(parentmodule, func_name, python_func)
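To close the loop on _inject_dataloader_api: after this PR the three factories live on ROOT.IO.ML.Experimental, populated the first time the facade's IO property is accessed. A hedged PyTorch sketch (placeholder file and columns; keyword names inferred as in the earlier examples; train_columns is exposed by TrainDataLoader, as used by CreateTFDatasets above):

    import torch
    import ROOT

    df = ROOT.RDataFrame("events", "data.root")  # hypothetical input
    gen_train, gen_val = ROOT.IO.ML.Experimental.CreatePyTorchGenerators(
        df, batch_size=128, chunk_size=5000,
        target="label", validation_split=0.3,
    )

    # Minimal linear model sized from the training columns.
    model = torch.nn.Linear(len(gen_train.train_columns), 1)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters())

    for x, y in gen_train:  # batches arrive as torch.Tensor
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()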
bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/__init__.py
@@ -16,27 +16,6 @@
 from ._dataloader import DataLoader
 from ._factory import Factory
 
-
-def inject_rbatchgenerator(ns):
-    from ._batchgenerator import (
-        CreateNumPyGenerators,
-        CreatePyTorchGenerators,
-        CreateTFDatasets,
-    )
-
-    python_batchgenerator_functions = [
-        CreateNumPyGenerators,
-        CreateTFDatasets,
-        CreatePyTorchGenerators,
-    ]
-
-    for python_func in python_batchgenerator_functions:
-        func_name = python_func.__name__
-        setattr(ns.Experimental, func_name, python_func)
-
-    return ns
-
 
 # list of python classes that are used to pythonize TMVA classes
 python_classes = [Factory, DataLoader, CrossValidation]

@@ -138,7 +117,6 @@ def pythonize_tmva(klass, name):
     func_names = get_defined_attributes(python_klass)
 
     for func_name in func_names:
-
         # if the TMVA class already has a function with the same name as our
         # pythonization, we rename it and prefix it with an underscore
         if hasattr(klass, func_name):
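With inject_rbatchgenerator gone, nothing attaches the batch-generator factories to the TMVA namespace anymore; callers migrate to the IO.ML namespace instead. A minimal before/after sketch (the old spelling comes from the deleted code above, the new one from _inject_dataloader_api):

    import ROOT

    # Before this PR, injected by inject_rbatchgenerator:
    #   ROOT.TMVA.Experimental.CreateNumPyGenerators(...)

    # After this PR, injected by _inject_dataloader_api via the IO facade:
    assert hasattr(ROOT.IO.ML.Experimental, "CreateNumPyGenerators")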
6 changes: 3 additions & 3 deletions bindings/pyroot/pythonizations/test/CMakeLists.txt
@@ -180,9 +180,9 @@ ROOT_ADD_PYUNITTEST(pyroot_tcomplex tcomplex_operators.py)
 # Tests with memory usage
 ROOT_ADD_PYUNITTEST(pyroot_memory memory.py)
 
-# rbatchgenerator tests
-if (tmva)
-  ROOT_ADD_PYUNITTEST(batchgen rbatchgenerator_completeness.py PYTHON_DEPS numpy tensorflow torch)
+# ML dataloader tests
+if (dataframe)
+  ROOT_ADD_PYUNITTEST(io_ml_dataloader io_ml_dataloader.py PYTHON_DEPS numpy tensorflow torch)
 endif()
 
 ROOT_ADD_PYUNITTEST(regression_18441 regression_18441.py)