Skip to content

Conversation

@copybara-service
Copy link

[mpmd] Add custom pipeline API for user-defined scheduling and merging

Introduces a new API for defining custom MPMD pipeline schedules and fragment
merging logic using Python predicates.

How to Use:

The API centers around the PipelineSchedule object. Users can either define
their own custom schedules or use pre-built ones:

Custom Schedules:
There are two approaches for defining custom scheduling and merging logic:

  1. Predicate-based (recommended for simple patterns): Define scheduling and
    merging logic using binary predicate functions. Helper functions like
    build_schedule_rules_from_predicate and build_merge_rules_from_predicate
    create RuleBuilders from your predicates. See the docstring in
    pipeline.py for a simple example.

  2. Direct construction (for complex custom schedules): Directly build execution
    order and merge rules for full control over fragment scheduling. This
    approach is useful for advanced patterns that may not neatly translate to a
    predicate. See the docstring in pipeline.py for a comparative example.

Pre-built Schedules:
Pre-defined common pipeline schedules (e.g., GPIPE, ONE_FWD_ONE_BWD) are
available through the pipeline_registry module. The get_pipeline_schedule
function returns a PipelineSchedule object for a given name. See the docstring
in pipeline_registry.py for usage.

In all cases, the resulting PipelineSchedule object is passed to
mpmd.jit via MpmdConfig - this integration is done in a subsequent CL.

Introduces a new API for defining custom MPMD pipeline schedules and fragment
merging logic using Python predicates.

How to Use:

The API centers around the PipelineSchedule object. Users can either define
their own custom schedules or use pre-built ones:

Custom Schedules:
There are two approaches for defining custom scheduling and merging logic:

1.  Predicate-based (recommended for simple patterns): Define scheduling and
    merging logic using binary predicate functions. Helper functions like
    `build_schedule_rules_from_predicate` and `build_merge_rules_from_predicate`
    create `RuleBuilders` from your predicates. See the docstring in
    `pipeline.py` for a simple example.

2.  Direct construction (for complex custom schedules): Directly build execution
    order and merge rules for full control over fragment scheduling. This
    approach is useful for advanced patterns that may not neatly translate to a
    predicate. See the docstring in `pipeline.py` for a comparative example.

Pre-built Schedules:
Pre-defined common pipeline schedules (e.g., GPIPE, ONE_FWD_ONE_BWD) are
available through the `pipeline_registry` module. The `get_pipeline_schedule`
function returns a `PipelineSchedule` object for a given name. See the docstring
in pipeline_registry.py for usage.

In all cases, the resulting `PipelineSchedule` object is passed to
`mpmd.jit` via `MpmdConfig` - this integration is done in a subsequent CL.

PiperOrigin-RevId: 832290784
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants