Skip to content

Commit 69a6cca

Browse files
[mpmd] Fix stage_id field nanobind compatability
.none() annotations needed for optional parameters to allow accepting None from Python. PiperOrigin-RevId: 846253556
1 parent 315bb93 commit 69a6cca

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jaxlib/sdy_mpmd.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ NB_MODULE(_sdy_mpmd, m) {
121121
std::optional<int>,
122122
std::optional<mlir::mpmd::SplitFragmentType>,
123123
const std::string&>(),
124-
nb::arg("origins"), nb::arg("stage_id"),
124+
nb::arg("origins"), nb::arg("stage_id").none() = std::nullopt,
125125
nb::arg("call_counter").none() = std::nullopt,
126126
nb::arg("split_type").none() = std::nullopt, nb::arg("mesh_name"))
127127
.def_ro("origins", &FragmentInfo::origins)

0 commit comments

Comments
 (0)