Skip to content

Conversation

@hsezhiyan
Copy link

No description provided.

y = torch.normal( # pylint: disable=no-member
mean=batch["player_future"][..., :2],
mean=target_mean,
std=torch.ones_like(batch["player_future"][..., :2]) * noise_level, # pylint: disable=no-member
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use target_mean instead of batch["player_future"][..., :2]

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Filangel it looks like I'm already using target_mean. Can you clarify?


# Calculates loss (NLL).
loss = -torch.mean(log_prob - logabsdet, dim=0) # pylint: disable=no-member
if use_tcn:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the API for the decoder should be the same, regardless if it's an RNN or a TCN

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what you mean by this? By API, do you want the TCN decoder to similarly have a .inverse method, as the Autoregressive decoder does?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants