|
9 | 9 |
|
10 | 10 | import edgeml_pytorch.utils as utils |
11 | 11 |
|
12 | | -if utils.findCUDA() is not None: |
13 | | - import fastgrnn_cuda |
| 12 | +try: |
| 13 | + if utils.findCUDA() is not None: |
| 14 | + import fastgrnn_cuda |
| 15 | +except: |
| 16 | + print("Running without FastGRNN CUDA") |
| 17 | + pass |
14 | 18 |
|
15 | 19 |
|
16 | 20 | # All the matrix vector computations of the form Wx are done |
@@ -351,29 +355,29 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", |
351 | 355 | self._name = name |
352 | 356 |
|
353 | 357 | if wRank is None: |
354 | | - self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device)) |
| 358 | + self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device)) |
355 | 359 | self.W1 = torch.empty(0) |
356 | 360 | self.W2 = torch.empty(0) |
357 | 361 | else: |
358 | 362 | self.W = torch.empty(0) |
359 | | - self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device)) |
360 | | - self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device)) |
| 363 | + self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device)) |
| 364 | + self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device)) |
361 | 365 |
|
362 | 366 | if uRank is None: |
363 | | - self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device)) |
| 367 | + self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device)) |
364 | 368 | self.U1 = torch.empty(0) |
365 | 369 | self.U2 = torch.empty(0) |
366 | 370 | else: |
367 | 371 | self.U = torch.empty(0) |
368 | | - self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device)) |
369 | | - self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device)) |
| 372 | + self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device)) |
| 373 | + self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device)) |
370 | 374 |
|
371 | 375 | self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity] |
372 | 376 |
|
373 | | - self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device)) |
374 | | - self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device)) |
375 | | - self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device)) |
376 | | - self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device)) |
| 377 | + self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device)) |
| 378 | + self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device)) |
| 379 | + self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device)) |
| 380 | + self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device)) |
377 | 381 |
|
378 | 382 | @property |
379 | 383 | def name(self): |
|
0 commit comments