diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index bf65142ad93..93527ed65ae 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -1007,6 +1007,52 @@ def simple_pipeline(img: str): self.assertTrue('base_image' in input_parameters) self.assertTrue('pipelinechannel--img' in input_parameters) + def test_pipeline_with_parameterized_container_image_inside_parallel_for( + self): + with tempfile.TemporaryDirectory() as tmpdir: + + @dsl.component(base_image='docker.io/python:3.11.17') + def empty_component(idx: int): + del idx + + @dsl.pipeline() + def simple_pipeline(img: str): + with dsl.ParallelFor(items=[1, 2]) as item: + task = empty_component(idx=item) + task.set_container_image(img) + + output_yaml = os.path.join(tmpdir, 'result.yaml') + compiler.Compiler().compile( + pipeline_func=simple_pipeline, + package_path=output_yaml, + pipeline_parameters={'img': 'someimage'}) + + self.assertTrue(os.path.exists(output_yaml)) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + loop_components = { + name: comp + for name, comp in pipeline_spec['components'].items() + if name.startswith('comp-for-loop') + } + self.assertTrue( + loop_components, + 'Expected to find at least one ParallelFor component in the pipeline spec' + ) + + loop_component = next(iter(loop_components.values())) + input_parameters = loop_component['inputDefinitions']['parameters'] + self.assertIn('pipelinechannel--img', input_parameters) + + loop_tasks = loop_component['dag']['tasks'] + self.assertIn('empty-component', loop_tasks) + loop_task_inputs = loop_tasks['empty-component']['inputs'][ + 'parameters'] + self.assertIn('base_image', loop_task_inputs) + self.assertIn('pipelinechannel--img', loop_task_inputs) + def test_pipeline_with_constant_container_image(self): with tempfile.TemporaryDirectory() as tmpdir: diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py index 1975dba22fd..0cb46cadeea 100644 --- a/sdk/python/kfp/dsl/pipeline_task.py +++ b/sdk/python/kfp/dsl/pipeline_task.py @@ -659,9 +659,22 @@ def set_container_image( Self return to allow chained setting calls. """ self._ensure_container_spec_exists() + pipeline_channels = pipeline_channel.extract_pipeline_channels_from_any( + name) + if isinstance(name, pipeline_channel.PipelineChannel): name = str(name) + self.container_spec.image = name + + if pipeline_channels: + existing_channel_patterns = { + channel.pattern for channel in self._channel_inputs + } + for channel in pipeline_channels: + if channel.pattern not in existing_channel_patterns: + self._channel_inputs.append(channel) + existing_channel_patterns.add(channel.pattern) return self @block_if_final()