Skip to content

Commit e2dd4d1

Browse files
authored
Add metrics for sampled policy age in replay buffer (#524)
1 parent 8058089 commit e2dd4d1

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

src/forge/actors/replay_buffer.py

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,27 @@ async def sample(
120120
entry.sample_count += 1
121121
sampled_episodes.append(entry.data)
122122

123+
# Calculate and record policy age metrics for sampled episodes
124+
sampled_policy_ages = [
125+
curr_policy_version - ep.policy_version for ep in sampled_episodes
126+
]
127+
if sampled_policy_ages:
128+
record_metric(
129+
"buffer/sample/avg_sampled_policy_age",
130+
sum(sampled_policy_ages) / len(sampled_policy_ages),
131+
Reduce.MEAN,
132+
)
133+
record_metric(
134+
"buffer/sample/max_sampled_policy_age",
135+
max(sampled_policy_ages),
136+
Reduce.MAX,
137+
)
138+
record_metric(
139+
"buffer/sample/min_sampled_policy_age",
140+
min(sampled_policy_ages),
141+
Reduce.MIN,
142+
)
143+
123144
# Reshape into (dp_size, bsz, ...)
124145
reshaped_episodes = [
125146
sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size]
@@ -149,22 +170,6 @@ def _evict(self, curr_policy_version):
149170
)
150171
self.buffer = deque(self._collect(indices))
151172

152-
# Record evict metrics
153-
policy_age = [
154-
curr_policy_version - ep.data.policy_version for ep in self.buffer
155-
]
156-
if policy_age:
157-
record_metric(
158-
"buffer/evict/avg_policy_age",
159-
sum(policy_age) / len(policy_age),
160-
Reduce.MEAN,
161-
)
162-
record_metric(
163-
"buffer/evict/max_policy_age",
164-
max(policy_age),
165-
Reduce.MAX,
166-
)
167-
168173
evicted_count = buffer_len_before_evict - len(self.buffer)
169174
record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM)
170175

0 commit comments

Comments
 (0)