|
| 1 | +From 5b902cf87f2f8caa8a7023991992e547b1b7aef8 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Archana Shettigar < [email protected]> |
| 3 | +Date: Mon, 24 Nov 2025 11:08:31 +0530 |
| 4 | +Subject: [PATCH] Addressing CVE-2025-55552 |
| 5 | + |
| 6 | +--- |
| 7 | + test/inductor/test_torchinductor.py | 22 ++++- |
| 8 | + ...st_torchinductor_codegen_dynamic_shapes.py | 1 + |
| 9 | + torch/_inductor/decomposition.py | 95 +++++++++---------- |
| 10 | + 3 files changed, 66 insertions(+), 52 deletions(-) |
| 11 | + |
| 12 | +diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py |
| 13 | +index f1cfb90c..4c424157 100644 |
| 14 | +--- a/test/inductor/test_torchinductor.py |
| 15 | ++++ b/test/inductor/test_torchinductor.py |
| 16 | +@@ -5872,7 +5872,7 @@ class CommonTemplate: |
| 17 | + model = Model() |
| 18 | + x = torch.rand(10, 3, 0) |
| 19 | + |
| 20 | +- self.common(model, (x,)) |
| 21 | ++ self.common(model, (x,), exact_stride=True) |
| 22 | + |
| 23 | + def test_randint(self): |
| 24 | + @torch.compile(fullgraph=True) |
| 25 | +@@ -5907,9 +5907,21 @@ class CommonTemplate: |
| 26 | + @config.patch(fallback_random=True) |
| 27 | + def test_like_rands(self): |
| 28 | + def fn(x): |
| 29 | +- return torch.rand_like(x), torch.randn_like(x) |
| 30 | ++ return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11) |
| 31 | + |
| 32 | +- self.common(fn, [torch.zeros([20, 20])]) |
| 33 | ++ self.common(fn, [torch.zeros([20, 20])], exact_stride=True) |
| 34 | ++ |
| 35 | ++ @config.patch(fallback_random=True) |
| 36 | ++ @xfail_if_mps # 100% are not close |
| 37 | ++ def test_like_rands_sliced(self): |
| 38 | ++ def fn(x): |
| 39 | ++ return ( |
| 40 | ++ torch.randn_like(x), |
| 41 | ++ torch.randn_like(x), |
| 42 | ++ torch.randint_like(x, 1, 11), |
| 43 | ++ ) |
| 44 | ++ |
| 45 | ++ self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True) |
| 46 | + |
| 47 | + def test_like_rands2(self): |
| 48 | + # rand_like with kwargs `device` of str type |
| 49 | +@@ -5924,6 +5936,8 @@ class CommonTemplate: |
| 50 | + a0 = fn(x).clone() |
| 51 | + a1 = fn(x).clone() |
| 52 | + self.assertFalse(torch.allclose(a0, a1)) |
| 53 | ++ self.assertEqual(a0.shape, a1.shape) |
| 54 | ++ self.assertEqual(a0.stride(), a1.stride()) |
| 55 | + |
| 56 | + @requires_cuda() |
| 57 | + def test_like_rands3(self): |
| 58 | +@@ -5940,6 +5954,8 @@ class CommonTemplate: |
| 59 | + a1 = test_like_rands_on_different_device("cuda", "cpu") |
| 60 | + self.assertTrue(a0.device.type == "cuda") |
| 61 | + self.assertTrue(a1.device.type == "cpu") |
| 62 | ++ self.assertEqual(a0.shape, a1.shape) |
| 63 | ++ self.assertEqual(a0.stride(), a1.stride()) |
| 64 | + |
| 65 | + def test_max_pool2d_with_indices_backward(self): |
| 66 | + def fn(a, b, c): |
| 67 | +diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py |
| 68 | +index fa4b8040..ae52a802 100644 |
| 69 | +--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py |
| 70 | ++++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py |
| 71 | +@@ -162,6 +162,7 @@ test_failures = { |
| 72 | + "test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"), |
| 73 | + "test_bucketize_int_dynamic_shapes": TestFailure("cpu"), |
| 74 | + "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 75 | ++ "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), |
| 76 | + "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 77 | + "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 78 | + "test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda")), |
| 79 | +diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py |
| 80 | +index 88a56dea..c7943b2b 100644 |
| 81 | +--- a/torch/_inductor/decomposition.py |
| 82 | ++++ b/torch/_inductor/decomposition.py |
| 83 | +@@ -344,36 +344,6 @@ def view_copy_dtype(self, dtype): |
| 84 | + return self.to(dtype).clone() |
| 85 | + |
| 86 | + |
| 87 | +-def get_like_layout( |
| 88 | +- tensor: torch.Tensor, memory_format: Optional[torch.memory_format] |
| 89 | +-) -> torch.memory_format: |
| 90 | +- # TODO: _to_copy tensor to stride permutation |
| 91 | +- if memory_format is torch.preserve_format or memory_format is None: |
| 92 | +- return utils.suggest_memory_format(tensor) |
| 93 | +- else: |
| 94 | +- return memory_format |
| 95 | +- |
| 96 | +- |
| 97 | +-@register_decomposition(aten.rand_like) |
| 98 | +-def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs): |
| 99 | +- return torch.rand( |
| 100 | +- [*self.size()], |
| 101 | +- dtype=dtype or self.dtype, |
| 102 | +- device=device or self.device, |
| 103 | +- **kwargs, |
| 104 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 105 | +- |
| 106 | +- |
| 107 | +-@register_decomposition(aten.randn_like) |
| 108 | +-def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs): |
| 109 | +- return torch.randn( |
| 110 | +- [*self.size()], |
| 111 | +- dtype=dtype or self.dtype, |
| 112 | +- device=device or self.device, |
| 113 | +- **kwargs, |
| 114 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 115 | +- |
| 116 | +- |
| 117 | + @register_decomposition(aten.full_like) |
| 118 | + def full_like( |
| 119 | + self, |
| 120 | +@@ -396,30 +366,57 @@ def full_like( |
| 121 | + ).to(memory_format=get_like_layout(self, memory_format)) |
| 122 | + |
| 123 | + |
| 124 | +-@register_decomposition(aten.randint_like.default) |
| 125 | +-def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs): |
| 126 | +- return aten.randint.low( |
| 127 | +- 0, |
| 128 | +- high, |
| 129 | +- [*self.size()], |
| 130 | +- dtype=dtype or self.dtype, |
| 131 | +- device=device or self.device, |
| 132 | ++def _rand_like( |
| 133 | ++ rand_fn: Callable[..., torch.Tensor], |
| 134 | ++ self: torch.Tensor, |
| 135 | ++ *, |
| 136 | ++ dtype: Optional[torch.dtype] = None, |
| 137 | ++ device: Optional[torch.device] = None, |
| 138 | ++ memory_format: torch.memory_format = torch.preserve_format, |
| 139 | ++ **kwargs: Any, |
| 140 | ++) -> torch.Tensor: |
| 141 | ++ dtype = self.dtype if dtype is None else dtype |
| 142 | ++ device = self.device if device is None else device |
| 143 | ++ |
| 144 | ++ if memory_format != torch.preserve_format: |
| 145 | ++ return rand_fn( |
| 146 | ++ self.shape, |
| 147 | ++ dtype=dtype, |
| 148 | ++ device=device, |
| 149 | ++ **kwargs, |
| 150 | ++ ).to(memory_format=memory_format) |
| 151 | ++ |
| 152 | ++ shape, permutation = _get_shape_permutation_like(self) |
| 153 | ++ result = rand_fn( |
| 154 | ++ shape, |
| 155 | ++ dtype=dtype, |
| 156 | ++ device=device, |
| 157 | + **kwargs, |
| 158 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 159 | ++ ) |
| 160 | ++ if permutation == list(range(len(permutation))): |
| 161 | ++ return result |
| 162 | ++ return result.permute(permutation).clone() |
| 163 | + |
| 164 | + |
| 165 | ++@register_decomposition(aten.rand_like) |
| 166 | ++def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| 167 | ++ return _rand_like(torch.rand, self, **kwargs) |
| 168 | ++ |
| 169 | ++ |
| 170 | ++@register_decomposition(aten.randn_like) |
| 171 | ++def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| 172 | ++ return _rand_like(torch.randn, self, **kwargs) |
| 173 | ++ |
| 174 | ++ |
| 175 | ++@register_decomposition(aten.randint_like.default) |
| 176 | ++def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor: |
| 177 | ++ return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs) |
| 178 | ++ |
| 179 | + @register_decomposition(aten.randint_like.low_dtype) |
| 180 | + def randint_like_low( |
| 181 | +- self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs |
| 182 | +-): |
| 183 | +- return aten.randint.low( |
| 184 | +- low, |
| 185 | +- high, |
| 186 | +- [*self.size()], |
| 187 | +- dtype=dtype or self.dtype, |
| 188 | +- device=device or self.device, |
| 189 | +- **kwargs, |
| 190 | +- ).to(memory_format=get_like_layout(self, memory_format)) |
| 191 | ++ self: torch.Tensor, low: int, high: int, **kwargs: Any |
| 192 | ++) -> torch.Tensor: |
| 193 | ++ return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs) |
| 194 | + |
| 195 | + |
| 196 | + @register_decomposition(aten.randint.default) |
| 197 | +-- |
| 198 | +2.45.4 |
| 199 | + |
0 commit comments