@@ -120,6 +120,27 @@ async def sample(
120120 entry .sample_count += 1
121121 sampled_episodes .append (entry .data )
122122
123+ # Calculate and record policy age metrics for sampled episodes
124+ sampled_policy_ages = [
125+ curr_policy_version - ep .policy_version for ep in sampled_episodes
126+ ]
127+ if sampled_policy_ages :
128+ record_metric (
129+ "buffer/sample/avg_sampled_policy_age" ,
130+ sum (sampled_policy_ages ) / len (sampled_policy_ages ),
131+ Reduce .MEAN ,
132+ )
133+ record_metric (
134+ "buffer/sample/max_sampled_policy_age" ,
135+ max (sampled_policy_ages ),
136+ Reduce .MAX ,
137+ )
138+ record_metric (
139+ "buffer/sample/min_sampled_policy_age" ,
140+ min (sampled_policy_ages ),
141+ Reduce .MIN ,
142+ )
143+
123144 # Reshape into (dp_size, bsz, ...)
124145 reshaped_episodes = [
125146 sampled_episodes [dp_idx * self .batch_size : (dp_idx + 1 ) * self .batch_size ]
@@ -149,22 +170,6 @@ def _evict(self, curr_policy_version):
149170 )
150171 self .buffer = deque (self ._collect (indices ))
151172
152- # Record evict metrics
153- policy_age = [
154- curr_policy_version - ep .data .policy_version for ep in self .buffer
155- ]
156- if policy_age :
157- record_metric (
158- "buffer/evict/avg_policy_age" ,
159- sum (policy_age ) / len (policy_age ),
160- Reduce .MEAN ,
161- )
162- record_metric (
163- "buffer/evict/max_policy_age" ,
164- max (policy_age ),
165- Reduce .MAX ,
166- )
167-
168173 evicted_count = buffer_len_before_evict - len (self .buffer )
169174 record_metric ("buffer/evict/sum_episodes_evicted" , evicted_count , Reduce .SUM )
170175
0 commit comments