Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 263 additions & 0 deletions autotest/test_mf6_set_data_replace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
"""
Test set_data() replace parameter (issue #2663). This parameter
toggles whether .set_data() has update or replacement semantics.
"""

from pathlib import Path

import numpy as np
import pytest

import flopy

pytestmark = pytest.mark.mf6


def count_stress_periods(file_path):
"""Count the number of 'BEGIN period' statements in an input file."""
with open(file_path, "r") as f:
return sum(1 for line in f if line.strip().upper().startswith("BEGIN PERIOD"))


@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"])
@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"])
def test_set_data_replace_array_based_pkg(function_tmpdir, replace, use_pandas):
name = "array_based"
og_ws = Path(function_tmpdir) / "original"
og_ws.mkdir(exist_ok=True)

nlay, nrow, ncol = 1, 10, 10
nper_original = 48
nper_new = 12

sim = flopy.mf6.MFSimulation(
sim_name=name,
sim_ws=str(og_ws),
exe_name="mf6",
use_pandas=use_pandas,
)
tdis = flopy.mf6.ModflowTdis(
sim,
nper=nper_original,
perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)],
)
ims = flopy.mf6.ModflowIms(sim)
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
dis = flopy.mf6.ModflowGwfdis(
gwf,
nlay=nlay,
nrow=nrow,
ncol=ncol,
delr=100.0,
delc=100.0,
top=100.0,
botm=0.0,
)
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0)
oc = flopy.mf6.ModflowGwfoc(
gwf,
budget_filerecord=f"{name}.cbc",
head_filerecord=f"{name}.hds",
saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")],
)
rch_data = {kper: 0.001 + kper * 0.0001 for kper in range(nper_original)}
rcha = flopy.mf6.ModflowGwfrcha(gwf, recharge=rch_data)

sim.write_simulation()

original_rch_file = og_ws / f"{name}.rcha"
original_sp_count = count_stress_periods(original_rch_file)
assert original_sp_count == nper_original

# Update RCH
new_rch_data = {kper: 0.002 + kper * 0.0002 for kper in range(nper_new)}
rcha.recharge.set_data(new_rch_data, replace=replace)

# Update TDIS
tdis.nper = nper_new
tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)]

mod_ws = Path(function_tmpdir) / f"modified_replace_{replace}"
mod_ws.mkdir(exist_ok=True)
sim.set_sim_path(str(mod_ws))
sim.write_simulation()

modified_rch_file = mod_ws / f"{name}.rcha"
modified_sp_count = count_stress_periods(modified_rch_file)

if replace:
# With replace=True, should only have 12 stress periods
assert modified_sp_count == nper_new, (
f"Expected {nper_new} stress periods "
f"with replace=True, got {modified_sp_count}"
)
else:
# With replace=False (backwards compatible), all 48 periods remain
assert modified_sp_count == nper_original, (
f"Expected {nper_original} stress periods "
f"with replace=False, got {modified_sp_count}"
)

with open(modified_rch_file, "r") as f:
content = f.read()
assert "0.00200000" in content or "2.00000000E-03" in content
assert "0.00420000" in content or "4.20000000E-03" in content


@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"])
@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"])
def test_set_data_replace_list_based_pkg(function_tmpdir, replace, use_pandas):
name = "list_based"
sim_ws = Path(function_tmpdir) / "wel_original"
sim_ws.mkdir(exist_ok=True)

nlay, nrow, ncol = 1, 10, 10
nper_original = 24
nper_new = 6

sim = flopy.mf6.MFSimulation(
sim_name=name, sim_ws=str(sim_ws), exe_name="mf6", use_pandas=use_pandas
)
tdis = flopy.mf6.ModflowTdis(
sim,
nper=nper_original,
perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)],
)
ims = flopy.mf6.ModflowIms(sim)
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
dis = flopy.mf6.ModflowGwfdis(
gwf,
nlay=nlay,
nrow=nrow,
ncol=ncol,
delr=100.0,
delc=100.0,
top=100.0,
botm=0.0,
)
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0)
oc = flopy.mf6.ModflowGwfoc(
gwf,
budget_filerecord=f"{name}.cbc",
head_filerecord=f"{name}.hds",
saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")],
)
wel_data = {
kper: [[(0, 5, 5), -1000.0 - kper * 10.0]] for kper in range(nper_original)
}
wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=wel_data)

