Skip to content

Commit d0eb418

Browse files
ColinPepplerfacebook-github-bot
authored andcommitted
Fix accuracy issue for double output alias (facebookincubator#950)
Summary: Pull Request resolved: facebookincubator#950 ## Problem Here's an edge case for AIT. Suppose we have two outputs, and both are view on the same tensor. Atm, AIT will not provide accurate results for output0. ``` some-tensor <--view-- output0 ^------view-- ouptut1 void SetUpInputOutput() { input_x = static_cast<decltype(input_x)>(params_[0].ptr); elementwise_0_0 = static_cast<decltype(elementwise_0_0)>(params_[2].ptr); output_0 = elementwise_0_0; output_1 = elementwise_0_0; } void DeviceToDeviceCopies(stream) { // empty } ``` Why doesn't AIT provide accurate results for output0? Because notice how `params_[1]` isn't assigned to anything. ## Solution Use a D2D copy to pass data from `params_[2]` to `params_[1]`. We do this by checking to see if the view is aliased by another output. * If yes, then run a D2D copy. * If no, don't worry about this output. ## Refactor We refactor `_codegen_output_tensor` by combining the `external_tensor` case with the `is_view` case. Differential Revision: D50202241 fbshipit-source-id: 8d61bac9ed2be0b3ba8f9e017b1373e4b0473d34
1 parent 37321c7 commit d0eb418

File tree

2 files changed

+35
-15
lines changed

2 files changed

+35
-15
lines changed

python/aitemplate/backend/codegen.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -631,19 +631,25 @@ def _codegen_output_tensor(self, tensor: Tensor) -> None:
631631
if is_param:
632632
self._codegen_param_setup(tensor)
633633
self.device_to_device_copies.append(device_copy(tensor, tensor, output_idx))
634-
elif external_tensor is not None:
635-
# Special view cases for outputs; we can hit this case if the output
636-
# is a view of a constant, input, or another output.
634+
elif is_view or external_tensor is not None:
637635
assert (
638636
is_view
639-
), f"orig_tensor is not None, but node {name} is not marked as a view! Node: {tensor}"
637+
), f"External tensor is not None, but node {name} is not marked as a view! Node: {tensor}"
638+
view_name = view._attrs["name"]
639+
self.set_inputs.append(set_value(name, view_name))
640640
self.set_inputs.append(
641-
check_not_null(tensor, output_idx, skip_if_lower_bound_is_zero=True)
641+
check_not_null(tensor, skip_if_lower_bound_is_zero=True)
642642
)
643-
self.set_inputs.append(set_value(name, view._attrs["name"]))
644-
self.device_to_device_copies.append(
645-
device_copy(tensor, external_tensor, output_idx)
643+
644+
view_assigned_to_another_output = (
645+
self._get_output_idx(view_name) != output_idx
646646
)
647+
if external_tensor or view_assigned_to_another_output:
648+
# Copy from original tensor so this output can also have the data.
649+
original_tensor = external_tensor if external_tensor else view
650+
self.device_to_device_copies.append(
651+
device_copy(tensor, original_tensor, output_idx)
652+
)
647653
elif is_input:
648654
# Inputs that are also outputs require an extra copy
649655
self.set_inputs.append(
@@ -655,11 +661,6 @@ def _codegen_output_tensor(self, tensor: Tensor) -> None:
655661
self._record_param_tensor_info(tensor, self.input_idx)
656662
self.device_to_device_copies.append(device_copy(tensor, tensor, output_idx))
657663
self.input_idx += 1
658-
elif is_view:
659-
self.set_inputs.append(set_value(name, view._attrs["name"]))
660-
self.set_inputs.append(
661-
check_not_null(tensor, skip_if_lower_bound_is_zero=True)
662-
)
663664
else:
664665
self.set_inputs.append(
665666
set_value(

tests/unittest/backend/test_codegen_output_tensor.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
import unittest
1919
from typing import Sequence
2020

21+
import torch
22+
2123
from aitemplate.backend.codegen import device_copy, set_value
2224

2325
from aitemplate.compiler import compile_model, ops
2426
from aitemplate.compiler.ops.common.epilogue import FuncEnum
2527
from aitemplate.testing import detect_target
2628

27-
from aitemplate.testing.test_utils import gen_input_tensor
29+
from aitemplate.testing.test_utils import gen_input_tensor, get_random_torch_tensor
2830

2931

3032
class TestCodegenOutput(unittest.TestCase):
@@ -102,6 +104,8 @@ def test_double_alias(self, test_name="double_alias"):
102104
Case: Two outputs are a view of the same tensor.
103105
Graph: ( gelu ) <--view-- ( output_0 )
104106
<--view-- ( output_1 )
107+
Expect: If a tensor is a view for multiple outputs, then it's assigned to
108+
only one of the outputs' ptrs. We expect D2D copies for the remaining outputs.
105109
"""
106110
# AIT, two outputs.
107111
x = gen_input_tensor(shape=self.SHAPE, name="input_x")
@@ -114,7 +118,7 @@ def test_double_alias(self, test_name="double_alias"):
114118
output1._attrs["is_output"] = True
115119
output1._attrs["name"] = "output_1"
116120

117-
compile_model(
121+
model = compile_model(
118122
[output0, output1],
119123
detect_target(),
120124
self.WORKDIR,
@@ -135,11 +139,26 @@ def test_double_alias(self, test_name="double_alias"):
135139
expected_codegen = (
136140
set_value("output_0", view_name),
137141
set_value("output_1", view_name),
142+
device_copy(output0, view, dst_idx=1),
138143
)
139144
self._assert_codegen_exists(
140145
test_name, expected_codegen, self.MODEL_GENERATED_FILE
141146
)
142147

148+
# This is an edge case -- test the accuracy.
149+
x_pt = get_random_torch_tensor(self.SHAPE)
150+
gelu_pt = torch.nn.functional.gelu(x_pt)
151+
output0_pt = torch.unsqueeze(gelu_pt, dim=0)
152+
output1_pt = torch.flatten(gelu_pt)
153+
output0_ait = torch.empty_like(output0_pt)
154+
output1_ait = torch.empty_like(output1_pt)
155+
156+
model.run_with_tensors(
157+
{"input_x": x_pt}, {"output_0": output0_ait, "output_1": output1_ait}
158+
)
159+
self.assertTrue(torch.allclose(output0_ait, output0_pt, atol=1e-2, rtol=1e-2))
160+
self.assertTrue(torch.allclose(output1_ait, output1_pt, atol=1e-2, rtol=1e-2))
161+
143162
def test_output_is_view_of_output(self, test_name="output_is_view_of_output"):
144163
"""
145164
Case: An output is a view of an output.

0 commit comments

Comments
 (0)