diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index 3fcc23c2573f6e..a238a67b3b71d1 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -547,12 +547,12 @@ void ReplaceAllReduceOp(const Node &node,
   // ### v1 = op1(grad1)
   // It is converted to:
   // ### fused_grad = check_memory_continue(grad0, grad1, grad2, ...)
-  // ### fused_grad = c_sum_allreduce(fused_grad)
+  // ### fused_grad = all_reduce(fused_grad)
   // ### v0 = op0(grad0)
   // ### v1 = op1(grad1)
   // We should add the following dependency to ensure that op0 and op1 both run
-  // after c_sum_allreduce:
-  // ### grad0 = depend(grad0, fused_grad)
+  // after all_reduce(sum):
+  // ### grad0 = depend(grad0, fused_grad)
   // ### grad1 = depend(grad1, fused_grad)
   if (is_fused) {
     for (const auto &in : in_var_handles) {