Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
387 changes: 387 additions & 0 deletions SPECS/pytorch/CVE-2025-55552.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,387 @@
From c849ccbd342b6067d19d5805c6614a21a4f0b49f Mon Sep 17 00:00:00 2001
From: Sam Larsen <[email protected]>
Date: Fri, 25 Jul 2025 09:31:15 -0700
Subject: [PATCH] Fix full_like decomposition to preserve strides (#158898)

Summary:
See original PR at: https://github.com/pytorch/pytorch/pull/144765, which landed internally but was reverted due to test failures. Addressing reviewer comments and trying again.

Upstream Patch Reference: https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/159294.patch & https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/158898.patch
---
test/inductor/test_torchinductor.py | 51 ++++++-
...st_torchinductor_codegen_dynamic_shapes.py | 1 +
test/test_decomp.py | 11 +-
torch/_inductor/decomposition.py | 139 +++++++++++-------
torch/_inductor/lowering.py | 1 -
5 files changed, 145 insertions(+), 58 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f1cfb90c..e2daa183 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -264,6 +264,8 @@ def check_model(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
+ # TODO: enable this for all tests
+ exact_stride=False,
):
kwargs = kwargs or {}
torch._dynamo.reset()
@@ -282,6 +284,12 @@ def check_model(
):
has_lowp_args = True
return x.float()
+ # Preserve strides when casting
+ result = torch.empty_strided(
+ x.size(), x.stride(), device=x.device, dtype=torch.float
+ )
+ result.copy_(x)
+ return result
else:
return x

@@ -353,6 +361,7 @@ def check_model(
rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
+ exact_stride=exact_stride,
)
# In case of input mutations, check that inputs are the same
self.assertEqual(
@@ -363,6 +372,7 @@ def check_model(
equal_nan=True,
# our testing sometimes uses higher precision inputs for the reference
exact_dtype=False,
+ exact_stride=exact_stride,
)
else:
for correct_val, actual_val in zip(correct_flat, actual_flat):
@@ -376,6 +386,8 @@ def check_model(
assert correct_val.layout == actual_val.layout
if exact_dtype:
assert correct_val.dtype == actual_val.dtype
+ if exact_stride:
+ assert correct_val.stride() == actual_val.stride()

if check_gradient:
actual = output_process_fn_grad(actual)
@@ -423,6 +435,7 @@ def check_model(
rtol=rtol,
equal_nan=True,
exact_dtype=exact_dtype,
+ exact_stride=exact_stride,
)

torch._dynamo.reset()
@@ -446,6 +459,8 @@ def check_model_cuda(
check_gradient=False,
check_has_compiled=True,
output_process_fn_grad=lambda x: x,
+ # TODO: enable this for all tests
+ exact_stride=False,
):
kwargs = kwargs or {}
if hasattr(model, "to"):
@@ -470,6 +485,7 @@ def check_model_cuda(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
+ exact_stride=exact_stride,
)

if check_lowp:
@@ -500,6 +516,7 @@ def check_model_cuda(
check_gradient=check_gradient,
check_has_compiled=check_has_compiled,
output_process_fn_grad=output_process_fn_grad,
+ exact_stride=exact_stride,
)


@@ -4194,6 +4211,18 @@ class CommonTemplate:

self.common(fn, (torch.randn(8),))

+ def test_full_like_transposed(self):
+ def fn(a):
+ return torch.full_like(a, 3)
+
+ self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
+
+ def test_full_like_sliced(self):
+ def fn(a):
+ return torch.full_like(a, 3)
+
+ self.common(fn, (torch.rand(3, 4)[:, ::2],), exact_stride=True)
+
def test_full_truncation(self):
def fn(a):
return a + torch.full_like(a, 7.777)
@@ -5872,7 +5901,7 @@ class CommonTemplate:
model = Model()
x = torch.rand(10, 3, 0)

- self.common(model, (x,))
+ self.common(model, (x,), exact_stride=True)

def test_randint(self):
@torch.compile(fullgraph=True)
@@ -5907,9 +5936,21 @@ class CommonTemplate:
@config.patch(fallback_random=True)
def test_like_rands(self):
def fn(x):
- return torch.rand_like(x), torch.randn_like(x)
+ return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11)
+
+ self.common(fn, [torch.zeros([20, 20])], exact_stride=True)
+
+ @config.patch(fallback_random=True)
+ @xfail_if_mps # 100% are not close
+ def test_like_rands_sliced(self):
+ def fn(x):
+ return (
+ torch.randn_like(x),
+ torch.randn_like(x),
+ torch.randint_like(x, 1, 11),
+ )

- self.common(fn, [torch.zeros([20, 20])])
+ self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True)

def test_like_rands2(self):
# rand_like with kwargs `device` of str type
@@ -5924,6 +5965,8 @@ class CommonTemplate:
a0 = fn(x).clone()
a1 = fn(x).clone()
self.assertFalse(torch.allclose(a0, a1))
+ self.assertEqual(a0.shape, a1.shape)
+ self.assertEqual(a0.stride(), a1.stride())

@requires_cuda()
def test_like_rands3(self):
@@ -5940,6 +5983,8 @@ class CommonTemplate:
a1 = test_like_rands_on_different_device("cuda", "cpu")
self.assertTrue(a0.device.type == "cuda")
self.assertTrue(a1.device.type == "cpu")
+ self.assertEqual(a0.shape, a1.shape)
+ self.assertEqual(a0.stride(), a1.stride())

def test_max_pool2d_with_indices_backward(self):
def fn(a, b, c):
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index fa4b8040..ae52a802 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -162,6 +162,7 @@ test_failures = {
"test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"),
"test_bucketize_int_dynamic_shapes": TestFailure("cpu"),
"test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")),
+ "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda")),
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 10df8b8b..9ad20995 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -693,7 +693,16 @@ class TestDecomp(TestCase):
assert len(real_out) == len(decomp_out)

