diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index ae3d6ddfcad0..6e82d995b782 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -132,6 +132,34 @@ Cloud Storage (GCS) bucket. We recommend the following configuration: * All encryption policies are supported. +It is **recommended** to use +[Google Cloud Storage Fuse](https://cloud.google.com/storage/docs/cloud-storage-fuse) +to mount the GCS bucket as a local directory. This is because when running JAX +in a multi-node setup, multiple nodes might try to write to the cache +simultaneously, leading to GCS rate-limit errors. GCSFuse handles this by +ensuring that only one process can write to a file at a time, preventing these +errors. + +To set up GCSFuse, follow instructions for +[GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/mount-bucket) or +[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-setup). +For better performance, enable file caching +([GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/file-caching) and +[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-perf#enable-and-use-file-caching)). + +Once GCSFuse is configured, set the JAX cache directory to the GCSFuse mount +point: + +```python +# Example assuming the GCS bucket is mounted at /gcs/my-bucket +jax.config.update("jax_compilation_cache_dir", "/gcs/my-bucket/jax-cache") +``` + +**Direct GCS access :** + +If you choose not to use GCSFuse, you can point the cache directly to a GCS +bucket. + Assuming that `gs://jax-cache` is the GCS bucket, set cache location as follows: