Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 95 additions & 75 deletions clippy_lints/src/methods/unnecessary_fold.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,49 @@
use clippy_utils::diagnostics::span_lint_and_sugg;
use clippy_utils::res::{MaybeDef, MaybeResPath, MaybeTypeckRes};
use clippy_utils::res::{MaybeDef, MaybeQPath, MaybeResPath, MaybeTypeckRes};
use clippy_utils::source::snippet_with_applicability;
use clippy_utils::{peel_blocks, strip_pat_refs};
use clippy_utils::{DefinedTy, ExprUseNode, expr_use_ctxt, peel_blocks, strip_pat_refs};
use rustc_ast::ast;
use rustc_data_structures::packed::Pu128;
use rustc_errors::Applicability;
use rustc_hir as hir;
use rustc_hir::PatKind;
use rustc_hir::def::{DefKind, Res};
use rustc_lint::LateContext;
use rustc_middle::ty;
use rustc_span::{Span, sym};
use rustc_span::{Span, Symbol, sym};

use super::UNNECESSARY_FOLD;

/// Do we need to suggest turbofish when suggesting a replacement method?
/// Changing `fold` to `sum` needs it sometimes when the return type can't be
/// inferred. This checks for some common cases where it can be safely omitted
fn needs_turbofish(cx: &LateContext<'_>, expr: &hir::Expr<'_>) -> bool {
let parent = cx.tcx.parent_hir_node(expr.hir_id);

// some common cases where turbofish isn't needed:
// - assigned to a local variable with a type annotation
if let hir::Node::LetStmt(local) = parent
&& local.ty.is_some()
fn needs_turbofish<'tcx>(cx: &LateContext<'tcx>, expr: &hir::Expr<'tcx>) -> bool {
let use_cx = expr_use_ctxt(cx, expr);
if use_cx.same_ctxt
&& let use_node = use_cx.use_node(cx)
&& let Some(ty) = use_node.defined_ty(cx)
{
return false;
}
// some common cases where turbofish isn't needed:
match (use_node, ty) {
// - assigned to a local variable with a type annotation
(ExprUseNode::LetStmt(_), _) => return false,

// - part of a function call argument, can be inferred from the function signature (provided that
// the parameter is not a generic type parameter)
if let hir::Node::Expr(parent_expr) = parent
&& let hir::ExprKind::Call(recv, args) = parent_expr.kind
&& let hir::ExprKind::Path(ref qpath) = recv.kind
&& let Some(fn_def_id) = cx.qpath_res(qpath, recv.hir_id).opt_def_id()
&& let fn_sig = cx.tcx.fn_sig(fn_def_id).skip_binder().skip_binder()
&& let Some(arg_pos) = args.iter().position(|arg| arg.hir_id == expr.hir_id)
&& let Some(ty) = fn_sig.inputs().get(arg_pos)
&& !matches!(ty.kind(), ty::Param(_))
{
return false;
// - part of a function call argument, can be inferred from the function signature (provided that the
// parameter is not a generic type parameter)
(ExprUseNode::FnArg(..), DefinedTy::Mir { ty: arg_ty, .. })
if !matches!(arg_ty.skip_binder().kind(), ty::Param(_)) =>
{
return false;
},

// - the final expression in the body of a function with a simple return type
(ExprUseNode::Return(_), DefinedTy::Mir { ty: fn_return_ty, .. })
if fn_return_ty.skip_binder().is_simple_text() =>
{
return false;
},
_ => {},
}
}

// if it's neither of those, stay on the safe side and suggest turbofish,
Expand All @@ -60,7 +65,7 @@ fn check_fold_with_op(
fold_span: Span,
op: hir::BinOpKind,
replacement: Replacement,
) {
) -> bool {
if let hir::ExprKind::Closure(&hir::Closure { body, .. }) = acc.kind
// Extract the body of the closure passed to fold
&& let closure_body = cx.tcx.hir_body(body)
Expand Down Expand Up @@ -93,7 +98,7 @@ fn check_fold_with_op(
r = snippet_with_applicability(cx, right_expr.span, "EXPR", &mut applicability),
)
} else {
format!("{method}{turbofish}()", method = replacement.method_name,)
format!("{method}{turbofish}()", method = replacement.method_name)
};

span_lint_and_sugg(
Expand All @@ -105,12 +110,47 @@ fn check_fold_with_op(
sugg,
applicability,
);
return true;
}
false
}