sim.write_simulation()

original_wel_file = sim_ws / f"{name}.wel"
original_sp_count = count_stress_periods(original_wel_file)
assert original_sp_count == nper_original

# Update WEL
new_wel_data = {
kper: [[(0, 5, 5), -2000.0 - kper * 20.0]] for kper in range(nper_new)
}
wel.stress_period_data.set_data(new_wel_data, replace=replace)

# Update TDIS
tdis.nper = nper_new
tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)]

mod_ws = Path(function_tmpdir) / f"wel_modified_replace_{replace}"
mod_ws.mkdir(exist_ok=True)
sim.set_sim_path(str(mod_ws))
sim.write_simulation()

modified_wel_file = mod_ws / f"{name}.wel"
modified_sp_count = count_stress_periods(modified_wel_file)

if replace:
# With replace=True, should only have 6 stress periods
assert modified_sp_count == nper_new, (
f"Expected {nper_new} stress periods with "
f"replace=True, got {modified_sp_count}"
)
else:
# With replace=False, all 24 periods remain
assert modified_sp_count == nper_original, (
f"Expected {nper_original} stress periods with "
f"replace=False, got {modified_sp_count}"
)


def test_set_data_update_array_based_pkg(function_tmpdir):
name = "update_array_based"
sim_ws = Path(function_tmpdir) / "compat"
sim_ws.mkdir(exist_ok=True)

sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6")
tdis = flopy.mf6.ModflowTdis(
sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)]
)
ims = flopy.mf6.ModflowIms(sim)
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10)
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0)
oc = flopy.mf6.ModflowGwfoc(gwf)

initial_data = dict.fromkeys(range(5), 0.001)
rch = flopy.mf6.ModflowGwfrcha(gwf, recharge=initial_data)

additional_data = dict.fromkeys(range(5, 10), 0.002)
rch.recharge.set_data(additional_data) # replace defaults to False

sim.write_simulation()

sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws))
gwf2 = sim2.get_model(name)
rch2 = gwf2.get_package("RCHA")

for kper in range(10):
data = rch2.recharge.get_data(key=kper)
assert np.allclose(data, 0.001 if kper < 5 else 0.002)


def test_set_data_update_list_based_pkg(function_tmpdir):
name = "update_list_based"
sim_ws = Path(function_tmpdir) / "wel_update"
sim_ws.mkdir(exist_ok=True)

sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6")
tdis = flopy.mf6.ModflowTdis(
sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)]
)
ims = flopy.mf6.ModflowIms(sim)
gwf = flopy.mf6.ModflowGwf(sim, modelname=name)
dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10)
ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0)
npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0)
oc = flopy.mf6.ModflowGwfoc(gwf)

initial_data = {kper: [[(0, 5, 5), -1000.0]] for kper in range(5)}
wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=initial_data)

additional_data = {kper: [[(0, 7, 7), -2000.0]] for kper in range(5, 10)}
wel.stress_period_data.set_data(additional_data) # replace defaults to False

sim.write_simulation()

sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws))
gwf2 = sim2.get_model(name)
wel2 = gwf2.get_package("WEL")

