@@ -81,34 +81,8 @@ def __init__(self, config: DictConfig):
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
-        self._init_dist()
         super().__init__(job_config)

-    def _init_dist(self):
-        """Initializes torch distributed.
-
-        torchrun normally hands this, but we need to do it ourselves
-        in monarch for now.
-
-        We should consider putting this into ForgeActor, but having this
-        be explicit for now.
-
-        """
-        env = {
-            "RANK": str(self._rank),
-            "LOCAL_RANK": str(self._rank),
-            "LOCAL_WORLD_SIZE": str(self._size),
-            "GROUP_RANK": str(self._size),
-            "GROUP_WORLD_SIZE": str(self._size),
-            "ROLE_RANK": str(self._rank),
-            "ROLE_WORLD_SIZE": str(self._size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self._size),
-            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
-        }
-        os.environ.update(env)
-        logger.info("env: {}".format(env))
-
     async def setup_metric_logger(self):
         """Initialization happens in the main process. Here we just retrieve it"""
         mlogger = await get_or_create_metric_logger()
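
For context on what the removed helper was doing: the variables it exported follow torchrun's standard environment contract, which torch.distributed consumes when a process group is initialized with the default env:// init method. Below is a minimal, self-contained sketch of that mechanism; the master address/port and the single-process values are illustrative assumptions and are not part of this change.

import os

import torch.distributed as dist

# Illustrative values for a single-process run; under torchrun the launcher
# sets these, and under monarch the removed _init_dist() populated them.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# With the default env:// init method, torch.distributed reads RANK and
# WORLD_SIZE from the environment, which is why populating os.environ
# before super().__init__(job_config) ran was sufficient.
dist.init_process_group(backend="gloo")
print(f"rank {dist.get_rank()} of {dist.get_world_size()}")
dist.destroy_process_group()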