
Commit 7edb5d6

fix typos & simplify model loader
1 parent 38cf2df

1 file changed (+28, -30 lines)

python/sgl_jax/srt/model_loader/loader.py

Lines changed: 28 additions & 30 deletions
@@ -7,8 +7,8 @@
 import huggingface_hub
 import jax
 import jax.numpy as jnp
-from flax import nnx
 import numpy as np
+from flax import nnx
 
 from sgl_jax.srt.configs.load_config import LoadConfig, LoadFormat
 from sgl_jax.srt.configs.model_config import ModelConfig
@@ -53,7 +53,7 @@ def init_new(cls, model_config: ModelConfig):
         )
 
     def __init__(
-        self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh
+        self, load_config: LoadConfig, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
     ):
         super().__init__(load_config)
         self.rng = rngs
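
Note on the annotation change: typing rngs as nnx.Rngs (rather than jax.Array) matches how the loader actually consumes the argument, e.g. rng_key = self.rng.default.key.value further down in this file. A minimal sketch of that usage, not part of this commit:

import jax
from flax import nnx

# nnx.Rngs bundles named PRNG streams; a bare int seeds the "default" stream.
rngs = nnx.Rngs(0)

# The loader later reads the raw key via self.rng.default.key.value.
key = rngs.default.key.value
assert isinstance(key, jax.Array)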
@@ -101,7 +101,7 @@ def create_model(rng: nnx.Rngs):
             nnx.update(model, sharded_state)
             return model
 
-        with self.mesh:
+        with jax.sharding.use_mesh(self.mesh):
             model = create_model(self.rng)
 
         rng_key = self.rng.default.key.value
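
The old form entered the Mesh object directly; jax.sharding.use_mesh(self.mesh) is the context-manager API recent JAX versions expose for setting the ambient mesh, against which with_sharding_constraint and nnx initialization then resolve partition specs. A rough sketch of the pattern outside this loader, assuming a recent JAX that provides jax.make_mesh and jax.sharding.use_mesh (the "data" axis name is illustrative):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# One-dimensional mesh over all local devices.
mesh = jax.make_mesh((jax.device_count(),), ("data",))

@jax.jit
def constrain(x):
    # Ask the compiler to shard the leading axis along the "data" mesh axis.
    return jax.lax.with_sharding_constraint(x, P("data", None))

with jax.sharding.use_mesh(mesh):
    y = constrain(jnp.zeros((jax.device_count() * 2, 4)))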
@@ -160,15 +160,15 @@ class JAXDummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values for JAX models."""
 
     def __init__(
-        self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh
+        self, load_config: LoadConfig, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
     ):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
             raise ValueError(
                 f"Model loader extra config is not supported for "
                 f"load format {load_config.load_format}"
             )
-        self.rng = rngs
+        self.rngs = rngs
         self.mesh = mesh
 
     def download_model(self, model_config: ModelConfig) -> None:
@@ -204,28 +204,27 @@ def _np_gen_dtype(jdtype) -> np.dtype:
                 return np.float16
             if jdtype == jnp.bfloat16:
                 return np.float32
-            return {jnp.float16: np.float16,
-                    jnp.float32: np.float32,
-                    jnp.float64: np.float64}.get(jdtype, np.float32)
-
-        def _init_leaf(x, pspec):
-            if isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating):
-                tgt_dtype = x.dtype
+            return {
+                jnp.float16: np.float16,
+                jnp.float32: np.float32,
+                jnp.float64: np.float64,
+            }.get(jdtype, np.float32)
+
+        def _init_leaf(x):
+            if isinstance(x, nnx.Param) and jnp.issubdtype(x.value.dtype, jnp.floating):
+                tgt_dtype = x.value.dtype
                 gen_dtype = _np_gen_dtype(tgt_dtype)
-                numel = int(np.prod(x.shape))
+                numel = int(np.prod(x.value.shape))
 
                 # Per-parameter reseed (shape-agnostic stream)
                 rng = np.random.default_rng(seed)
                 flat = rng.uniform(low, high, size=(numel,)).astype(gen_dtype)
-                arr_np = flat.reshape(x.shape)
-
-                arr_jax = jnp.asarray(arr_np, dtype=tgt_dtype)
-                if pspec is not None:
-                    arr_jax = jax.lax.with_sharding_constraint(arr_jax, pspec)
-                return arr_jax
+                x.value = jnp.asarray(flat.reshape(x.shape), dtype=tgt_dtype)
             return x
 
-        new_params = jax.tree_util.tree_map(_init_leaf, params, pspecs)
+        new_params = jax.tree.map(
+            _init_leaf, params, is_leaf=lambda x: isinstance(x, nnx.Param)
+        )
 
         # Do not alter rotary embedding caches
         def _preserve_rope_caches(path, old, new):
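
The rewritten _init_leaf treats each nnx.Param as a whole leaf (via is_leaf) and assigns the freshly generated array to param.value in place, so per-parameter sharding constraints no longer have to be re-applied by hand. A stripped-down sketch of the same pattern; TinyLinear and the uniform bounds are illustrative, not from the repo:

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

class TinyLinear(nnx.Module):  # illustrative module, not from sgl_jax
    def __init__(self, din, dout, rngs: nnx.Rngs):
        self.w = nnx.Param(jax.random.normal(rngs.params(), (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))

model = TinyLinear(4, 2, nnx.Rngs(params=0))
params = nnx.state(model, nnx.Param)

def _init_leaf(p):
    # Each leaf is a whole nnx.Param because of is_leaf below.
    if isinstance(p, nnx.Param) and jnp.issubdtype(p.value.dtype, jnp.floating):
        flat = np.random.default_rng(0).uniform(-1e-3, 1e-3, size=p.value.size)
        p.value = jnp.asarray(flat.reshape(p.value.shape), dtype=p.value.dtype)
    return p

new_params = jax.tree.map(_init_leaf, params, is_leaf=lambda x: isinstance(x, nnx.Param))
nnx.update(model, new_params)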
@@ -235,15 +234,16 @@ def _preserve_rope_caches(path, old, new):
                 return old
             return new
 
-        new_params = jax.tree_util.tree_map_with_path(_preserve_rope_caches, params, new_params)
+        new_params = jax.tree.map_with_path(_preserve_rope_caches, params, new_params)
         nnx.update(model, new_params)
         print(f"new_params: {new_params}")
+
         # Print out the shape of each param in new_params
         def _print_shape(path, x):
             if isinstance(x, jax.Array):
                 print(f"Param {path}: shape={x.shape}")
-        jax.tree_util.tree_map_with_path(_print_shape, new_params)
 
+        jax.tree.map_with_path(_print_shape, new_params)
 
     def load_model(
         self,
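
jax.tree.map_with_path passes each leaf's key path to the callback, which is what _print_shape relies on for the per-parameter shape dump. A small sketch with a plain dict (the dict is illustrative), using jax.tree_util.keystr to render the path:

import jax
import jax.numpy as jnp

tree = {"w": jnp.zeros((4, 2)), "b": jnp.zeros((2,))}

def _print_shape(path, x):
    if isinstance(x, jax.Array):
        # keystr renders the key path, e.g. "['w']", as a readable string.
        print(f"Param {jax.tree_util.keystr(path)}: shape={x.shape}")
    return x

jax.tree.map_with_path(_print_shape, tree)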
@@ -253,24 +253,22 @@ def load_model(
         # Initialize JAX model definition on mesh
         model_class = self._initialize_model(model_config)
 
-        def create_model(rng: nnx.Rngs):
-            model = model_class(model_config.hf_config, rng, self.mesh)
+        with jax.sharding.use_mesh(self.mesh):
+            model = model_class(model_config.hf_config, self.rngs, self.mesh)
+            # self._initialize_dummy_weights(model)
             state = nnx.state(model)
             pspecs = nnx.get_partition_spec(state)
             sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
             nnx.update(model, sharded_state)
-            return model
 
-        with self.mesh:
-            model = create_model(self.rng)
-            # Assign random weights deterministically
-            self._initialize_dummy_weights(model)
+        for dev in jax.local_devices():
+            print(dev, dev.memory_stats())
 
         return model
 
 
 def get_model_loader(
-    load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh
+    load_config: LoadConfig, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
 ) -> BaseModelLoader:
     """Get a model loader based on the load format."""
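
The added loop prints memory stats for every local device once the model has been created. Device.memory_stats() returns a dict of allocator counters on GPU/TPU backends and may be None where stats are unavailable; a hedged sketch of reading it:

import jax

for dev in jax.local_devices():
    stats = dev.memory_stats()  # may be None on backends without allocator stats
    if stats:
        # bytes_in_use is the live-allocation counter reported on GPU/TPU.
        print(f"{dev}: {stats.get('bytes_in_use', 0) / 1e6:.1f} MB in use")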
