Commit f69aa08

[FSTORE-1805][APPEND] Engine.add_file() got an unexpected keyword argument 'distribute' (#751)
1 parent c69ca72 commit f69aa08

3 files changed: +48 -9 lines changed

python/hsfs/engine/python.py
Lines changed: 1 addition & 1 deletion

@@ -1553,7 +1553,7 @@ def _get_app_options(
             spark_job_configuration=spark_job_configuration,
         )

-    def add_file(self, file: Optional[str]) -> Optional[str]:
+    def add_file(self, file: Optional[str], distribute=True) -> Optional[str]:
         if not file:
             return file
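For reference, the method after this change, as far as the diff shows it: only the new signature and the existing early return are visible, and everything past that point (elided below) is untouched by this commit.

from typing import Optional


class Engine:
    def add_file(self, file: Optional[str], distribute=True) -> Optional[str]:
        # `distribute` is now accepted so that call sites shared with the Spark
        # engine can pass it unconditionally; any further effect of the flag on
        # the Python engine is not visible in this diff.
        if not file:
            return file
        ...  # remainder of the method is unchanged and not shown here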

python/tests/engine/test_python.py
Lines changed: 15 additions & 2 deletions

@@ -3253,14 +3253,27 @@ def test_get_app_options(self, mocker):
         assert mock_ingestion_job_conf.call_count == 1
         assert mock_ingestion_job_conf.call_args[1]["write_options"] == {"test": 2}

-    def test_add_file(self):
+    @pytest.mark.parametrize(
+        "distribute_arg",
+        [
+            None,  # Test without providing distribute argument (uses default)
+            True,  # Test with distribute=True
+            False,  # Test with distribute=False
+        ],
+    )
+    def test_add_file(self, distribute_arg):
         # Arrange
         python_engine = python.Engine()

         file = None

         # Act
-        result = python_engine.add_file(file=file)
+        if distribute_arg is None:
+            # Call without distribute argument
+            result = python_engine.add_file(file=file)
+        else:
+            # Call with distribute argument
+            result = python_engine.add_file(file=file, distribute=distribute_arg)

         # Assert
         assert result == file
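A quick usage sketch of the call pattern this test exercises, outside the test harness; engine construction follows the test above, and only the falsy-file path covered by the diff is shown.

from hsfs.engine import python

engine = python.Engine()

# With a falsy `file`, add_file is a no-op and returns its input unchanged,
# whether or not the new keyword is passed.
assert engine.add_file(file=None) is None
assert engine.add_file(file=None, distribute=False) is None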

python/tests/engine/test_spark.py
Lines changed: 32 additions & 6 deletions

@@ -4379,8 +4379,17 @@ def test_read_stream_kafka_message_format_avro_include_metadata(self, mocker):
         assert "name" in result.columns
         assert result.schema["name"].dataType == StringType()

-    def test_add_file(self, mocker):
+    @pytest.mark.parametrize(
+        "distribute_arg",
+        [
+            None,  # Test without providing distribute argument (uses default)
+            True,  # Test with distribute=True
+            False,  # Test with distribute=False
+        ],
+    )
+    def test_add_file(self, mocker, distribute_arg):
         # Arrange
+        mock_dataset_api = mocker.patch("hsfs.core.dataset_api.DatasetApi")
         mock_pyspark_files_get = mocker.patch("pyspark.files.SparkFiles.get")
         mocker.patch("hopsworks_common.client._is_external", return_value=False)
         mocker.patch("shutil.copy")
@@ -4389,14 +4398,31 @@ def test_add_file(self, mocker):

         spark_engine = spark.Engine()

+        # Mock dataset API and file I/O for distribute=False case
+        if distribute_arg is False:
+            mock_dataset_api.return_value.read_content.return_value.content = bytes()
+            mocker.patch("builtins.open", mocker.mock_open())
+
         # Act
-        spark_engine.add_file(
-            file="test_file",
-        )
+        if distribute_arg is None:
+            # Call without distribute argument
+            spark_engine.add_file(file="test_file")
+        else:
+            # Call with distribute argument
+            spark_engine.add_file(file="test_file", distribute=distribute_arg)

         # Assert
-        mock_add_file.assert_called_once_with("hdfs://test_file")
-        mock_pyspark_files_get.assert_called_once_with("test_file")
+        if distribute_arg is False:
+            # When distribute=False, read_content should be called once
+            mock_dataset_api.return_value.read_content.assert_called_once()
+            # addFile and SparkFiles.get should NOT be called
+            mock_add_file.assert_not_called()
+            mock_pyspark_files_get.assert_not_called()
+        else:
+            # When distribute is True or None (default), addFile should be called
+            mock_dataset_api.return_value.read_content.assert_not_called()
+            mock_add_file.assert_called_once_with("hdfs://test_file")
+            mock_pyspark_files_get.assert_called_once_with("test_file")

     def test_add_file_if_present_in_job_configuration(self, mocker):
         # Arrange
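Read together, the mocks and assertions above imply the following branching inside the Spark engine's add_file. This is a sketch reverse-engineered from the test, not the repository's implementation: the exact read_content signature, the local write target, and the helper name add_file_sketch are assumptions.

import os

from pyspark import SparkFiles
from pyspark.sql import SparkSession

from hsfs.core import dataset_api


def add_file_sketch(file: str, distribute: bool = True) -> str:
    # Illustration of the behavior the parametrized test asserts on.
    if not file:
        return file

    if distribute:
        # distribute=True (or omitted): ship the HDFS file to the executors and
        # resolve the driver-local copy via SparkFiles, matching the assertions
        # addFile("hdfs://test_file") and SparkFiles.get("test_file").
        spark = SparkSession.builder.getOrCreate()
        spark.sparkContext.addFile("hdfs://" + file)
        return SparkFiles.get(os.path.basename(file))

    # distribute=False: fetch the bytes through the dataset API and write them
    # locally instead of calling addFile. The test only checks that read_content
    # is called and that addFile / SparkFiles.get are not.
    content = dataset_api.DatasetApi().read_content(file).content
    local_path = os.path.join(os.getcwd(), os.path.basename(file))  # assumed target
    with open(local_path, "wb") as f:
        f.write(content)
    return local_path

Keeping distribute=True as the default preserves the previous behavior, which is why the parametrization treats the no-argument case and distribute=True as equivalent in the assertions.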