pub(super) fn check(
fn check_fold_with_method(
cx: &LateContext<'_>,
expr: &hir::Expr<'_>,
acc: &hir::Expr<'_>,
fold_span: Span,
method: Symbol,
replacement: Replacement,
) {
// Extract the name of the function passed to `fold`
if let Res::Def(DefKind::AssocFn, fn_did) = acc.res_if_named(cx, method)
// Check if the function belongs to the operator
&& cx.tcx.is_diagnostic_item(method, fn_did)
{
let applicability = Applicability::MachineApplicable;

let turbofish = if replacement.has_generic_return {
format!("::<{}>", cx.typeck_results().expr_ty(expr))
} else {
String::new()
};

span_lint_and_sugg(
cx,
UNNECESSARY_FOLD,
fold_span.with_hi(expr.span.hi()),
"this `.fold` can be written more succinctly using another method",
"try",
format!("{method}{turbofish}()", method = replacement.method_name),
applicability,
);
}
}

pub(super) fn check<'tcx>(
cx: &LateContext<'tcx>,
expr: &hir::Expr<'tcx>,
init: &hir::Expr<'_>,
acc: &hir::Expr<'_>,
fold_span: Span,
Expand All @@ -124,60 +164,40 @@ pub(super) fn check(
if let hir::ExprKind::Lit(lit) = init.kind {
match lit.node {
ast::LitKind::Bool(false) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Or,
Replacement {
method_name: "any",
has_args: true,
has_generic_return: false,
},
);
let replacement = Replacement {
method_name: "any",
has_args: true,
has_generic_return: false,
};
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Or, replacement);
},
ast::LitKind::Bool(true) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::And,
Replacement {
method_name: "all",
has_args: true,
has_generic_return: false,
},
);
let replacement = Replacement {
method_name: "all",
has_args: true,
has_generic_return: false,
};
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::And, replacement);
},
ast::LitKind::Int(Pu128(0), _) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Add,
Replacement {
method_name: "sum",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
},
);
let replacement = Replacement {
method_name: "sum",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
};
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Add, replacement) {
check_fold_with_method(cx, expr, acc, fold_span, sym::add, replacement);
}
},
ast::LitKind::Int(Pu128(1), _) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Mul,
Replacement {
method_name: "product",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
},
);
let replacement = Replacement {
method_name: "product",
has_args: false,
has_generic_return: needs_turbofish(cx, expr),
};
if !check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Mul, replacement) {
check_fold_with_method(cx, expr, acc, fold_span, sym::mul, replacement);
}
},
_ => (),
}
Expand Down
3 changes: 1 addition & 2 deletions clippy_utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2490,7 +2490,7 @@ pub enum DefinedTy<'tcx> {
/// in the context of its definition site. We also track the `def_id` of its
/// definition site.
///
/// WARNING: As the `ty` in in the scope of the definition, not of the function
/// WARNING: As the `ty` is in the scope of the definition, not of the function
/// using it, you must be very careful with how you use it. Using it in the wrong
/// scope easily results in ICEs.
Mir {
Expand Down Expand Up @@ -2719,7 +2719,6 @@ pub fn expr_use_ctxt<'tcx>(cx: &LateContext<'tcx>, e: &Expr<'tcx>) -> ExprUseCtx
moved_before_use,
same_ctxt,
},
Some(ControlFlow::Break(_)) => unreachable!("type of node is ControlFlow<!>"),
None => ExprUseCtxt {
node: Node::Crate(cx.tcx.hir_root_module()),
child_id: HirId::INVALID,
Expand Down
85 changes: 85 additions & 0 deletions tests/ui/unnecessary_fold.fixed
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,35 @@ fn is_any(acc: bool, x: usize) -> bool {

/// Calls which should trigger the `UNNECESSARY_FOLD` lint
fn unnecessary_fold() {
use std::ops::{Add, Mul};

// Can be replaced by .any
let _ = (0..3).any(|x| x > 2);
//~^ unnecessary_fold

// Can be replaced by .any (checking suggestion)
let _ = (0..3).fold(false, is_any);
//~^ redundant_closure

// Can be replaced by .all
let _ = (0..3).all(|x| x > 2);
//~^ unnecessary_fold

// Can be replaced by .sum
let _: i32 = (0..3).sum();
//~^ unnecessary_fold
let _: i32 = (0..3).sum();
//~^ unnecessary_fold
let _: i32 = (0..3).sum();
//~^ unnecessary_fold

// Can be replaced by .product
let _: i32 = (0..3).product();
//~^ unnecessary_fold
let _: i32 = (0..3).product();
//~^ unnecessary_fold
let _: i32 = (0..3).product();
//~^ unnecessary_fold
}

/// Should trigger the `UNNECESSARY_FOLD` lint, with an error span including exactly `.fold(...)`
Expand All @@ -37,6 +51,43 @@ fn unnecessary_fold_should_ignore() {
let _ = (0..3).fold(0, |acc, x| acc * x);
let _ = (0..3).fold(0, |acc, x| 1 + acc + x);

struct Adder;
impl Adder {
fn add(lhs: i32, rhs: i32) -> i32 {
unimplemented!()
}
fn mul(lhs: i32, rhs: i32) -> i32 {
unimplemented!()
}
}
// `add`/`mul` are inherent methods
let _: i32 = (0..3).fold(0, Adder::add);
let _: i32 = (0..3).fold(1, Adder::mul);

trait FakeAdd<Rhs = Self> {
type Output;
fn add(self, other: Rhs) -> Self::Output;
}
impl FakeAdd for i32 {
type Output = Self;
fn add(self, other: i32) -> Self::Output {
self + other
}
}
trait FakeMul<Rhs = Self> {
type Output;
fn mul(self, other: Rhs) -> Self::Output;
}
impl FakeMul for i32 {
type Output = Self;
fn mul(self, other: i32) -> Self::Output {
self * other
}
}
// `add`/`mul` come from an unrelated trait
let _: i32 = (0..3).fold(0, FakeAdd::add);
let _: i32 = (0..3).fold(1, FakeMul::mul);

// We only match against an accumulator on the left
// hand side. We could lint for .sum and .product when
// it's on the right, but don't for now (and this wouldn't
Expand All @@ -63,6 +114,7 @@ fn unnecessary_fold_over_multiple_lines() {
fn issue10000() {
use std::collections::HashMap;
use std::hash::BuildHasher;
use std::ops::{Add, Mul};

fn anything<T>(_: T) {}
fn num(_: i32) {}
Expand All @@ -74,23 +126,56 @@ fn issue10000() {
// more cases:
let _ = map.values().sum::<i32>();
//~^ unnecessary_fold
let _ = map.values().sum::<i32>();
//~^ unnecessary_fold
let _ = map.values().product::<i32>();
//~^ unnecessary_fold
let _ = map.values().product::<i32>();
//~^ unnecessary_fold
let _: i32 = map.values().sum();
//~^ unnecessary_fold
let _: i32 = map.values().sum();
//~^ unnecessary_fold
let _: i32 = map.values().product();
//~^ unnecessary_fold
let _: i32 = map.values().product();
//~^ unnecessary_fold
anything(map.values().sum::<i32>());
//~^ unnecessary_fold
anything(map.values().sum::<i32>());
//~^ unnecessary_fold
anything(map.values().product::<i32>());
//~^ unnecessary_fold
anything(map.values().product::<i32>());
//~^ unnecessary_fold
num(map.values().sum());
//~^ unnecessary_fold
num(map.values().sum());
//~^ unnecessary_fold
num(map.values().product());
//~^ unnecessary_fold
num(map.values().product());
//~^ unnecessary_fold
}

smoketest_map(HashMap::new());

fn add_turbofish_not_necessary() -> i32 {
(0..3).sum()
//~^ unnecessary_fold
}
fn mul_turbofish_not_necessary() -> i32 {
(0..3).product()
//~^ unnecessary_fold
}
fn add_turbofish_necessary() -> impl Add {
(0..3).sum::<i32>()
//~^ unnecessary_fold
}
fn mul_turbofish_necessary() -> impl Mul {
(0..3).product::<i32>()
//~^ unnecessary_fold
}
}

fn main() {}
Loading