-
Notifications
You must be signed in to change notification settings - Fork 430
Closed
Labels
bug: Something isn't working
Description
Describe the bug
The `MultiAgentNetBase` class takes an `agent_dim` parameter, which I understood to be the axis of the input along which the agents are laid out (and over which they are parallelized). Changing it, however, does not seem to have any effect.
To Reproduce
# -*- coding: utf-8 -*-
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.modules import ProbabilisticActor
from torchrl.modules.distributions import TanhNormal
from tensordict.nn.distributions import NormalParamExtractor
# Hyperparameters for the reproduction script
num_envs = 4  # batch size of the input TensorDict
n_agents = 3  # number of agents handled by the multi-agent network
obs_dim = 5  # per-agent observation feature size (input width of the MLP)
seq_len = 6  # extra sequence axis placed in the observation tensor
action_dim = 2  # per-agent action size; the policy emits 2 * action_dim values
# Action bounds (needed by TanhNormal)
action_low = -1.0
action_high = 1.0
class SingleAgentMLP(nn.Module):
    """Two-layer tanh MLP used as the network of a single agent.

    The network maps the last axis of its input from ``in_dim`` to ``out_dim``
    through one 64-unit hidden layer. All leading axes are left untouched,
    since ``nn.Linear`` only acts on the last axis.

    Args:
        in_dim: Size of the last axis of the input tensor.
        out_dim: Size of the last axis of the output tensor.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.Tanh(),
            nn.Linear(64, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` along its last axis and return the result."""
        return self.net(x)
