Skip to content

Commit f265496

Browse files
author
Orbax Authors
committed
Add additional support for TENSORSTORE_GCS_BACKEND environment variable
This allows the user to configure the Tensorstore GCS backend (i.e. with `gcs` for http or `gcs_grpc` for grpc) for additional places in the code that did not previously have support, including when using ocdbt. PiperOrigin-RevId: 809285436
1 parent 535542d commit f265496

File tree

6 files changed

+154
-55
lines changed

6 files changed

+154
-55
lines changed

checkpoint/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919

2020
- #v1 Modify LeafHandler definitions so that `AbstractLeaf` or
2121
`Type[AbstractLeaf]` are always accepted as valid abstract values.
22+
- Configuring the `TENSORSTORE_GCS_BACKEND` environment variable is now
23+
supported for additional locations in the code, notably when using ocdbt.
2224

2325
## [0.11.25] - 2025-09-10
2426

checkpoint/orbax/checkpoint/_src/path/gcs_utils.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Utils for interacting with GCS paths."""
1616

1717
import functools
18+
import os
1819
from urllib import parse
1920
from etils import epath
2021

@@ -26,19 +27,55 @@ def is_gcs_path(path: epath.Path) -> bool:
2627
return path.as_posix().startswith(_GCS_PATH_PREFIX)
2728

2829

29-
def parse_gcs_path(path: epath.PathLike) -> tuple[str, str]:
30+
def parse_gcs_path(
31+
path: epath.PathLike, add_trailing_slash: bool = True
32+
) -> tuple[str, str]:
33+
"""Parses a GCS path into bucket and path within the bucket.
34+
35+
Args:
36+
path: The GCS path to parse (e.g., "gs://my-bucket/path/to/object").
37+
add_trailing_slash: Whether to ensure the returned path has a trailing
38+
slash.
39+
40+
Returns:
41+
A tuple containing the bucket name and the path within the bucket.
42+
"""
3043
parsed = parse.urlparse(str(path))
3144
assert parsed.scheme == 'gs', f'Unsupported scheme for GCS: {parsed.scheme}'
3245
# Strip the leading slash from the path.
3346
standardized_path = parsed.path
3447
if standardized_path.startswith('/'):
3548
standardized_path = standardized_path[1:]
3649
# Add a trailing slash if it's missing.
37-
if not standardized_path.endswith('/'):
50+
if add_trailing_slash and not standardized_path.endswith('/'):
3851
standardized_path = standardized_path + '/'
3952
return parsed.netloc, standardized_path
4053

4154

55+
def get_kvstore_for_gcs(ckpt_path: str):
56+
"""Constructs a TensorStore kvstore spec for a GCS path.
57+
58+
Args:
59+
ckpt_path: A GCS path of the form gs://<bucket>/<path>.
60+
61+
Returns:
62+
A dictionary containing the TensorStore kvstore spec.
63+
64+
Raises:
65+
ValueError: if ckpt_path is not a valid GCS path.
66+
"""
67+
gcs_bucket, path_without_bucket = parse_gcs_path(
68+
ckpt_path, add_trailing_slash=False
69+
)
70+
# TODO(stoelinga): Switch to gcs_grpc by default.
71+
# gcs_grpc performs roughly twice as fast as gcs backend.
72+
gcs_backend = os.environ.get('TENSORSTORE_GCS_BACKEND', 'gcs')
73+
spec = {'driver': gcs_backend, 'bucket': gcs_bucket}
74+
if path_without_bucket:
75+
spec['path'] = path_without_bucket
76+
return spec
77+
78+
4279
@functools.lru_cache(maxsize=32)
4380
def get_bucket(bucket_name: str):
4481
# pylint: disable=g-import-not-at-top

checkpoint/orbax/checkpoint/_src/serialization/serialization.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from collections.abc import Mapping
2020
import contextlib
2121
import os
22-
import re
2322
from typing import Any, AsyncIterator, Dict, Optional, Protocol, Sequence, Union
2423

2524
from absl import logging
@@ -32,6 +31,7 @@
3231
from orbax.checkpoint._src.arrays import numpy_utils as np_utils
3332
from orbax.checkpoint._src.arrays import types
3433
from orbax.checkpoint._src.multihost import multihost
34+
from orbax.checkpoint._src.path import gcs_utils
3535
from orbax.checkpoint._src.serialization import replica_slices
3636
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
3737
import tensorstore as ts
@@ -65,36 +65,6 @@ def _spec_has_metadata(tree):
6565
)
6666

6767

68-
def _get_kvstore_for_gcs(ckpt_path: str):
69-
"""Constructs a TensorStore kvstore spec for a GCS path.
70-
71-
Args:
72-
ckpt_path: A GCS path of the form gs://<bucket>/<path>.
73-
74-
Returns:
75-
A dictionary containing the TensorStore kvstore spec.
76-
77-
Raises:
78-
ValueError: if ckpt_path is not a valid GCS path.
79-
"""
80-
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
81-
if m is None:
82-
raise ValueError(
83-
'The ckpt_path should contain the bucket name and the '
84-
f'file path inside the bucket. Got: {ckpt_path}'
85-
)
86-
gcs_bucket = m.group(1)
87-
path_without_bucket = m.group(2)
88-
# TODO(stoelinga): Switch to gcs_grpc by default.
89-
# gcs_grpc performs roughly twice as fast as gcs backend.
90-
gcs_backend = os.environ.get('TENSORSTORE_GCS_BACKEND', 'gcs')
91-
return {
92-
'driver': f'{gcs_backend}',
93-
'bucket': gcs_bucket,
94-
'path': path_without_bucket,
95-
}
96-
97-
9868
def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
9969
"""Constructs a TensorStore spec for the given checkpoint path."""
10070
# Normalize path to exclude trailing '/'. In GCS path case, we will need to
@@ -107,7 +77,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
10777
raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
10878
base_path = os.path.dirname(ckpt_path)
10979
base_driver_spec = (
110-
base_path
80+
gcs_utils.get_kvstore_for_gcs(base_path)
11181
if is_gcs_path
11282
else {'driver': ts_utils.DEFAULT_DRIVER, 'path': base_path}
11383
)
@@ -118,7 +88,7 @@ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
11888
}
11989
else:
12090
if is_gcs_path:
121-
spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path)
91+
spec['kvstore'] = gcs_utils.get_kvstore_for_gcs(ckpt_path)
12292
else:
12393
spec['kvstore'] = {'driver': ts_utils.DEFAULT_DRIVER, 'path': ckpt_path}
12494

checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,14 @@ def test_get_tensorstore_spec_ocdbt(self, path):
475475
spec = serialization.get_tensorstore_spec(path, ocdbt=True)
476476
is_gcs_path = path.startswith('gs://')
477477
if is_gcs_path:
478-
self.assertEqual(spec['kvstore']['base'], os.path.dirname(path))
478+
self.assertEqual(
479+
spec['kvstore']['base'],
480+
{
481+
'driver': 'gcs',
482+
'bucket': 'my',
483+
'path': 'ckpt/dir',
484+
},
485+
)
479486
else:
480487
self.assertEqual(
481488
spec['kvstore']['base'],
@@ -493,6 +500,52 @@ def test_get_tensorstore_spec_not_absolute_path(self):
493500
):
494501
serialization.get_tensorstore_spec(path, ocdbt=True)
495502

503+
@parameterized.named_parameters(
504+
dict(testcase_name='none', backend=None, target_driver='gcs'),
505+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
506+
dict(
507+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
508+
),
509+
)
510+
def test_get_tensorstore_spec_ocdbt_grpc(self, backend, target_driver):
511+
if backend:
512+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
513+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
514+
spec = serialization.get_tensorstore_spec(
515+
'gs://my/ckpt/dir/path', ocdbt=True
516+
)
517+
self.assertEqual(
518+
spec['kvstore']['base'],
519+
{
520+
'driver': target_driver,
521+
'bucket': 'my',
522+
'path': 'ckpt/dir',
523+
},
524+
)
525+
526+
@parameterized.named_parameters(
527+
dict(testcase_name='none', backend=None, target_driver='gcs'),
528+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
529+
dict(
530+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
531+
),
532+
)
533+
def test_get_tensorstore_spec_grpc(self, backend, target_driver):
534+
if backend:
535+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
536+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
537+
spec = serialization.get_tensorstore_spec(
538+
'gs://my/ckpt/dir/path', ocdbt=False
539+
)
540+
self.assertEqual(
541+
spec['kvstore'],
542+
{
543+
'driver': target_driver,
544+
'bucket': 'my',
545+
'path': 'ckpt/dir/path',
546+
},
547+
)
548+
496549
def test_deserialization_with_int4(self):
497550
dtype = jnp.int4
498551
shape = (8, 2)

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from orbax.checkpoint._src.arrays import subchunking
2626
from orbax.checkpoint._src.arrays import types
2727
from orbax.checkpoint._src.metadata import array_metadata
28+
from orbax.checkpoint._src.path import gcs_utils
2829
import tensorstore as ts
2930

3031
JsonSpec: TypeAlias = dict[str, Any]
@@ -43,8 +44,6 @@
4344
ZARR_VER2 = 'zarr'
4445
ZARR_VER3 = 'zarr3'
4546

46-
_GCS_PATH_RE = r'^gs://([^/]*)/(.*)$'
47-
4847
# Even if the data is equal to the fill value, we still want to write it
4948
# to the checkpoint. This results in unnecessary writes in some edge
5049
# cases, but it allows us to verify that data was actually written when
@@ -111,18 +110,6 @@ def get_ts_context(
111110
### Building KvStore specs.
112111

113112

114-
def _get_kvstore_for_gcs(ckpt_path: str) -> JsonSpec:
115-
m = re.fullmatch(_GCS_PATH_RE, ckpt_path, re.DOTALL)
116-
if m is None:
117-
raise ValueError(
118-
'The ckpt_path should contain the bucket name and the '
119-
f'file path inside the bucket. Got: {ckpt_path}'
120-
)
121-
gcs_bucket = m.group(1)
122-
path_without_bucket = m.group(2)
123-
return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}
124-
125-
126113
def build_kvstore_tspec(
127114
directory: str,
128115
name: str | None = None,
@@ -165,7 +152,7 @@ def build_kvstore_tspec(
165152
directory, f'{PROCESS_SUBDIR_PREFIX}{process_id}'
166153
)
167154
base_driver_spec = (
168-
directory
155+
gcs_utils.get_kvstore_for_gcs(str(directory))
169156
if is_gcs_path
170157
else {'driver': default_driver, 'path': str(directory)}
171158
)
@@ -196,7 +183,7 @@ def build_kvstore_tspec(
196183
else:
197184
path = os.path.join(directory, name)
198185
if is_gcs_path:
199-
kv_spec = _get_kvstore_for_gcs(path)
186+
kv_spec = gcs_utils.get_kvstore_for_gcs(path)
200187
else:
201188
kv_spec = {'driver': default_driver, 'path': path}
202189

checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,12 @@ def test_ocdbt_kvstore(
203203
dict(
204204
testcase_name='regular_path',
205205
directory='gs://gcs_bucket/object_path',
206-
expected_directory=None,
206+
expected_directory='object_path',
207207
),
208208
dict(
209209
testcase_name='path_with_single_slash',
210210
directory='gs:/gcs_bucket/object_path',
211-
expected_directory='gs://gcs_bucket/object_path',
211+
expected_directory='object_path',
212212
),
213213
)
214214
def test_ocdbt_kvstore_with_gcs_path(
@@ -228,10 +228,60 @@ def test_ocdbt_kvstore_with_gcs_path(
228228
self.assertEqual(kvstore_tspec['driver'], 'ocdbt')
229229
self.assertEqual(
230230
kvstore_tspec['base'],
231-
os.path.join(expected_directory or directory, 'ocdbt.process_0'),
231+
{
232+
'driver': 'gcs',
233+
'bucket': 'gcs_bucket',
234+
'path': os.path.join(
235+
expected_directory or directory, 'ocdbt.process_0'
236+
),
237+
},
232238
)
233239
self.assertEqual(kvstore_tspec['path'], self.param_name)
234240

241+
@parameterized.named_parameters(
242+
dict(testcase_name='none', backend=None, target_driver='gcs'),
243+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
244+
dict(
245+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
246+
),
247+
)
248+
def test_get_tensorstore_spec_ocdbt_grpc(self, backend, target_driver):
249+
if backend:
250+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
251+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
252+
spec = ts_utils.build_kvstore_tspec('gs://my/ckpt/dir/path', use_ocdbt=True)
253+
self.assertEqual(
254+
spec['base'],
255+
{
256+
'driver': target_driver,
257+
'bucket': 'my',
258+
'path': 'ckpt/dir/path',
259+
},
260+
)
261+
262+
@parameterized.named_parameters(
263+
dict(testcase_name='none', backend=None, target_driver='gcs'),
264+
dict(testcase_name='gcs', backend='gcs', target_driver='gcs'),
265+
dict(
266+
testcase_name='gcs_grpc', backend='gcs_grpc', target_driver='gcs_grpc'
267+
),
268+
)
269+
def test_get_tensorstore_spec_grpc(self, backend, target_driver):
270+
if backend:
271+
os.environ['TENSORSTORE_GCS_BACKEND'] = backend
272+
self.addCleanup(lambda: os.environ.pop('TENSORSTORE_GCS_BACKEND'))
273+
spec = ts_utils.build_kvstore_tspec(
274+
'gs://my/ckpt/dir/path', use_ocdbt=False
275+
)
276+
self.assertEqual(
277+
spec,
278+
{
279+
'driver': target_driver,
280+
'bucket': 'my',
281+
'path': 'ckpt/dir/path',
282+
},
283+
)
284+
235285
@parameterized.product(use_zarr3=(True, False))
236286
def test_ocdbt_kvstore_default_target_data_file_size(self, use_zarr3: bool):
237287
tspec = self.array_write_spec_constructor(

0 commit comments

Comments
 (0)