
Commit 1a98de4

Patch pytorch for CVE-2025-55552
1 parent 962e73f commit 1a98de4

File tree

3 files changed: +427 -1 lines changed

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
From 5b902cf87f2f8caa8a7023991992e547b1b7aef8 Mon Sep 17 00:00:00 2001
From: Archana Shettigar <[email protected]>
Date: Mon, 24 Nov 2025 11:08:31 +0530
Subject: [PATCH] Addressing CVE-2025-55552

---
 test/inductor/test_torchinductor.py | 22 ++++-
 ...st_torchinductor_codegen_dynamic_shapes.py | 1 +
 torch/_inductor/decomposition.py | 95 +++++++++----------
 3 files changed, 66 insertions(+), 52 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f1cfb90c..4c424157 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5872,7 +5872,7 @@ class CommonTemplate:
 model = Model()
 x = torch.rand(10, 3, 0)

- self.common(model, (x,))
+ self.common(model, (x,), exact_stride=True)

 def test_randint(self):
 @torch.compile(fullgraph=True)
@@ -5907,9 +5907,21 @@ class CommonTemplate:
 @config.patch(fallback_random=True)
 def test_like_rands(self):
 def fn(x):
- return torch.rand_like(x), torch.randn_like(x)
+ return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11)

- self.common(fn, [torch.zeros([20, 20])])
+ self.common(fn, [torch.zeros([20, 20])], exact_stride=True)
+
+ @config.patch(fallback_random=True)
+ @xfail_if_mps # 100% are not close
+ def test_like_rands_sliced(self):
+ def fn(x):
+ return (
+ torch.randn_like(x),
+ torch.randn_like(x),
+ torch.randint_like(x, 1, 11),
+ )
+
+ self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True)

 def test_like_rands2(self):
 # rand_like with kwargs `device` of str type
@@ -5924,6 +5936,8 @@ class CommonTemplate:
 a0 = fn(x).clone()
 a1 = fn(x).clone()
 self.assertFalse(torch.allclose(a0, a1))
+ self.assertEqual(a0.shape, a1.shape)
+ self.assertEqual(a0.stride(), a1.stride())

 @requires_cuda()
 def test_like_rands3(self):
@@ -5940,6 +5954,8 @@ class CommonTemplate:
 a1 = test_like_rands_on_different_device("cuda", "cpu")
 self.assertTrue(a0.device.type == "cuda")
 self.assertTrue(a1.device.type == "cpu")
+ self.assertEqual(a0.shape, a1.shape)
+ self.assertEqual(a0.stride(), a1.stride())

 def test_max_pool2d_with_indices_backward(self):
 def fn(a, b, c):
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index fa4b8040..ae52a802 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -162,6 +162,7 @@ test_failures = {
 "test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"),
 "test_bucketize_int_dynamic_shapes": TestFailure("cpu"),
 "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")),
+ "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
 "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")),
 "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")),
 "test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda")),
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 88a56dea..c7943b2b 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -344,36 +344,6 @@ def view_copy_dtype(self, dtype):
 return self.to(dtype).clone()


-def get_like_layout(
- tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
-) -> torch.memory_format:
- # TODO: _to_copy tensor to stride permutation
- if memory_format is torch.preserve_format or memory_format is None:
- return utils.suggest_memory_format(tensor)
- else:
- return memory_format
-
-
-@register_decomposition(aten.rand_like)
-def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
- return torch.rand(
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
-
-
-@register_decomposition(aten.randn_like)
-def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
- return torch.randn(
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
-
-
 @register_decomposition(aten.full_like)
 def full_like(
 self,
@@ -396,30 +366,57 @@ def full_like(
 ).to(memory_format=get_like_layout(self, memory_format))


-@register_decomposition(aten.randint_like.default)
-def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
- return aten.randint.low(
- 0,
- high,
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
+def _rand_like(
+ rand_fn: Callable[..., torch.Tensor],
+ self: torch.Tensor,
+ *,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ memory_format: torch.memory_format = torch.preserve_format,
+ **kwargs: Any,
+) -> torch.Tensor:
+ dtype = self.dtype if dtype is None else dtype
+ device = self.device if device is None else device
+
+ if memory_format != torch.preserve_format:
+ return rand_fn(
+ self.shape,
+ dtype=dtype,
+ device=device,
+ **kwargs,
+ ).to(memory_format=memory_format)
+
+ shape, permutation = _get_shape_permutation_like(self)
+ result = rand_fn(
+ shape,
+ dtype=dtype,
+ device=device,
 **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ )
+ if permutation == list(range(len(permutation))):
+ return result
+ return result.permute(permutation).clone()


+@register_decomposition(aten.rand_like)
+def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(torch.rand, self, **kwargs)
+
+
+@register_decomposition(aten.randn_like)
+def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(torch.randn, self, **kwargs)
+
+
+@register_decomposition(aten.randint_like.default)
+def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
+
 @register_decomposition(aten.randint_like.low_dtype)
 def randint_like_low(
- self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
-):
- return aten.randint.low(
- low,
- high,
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ self: torch.Tensor, low: int, high: int, **kwargs: Any
+) -> torch.Tensor:
+ return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)


 @register_decomposition(aten.randint.default)
--
2.45.4

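The core of the decomposition change above: under torch.preserve_format, the new _rand_like generates values in the input's physical (stride-ordered) shape and then permutes back, instead of coercing everything through suggest_memory_format. The _get_shape_permutation_like helper it calls is not shown in this hunk, so the snippet below is only an illustrative sketch of the idea; shape_and_permutation_like is a hypothetical stand-in for that helper.

# Illustrative sketch only, not the PyTorch implementation.
# shape_and_permutation_like is a hypothetical stand-in for the
# _get_shape_permutation_like helper referenced (but not shown) in the patch.
import torch


def shape_and_permutation_like(t: torch.Tensor):
    # Order dimensions from largest to smallest stride, so values are
    # generated in the tensor's physical layout order.
    order = sorted(range(t.dim()), key=lambda d: t.stride(d), reverse=True)
    shape = [t.shape[d] for d in order]
    # Permutation that maps the generated tensor back to the logical dim order.
    permutation = [order.index(d) for d in range(t.dim())]
    return shape, permutation


def randn_like_sketch(t: torch.Tensor) -> torch.Tensor:
    # Mirrors the _rand_like flow for memory_format=torch.preserve_format.
    shape, permutation = shape_and_permutation_like(t)
    result = torch.randn(shape, dtype=t.dtype, device=t.device)
    if permutation == list(range(len(permutation))):
        return result
    return result.permute(permutation).clone()


# A sliced, permuted input like the one used by the new test_like_rands_sliced test.
x = torch.zeros([3, 4])[:, ::2].permute(1, 0)
y = randn_like_sketch(x)
print(y.shape, y.stride())

For a contiguous input the permutation is the identity and the generated tensor is returned as-is, which matches the early return in the patched _rand_like.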
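A minimal way to exercise the behavior covered by the new test_like_rands_sliced test outside the test suite, assuming a build that includes this patch:

# Usage sketch: compare eager vs. torch.compile output layout for randn_like
# on a non-contiguous input (the same input used by test_like_rands_sliced).
import torch


def fn(x):
    return torch.randn_like(x)


x = torch.zeros([3, 4])[:, ::2].permute(1, 0)  # sliced, permuted -> non-contiguous
eager = fn(x)
compiled = torch.compile(fn, fullgraph=True)(x)

# With the patch applied, the compiled result's shape and stride ordering
# should line up with eager's, which is presumably what exact_stride=True
# asserts inside self.common.
print("eager   :", eager.shape, eager.stride())
print("compiled:", compiled.shape, compiled.stride())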