Status: Open
Labels: bug (Something isn't working)
Description
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print()

jax_gpu_array = jax.device_put(jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32), jax.devices("gpu")[0])

# Test copy=True (expected to fail with kMutableZeroCopy error)
try:
    result = jax.dlpack.from_dlpack(jax_gpu_array, copy=True)
    print(f"from_dlpack(copy=True) SUCCEEDED: {result}")
except Exception as e:
    print(f"from_dlpack(copy=True) FAILED: {type(e).__name__}: {e}")

# Test copy=False (expected to work)
try:
    result = jax.dlpack.from_dlpack(jax_gpu_array, copy=False)
    print(f"from_dlpack(copy=False) SUCCEEDED: {result}")
except Exception as e:
    print(f"from_dlpack(copy=False) FAILED: {type(e).__name__}: {e}")

# Workaround: copy=False, then jnp.array with copy=True
try:
    result_nocopy = jax.dlpack.from_dlpack(jax_gpu_array, copy=False)
    result_copied = jnp.array(result_nocopy, copy=True)
    print(f"Workaround (copy=False + jnp.array copy=True) SUCCEEDED: {result_copied}")
except Exception as e:
    print(f"Workaround FAILED: {type(e).__name__}: {e}")

Output
JAX version: 0.8.1
Devices: [CudaDevice(id=0)]
from_dlpack(copy=True) FAILED: JaxRuntimeError: UNIMPLEMENTED: PJRT C API does not support HostBufferSemantics other than HostBufferSemantics::kImmutableOnlyDuringCall, HostBufferSemantics::kImmutableZeroCopy and HostBufferSemantics::kImmutableUntilTransferCompletes.
from_dlpack(copy=False) SUCCEEDED: [1. 2. 3. 4.]
Workaround (copy=False + jnp.array copy=True) SUCCEEDED: [1. 2. 3. 4.]
I believe this was introduced in ec1e65e, which changed the copy=True code path. As a temporary workaround, I can take the old code path with the following code:

result_nocopy = jax.dlpack.from_dlpack(jax_gpu_array, copy=False)
result_copied = jnp.array(result_nocopy, copy=True)

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.8.1
jaxlib: 0.8.1
numpy: 2.2.6
python: 3.11.9 (main, Aug 14 2024, 05:07:28) [Clang 18.1.8 ]
device info: NVIDIA H100 80GB HBM3-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='h100-reserved-192-143', release='6.5.13-65-650-4141-22041-coreweave-amd64-85c45edc', version='#1 SMP PREEMPT_DYNAMIC Mon Oct 14 20:37:13 UTC 2024', machine='x86_64')
$ nvidia-smi
Mon Dec 8 01:41:28 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   28C    P0             84W /  700W |     555MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   3575769      C   python                                        546MiB |
+-----------------------------------------------------------------------------------------+
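Until the copy=True path is fixed, the workaround could be wrapped in a small helper. This is a minimal sketch, assuming jax 0.8.x, where the failure surfaces as `jax.errors.JaxRuntimeError` (as in the output above). `from_dlpack_copy` is a hypothetical helper, not a JAX API: it tries `copy=True` first and falls back to `copy=False` followed by `jnp.array(..., copy=True)` only when the error message contains `UNIMPLEMENTED`, so unrelated failures still propagate. Comparing `unsafe_buffer_pointer()` values is only a rough check that the result does not alias the source buffer.

```python
import jax
import jax.numpy as jnp


def from_dlpack_copy(x):
    """Hypothetical helper: DLPack-import `x` and return an independent copy.

    Tries copy=True first. If the PJRT C API rejects it as UNIMPLEMENTED
    (the failure reported above), falls back to a zero-copy import followed
    by an explicit jnp.array copy, the path the report attributes to
    versions before ec1e65e.
    """
    try:
        return jax.dlpack.from_dlpack(x, copy=True)
    except jax.errors.JaxRuntimeError as e:
        if "UNIMPLEMENTED" not in str(e):
            raise  # unrelated failure: don't mask it
        aliased = jax.dlpack.from_dlpack(x, copy=False)  # zero-copy, may alias x
        return jnp.array(aliased, copy=True)  # materialize a separate buffer


# Default device: the GPU on the reporter's machine, the CPU elsewhere.
x = jax.device_put(jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32))
y = from_dlpack_copy(x)
print(y)
# A real copy should live in a different device buffer than the source.
print("separate buffer:", x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer())
```

Matching on the message text is a deliberate trade-off: it keeps the fallback narrow, at the cost of depending on the wording of the PJRT error.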