diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py
index a3c7e1bac7..e8b039a2a0 100644
--- a/transformer_engine/pytorch/ops/fuser.py
+++ b/transformer_engine/pytorch/ops/fuser.py
@@ -65,6 +65,7 @@ def forward(
         input_: torch.Tensor,
         fuser: OperationFuser,
         basic_op_kwargs: list[dict[str, Any]],
+        set_output_requires_grad: bool,
         *params_and_extra_inputs: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass
@@ -79,6 +80,8 @@ def forward(
             Container for the pipeline of operations to run
         basic_op_kwargs: list of dict
             Keyword arguments to BasicOperation
+        set_output_requires_grad: bool
+            Whether to set ``requires_grad`` flags on returned tensors
         *params_and_extra_inputs: torch.Tensor
             Other tensor inputs to include in autograd graph. Consists
             of parameter tensors, followed by extra operation inputs.
@@ -138,7 +141,8 @@ def forward(
             )
             for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                 for y in ys:
-                    y.requires_grad_(idx >= fuser.first_op_requiring_backward)
+                    if set_output_requires_grad:
+                        y.requires_grad_(idx >= fuser.first_op_requiring_backward)
                 extra_outputs[idx] = ys

         # Flatten list of extra outputs
@@ -190,7 +194,8 @@ def forward(
             for tensor in [x] + extra_outputs_flat:
                 tensor._do_not_clear = True

-        x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
+        if set_output_requires_grad:
+            x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)

         if extra_outputs_flat:
             return x, *extra_outputs_flat
@@ -293,6 +298,7 @@ def backward(
             dx,  # input_
             None,  # fuser
             None,  # basic_op_kwargs
+            None,  # set_output_requires_grad
             *grad_params_flat,
             *grad_extra_inputs_flat,
         )
@@ -501,20 +507,24 @@ def __call__(
             op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)

         # Fuser forward pass
-        if is_grad_enabled:
-            forward_func = _OperationFuserAutogradFunction.apply
-            args = []
-        else:
-            forward_func = _OperationFuserAutogradFunction.forward
-            args = [None]
-        args += (
+        # When is_grad_enabled is False, call forward() directly. This does
+        # not register a PyTorch autograd node, so the fuser backward never
+        # runs. We also pass set_output_requires_grad=False in that path to
+        # avoid setting requires_grad on outputs, since they may be non-leaf
+        # tensors produced by the inner ops.
+        args = (
             input,
             self,
             basic_op_kwargs,
+            is_grad_enabled,  # set_output_requires_grad
             *self._flat_basic_op_params,
             *extra_inputs,
         )
-        return forward_func(*args)
+
+        if not is_grad_enabled:
+            return _OperationFuserAutogradFunction.forward(None, *args)
+
+        return _OperationFuserAutogradFunction.apply(*args)

     def register_forward_fusion(
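
For context, here is a minimal sketch of the dispatch pattern this patch adopts. It is not TransformerEngine code: the `_ScaleByTwo` Function, its inputs, and the call sites below are invented for illustration. The point is that `torch.autograd.Function.apply` is only used when grad is enabled, while the no-grad path calls `forward` directly, records no autograd node, and leaves output `requires_grad` flags untouched.

```python
import torch


class _ScaleByTwo(torch.autograd.Function):
    """Toy stand-in for an autograd-wrapped fused op (hypothetical example)."""

    @staticmethod
    def forward(ctx, x, set_output_requires_grad):
        # Grad mode is disabled inside a custom Function's forward, so the
        # output has no grad_fn here and requires_grad_ is safe to call.
        y = 2 * x
        if set_output_requires_grad:
            y.requires_grad_(x.requires_grad)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        # Gradient w.r.t. x; None for the non-tensor flag.
        return 2 * grad_y, None


x = torch.ones(4, requires_grad=True)

if torch.is_grad_enabled():
    # Grad path: .apply() records an autograd node so backward() can run.
    y = _ScaleByTwo.apply(x, True)
else:
    # No-grad path: call forward directly with a dummy ctx. No autograd node
    # is recorded, and requires_grad on the output is left alone.
    y = _ScaleByTwo.forward(None, x, False)
```

This mirrors the intent stated in the patch comment: in the direct-forward path the outputs may be non-leaf tensors produced by the inner ops, so their `requires_grad` flags are not modified.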