We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents ec47c8a + c2c2fd3 commit 12869c7Copy full SHA for 12869c7
layer.py
@@ -184,7 +184,7 @@ def forward(self, idx):
184
adj = F.relu(torch.tanh(self.alpha*a))
185
mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device)
186
mask.fill_(float('0'))
187
- s1,t1 = adj.topk(self.k,1)
+ s1,t1 = (adj + torch.rand_like(adj)*0.01).topk(self.k,1)
188
mask.scatter_(1,t1,s1.fill_(1))
189
adj = adj*mask
190
return adj
0 commit comments