9 changes: 4 additions & 5 deletions pufferlib/config/ocean/adaptive.ini
@@ -2,8 +2,7 @@
 package = ocean
 env_name = puffer_adaptive_drive
 policy_name = Drive
-transformer_name = Transformer
-; Changed from rnn_name
+policy_architecture = Transformer
 
 [vec]
 num_workers = 16
@@ -60,15 +59,15 @@ k_scenarios = 2
 termination_mode = 1
 ; 0 - terminate at episode_length, 1 - terminate after all agents have been reset
 map_dir = "resources/drive/binaries/training"
-num_maps = 1000
+num_maps = 10000
 ; Determines which step of the trajectory to initialize the agents at upon reset
 init_steps = 0
 ; Options: "control_vehicles", "control_agents", "control_wosac", "control_sdc_only"
 control_mode = "control_vehicles"
 ; Options: "create_all_valid", "create_only_controlled"
 init_mode = "create_all_valid"
 ; train with co-players
-co_player_enabled = False
+co_player_enabled = True
 
 
 [env.conditioning]
@@ -120,7 +119,7 @@ minibatch_size = 36400
 ; 400 * 91
 max_minibatch_size = 36400
 minibatch_multiplier = 400
-policy_architecture = Transformer
+
 ; Matches scenario_length for buffer organization
 bptt_horizon = 32
 ; Keep for backward compatibility
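For scale, the numbers in this file compose as the comments suggest. A sketch of that arithmetic, assuming (from the `; 400 * 91` and `k_scenarios = 2` comments) that an episode spans `k_scenarios` scenarios of 91 steps each; the scenario length itself is not named in these hunks:

```python
# Assumed from the config comments: 2 scenarios of 91 steps per episode.
k_scenarios, scenario_length = 2, 91
episode_length = k_scenarios * scenario_length   # 182

# For puffer_adaptive_drive, load_policy (later in this diff) sets the
# transformer's context_length to episode_length, i.e. the full episode:
context_length = episode_length                  # 182

# "; 400 * 91": one minibatch is 400 trajectories of one 91-step scenario.
assert 400 * scenario_length == 36400            # minibatch_size in [train]
```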
10 changes: 4 additions & 6 deletions pufferlib/config/ocean/drive.ini
@@ -2,7 +2,7 @@
 package = ocean
 env_name = puffer_drive
 policy_name = Drive
-rnn_name = Transformer
+policy_architecture = Transformer
 
 [vec]
 num_workers = 16
@@ -11,7 +11,7 @@ batch_size = 2
 ; backend = Serial
 
 [policy]
-input_size = 64
+input_size = 128
 hidden_size = 256
 
 ; [rnn]
@@ -25,13 +25,12 @@ num_layers = 2
 ; Number of transformer layers
 num_heads = 4
 ; Number of attention heads (must divide hidden_size)
-context_window = 32
 ; k_scenarios (2) * scenario_length (91) = maximum attention span
 dropout = 0.0
 ; Dropout (keep at 0 for RL stability initially)
 
 [env]
-num_agents = 512
+num_agents = 1024
 num_ego_agents = 512
 ; Options: discrete, continuous
 action_type = discrete
@@ -139,8 +138,7 @@ vf_coef = 2
 vtrace_c_clip = 1
 vtrace_rho_clip = 1
 checkpoint_interval = 100
-use_transformer = True
-context_window = 32
+context_length = 32
 # Rendering options
 render = True
 render_interval = 100
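As a cross-check on these values: with the ego-agent batch computation in `pufferl.py` (later in this diff), the settings above imply the batch size below. A sketch only; it assumes drive runs through the `num_ego_agents` branch with the `[vec]` worker count shown here:

```python
num_ego_agents = 512   # [env]
context_length = 32    # [train], renamed from context_window
num_workers = 16       # [vec]

# Mirrors: batch_size = vecenv.driver_env.num_ego_agents
#          * config["context_length"] * vecenv.num_workers
batch_size = num_ego_agents * context_length * num_workers
print(batch_size)  # 262144
```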
2 changes: 1 addition & 1 deletion pufferlib/ocean/benchmark/evaluator.py
@@ -663,7 +663,7 @@ def rollout(self, args, puffer_env, policy):
                 lstm_c=torch.zeros(num_agents, policy.hidden_size, device=device),
             )
         elif policy_architecture == "Transformer":
-            context_length = args["train"].get("context_window", 182)
+            context_length = args["train"].get("context_length", 182)
             state = dict(
                 transformer_context=torch.zeros(num_agents, context_length, policy.hidden_size, device=device),
                 transformer_position=torch.zeros(1, dtype=torch.long, device=device),
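The renamed key feeds the pre-allocated transformer state above (the 182 default matches the 2 × 91 episode length assumed earlier). To make the shapes concrete, a minimal sketch of a rolling context buffer under single-step inference; the write-and-advance (and wraparound) logic is an illustration of how such a buffer can be used, not the wrapper's actual code:

```python
import torch

num_agents, context_length, hidden_size = 4, 182, 256

# Same shapes as the eval state above (device handling omitted).
state = dict(
    transformer_context=torch.zeros(num_agents, context_length, hidden_size),
    transformer_position=torch.zeros(1, dtype=torch.long),
)

def step(state, hidden):
    """Write this step's features at the current position, then advance."""
    pos = int(state["transformer_position"]) % context_length  # wraparound assumed
    state["transformer_context"][:, pos] = hidden
    state["transformer_position"] += 1

for _ in range(3):
    step(state, torch.randn(num_agents, hidden_size))
print(int(state["transformer_position"]))  # 3
```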
53 changes: 32 additions & 21 deletions pufferlib/pufferl.py
@@ -79,10 +79,14 @@ def __init__(self, config, vecenv, policy, logger=None):
             if config.get("policy_architecture", "Recurrent") == "Recurrent":
                 config["bptt_horizon"] = vecenv.driver_env.episode_length
             if config.get("policy_architecture", "Recurrent") == "Transformer":
-                config["context_window"] = self.context_length = vecenv.driver_env.episode_length
+                config["context_length"] = self.context_length = vecenv.driver_env.episode_length
                 config["bptt_horizon"] = (
                     vecenv.driver_env.episode_length
                 )  # this is used downstream, so it needs to be defined here too
+        else:
+            if config.get("policy_architecture", "Recurrent") == "Transformer":
+                self.context_length = config["context_length"]
+                config["bptt_horizon"] = config["context_length"]
 
         vecenv.async_reset(seed)
         obs_space = vecenv.single_observation_space
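Distilled from the branch above: `context_length` now comes either from the env (when the trainer derives it from `episode_length`) or from the config, and `bptt_horizon` is always made to agree so downstream buffer code keeps working. A simplified sketch; the `derive_from_env` flag stands in for the surrounding branch of `__init__`, which is not shown in this hunk:

```python
def resolve_horizons(config, episode_length, derive_from_env):
    """Simplified model of the Transformer horizon setup in PuffeRL.__init__."""
    if config.get("policy_architecture", "Recurrent") != "Transformer":
        return config
    if derive_from_env:
        # Auto mode: the env's episode length defines the attention span.
        config["context_length"] = episode_length
        config["bptt_horizon"] = episode_length
    else:
        # Explicit mode: trust the ini's context_length.
        config["bptt_horizon"] = config["context_length"]
    return config

print(resolve_horizons({"policy_architecture": "Transformer"}, 182, True))
print(resolve_horizons({"policy_architecture": "Transformer", "context_length": 32}, 182, False))
```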
@@ -95,7 +99,7 @@ def __init__(self, config, vecenv, policy, logger=None):
             if config.get("policy_architecture", "Recurrent") == "Recurrent":
                 batch_size = vecenv.driver_env.num_ego_agents * config["bptt_horizon"] * vecenv.num_workers
             if config.get("policy_architecture", "Recurrent") == "Transformer":
-                batch_size = vecenv.driver_env.num_ego_agents * config["context_window"] * vecenv.num_workers
+                batch_size = vecenv.driver_env.num_ego_agents * config["context_length"] * vecenv.num_workers
             config["batch_size"] = batch_size  # dynamic: based on the number of ego agents
         else:
             agents_for_calc = total_agents
@@ -107,34 +111,34 @@ def __init__(self, config, vecenv, policy, logger=None):
             if (
                 config["batch_size"] == "auto"
                 and config.get("bptt_horizon", "auto") == "auto"
-                and config.get("context_window", "auto") == "auto"
+                and config.get("context_length", "auto") == "auto"
             ):
-                raise pufferlib.APIUsageError("Must specify batch_size, bptt_horizon, or context_window")
+                raise pufferlib.APIUsageError("Must specify batch_size, bptt_horizon, or context_length")
             elif config["batch_size"] == "auto":
                 if config.get("policy_architecture", "Recurrent") == "Recurrent":
                     config["batch_size"] = agents_for_calc * config["bptt_horizon"]
                 elif config.get("policy_architecture", "Recurrent") == "Transformer":
-                    config["batch_size"] = agents_for_calc * config["context_window"]
+                    config["batch_size"] = agents_for_calc * config["context_length"]
             elif (
                 config.get("bptt_horizon", "auto") == "auto"
                 and config.get("policy_architecture", "Recurrent") == "Recurrent"
             ):
                 config["bptt_horizon"] = config["batch_size"] // agents_for_calc
             elif (
-                config.get("context_window", "auto") == "auto"
+                config.get("context_length", "auto") == "auto"
                 and config.get("policy_architecture", "Recurrent") == "Transformer"
             ):
-                config["context_window"] = config["batch_size"] // agents_for_calc
+                config["context_length"] = config["batch_size"] // agents_for_calc
 
         batch_size = config["batch_size"]
 
         # Set horizon based on model type
         if config.get("policy_architecture", "Recurrent") == "Recurrent":
             horizon = config["bptt_horizon"]
         elif config.get("policy_architecture", "Recurrent") == "Transformer":
-            horizon = config["context_window"]
+            horizon = config["context_length"]
         else:
-            horizon = config.get("bptt_horizon", config.get("context_window", 1))
+            horizon = config.get("bptt_horizon", config.get("context_length", 1))
 
         config["bptt_horizon"] = horizon  # For backward compatibility

@@ -426,7 +430,7 @@ def evaluate(self):
                 state["transformer_context"] = self.transformer_context[state_key]
                 state["transformer_position"] = self.transformer_position[state_key]
                 # Note: terminals not needed for eval since we're doing single-step inference
-
+            print(f"o_device shape: {o_device.shape}", flush=True)
             logits, value = self.policy.forward_eval(o_device, state)
             action, logprob, _ = pufferlib.pytorch.sample_logits(logits)
             r = torch.clamp(r, -1, 1)
@@ -483,7 +487,7 @@ def evaluate(self):
                     self.ep_lengths[env_id] += 1
                 # Use appropriate horizon based on model type
                 horizon = (
-                    config.get("context_window")
+                    config.get("context_length")
                     if config.get("policy_architecture", "Recurrent") == "Transformer"
                     else config["bptt_horizon"]
                 )
@@ -1255,7 +1259,13 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None):
     elif args["wandb"]:
         logger = WandbLogger(args)
 
-    train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {}), env_config=args.get("env", {}))
+    train_config = dict(
+        **args["train"],
+        env=env_name,
+        eval=args.get("eval", {}),
+        env_config=args.get("env", {}),
+        policy_architecture=args.get("policy_architecture", "Recurrent"),
+    )
     pufferl = PuffeRL(train_config, vecenv, policy, logger)
 
     all_logs = []
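The new `policy_architecture` entry is how `PuffeRL.__init__` (first hunk in this file) learns which horizon logic to use. A hypothetical minimal `args` layout that would exercise this path; only the key names are taken from the configs above, every literal value is illustrative:

```python
args = {
    "train": {"context_length": 32, "batch_size": "auto"},
    "eval": {},
    "env": {"num_agents": 1024},
    "policy_architecture": "Transformer",
}

train_config = dict(
    **args["train"],
    env="puffer_drive",
    eval=args.get("eval", {}),
    env_config=args.get("env", {}),
    policy_architecture=args.get("policy_architecture", "Recurrent"),
)
print(train_config["policy_architecture"])  # Transformer
```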
@@ -1674,17 +1684,18 @@ def load_policy(args, vecenv, env_name=""):
     policy = policy_cls(vecenv.driver_env, **args["policy"])
 
     # Handle both RNN and Transformer wrappers
-    rnn_name = args.get("rnn_name")
-    transformer_name = args.get("transformer_name")
-
-    if transformer_name is not None:
-        # Load transformer wrapper
-        transformer_cls = getattr(env_module.torch, transformer_name)
-        args["transformer"]["context_length"] = vecenv.driver_env.episode_length
+    policy_architecture = args.get("policy_architecture", "Recurrent")
+
+    if policy_architecture == "Transformer":  # Load transformer wrapper
+        transformer_cls = getattr(env_module.torch, policy_architecture)
+        if args.get("env_name") == "puffer_drive":
+            args["transformer"]["context_length"] = args["train"]["context_length"]
+        elif args.get("env_name") == "puffer_adaptive_drive":
+            args["transformer"]["context_length"] = vecenv.driver_env.episode_length
         policy = transformer_cls(vecenv.driver_env, policy, **args["transformer"])
-    elif rnn_name is not None:
+    elif policy_architecture == "Recurrent":
         # Load RNN wrapper
-        rnn_cls = getattr(env_module.torch, rnn_name)
+        rnn_cls = getattr(env_module.torch, policy_architecture)
         policy = rnn_cls(vecenv.driver_env, policy, **args["rnn"])
 
     policy = policy.to(device)
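The env split above is the behavioral core of this file's change: a fixed attention window for `puffer_drive`, the full episode for `puffer_adaptive_drive`. The same rule as a standalone function, for clarity (a sketch; 182 reuses the 2 × 91 episode length assumed earlier):

```python
def transformer_context_length(env_name, train_config, episode_length):
    """Sketch of the context_length selection in load_policy."""
    if env_name == "puffer_drive":
        return train_config["context_length"]  # fixed window from [train]
    if env_name == "puffer_adaptive_drive":
        return episode_length                  # attend over the whole episode
    raise ValueError(f"no Transformer context rule for {env_name}")

print(transformer_context_length("puffer_drive", {"context_length": 32}, 182))           # 32
print(transformer_context_length("puffer_adaptive_drive", {"context_length": 32}, 182))  # 182
```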