Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,41 @@ def test_restricted_samples_float(tmp_path):
assert data is not None


def test_restricted_indices(tmp_path):
    """Excluding global indices via ``restrict_indices`` shrinks the dataset
    by exactly that many samples and removes them from the restriction set.
    """
    filename = tmp_path / "dummy_well_data.hdf5"
    write_dummy_data(filename)
    # Baseline dataset with no restrictions, to measure the full length.
    full_dataset = WellDataset(
        path=str(tmp_path),
        use_normalization=False,
        return_grid=True,
    )
    full_length = len(full_dataset)

    # Exclude specific global indices
    indices_to_exclude = [0, 1, 5, 10]
    dataset = WellDataset(
        path=str(tmp_path),
        use_normalization=False,
        return_grid=True,
        restrict_indices=indices_to_exclude,
    )

    expected_length = full_length - len(indices_to_exclude)
    # Interpolate the actual counts rather than hard-coding "18 - 4", so the
    # failure message stays truthful if the dummy data or exclusion list changes.
    assert len(dataset) == expected_length, (
        f"Restricted dataset should contain {expected_length} samples "
        f"({full_length} - {len(indices_to_exclude)} excluded), "
        f"but found {len(dataset)}"
    )

    # Verify we can still access data after restriction
    data = dataset[0]
    assert data is not None

    # No excluded index may survive in the restriction set.
    assert set(indices_to_exclude).isdisjoint(
        dataset.restriction_set
    ), "Restriction set should not contain any excluded indices"


@pytest.mark.parametrize("start_output_steps_at_t", [-1, 4])
def test_full_trajectory_mode_minimum_steps(tmp_path, start_output_steps_at_t):
filename = tmp_path / "dummy_well_data.hdf5"
Expand Down
41 changes: 37 additions & 4 deletions the_well/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ class WellDataset(Dataset):
Whether to restrict the number of trajectories to a subset of the dataset. Integer inputs restrict to a number. Float to a percentage.
restrict_num_samples:
Whether to restrict the number of samples to a subset of the dataset. Integer inputs restrict to a number. Float to a percentage.
restrict_indices:
List of global indices to skip/exclude from the dataset. Only one restriction type should be used.
restriction_seed:
Seed used to generate restriction set. Necessary to ensure same set is sampled across runs.
cache_small:
Expand Down Expand Up @@ -205,6 +207,7 @@ def __init__(
flatten_tensors: bool = True,
restrict_num_trajectories: Optional[float | int] = None,
restrict_num_samples: Optional[float | int] = None,
restrict_indices: Optional[list[int]] = None,
restriction_seed: int = 0,
cache_small: bool = True,
max_cache_size: float = 1e9,
Expand Down Expand Up @@ -272,6 +275,7 @@ def __init__(
self.flatten_tensors = flatten_tensors
self.restrict_num_trajectories = restrict_num_trajectories
self.restrict_num_samples = restrict_num_samples
self.restrict_indices = restrict_indices
self.restriction_seed = restriction_seed
self.return_grid = return_grid
self.normalize_time_grid = normalize_time_grid
Expand Down Expand Up @@ -341,23 +345,41 @@ def __init__(

# If we're limiting number of samples/trajectories...
self.restriction_set = None
if restrict_num_samples is not None or restrict_num_trajectories is not None:
if (
restrict_num_samples is not None
or restrict_num_trajectories is not None
or restrict_indices is not None
):
self._build_restriction_set(
restrict_num_samples, restrict_num_trajectories, restriction_seed
restrict_num_samples,
restrict_num_trajectories,
restrict_indices,
restriction_seed,
)


def _build_restriction_set(
self,
restrict_num_samples: Optional[int | float],
restrict_num_trajectories: Optional[int | float],
restrict_indices: Optional[list[int]],
seed: int,
):
"""Builds a restriction set for the dataset based on the specified restrictions"""
gen = np.random.default_rng(seed)
if restrict_num_samples is not None and restrict_num_trajectories is not None:
non_none_count = sum(
[
restrict_num_samples is not None,
restrict_num_trajectories is not None,
restrict_indices is not None,
]
)

if non_none_count > 1:
warnings.warn(
"Both restrict_num_samples and restrict_num_trajectories are set. Using restrict_num_samples."
"More than one restriction is set. Using restrict_num_samples."
)

global_indices = np.arange(self.len)
if restrict_num_trajectories is not None:
# Compute total number of trajectories, collect all indices corresponding to them, then select a subset
Expand Down Expand Up @@ -395,6 +417,17 @@ def _build_restriction_set(
current_index += self.n_windows_per_trajectory[file_index]
global_indices = np.array(global_indices)

if restrict_indices is not None:
# Skip the specified indices by creating a mask
skip_set = set(restrict_indices)
# Filter out indices to skip, keeping only those not in skip_set
global_indices = np.array(
[idx for idx in global_indices if idx not in skip_set]
)
if len(global_indices) == 0:
warnings.warn(
"All indices were excluded by restrict_indices. Dataset will be empty."
)
if restrict_num_samples is not None:
if 0.0 < restrict_num_samples < 1.0:
restrict_num_samples = int(self.len * restrict_num_samples)
Expand Down
Loading