diff --git a/autotest/test_mf6_set_data_replace.py b/autotest/test_mf6_set_data_replace.py new file mode 100644 index 000000000..463bbf44b --- /dev/null +++ b/autotest/test_mf6_set_data_replace.py @@ -0,0 +1,263 @@ +""" +Test set_data() replace parameter (issue #2663). This parameter +toggles whether .set_data() has update or replacement semantics. +""" + +from pathlib import Path + +import numpy as np +import pytest + +import flopy + +pytestmark = pytest.mark.mf6 + + +def count_stress_periods(file_path): + """Count the number of 'BEGIN period' statements in an input file.""" + with open(file_path, "r") as f: + return sum(1 for line in f if line.strip().upper().startswith("BEGIN PERIOD")) + + +@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"]) +@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"]) +def test_set_data_replace_array_based_pkg(function_tmpdir, replace, use_pandas): + name = "array_based" + og_ws = Path(function_tmpdir) / "original" + og_ws.mkdir(exist_ok=True) + + nlay, nrow, ncol = 1, 10, 10 + nper_original = 48 + nper_new = 12 + + sim = flopy.mf6.MFSimulation( + sim_name=name, + sim_ws=str(og_ws), + exe_name="mf6", + use_pandas=use_pandas, + ) + tdis = flopy.mf6.ModflowTdis( + sim, + nper=nper_original, + perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)], + ) + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis( + gwf, + nlay=nlay, + nrow=nrow, + ncol=ncol, + delr=100.0, + delc=100.0, + top=100.0, + botm=0.0, + ) + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) + oc = flopy.mf6.ModflowGwfoc( + gwf, + budget_filerecord=f"{name}.cbc", + head_filerecord=f"{name}.hds", + saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")], + ) + rch_data = {kper: 0.001 + kper * 0.0001 for kper in range(nper_original)} + rcha = flopy.mf6.ModflowGwfrcha(gwf, recharge=rch_data) + + sim.write_simulation() + + original_rch_file = og_ws / f"{name}.rcha" + original_sp_count = count_stress_periods(original_rch_file) + assert original_sp_count == nper_original + + # Update RCH + new_rch_data = {kper: 0.002 + kper * 0.0002 for kper in range(nper_new)} + rcha.recharge.set_data(new_rch_data, replace=replace) + + # Update TDIS + tdis.nper = nper_new + tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] + + mod_ws = Path(function_tmpdir) / f"modified_replace_{replace}" + mod_ws.mkdir(exist_ok=True) + sim.set_sim_path(str(mod_ws)) + sim.write_simulation() + + modified_rch_file = mod_ws / f"{name}.rcha" + modified_sp_count = count_stress_periods(modified_rch_file) + + if replace: + # With replace=True, should only have 12 stress periods + assert modified_sp_count == nper_new, ( + f"Expected {nper_new} stress periods " + f"with replace=True, got {modified_sp_count}" + ) + else: + # With replace=False (backwards compatible), all 48 periods remain + assert modified_sp_count == nper_original, ( + f"Expected {nper_original} stress periods " + f"with replace=False, got {modified_sp_count}" + ) + + with open(modified_rch_file, "r") as f: + content = f.read() + assert "0.00200000" in content or "2.00000000E-03" in content + assert "0.00420000" in content or "4.20000000E-03" in content + + +@pytest.mark.parametrize("replace", [False, True], ids=["replace", "no_replace"]) +@pytest.mark.parametrize("use_pandas", [False, True], ids=["use_pandas", "no_pandas"]) +def test_set_data_replace_list_based_pkg(function_tmpdir, replace, use_pandas): + name = "list_based" + sim_ws = Path(function_tmpdir) / "wel_original" + sim_ws.mkdir(exist_ok=True) + + nlay, nrow, ncol = 1, 10, 10 + nper_original = 24 + nper_new = 6 + + sim = flopy.mf6.MFSimulation( + sim_name=name, sim_ws=str(sim_ws), exe_name="mf6", use_pandas=use_pandas + ) + tdis = flopy.mf6.ModflowTdis( + sim, + nper=nper_original, + perioddata=[(1.0, 1, 1.0) for _ in range(nper_original)], + ) + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis( + gwf, + nlay=nlay, + nrow=nrow, + ncol=ncol, + delr=100.0, + delc=100.0, + top=100.0, + botm=0.0, + ) + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, icelltype=1, k=10.0) + oc = flopy.mf6.ModflowGwfoc( + gwf, + budget_filerecord=f"{name}.cbc", + head_filerecord=f"{name}.hds", + saverecord=[("HEAD", "LAST"), ("BUDGET", "LAST")], + ) + wel_data = { + kper: [[(0, 5, 5), -1000.0 - kper * 10.0]] for kper in range(nper_original) + } + wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=wel_data) + + sim.write_simulation() + + original_wel_file = sim_ws / f"{name}.wel" + original_sp_count = count_stress_periods(original_wel_file) + assert original_sp_count == nper_original + + # Update WEL + new_wel_data = { + kper: [[(0, 5, 5), -2000.0 - kper * 20.0]] for kper in range(nper_new) + } + wel.stress_period_data.set_data(new_wel_data, replace=replace) + + # Update TDIS + tdis.nper = nper_new + tdis.perioddata = [(1.0, 1, 1.0) for _ in range(nper_new)] + + mod_ws = Path(function_tmpdir) / f"wel_modified_replace_{replace}" + mod_ws.mkdir(exist_ok=True) + sim.set_sim_path(str(mod_ws)) + sim.write_simulation() + + modified_wel_file = mod_ws / f"{name}.wel" + modified_sp_count = count_stress_periods(modified_wel_file) + + if replace: + # With replace=True, should only have 6 stress periods + assert modified_sp_count == nper_new, ( + f"Expected {nper_new} stress periods with " + f"replace=True, got {modified_sp_count}" + ) + else: + # With replace=False, all 24 periods remain + assert modified_sp_count == nper_original, ( + f"Expected {nper_original} stress periods with " + f"replace=False, got {modified_sp_count}" + ) + + +def test_set_data_update_array_based_pkg(function_tmpdir): + name = "update_array_based" + sim_ws = Path(function_tmpdir) / "compat" + sim_ws.mkdir(exist_ok=True) + + sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6") + tdis = flopy.mf6.ModflowTdis( + sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)] + ) + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10) + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0) + oc = flopy.mf6.ModflowGwfoc(gwf) + + initial_data = dict.fromkeys(range(5), 0.001) + rch = flopy.mf6.ModflowGwfrcha(gwf, recharge=initial_data) + + additional_data = dict.fromkeys(range(5, 10), 0.002) + rch.recharge.set_data(additional_data) # replace defaults to False + + sim.write_simulation() + + sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) + gwf2 = sim2.get_model(name) + rch2 = gwf2.get_package("RCHA") + + for kper in range(10): + data = rch2.recharge.get_data(key=kper) + assert np.allclose(data, 0.001 if kper < 5 else 0.002) + + +def test_set_data_update_list_based_pkg(function_tmpdir): + name = "update_list_based" + sim_ws = Path(function_tmpdir) / "wel_update" + sim_ws.mkdir(exist_ok=True) + + sim = flopy.mf6.MFSimulation(sim_name=name, sim_ws=str(sim_ws), exe_name="mf6") + tdis = flopy.mf6.ModflowTdis( + sim, nper=10, perioddata=[(1.0, 1, 1.0) for _ in range(10)] + ) + ims = flopy.mf6.ModflowIms(sim) + gwf = flopy.mf6.ModflowGwf(sim, modelname=name) + dis = flopy.mf6.ModflowGwfdis(gwf, nlay=1, nrow=10, ncol=10) + ic = flopy.mf6.ModflowGwfic(gwf, strt=100.0) + npf = flopy.mf6.ModflowGwfnpf(gwf, k=10.0) + oc = flopy.mf6.ModflowGwfoc(gwf) + + initial_data = {kper: [[(0, 5, 5), -1000.0]] for kper in range(5)} + wel = flopy.mf6.ModflowGwfwel(gwf, stress_period_data=initial_data) + + additional_data = {kper: [[(0, 7, 7), -2000.0]] for kper in range(5, 10)} + wel.stress_period_data.set_data(additional_data) # replace defaults to False + + sim.write_simulation() + + sim2 = flopy.mf6.MFSimulation.load(sim_ws=str(sim_ws)) + gwf2 = sim2.get_model(name) + wel2 = gwf2.get_package("WEL") + + for kper in range(10): + data = wel2.stress_period_data.get_data(key=kper) + assert data is not None, f"Period {kper} should have data" + if kper < 5: + # Original data should be at (0, 5, 5) + assert len(data) == 1 + assert data[0]["cellid"] == (0, 5, 5) + assert data[0]["q"] == -1000.0 + else: + # Additional data should be at (0, 7, 7) + assert len(data) == 1 + assert data[0]["cellid"] == (0, 7, 7) + assert data[0]["q"] == -2000.0 diff --git a/flopy/mf6/data/mfdataarray.py b/flopy/mf6/data/mfdataarray.py index e00c877b1..27d7e2a92 100644 --- a/flopy/mf6/data/mfdataarray.py +++ b/flopy/mf6/data/mfdataarray.py @@ -1890,7 +1890,7 @@ def _build_period_data( output[sp] = data return output - def set_record(self, data_record): + def set_record(self, data_record, replace=False): """Sets data and metadata at layer `layer` and time `key` to `data_record`. For unlayered data do not pass in `layer`. @@ -1902,10 +1902,15 @@ def set_record(self, data_record): and metadata (factor, iprn, filename, binary, data) for a given stress period. How to define the dictionary of data and metadata is described in the MFData class's set_record method. + replace : bool + Perform the operation with replacement semantics: all existing + stress period keys not present in the new dictionary will be + removed. If False, existing keys not in the new dictionary + will be preserved. Defaults False for backwards compatibility. """ - self._set_data_record(data_record, is_record=True) + self._set_data_record(data_record, is_record=True, replace=replace) - def set_data(self, data, multiplier=None, layer=None, key=None): + def set_data(self, data, multiplier=None, layer=None, key=None, replace=False): """Sets the contents of the data at layer `layer` and time `key` to `data` with multiplier `multiplier`. For unlayered data do not pass in `layer`. @@ -1926,15 +1931,30 @@ def set_data(self, data, multiplier=None, layer=None, key=None): key : int Zero based stress period to assign data too. Does not apply if `data` is a dictionary. + replace : bool + If True and `data` is a dictionary, perform the operation + with replacement semantics: all existing stress period keys + not present in the new dictionary will be removed. If False, + existing keys not in the new dictionary will be preserved. + Defaults False for backwards compatibility. """ - self._set_data_record(data, multiplier, layer, key) + self._set_data_record(data, multiplier, layer, key, replace=replace) def _set_data_record( - self, data, multiplier=None, layer=None, key=None, is_record=False + self, data, multiplier=None, layer=None, key=None, is_record=False, replace=False ): if isinstance(data, dict): # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for + + # If replacing, remove keys not in the new data + if replace and self._data_storage: + keys_to_remove = set(self._data_storage.keys()) - set(data.keys()) + for k in keys_to_remove: + self.remove_transient_key(k) + if k in self.empty_keys: + del self.empty_keys[k] + del_keys = [] for key, list_item in data.items(): if list_item is None: diff --git a/flopy/mf6/data/mfdatalist.py b/flopy/mf6/data/mfdatalist.py index 529177c41..b31f47cec 100644 --- a/flopy/mf6/data/mfdatalist.py +++ b/flopy/mf6/data/mfdatalist.py @@ -1780,7 +1780,7 @@ def get_data(self, key=None, apply_mult=False, **kwargs): else: return None - def set_record(self, data_record, autofill=False, check_data=True): + def set_record(self, data_record, autofill=False, check_data=True, replace=False): """Sets the contents of the data based on the contents of 'data_record`. @@ -1795,15 +1795,21 @@ def set_record(self, data_record, autofill=False, check_data=True): Automatically correct data check_data : bool Whether to verify the data + replace : bool + Perform the operation with replacement semantics: all existing + stress period keys not present in the new dictionary will be + removed. If False, existing keys not in the new dictionary + will be preserved. Defaults False for backwards compatibility. """ self._set_data_record( data_record, autofill=autofill, check_data=check_data, is_record=True, + replace=replace, ) - def set_data(self, data, key=None, autofill=False): + def set_data(self, data, key=None, autofill=False, replace=False): """Sets the contents of the data at time `key` to `data`. Parameters @@ -1819,17 +1825,32 @@ def set_data(self, data, key=None, autofill=False): if `data` is a dictionary. autofill : bool Automatically correct data. + replace : bool + If True and `data` is a dictionary, perform the operation + with replacement semantics: all existing stress period keys + not present in the new dictionary will be removed. If False, + existing keys not in the new dictionary will be preserved. + Defaults False for backwards compatibility. """ - self._set_data_record(data, key, autofill) + self._set_data_record(data, key, autofill, replace=replace) def _set_data_record( - self, data, key=None, autofill=False, check_data=False, is_record=False + self, data, key=None, autofill=False, check_data=False, is_record=False, replace=False ): self._cache_model_grid = True if isinstance(data, dict): if "filename" not in data and "data" not in data: # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for + + # If replacing, remove keys not in the new data + if replace and self._data_storage: + keys_to_remove = set(self._data_storage.keys()) - set(data.keys()) + for k in keys_to_remove: + self.remove_transient_key(k) + if k in self.empty_keys: + del self.empty_keys[k] + del_keys = [] for key, list_item in data.items(): if list_item is None: diff --git a/flopy/mf6/data/mfdataplist.py b/flopy/mf6/data/mfdataplist.py index a7aef4625..5b9e20825 100644 --- a/flopy/mf6/data/mfdataplist.py +++ b/flopy/mf6/data/mfdataplist.py @@ -2271,7 +2271,7 @@ def get_data(self, key=None, apply_mult=False, dataframe=False, **kwargs): else: return None - def set_record(self, record, autofill=False, check_data=True): + def set_record(self, record, autofill=False, check_data=True, replace=False): """Sets the contents of the data based on the contents of 'record`. @@ -2286,15 +2286,21 @@ def set_record(self, record, autofill=False, check_data=True): Automatically correct data check_data : bool Whether to verify the data + replace : bool + Perform the operation with replacement semantics: all existing + stress period keys not present in the new dictionary will be + removed. If False, existing keys not in the new dictionary + will be preserved. Defaults False for backwards compatibility. """ self._set_data_record( record, autofill=autofill, check_data=check_data, is_record=True, + replace=replace, ) - def set_data(self, data, key=None, autofill=False): + def set_data(self, data, key=None, autofill=False, replace=False): """Sets the contents of the data at time `key` to `data`. Parameters @@ -2310,8 +2316,14 @@ def set_data(self, data, key=None, autofill=False): if `data` is a dictionary. autofill : bool Automatically correct data. + replace : bool + If True and `data` is a dictionary, perform the operation + with replacement semantics: all existing stress period keys + not present in the new dictionary will be removed. If False, + existing keys not in the new dictionary will be preserved. + Defaults False for backwards compatibility. """ - self._set_data_record(data, key, autofill) + self._set_data_record(data, key, autofill, replace=replace) def masked_4D_arrays_itr(self): """Returns list data as an iterator of a masked 4D array.""" @@ -2339,12 +2351,22 @@ def _set_data_record( autofill=False, check_data=False, is_record=False, + replace=False, ): self._cache_model_grid = True if isinstance(data_record, dict): if "filename" not in data_record and "data" not in data_record: # each item in the dictionary is a list for one stress period # the dictionary key is the stress period the list is for + + # If replacing, remove keys not in the new data + if replace and self._data_storage: + keys_to_remove = set(self._data_storage.keys()) - set(data_record.keys()) + for k in keys_to_remove: + self.remove_transient_key(k) + if k in self.empty_keys: + del self.empty_keys[k] + del_keys = [] for key, list_item in data_record.items(): list_item_record = False diff --git a/flopy/mf6/mfpackage.py b/flopy/mf6/mfpackage.py index 702becfcc..270a9bb65 100644 --- a/flopy/mf6/mfpackage.py +++ b/flopy/mf6/mfpackage.py @@ -1285,6 +1285,37 @@ def write(self, fd, ext_file_action=ExtFileAction.copy_relative_paths): return if self.structure.repeating(): repeating_datasets = self._find_repeating_datasets() + + # First, collect active keys from ALL datasets in this block + # This is important for blocks with multiple datasets (e.g., storage package + # has both "steady-state" and "transient" datasets) that share block_headers. + # We need to preserve headers that are active in ANY dataset, not just the + # current one being processed. + all_active_keys = set() + for repeating_dataset in repeating_datasets: + for key_data in repeating_dataset.get_active_key_list(): + all_active_keys.add(key_data[0]) + for key, value in repeating_dataset.empty_keys.items(): + if value: + all_active_keys.add(key) + + # Clean up stale block headers once, using combined active keys from all datasets + # Only clean up if we have multiple headers and active data. + # This avoids breaking the initial write case where block_headers + # may have a template header with transient_key=None. Otherwise we + # get IndexError when _build_repeating_header tries to use index -1. + if len(self.block_headers) > 1 and all_active_keys: + headers_to_remove = [] + for i, header in enumerate(self.block_headers): + k = header.get_transient_key() + if k is not None and k not in all_active_keys: + headers_to_remove.append(i) + + # Remove in reverse order to preserve indices + for i in reversed(headers_to_remove): + del self.block_headers[i] + + # Now add missing block headers for each dataset for repeating_dataset in repeating_datasets: # resolve any missing block headers self._add_missing_block_headers(repeating_dataset)