Skip to content

Commit fad488c

Browse files
committed
feat: validate minimum records in dataset file
1 parent 4942f89 commit fad488c

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

src/llama_prompt_ops/interfaces/cli.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,14 @@ def get_dataset_adapter_from_config(config_dict, config_path):
430430
return get_dataset_adapter(config_dict)
431431

432432

433+
def validate_min_records_in_dataset(dataset_adapter: DatasetAdapter, min_records: int = 4):
    """
    Check that the dataset has enough records to run optimization.

    The default minimum is 4 records, to avoid runtime errors during
    optimization: the data is split into 25% training, 25% validation,
    and 50% testing.

    Args:
        dataset_adapter: Adapter whose ``load_raw_data()`` result is measured
            with ``len()``.
        min_records: Minimum number of records the dataset must contain.
            Defaults to 4.

    Raises:
        ValueError: If the loaded data has fewer than ``min_records`` records.
    """
    data = dataset_adapter.load_raw_data()
    if len(data) < min_records:
        # With the default threshold this message is exactly
        # "Dataset must contain at least 4 records".
        raise ValueError(f"Dataset must contain at least {min_records} records")
439+
440+
433441
def get_models_from_config(config_dict, override_model_name=None, api_key=None):
434442
"""
435443
Create model adapter instances from configuration.
@@ -791,6 +799,13 @@ def migrate(config, model, output_dir, save_yaml, api_key_env, dotenv_path, log_
791799
except ValueError as e:
792800
click.echo(f"Error: {str(e)}", err=True)
793801
sys.exit(1)
802+
803+
# Validate the minimum number of records in dataset
804+
try:
805+
validate_min_records_in_dataset(dataset_adapter)
806+
except ValueError as e:
807+
click.echo(f"Error: {str(e)}", err=True)
808+
sys.exit(1)
794809

795810
# Create strategy based on config
796811
strategy = get_strategy(

tests/integration/test_cli_integration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def test_cli_migrate_command(self, mock_api_key_check, temp_config_file):
131131
"llama_prompt_ops.interfaces.cli.get_strategy", return_value=MagicMock()
132132
),
133133
patch("llama_prompt_ops.interfaces.cli.load_config", return_value={}),
134+
patch("llama_prompt_ops.interfaces.cli.validate_min_records_in_dataset", return_value=None),
134135
):
135136

136137
# Run the migrate command
@@ -194,6 +195,7 @@ def test_cli_config_loading(self, mock_api_key_check, facility_config_path):
194195
"llama_prompt_ops.interfaces.cli.get_strategy", return_value=MagicMock()
195196
),
196197
patch("llama_prompt_ops.interfaces.cli.load_config", return_value={}),
198+
patch("llama_prompt_ops.interfaces.cli.validate_min_records_in_dataset", return_value=None),
197199
):
198200

199201
# Run the migrate command with the real config
@@ -268,6 +270,7 @@ def test_end_to_end_cli_flow(self, mock_api_key_check, temp_config_file):
268270
return_value=MagicMock(),
269271
),
270272
patch("llama_prompt_ops.interfaces.cli.load_config", return_value={}),
273+
patch("llama_prompt_ops.interfaces.cli.validate_min_records_in_dataset", return_value=None),
271274
):
272275

273276
# Run the migrate command with the actual file output

tests/unit/test_datasets.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,22 @@ def test_custom_split_ratios(mock_dataset_adapter):
163163
assert len(train) == 70
164164
assert len(val) == 20
165165
assert len(test) == 10
166+
167+
168+
def test_minimum_records_in_dataset(simple_data_file):
    """Verify that a dataset with fewer than four records is rejected."""
    # Import inside the test so that an unimportable CLI module skips the
    # test instead of failing it.
    try:
        from llama_prompt_ops.interfaces.cli import validate_min_records_in_dataset
    except ImportError as import_error:
        pytest.skip(f"Skipping test because module import failed: {import_error!s}")

    # The sample data file has just 2 records, which is below the minimum.
    data_file, _ = simple_data_file
    adapter_options = {
        "dataset_path": data_file.name,
        "input_field": "question",
        "golden_output_field": "answer",
    }
    undersized_adapter = ConfigurableJSONAdapter(**adapter_options)

    expected_message = "Dataset must contain at least 4 records"
    with pytest.raises(ValueError, match=expected_message):
        validate_min_records_in_dataset(undersized_adapter)

0 commit comments

Comments
 (0)