Skip to content

Commit b9291e5

Browse files
emilyfertigGoogle-ML-Automation
authored andcommitted
Raise an error instead of crashing when get_on_device_size_in_bytes is called on a nonaddressable array.
This method will be supported in the future, but in the meantime this fails more gracefully. PiperOrigin-RevId: 799784831
1 parent 35d632e commit b9291e5

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

jaxlib/py_array.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,15 @@ absl::StatusOr<size_t> PyArray::GetOnDeviceSizeInBytes() {
809809
return xla::InvalidArgument(
810810
"GetOnDeviceSizeInBytes() called on deleted or donated buffer");
811811
}
812-
812+
// TODO(emilyaf): Support this method for non-addressable arrays by calling
813+
// py_client()->pjrt_client()->GetOnDeviceBytesCount once all clients
814+
// implement it.
815+
if (ifrt_array()->sharding().devices()->AddressableDeviceList()->size() ==
816+
0) {
817+
return xla::Unimplemented(
818+
"GetOnDeviceSizeInBytes() is not yet supported for arrays with no "
819+
"addressable devices");
820+
}
813821
TF_ASSIGN_OR_RETURN(size_t shard_size,
814822
GetPjrtBuffer(ifrt_array())->GetOnDeviceSizeInBytes());
815823
return shard_size * nb::len(nb::object(sharding().attr("device_set")));

0 commit comments

Comments
 (0)