for kper in range(10):
data = wel2.stress_period_data.get_data(key=kper)
assert data is not None, f"Period {kper} should have data"
if kper < 5:
# Original data should be at (0, 5, 5)
assert len(data) == 1
assert data[0]["cellid"] == (0, 5, 5)
assert data[0]["q"] == -1000.0
else:
# Additional data should be at (0, 7, 7)
assert len(data) == 1
assert data[0]["cellid"] == (0, 7, 7)
assert data[0]["q"] == -2000.0
30 changes: 25 additions & 5 deletions flopy/mf6/data/mfdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,7 +1890,7 @@ def _build_period_data(
output[sp] = data
return output

def set_record(self, data_record):
def set_record(self, data_record, replace=False):
"""Sets data and metadata at layer `layer` and time `key` to
`data_record`. For unlayered data do not pass in `layer`.

Expand All @@ -1902,10 +1902,15 @@ def set_record(self, data_record):
and metadata (factor, iprn, filename, binary, data) for a given
stress period. How to define the dictionary of data and
metadata is described in the MFData class's set_record method.
replace : bool
Perform the operation with replacement semantics: all existing
stress period keys not present in the new dictionary will be
removed. If False, existing keys not in the new dictionary
will be preserved. Defaults False for backwards compatibility.
"""
self._set_data_record(data_record, is_record=True)
self._set_data_record(data_record, is_record=True, replace=replace)

def set_data(self, data, multiplier=None, layer=None, key=None):
def set_data(self, data, multiplier=None, layer=None, key=None, replace=False):
"""Sets the contents of the data at layer `layer` and time `key` to
`data` with multiplier `multiplier`. For unlayered data do not pass
in `layer`.
Expand All @@ -1926,15 +1931,30 @@ def set_data(self, data, multiplier=None, layer=None, key=None):
key : int
Zero based stress period to assign data too. Does not apply
if `data` is a dictionary.
replace : bool
If True and `data` is a dictionary, perform the operation
with replacement semantics: all existing stress period keys
not present in the new dictionary will be removed. If False,
existing keys not in the new dictionary will be preserved.
Defaults False for backwards compatibility.
"""
self._set_data_record(data, multiplier, layer, key)
self._set_data_record(data, multiplier, layer, key, replace=replace)

def _set_data_record(
self, data, multiplier=None, layer=None, key=None, is_record=False
self, data, multiplier=None, layer=None, key=None, is_record=False, replace=False
):
if isinstance(data, dict):
# each item in the dictionary is a list for one stress period
# the dictionary key is the stress period the list is for

# If replacing, remove keys not in the new data
if replace and self._data_storage:
keys_to_remove = set(self._data_storage.keys()) - set(data.keys())
for k in keys_to_remove:
self.remove_transient_key(k)
if k in self.empty_keys:
del self.empty_keys[k]

del_keys = []
for key, list_item in data.items():
if list_item is None:
Expand Down
29 changes: 25 additions & 4 deletions flopy/mf6/data/mfdatalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,7 @@ def get_data(self, key=None, apply_mult=False, **kwargs):
else:
return None

def set_record(self, data_record, autofill=False, check_data=True):
def set_record(self, data_record, autofill=False, check_data=True, replace=False):
"""Sets the contents of the data based on the contents of
'data_record`.

Expand All @@ -1795,15 +1795,21 @@ def set_record(self, data_record, autofill=False, check_data=True):
Automatically correct data
check_data : bool
Whether to verify the data
replace : bool
Perform the operation with replacement semantics: all existing
stress period keys not present in the new dictionary will be
removed. If False, existing keys not in the new dictionary
will be preserved. Defaults False for backwards compatibility.
"""
self._set_data_record(
data_record,
autofill=autofill,
check_data=check_data,
is_record=True,
replace=replace,
)

def set_data(self, data, key=None, autofill=False):
def set_data(self, data, key=None, autofill=False, replace=False):
"""Sets the contents of the data at time `key` to `data`.

Parameters
Expand All @@ -1819,17 +1825,32 @@ def set_data(self, data, key=None, autofill=False):
if `data` is a dictionary.
autofill : bool
Automatically correct data.
replace : bool
If True and `data` is a dictionary, perform the operation
with replacement semantics: all existing stress period keys
not present in the new dictionary will be removed. If False,
existing keys not in the new dictionary will be preserved.
Defaults False for backwards compatibility.
"""
self._set_data_record(data, key, autofill)
self._set_data_record(data, key, autofill, replace=replace)

def _set_data_record(
self, data, key=None, autofill=False, check_data=False, is_record=False
self, data, key=None, autofill=False, check_data=False, is_record=False, replace=False
):
self._cache_model_grid = True
if isinstance(data, dict):
if "filename" not in data and "data" not in data:
# each item in the dictionary is a list for one stress period
# the dictionary key is the stress period the list is for

# If replacing, remove keys not in the new data
if replace and self._data_storage:
keys_to_remove = set(self._data_storage.keys()) - set(data.keys())
for k in keys_to_remove:
self.remove_transient_key(k)
if k in self.empty_keys:
del self.empty_keys[k]

del_keys = []
for key, list_item in data.items():
if list_item is None:
Expand Down
Loading
Loading