class MultiAgentPolicyNet(MultiAgentNetBase):
    """Decentralized multi-agent policy network built on ``MultiAgentNetBase``.

    Every agent is served by a ``SingleAgentMLP`` that maps ``obs_dim``
    observation features to ``2 * action_dim`` outputs (loc and scale of the
    action distribution).

    Args:
        obs_dim: Per-agent observation feature size.
        action_dim: Per-agent action size.
        n_agents: Number of agents, forwarded to ``MultiAgentNetBase``.
        share_params: Whether all agents share one set of parameters,
            forwarded to ``MultiAgentNetBase``.
        device: Optional device on which the per-agent network is placed.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        n_agents: int,
        share_params: bool = True,
        device=None,
    ):
        # These are read by _build_single_net, so they are stored before
        # super().__init__() -- presumably the base constructor invokes that
        # hook while building the network(s); confirm against torchrl.
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        super().__init__(
            n_agents=n_agents,
            centralized=False,
            share_params=share_params,
            agent_dim=1,  # (..., agents, features)
            device=device,
        )

    def _build_single_net(self, *, device, **kwargs):
        """Build the network used for one agent.

        Args:
            device: Device to move the network to; ``None`` leaves it where
                it was created.
            **kwargs: Accepted for compatibility with the base class hook and
                ignored here.

        Returns:
            A ``SingleAgentMLP`` with ``obs_dim`` inputs and
            ``2 * action_dim`` outputs.
        """
        # Output 2 * action_dim: loc + scale
        net = SingleAgentMLP(self.obs_dim, 2 * self.action_dim)
        if device is None:
            return net
        return net.to(device)

    def _pre_forward_check(self, inputs):
        """Return ``inputs`` unchanged; no shape validation is performed."""
        # inputs: (B, A, obs_dim)
        return inputs
# Shared-parameter policy: the same per-agent MLP is used for every agent.
policy_net = MultiAgentPolicyNet(
    obs_dim=obs_dim,
    action_dim=action_dim,
    n_agents=n_agents,
    share_params=True,
)

# The extractor turns the 2 * action_dim network output into (loc, scale).
loc_scale_net = nn.Sequential(policy_net, NormalParamExtractor())
policy_module = TensorDictModule(
    module=loc_scale_net,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "loc"), ("agents", "scale")],
)

# Bounds handed to the TanhNormal distribution.
tanh_normal_bounds = {"low": action_low, "high": action_high}
actor = ProbabilisticActor(
    module=policy_module,
    in_keys=[("agents", "loc"), ("agents", "scale")],
    out_keys=[("agents", "action")],
    distribution_class=TanhNormal,
    distribution_kwargs=tanh_normal_bounds,
    return_log_prob=True,
)

# Agents sit at dim 1 here: (num_envs, n_agents, seq_len, obs_dim).
agent_observation = torch.randn(num_envs, n_agents, seq_len, obs_dim)
td = TensorDict(
    {("agents", "observation"): agent_observation},
    batch_size=[num_envs],
)

# In the reported run this call raises the ValueError shown in the traceback.
td_out = actor(td)
# Report every (nested) key produced by the actor, then the tensor shapes of
# the input observation and the sampled action / log-probability.
print("Keys in output TensorDict:")
for k in td_out.keys(True):
    print(" ", k)
print("\nShapes:")
print("observation:", td_out["agents", "observation"].shape)
print("action:", td_out["agents", "action"].shape)
print("log_prob:", td_out["agents", "action_log_prob"].shape)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
RuntimeError: TensorDictModule failed with operation
Sequential(
(0): MultiAgentPolicyNet(
SingleAgentMLP(
(net): Sequential(
(0): Linear(in_features=5, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=4, bias=True)
)
),
n_agents=3,
share_params=True,
centralized=False,
agent_dim=1)
(1): NormalParamExtractor(
(scale_mapping): biased_softplus()
)
)
in_keys=[('agents', 'observation')]
out_keys=[('agents', 'loc'), ('agents', 'scale')].
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
File e:\ai_assist\market_analysis\multi_agent_example_1.py:115
101 td = TensorDict(
102 {
103 ("agents", "observation"): torch.randn(
(...) 107 batch_size=[num_envs],
108 )
110 # td = TensorDict({f"agent_{i}":{"observation": torch.randn(num_envs,
111 # seq_len,
112 # obs_dim)}
113 # for i in range(n_agents)})
--> 115 td_out = actor(td)
117 print("Keys in output TensorDict:")
118 for k in td_out.keys(True):
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\common.py:328, in dispatch.__call__.<locals>.wrapper(*args, **kwargs)
325 return out[0] if len(out) == 1 else out
327 if _self is not None:
--> 328 return func(_self, tensordict, *args, **kwargs)
329 return func(tensordict, *args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\utils.py:372, in _set_skip_existing_None.__call__.<locals>.wrapper(_self, tensordict, *args, **kwargs)
370 self.prev = _skip_existing.get_mode()
371 try:
--> 372 result = func(_self, tensordict, *args, **kwargs)
373 finally:
374 _skip_existing.set_mode(self.prev)
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\probabilistic.py:1364, in ProbabilisticTensorDictSequential.forward(self, tensordict, tensordict_out, **kwargs)
1360 raise RuntimeError(
1361 f"Failed while executing module '{module_num_or_key}'. Scroll up for more info."
1362 ) from e
1363 else:
-> 1364 tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
1365 tensordict_exec = self._last_module(
1366 tensordict_exec, _requires_sample=self._requires_sample
1367 )
1369 if self.inplace is True:
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\probabilistic.py:1116, in ProbabilisticTensorDictSequential.get_dist_params(self, tensordict, tensordict_out, **kwargs)
1114 raise ValueError("Could not find a default interaction in the modules.")
1115 with set_interaction_type(type):
-> 1116 return tds(tensordict, tensordict_out, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\common.py:328, in dispatch.__call__.<locals>.wrapper(*args, **kwargs)
325 return out[0] if len(out) == 1 else out
327 if _self is not None:
--> 328 return func(_self, tensordict, *args, **kwargs)
329 return func(tensordict, *args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\utils.py:372, in _set_skip_existing_None.__call__.<locals>.wrapper(_self, tensordict, *args, **kwargs)
370 self.prev = _skip_existing.get_mode()
371 try:
--> 372 result = func(_self, tensordict, *args, **kwargs)
373 finally:
374 _skip_existing.set_mode(self.prev)
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\sequence.py:633, in TensorDictSequential.forward(self, tensordict, tensordict_out, **kwargs)
631 for module in self._module_iter():
632 try:
--> 633 tensordict_exec = self._run_module(
634 module, tensordict_exec, **kwargs
635 )
636 except Exception as e:
637 if _has_py311_or_greater:
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\sequence.py:579, in TensorDictSequential._run_module(self, module, tensordict, **kwargs)
570 def _run_module(
571 self,
572 module: TensorDictModuleBase,
573 tensordict: TensorDictBase,
574 **kwargs: Any,
575 ) -> Any:
576 if not self.partial_tolerant or all(
577 key in tensordict.keys(include_nested=True) for key in module.in_keys
578 ):
--> 579 tensordict = module(tensordict, **kwargs)
580 elif self.partial_tolerant and isinstance(tensordict, LazyStackedTensorDict):
581 for sub_td in tensordict.tensordicts:
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\common.py:328, in dispatch.__call__.<locals>.wrapper(*args, **kwargs)
325 return out[0] if len(out) == 1 else out
327 if _self is not None:
--> 328 return func(_self, tensordict, *args, **kwargs)
329 return func(tensordict, *args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\utils.py:372, in _set_skip_existing_None.__call__.<locals>.wrapper(_self, tensordict, *args, **kwargs)
370 self.prev = _skip_existing.get_mode()
371 try:
--> 372 result = func(_self, tensordict, *args, **kwargs)
373 finally:
374 _skip_existing.set_mode(self.prev)
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\common.py:1218, in TensorDictModule.forward(self, tensordict, tensordict_out, *args, **kwargs)
1216 in_keys = indent(f"in_keys={self.in_keys}", 4 * " ")
1217 out_keys = indent(f"out_keys={self.out_keys}", 4 * " ")
-> 1218 raise err from RuntimeError(
1219 f"TensorDictModule failed with operation\n{module}\n{in_keys}\n{out_keys}."
1220 )
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\common.py:1190, in TensorDictModule.forward(self, tensordict, tensordict_out, *args, **kwargs)
1184 raise KeyError(
1185 "Some tensors that are necessary for the module call may "
1186 "not have not been found in the input tensordict: "
1187 f"the following inputs are None: {none_set}."
1188 ) from err
1189 else:
-> 1190 raise err
1191 if isinstance(tensors_out, (dict, TensorDictBase)) and all(
1192 key in tensors_out for key in self.out_keys
1193 ):
1194 if isinstance(tensors_out, dict):
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\common.py:1174, in TensorDictModule.forward(self, tensordict, tensordict_out, *args, **kwargs)
1165 tensors = tuple( # type: ignore[unreachable]
1166 tensordict._get_tuple_maybe_non_tensor(
1167 _unravel_key_to_tuple(in_key),
(...) 1171 for in_key in self.in_keys
1172 )
1173 try:
-> 1174 tensors_out = self._call_module(tensors, **kwargs)
1175 if tensors_out is None:
1176 tensors_out = ()
File E:\miniconda\envs\science\Lib\site-packages\tensordict\nn\common.py:1133, in TensorDictModule._call_module(self, tensors, **kwargs)
1131 kwargs.update(self.method_kwargs)
1132 if self.method is None:
-> 1133 out = self.module(*tensors, **kwargs)
1134 else:
1135 out = getattr(self.module, self.method)(*tensors, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\container.py:240, in Sequential.forward(self, input)
238 def forward(self, input):
239 for module in self:
--> 240 input = module(input)
241 return input
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File E:\miniconda\envs\science\Lib\site-packages\torch\nn\modules\module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File E:\miniconda\envs\science\Lib\site-packages\torchrl\modules\models\multiagent.py:166, in MultiAgentNetBase.forward(self, *inputs)
161 output = output.expand(
162 *output.shape[:-2], self.n_agents, n_agent_outputs
163 )
165 if output.shape[-2] != (self.n_agents):
--> 166 raise ValueError(
167 f"Multi-agent network expected output with shape[-2]={self.n_agents}"
168 f" but got {output.shape}"
169 )
171 return output
ValueError: Multi-agent network expected output with shape[-2]=3 but got torch.Size([4, 3, 6, 4])
Failed while executing module '0'.
Expected behavior
I expected that specifying `agent_dim` would make the network treat that dimension of the input as the agent axis, instead of always requiring the agents to be at dim=-2.
# Agents at dim -2: observation is (num_envs, seq_len, n_agents, obs_dim).
td = TensorDict(
{
("agents", "observation"): torch.randn(
num_envs, seq_len, n_agents, obs_dim
)
},
batch_size=[num_envs],
)
works as input, which means that the previous input,
# Agents at dim 1: observation is (num_envs, n_agents, seq_len, obs_dim).
td = TensorDict(
{
("agents", "observation"): torch.randn(
num_envs, n_agents, seq_len, obs_dim
)
},
batch_size=[num_envs],
)
was exactly what triggered the error, which in turn means that
super().__init__(
n_agents=n_agents,
centralized=False,
share_params=share_params,
agent_dim=1, # (..., agents, features)
device=device,
)
had no effect: the agent dimension was still expected at dim=-2.
System info
Describe the characteristic of your environment:
- torchrl installed via pip 1 week ago
- Python 3.11.11
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
Metadata
Metadata
Assignees
Labels
bug: Something isn't working