|
58 | 58 | "name": "stdout", |
59 | 59 | "output_type": "stream", |
60 | 60 | "text": [ |
61 | | - "PyTorch version: 2.6.0+cu124\n" |
| 61 | + "PyTorch version: 2.8.0\n" |
62 | 62 | ] |
63 | 63 | } |
64 | 64 | ], |
|
89 | 89 | }, |
90 | 90 | { |
91 | 91 | "cell_type": "code", |
92 | | - "execution_count": null, |
| 92 | + "execution_count": 2, |
93 | 93 | "id": "1db27f43-86f4-478f-89df-fbc2182a129b", |
94 | 94 | "metadata": { |
95 | 95 | "id": "1db27f43-86f4-478f-89df-fbc2182a129b" |
|
114 | 114 | }, |
115 | 115 | { |
116 | 116 | "cell_type": "code", |
117 | | - "execution_count": 2, |
| 117 | + "execution_count": 3, |
118 | 118 | "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", |
119 | 119 | "metadata": { |
120 | 120 | "colab": { |
|
205 | 205 | }, |
206 | 206 | { |
207 | 207 | "cell_type": "code", |
208 | | - "execution_count": 3, |
| 208 | + "execution_count": 4, |
209 | 209 | "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", |
210 | 210 | "metadata": { |
211 | 211 | "colab": { |
|
326 | 326 | }, |
327 | 327 | { |
328 | 328 | "cell_type": "code", |
329 | | - "execution_count": 4, |
| 329 | + "execution_count": 5, |
330 | 330 | "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", |
331 | 331 | "metadata": { |
332 | 332 | "colab": { |
|
434 | 434 | }, |
435 | 435 | { |
436 | 436 | "cell_type": "code", |
437 | | - "execution_count": 5, |
| 437 | + "execution_count": 6, |
438 | 438 | "id": "92481814-068d-439b-a65c-b1310ebbe0aa", |
439 | 439 | "metadata": { |
440 | 440 | "colab": { |
|
466 | 466 | " self.num_heads = num_heads\n", |
467 | 467 | " self.head_dim = d_out // num_heads\n", |
468 | 468 | "\n", |
469 | | - " # Initialize parameters for Q, K, V\n", |
470 | 469 | " self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n", |
471 | 470 | " self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n", |
472 | 471 | " self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n", |
|
483 | 482 | " self.out_proj = nn.Linear(d_out, d_out)\n", |
484 | 483 | " self.dropout = nn.Dropout(dropout)\n", |
485 | 484 | " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", |
486 | | - "\n", |
487 | | - " # Initialize parameters\n", |
488 | 485 | " self.reset_parameters()\n", |
489 | 486 | "\n", |
490 | 487 | "\n", |
|
0 commit comments