Skip to content

Commit 987a025

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Update JAX tests for separate input_striding and input_tiling on transposes.
This updates JAX after an PJRT API change. PiperOrigin-RevId: 846260872
1 parent e2815d5 commit 987a025

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

jaxlib/callback.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ absl::Status CpuCallback::PrepareAndCall(void** result, void** arg_ptrs) {
104104
xla::primitive_util::ByteWidth(results_[i].type);
105105
options.dims = dims;
106106
options.permutation = results_[i].reversed_layout;
107-
options.input_layout = xla::TransposePlan::Striding{strides};
107+
options.input_striding = xla::TransposePlan::Striding{strides};
108108
absl::StatusOr<std::shared_ptr<xla::TransposePlan>> plan =
109109
transpose_cache_.GetOrCreate(options);
110110
if (!plan.ok()) {

jaxlib/gpu/py_client_gpu.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
217217
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
218218
reversed_layout.begin());
219219
options.permutation = reversed_layout;
220-
options.input_layout = xla::TransposePlan::Striding{strides};
220+
options.input_striding = xla::TransposePlan::Striding{strides};
221221
auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
222222
if (!maybe_plan.ok()) {
223223
return xla::ffi::Error::Internal(maybe_plan.status().ToString());

jaxlib/py_client_cpu.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ ffi::Error XlaFfiPythonCpuCallback(xla::FfiLoadedHostCallbacks* callbacks,
173173
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
174174
reversed_layout.begin());
175175
options.permutation = reversed_layout;
176-
options.input_layout = xla::TransposePlan::Striding{strides};
176+
options.input_striding = xla::TransposePlan::Striding{strides};
177177
auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
178178
if (!maybe_plan.ok()) {
179179
return ffi::Error::Internal(maybe_plan.status().ToString());

0 commit comments

Comments
 (0)