if do_relative_check:
- upcast = partial(upcast_tensor, dtype=torch.float64)
+ device_arg = kwargs.get("device", None)
+
+ def upcast(x):
+ if (isinstance(x, Tensor) and x.device.type == "mps") or (
+ device_arg and torch.device(device_arg).type == "mps"
+ ):
+ return upcast_tensor(x, dtype=torch.float32)
+ else:
+ return upcast_tensor(x, dtype=torch.float64)
+
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 88a56dea..6f396f50 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -343,35 +343,19 @@ def view_copy_default(self, size):
def view_copy_dtype(self, dtype):
return self.to(dtype).clone()

+def _get_shape_permutation_like(
+ self: torch.Tensor, layout: torch.layout
+) -> tuple[utils.ShapeType, utils.StrideType]:
+ assert layout == torch.strided

-def get_like_layout(
- tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
-) -> torch.memory_format:
- # TODO: _to_copy tensor to stride permutation
- if memory_format is torch.preserve_format or memory_format is None:
- return utils.suggest_memory_format(tensor)
- else:
- return memory_format
-
-
-@register_decomposition(aten.rand_like)
-def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
- return torch.rand(
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
+ shape = [self.shape[l] for l in physical_layout]

+ permutation = [0] * len(shape)
+ for p, l in enumerate(physical_layout):
+ permutation[l] = p

-@register_decomposition(aten.randn_like)
-def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
- return torch.randn(
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ return (shape, permutation)


@register_decomposition(aten.full_like)
@@ -386,40 +370,89 @@ def full_like(
requires_grad=False,
memory_format=torch.preserve_format,
):
- return torch.full(
- [*self.size()],
- fill_value,
- dtype=dtype or self.dtype,
- layout=layout or self.layout,
- device=device or self.device,
- requires_grad=requires_grad,
- ).to(memory_format=get_like_layout(self, memory_format))
+ dtype = self.dtype if dtype is None else dtype
+ layout = self.layout if layout is None else layout
+ device = self.device if device is None else device
+
+ if memory_format != torch.preserve_format:
+ result = torch.full(
+ self.shape,
+ fill_value,
+ dtype=dtype,
+ layout=layout,
+ device=device,
+ pin_memory=pin_memory,
+ requires_grad=requires_grad,
+ )
+ return result.to(memory_format=memory_format)

+ else:
+ shape, permutation = _get_shape_permutation_like(self, layout)
+ result = torch.full(
+ shape,
+ fill_value,
+ dtype=dtype,
+ layout=layout,
+ device=device,
+ pin_memory=pin_memory,
+ requires_grad=requires_grad,
+ )
+ if permutation == list(range(len(permutation))):
+ return result
+ return result.permute(permutation).clone()

-@register_decomposition(aten.randint_like.default)
-def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
- return aten.randint.low(
- 0,
- high,
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
+
+def _rand_like(
+ rand_fn: Callable[..., torch.Tensor],
+ self: torch.Tensor,
+ *,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ memory_format: torch.memory_format = torch.preserve_format,
+ **kwargs: Any,
+) -> torch.Tensor:
+ dtype = self.dtype if dtype is None else dtype
+ device = self.device if device is None else device
+
+ if memory_format != torch.preserve_format:
+ return rand_fn(
+ self.shape,
+ dtype=dtype,
+ device=device,
+ **kwargs,
+ ).to(memory_format=memory_format)
+
+ shape, permutation = _get_shape_permutation_like(self, torch.strided)
+ result = rand_fn(
+ shape,
+ dtype=dtype,
+ device=device,
**kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ )
+ if permutation == list(range(len(permutation))):
+ return result
+ return result.permute(permutation).clone()


+@register_decomposition(aten.rand_like)
+def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(torch.rand, self, **kwargs)
+
+
+@register_decomposition(aten.randn_like)
+def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(torch.randn, self, **kwargs)
+
+
+@register_decomposition(aten.randint_like.default)
+def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
+ return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
+
@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
- self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
-):
- return aten.randint.low(
- low,
- high,
- [*self.size()],
- dtype=dtype or self.dtype,
- device=device or self.device,
- **kwargs,
- ).to(memory_format=get_like_layout(self, memory_format))
+ self: torch.Tensor, low: int, high: int, **kwargs: Any
+) -> torch.Tensor:
+ return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)


@register_decomposition(aten.randint.default)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index e6f2e8d0..dbd9aa28 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2550,7 +2550,6 @@ def _full(fill_value, device, dtype, size):
)


-@register_lowering(aten.full_like, type_promotion_kind=None)
def full_like(x, fill_value, **kwargs):
return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)

--
2.45.4

Loading
Loading