 import huggingface_hub
 import jax
 import jax.numpy as jnp
-from flax import nnx
 import numpy as np
+from flax import nnx
 
 from sgl_jax.srt.configs.load_config import LoadConfig, LoadFormat
 from sgl_jax.srt.configs.model_config import ModelConfig
@@ -53,7 +53,7 @@ def init_new(cls, model_config: ModelConfig):
         )
 
     def __init__(
-        self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh
+        self, load_config: LoadConfig, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
     ):
         super().__init__(load_config)
         self.rng = rngs
@@ -101,7 +101,7 @@ def create_model(rng: nnx.Rngs):
             nnx.update(model, sharded_state)
             return model
 
-        with self.mesh:
+        with jax.sharding.use_mesh(self.mesh):
             model = create_model(self.rng)
 
         rng_key = self.rng.default.key.value
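
(Aside, not part of the diff: a minimal sketch of the pattern the hunk above moves to. jax.sharding.use_mesh(mesh) activates the mesh as an explicit context instead of the implicit "with self.mesh:" block; it requires a recent JAX release, and the mesh axis name and array shape below are illustrative, not taken from sgl_jax.)

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-axis mesh over all local devices (axis name is illustrative).
mesh = Mesh(np.array(jax.devices()), ("data",))

with jax.sharding.use_mesh(mesh):
    # Arrays constrained inside the context are laid out across the mesh.
    x = jnp.zeros((jax.device_count(), 128))
    x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", None)))
    print(x.sharding)
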
@@ -160,15 +160,15 @@ class JAXDummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values for JAX models."""
 
     def __init__(
-        self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh
+        self, load_config: LoadConfig, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
     ):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
             raise ValueError(
                 f"Model loader extra config is not supported for "
                 f"load format {load_config.load_format}"
             )
-        self.rng = rngs
+        self.rngs = rngs
         self.mesh = mesh
 
     def download_model(self, model_config: ModelConfig) -> None:
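
(Aside, not part of the diff: the loader constructors now take a flax.nnx Rngs container rather than a raw jax.Array key. A small usage sketch with an illustrative seed; the .default.key.value access mirrors what JAXModelLoader.load_model does above.)

from flax import nnx

rngs = nnx.Rngs(0)              # named PRNG streams; "default" is created from the seed
key = rngs.default.key.value    # the underlying jax.Array key
subkey = rngs.default()         # drawing a key advances the stream's counter
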
@@ -204,28 +204,27 @@ def _np_gen_dtype(jdtype) -> np.dtype:
                 return np.float16
             if jdtype == jnp.bfloat16:
                 return np.float32
-            return {jnp.float16: np.float16,
-                    jnp.float32: np.float32,
-                    jnp.float64: np.float64}.get(jdtype, np.float32)
-
-        def _init_leaf(x, pspec):
-            if isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating):
-                tgt_dtype = x.dtype
+            return {
+                jnp.float16: np.float16,
+                jnp.float32: np.float32,
+                jnp.float64: np.float64,
+            }.get(jdtype, np.float32)
+
+        def _init_leaf(x):
+            if isinstance(x, nnx.Param) and jnp.issubdtype(x.value.dtype, jnp.floating):
+                tgt_dtype = x.value.dtype
                 gen_dtype = _np_gen_dtype(tgt_dtype)
-                numel = int(np.prod(x.shape))
+                numel = int(np.prod(x.value.shape))
 
                 # Per-parameter reseed (shape-agnostic stream)
                 rng = np.random.default_rng(seed)
                 flat = rng.uniform(low, high, size=(numel,)).astype(gen_dtype)
-                arr_np = flat.reshape(x.shape)
-
-                arr_jax = jnp.asarray(arr_np, dtype=tgt_dtype)
-                if pspec is not None:
-                    arr_jax = jax.lax.with_sharding_constraint(arr_jax, pspec)
-                return arr_jax
+                x.value = jnp.asarray(flat.reshape(x.shape), dtype=tgt_dtype)
             return x
 
-        new_params = jax.tree_util.tree_map(_init_leaf, params, pspecs)
+        new_params = jax.tree.map(
+            _init_leaf, params, is_leaf=lambda x: isinstance(x, nnx.Param)
+        )
 
         # Do not alter rotary embedding caches
         def _preserve_rope_caches(path, old, new):
@@ -235,15 +234,16 @@ def _preserve_rope_caches(path, old, new):
                 return old
             return new
 
-        new_params = jax.tree_util.tree_map_with_path(_preserve_rope_caches, params, new_params)
+        new_params = jax.tree.map_with_path(_preserve_rope_caches, params, new_params)
         nnx.update(model, new_params)
         print(f"new_params: {new_params}")
+
         # Print out the shape of each param in new_params
         def _print_shape(path, x):
             if isinstance(x, jax.Array):
                 print(f"Param {path}: shape={x.shape}")
-        jax.tree_util.tree_map_with_path(_print_shape, new_params)
 
+        jax.tree.map_with_path(_print_shape, new_params)
 
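
(Aside, not part of the diff: a self-contained sketch of the rewritten dummy-init pattern, using a toy nnx.Linear in place of the real model; the seed/low/high values are illustrative. Depending on the Flax version, nnx.state yields nnx.Param variables or VariableState leaves, so the isinstance check mirrors the assumption the hunk above makes.)

import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 8, rngs=nnx.Rngs(0))   # toy stand-in for the real model
params = nnx.state(model, nnx.Param)

def init_leaf(p, seed=1234, low=-1e-3, high=1e-3):
    # Overwrite floating-point Param leaves with deterministic uniform noise.
    if isinstance(p, nnx.Param) and jnp.issubdtype(p.value.dtype, jnp.floating):
        rng = np.random.default_rng(seed)
        flat = rng.uniform(low, high, size=p.value.size)
        p.value = jnp.asarray(flat.reshape(p.value.shape), dtype=p.value.dtype)
    return p

new_params = jax.tree.map(init_leaf, params, is_leaf=lambda x: isinstance(x, nnx.Param))
nnx.update(model, new_params)
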
     def load_model(
         self,
@@ -253,24 +253,22 @@ def load_model(
         # Initialize JAX model definition on mesh
         model_class = self._initialize_model(model_config)
 
-        def create_model(rng: nnx.Rngs):
-            model = model_class(model_config.hf_config, rng, self.mesh)
+        with jax.sharding.use_mesh(self.mesh):
+            model = model_class(model_config.hf_config, self.rngs, self.mesh)
+            # self._initialize_dummy_weights(model)
             state = nnx.state(model)
             pspecs = nnx.get_partition_spec(state)
             sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
             nnx.update(model, sharded_state)
-            return model
 
-        with self.mesh:
-            model = create_model(self.rng)
-        # Assign random weights deterministically
-        self._initialize_dummy_weights(model)
+        for dev in jax.local_devices():
+            print(dev, dev.memory_stats())
 
         return model
 
 
 def get_model_loader(
-    load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh
+    load_config: LoadConfig, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
 ) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
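
(Aside, not part of the diff: a hedged call-site sketch of the updated get_model_loader signature. The LoadFormat.DUMMY member, the LoadConfig keyword, and the mesh axis name are assumptions for illustration; ModelConfig construction is elided.)

import numpy as np
import jax
from flax import nnx
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()), ("tensor",))           # axis name illustrative
loader = get_model_loader(
    load_config=LoadConfig(load_format=LoadFormat.DUMMY),   # assumed enum member / kwarg
    rngs=nnx.Rngs(0),
    mesh=mesh,
)
# model = loader.load_model(model_config=model_config)      # ModelConfig built elsewhere
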