diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index 32228bae1..7ae876e62 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -2,6 +2,9 @@ mod test { use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; + const STACKED_TYPE_GUARDS: usize = 180; + const LARGE_LINEAR_ASSIGNMENT_STEPS: usize = 2048; + #[test] fn test_closure_return() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -92,13 +95,602 @@ mod test { local foo = props.bar() end - if type(props.bar) == 'function' then - local foo = props.bar() + if type(props.bar) == 'function' then + local foo = props.bar() + end + + local foo = props.bar and props.bar() or nil + "# + )); + } + + #[test] + fn test_stacked_same_var_type_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if type(value) ~= 'string' then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + local value ---@type string|integer|boolean + + {repeated_guards} + local narrowed ---@type string + narrowed = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-variable type guard repro" + ); + assert!(ws.check_code_for(DiagnosticCode::AssignTypeMismatch, &block)); + } + + #[test] + fn test_stacked_same_var_truthiness_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if not value then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + local value ---@type string? + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-variable truthiness repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_stacked_same_var_call_type_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not instance_of(value, 'string') then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@generic T + ---@param inst any + ---@param type `T` + ---@return TypeGuard + local function instance_of(inst, type) + return true + end + + local value ---@type string|integer|boolean + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-variable call type guard repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_stacked_same_var_call_type_guard_eq_false_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if instance_of(value, 'string') == false then return end\n" + .repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@generic T + ---@param inst any + ---@param type `T` + ---@return TypeGuard + local function instance_of(inst, type) + return true + end + + local value ---@type string|integer|boolean + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked binary call type guard repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_branch_join_keeps_union_when_only_one_side_narrows() { + let mut ws = VirtualWorkspace::new(); + let block = r#" + local cond ---@type boolean + local value ---@type string|integer + + if cond then + if type(value) ~= 'string' then + return + end + end + + after_join = value + "#; + + let file_id = ws.def(block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for branch join merge-safety repro" + ); + assert_eq!(ws.expr_ty("after_join"), ws.ty("string|integer")); + } + + #[test] + fn test_stacked_same_field_truthiness_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if not value.foo then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@class HasFoo + ---@field foo string + + ---@class NoFoo + ---@field bar integer + + local value ---@type HasFoo|NoFoo + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-field truthiness repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("HasFoo")); + } + + #[test] + fn test_stacked_return_cast_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not is_player(creature) then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + + ---@return boolean + ---@return_cast creature Player else Monster + local function is_player(creature) + return true + end + + local creature ---@type Creature + + {repeated_guards} + after_guard = creature + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked return-cast repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_stacked_return_cast_self_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not creature:is_player() then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + local creature = {{}} + + ---@return boolean + ---@return_cast self Player else Monster + function creature:is_player() + return true + end + + {repeated_guards} + after_guard = creature + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked self return-cast repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_large_linear_assignment_file_builds_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let mut block = String::from( + r#" + local value ---@type integer + value = 1 + + "#, + ); + + for i in 0..LARGE_LINEAR_ASSIGNMENT_STEPS { + block.push_str(&format!("local alias_{i} = value\n")); + block.push_str(&format!("value = alias_{i}\n")); + } + block.push_str("after_assign = value\n"); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for large linear assignment stress case" + ); + let after_assign = ws.expr_ty("after_assign"); + assert_eq!(ws.humanize_type(after_assign), "integer"); + } + + #[test] + fn test_pending_replay_order_uses_type_guard_before_self_return_cast_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = {} + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local branch ---@type boolean + local creature = branch and checker or false + + if type(creature) ~= "table" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_pending_replay_order_with_three_guards_before_self_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class PlayerA + + ---@class MonsterA + + ---@class PlayerB + + ---@class MonsterB + + local checker_a = { + kind = "checker_a", + } + + ---@return boolean + ---@return_cast self PlayerA else MonsterA + function checker_a:is_player() + return true + end + + local checker_b = { + kind = "checker_b", + } + + ---@return boolean + ---@return_cast self PlayerB else MonsterB + function checker_b:is_player() + return true + end + + local allow_false ---@type boolean + local choose_a ---@type boolean + local creature = allow_false and false or (choose_a and checker_a or checker_b) + + if type(creature) ~= "table" then + return + end + + if creature.kind ~= "checker_a" then + return + end + + if creature:is_player() == false then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("PlayerA")); + } + + #[test] + fn test_return_cast_self_guard_uses_prior_narrowing_for_method_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + } + + local branch ---@type boolean + local creature = branch and checker or monster + + if creature.kind ~= "checker" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_return_cast_self_guard_without_prior_method_lookup_narrowing_does_not_apply() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + } + + local branch ---@type boolean + local creature = branch and checker or monster + before_guard = creature + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.expr_ty("before_guard")); + } + + #[test] + fn test_return_cast_self_guard_with_multiple_method_candidates_uses_prior_narrowing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class PlayerA + + ---@class MonsterA + + ---@class PlayerB + + ---@class MonsterB + + local checker_a = { + kind = "checker_a", + } + + ---@return boolean + ---@return_cast self PlayerA else MonsterA + function checker_a:is_player() + return true + end + + local checker_b = { + kind = "checker_b", + } + + ---@return boolean + ---@return_cast self PlayerB else MonsterB + function checker_b:is_player() + return true + end + + local branch ---@type boolean + local creature = branch and checker_a or checker_b + + if creature.kind ~= "checker_a" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("PlayerA")); + } + + #[test] + fn test_return_cast_self_guard_with_non_callable_member_uses_prior_narrowing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + is_player = false, + } + + local branch ---@type boolean + local creature = branch and checker or monster + + if creature.kind ~= "checker" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_return_cast_self_guard_eq_false_uses_prior_narrowing_for_method_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + is_player = false, + } + + local branch ---@type boolean + local creature = branch and checker or monster + + if creature.kind ~= "checker" then + return + end + + if creature:is_player() == false then + return end - local foo = props.bar and props.bar() or nil - "# - )); + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); } #[test] @@ -627,6 +1219,111 @@ end assert_eq!(b, LuaType::String); } + #[test] + fn test_while_loop_post_flow_keeps_incoming_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local condition ---@type boolean + local value ---@type string? + + while condition do + value = "loop" + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + + #[test] + fn test_repeat_loop_post_flow_keeps_body_assignment() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local condition ---@type boolean + local value ---@type string? + + repeat + value = "loop" + until condition + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string")); + } + + #[test] + fn test_numeric_for_loop_post_flow_keeps_incoming_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local value ---@type string? + + for i = 1, 3 do + value = "loop" + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + + #[test] + fn test_for_in_loop_post_flow_keeps_incoming_type_after_break() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + local value ---@type string? + + for _, _value in ipairs({ "loop" }) do + value = "loop" + break + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + + #[test] + fn test_nested_while_loop_post_flow_keeps_incoming_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local outer_condition ---@type boolean + local inner_condition ---@type boolean + local value ---@type string? + + while outer_condition do + while inner_condition do + value = "loop" + break + end + + break + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + #[test] fn test_issue_347() { let mut ws = VirtualWorkspace::new(); @@ -1787,4 +2484,305 @@ _2 = a[1] let type_str = ws.humanize_type_detailed(e_ty); assert_eq!(type_str, "(A|B|table)"); } + + #[test] + fn test_assignment_from_wider_single_return_call_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param ok boolean + ---@return Foo|Bar + local function pick(ok) + if ok then + return { kind = "foo", a = 1 } + end + + return { kind = "bar", b = 2 } + end + + local ok ---@type boolean + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = pick(ok) + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo|Bar")); + } + + #[test] + fn test_assignment_after_pending_return_cast_guard_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + + ---@param creature Creature + ---@return boolean + ---@return_cast creature Player else Monster + local function is_player(creature) + return true + end + + local creature ---@type Creature + local next_creature ---@type Creature + + if not is_player(creature) then + return + end + + before_assign = creature + creature = next_creature + after_assign = creature + "#, + ); + + assert_eq!(ws.expr_ty("before_assign"), ws.ty("Player")); + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Creature")); + } + + #[test] + fn test_assignment_after_binary_call_guard_eq_false_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T + ---@param inst any + ---@param type `T` + ---@return TypeGuard + local function instance_of(inst, type) + return true + end + + local value ---@type string|integer|boolean + local next_value ---@type string|integer|boolean + + if instance_of(value, 'string') == false then + return + end + + before_assign = value + value = next_value + after_assign = value + "#, + ); + + assert_eq!(ws.expr_ty("before_assign"), ws.ty("string")); + assert_eq!(ws.expr_ty("after_assign"), ws.ty("string|integer|boolean")); + } + + #[test] + fn test_assignment_after_mixed_eager_and_pending_guards_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Player + ---@field kind "player" + + ---@class Monster + ---@field kind "monster" + + ---@param creature Player|Monster + ---@return boolean + ---@return_cast creature Player else Monster + local function is_player(creature) + return true + end + + local creature ---@type Player|Monster + local next_creature ---@type Player|Monster + + if creature.kind ~= "player" then + return + end + + if not is_player(creature) then + return + end + + before_assign = creature + creature = next_creature + after_assign = creature + "#, + ); + + assert_eq!(ws.expr_ty("before_assign"), ws.ty("Player")); + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Player|Monster")); + } + + #[test] + fn test_assignment_from_nullable_union_keeps_rhs_members() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type string? + local y ---@type number? + + if x then + x = y + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("number?")); + } + + #[test] + fn test_assignment_from_partially_overlapping_union_keeps_rhs_members() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type string|number + local y ---@type integer|string + + if x == 1 then + x = y + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("integer|string")); + } + + #[test] + fn test_partial_table_reassignment_preserves_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = {} + x.kind = "foo" + x.a = 1 + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo")); + } + + #[test] + fn test_partial_table_reassignment_with_discriminant_preserves_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = { kind = "foo" } + x.a = 1 + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo")); + } + + #[test] + fn test_exact_string_reassignment_preserves_literal_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type string|number + + if x == 1 then + x = "a" + after_assign = x + end + "#, + ); + + let after_assign = ws.expr_ty("after_assign"); + assert_eq!(ws.humanize_type(after_assign), r#""a""#); + } + + #[test] + fn test_assignment_from_broad_string_drops_literal_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type "a"|boolean + local y ---@type string + + if x == "a" then + x = y + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("string")); + } + + #[test] + fn test_partial_table_reassignment_with_conflicting_discriminant_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = { kind = "bar" } + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo|Bar")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 5cbb83f2b..7700022c1 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -2,6 +2,8 @@ mod test { use crate::{DiagnosticCode, VirtualWorkspace}; + const STACKED_PCALL_ALIAS_GUARDS: usize = 180; + #[test] fn test_issue_263() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -141,4 +143,36 @@ mod test { assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string")); assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string")); } + + #[test] + fn test_pcall_stacked_alias_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + let repeated_guards = + "if failed then error(result) end\n".repeat(STACKED_PCALL_ALIAS_GUARDS); + let block = format!( + r#" + ---@return integer + local function foo() + return 1 + end + + local ok, result = pcall(foo) + local failed = ok == false + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked pcall alias guard repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs index f0a74b505..0cf6b74a8 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs @@ -2,6 +2,8 @@ mod test { use crate::{DiagnosticCode, VirtualWorkspace}; + const STACKED_CORRELATED_GUARDS: usize = 180; + #[test] fn test_return_overload_narrow_after_not() { let mut ws = VirtualWorkspace::new(); @@ -37,6 +39,44 @@ mod test { assert_eq!(ws.expr_ty("a"), ws.ty("integer")); } + #[test] + fn test_return_overload_narrow_tracks_multiple_targets_from_same_call() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return boolean + ---@return integer|string + ---@return string|boolean + ---@return_overload true, integer, string + ---@return_overload false, string, boolean + local function pick(ok) + if ok then + return true, 1, "value" + end + return false, "error", false + end + + local cond ---@type boolean + local ok, result, extra = pick(cond) + + if ok then + a = result + b = extra + else + c = result + d = extra + end + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer")); + assert_eq!(ws.expr_ty("b"), ws.ty("string")); + assert_eq!(ws.expr_ty("c"), ws.ty("string")); + assert_eq!(ws.expr_ty("d"), ws.ty("boolean")); + } + #[test] fn test_return_overload_reassign_clears_multi_return_mapping() { let mut ws = VirtualWorkspace::new(); @@ -541,4 +581,610 @@ mod test { assert!(after_guard.contains("integer")); assert!(!after_guard.contains("string")); } + + #[test] + fn test_return_overload_stacked_same_discriminant_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not ok then error(result) end\n".repeat(STACKED_CORRELATED_GUARDS); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked correlated-guard repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_stacked_eq_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if ok == false then error(result) end\n".repeat(STACKED_CORRELATED_GUARDS); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked correlated-eq repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_stacked_mixed_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if ok == false then error(result) end\nif not ok then error(result) end\n" + .repeat(STACKED_CORRELATED_GUARDS / 2); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked mixed correlated-guard repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_stacked_noncorrelated_origin_guards_keep_extra_type() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not ok then error(result) end\n".repeat(STACKED_CORRELATED_GUARDS); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond, 1, "err") + + if branch then + result = false + end + + {repeated_guards} + after_guard = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked noncorrelated correlated-guard repro" + ); + let after_guard_ty = ws.expr_ty("after_guard"); + let after_guard = ws.humanize_type(after_guard_ty); + assert!(after_guard.contains("false")); + assert!(after_guard.contains("integer")); + assert!(!after_guard.contains("string")); + } + + #[test] + fn test_return_overload_uncorrelated_later_guard_keeps_prior_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "err") + + if not ok then + error(result) + end + + ok = cond + + if not ok then + error(result) + end + + narrowed = result + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_unmatched_discriminant_call_keeps_target_wide() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return boolean + ---@return integer|string + ---@return_overload true, integer + ---@return_overload false, string + local function pick(ok) + if ok then + return true, 1 + end + return false, "err" + end + + ---@param ok boolean + ---@return boolean + local function bounce(ok) + return ok + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond) + + if branch then + ok = bounce(other) + end + + if not ok then + error(result) + end + + after_guard = result + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_unmatched_target_call_keeps_guard_union() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return boolean|table + ---@return_overload true, boolean + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, true + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("boolean|table|integer")); + } + + #[test] + fn test_return_overload_unmatched_target_root_then_truthiness_guard() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return false|table + ---@return_overload true, false + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, false + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + after_guard = result + + if result then + truthy = result + end + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("false|table|integer")); + assert_eq!(ws.expr_ty("truthy"), ws.ty("table|integer")); + } + + #[test] + fn test_return_overload_unmatched_target_root_then_type_guard() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return false|table + ---@return_overload true, false + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, false + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + after_guard = result + + if type(result) == "table" then + table_result = result + end + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("false|table|integer")); + let table_result = ws.expr_ty("table_result"); + assert_eq!(ws.humanize_type(table_result), "table"); + } + + #[test] + fn test_return_overload_post_guard_reassign_clears_mixed_root_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return false|table + ---@return_overload true, false + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, false + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local next_result ---@type string + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + before_reassign = result + result = next_result + after_reassign = result + end + "#, + ); + + assert_eq!(ws.expr_ty("before_reassign"), ws.ty("false|table|integer")); + assert_eq!(ws.expr_ty("after_reassign"), ws.ty("string")); + } + + #[test] + fn test_return_overload_reassign_from_fresh_call_ignores_prior_guard() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond, 1, "err") + + if not ok then + error(result) + end + + if branch then + ok, result = pick(cond, "x", 2) + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_branch_reassign_to_different_call_preserves_matching_root_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return "right_ok"|"right_err" + ---@return boolean|table + ---@return_overload "right_ok", boolean + ---@return_overload "right_err", table + local function pick_right(ok) + if ok then + return "right_ok", true + end + return "right_err", {} + end + + local cond ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + tag, result = pick_right(cond) + end + + at_join = result + + if tag == "left_ok" then + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("at_join"), ws.ty("boolean|table|integer|string")); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_branch_reassign_to_different_call_narrows_alternate_matching_root() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return "right_ok"|"right_err" + ---@return boolean|table + ---@return_overload "right_ok", boolean + ---@return_overload "right_err", table + local function pick_right(ok) + if ok then + return "right_ok", true + end + return "right_err", {} + end + + local cond ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + tag, result = pick_right(cond) + end + + if tag == "right_ok" then + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("boolean")); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs index fa8e9c209..f67193db2 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs @@ -1184,4 +1184,64 @@ return t "#, )); } + + #[test] + fn test_exact_string_reassignment_in_narrowed_branch_keeps_assign_literal() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::AssignTypeMismatch, + r#" + local x ---@type string|number + + if x == 1 then + x = "a" + + ---@type "a" + local y = x + end + "#, + )); + } + + #[test] + fn test_return_overload_mixed_guards_keep_assign_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + ---@param cond boolean + local function test(cond) + local ok, result = pick(cond, 1, "err") + + if ok == false then + error(result) + end + + if not ok then + error(result) + end + + ---@type integer + local narrowed = result + end + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs index 1fcb2438c..bdf6cd523 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs @@ -131,6 +131,158 @@ mod tests { )); } + #[test] + fn test_discriminated_union_assignment_keeps_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param x Foo|Bar + ---@return Foo + local function test(x) + if x.kind == "foo" then + x = { kind = "foo", a = 1 } + return x + end + + return { kind = "foo", a = 2 } + end + "# + )); + } + + #[test] + fn test_discriminated_union_partial_assignment_keeps_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param x Foo|Bar + ---@return Foo + local function test(x) + if x.kind == "foo" then + x = {} + x.kind = "foo" + x.a = 1 + return x + end + + return { kind = "foo", a = 2 } + end + "# + )); + } + + #[test] + fn test_discriminated_union_partial_literal_assignment_keeps_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param x Foo|Bar + ---@return Foo + local function test(x) + if x.kind == "foo" then + x = { kind = "foo" } + x.a = 1 + return x + end + + return { kind = "foo", a = 2 } + end + "# + )); + } + + #[test] + fn test_exact_string_reassignment_in_narrowed_branch_keeps_return_literal() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@param x string|number + ---@return "a" + local function test(x) + if x == 1 then + x = "a" + return x + end + + return "a" + end + "# + )); + } + + #[test] + fn test_return_overload_mixed_guards_keep_return_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + ---@param cond boolean + ---@return integer + local function test(cond) + local ok, result = pick(cond, 1, "err") + + if ok == false then + error(result) + end + + if not ok then + error(result) + end + + return result + end + "# + )); + } + #[test] fn test_variadic_return_type_mismatch() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs index a11a2f6d4..f7ee7d8d0 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs @@ -7,7 +7,11 @@ use std::{ sync::Arc, }; -use crate::{FileId, FlowId, LuaFunctionType, db_index::LuaType, semantic::infer::VarRefId}; +use crate::{ + FileId, FlowId, LuaFunctionType, + db_index::LuaType, + semantic::infer::{ConditionFlowAction, VarRefId}, +}; #[derive(Debug)] pub enum CacheEntry { @@ -22,7 +26,9 @@ pub struct LuaInferCache { pub expr_cache: HashMap>, pub call_cache: HashMap<(LuaSyntaxId, Option, LuaType), CacheEntry>>, - pub flow_node_cache: HashMap<(VarRefId, FlowId), CacheEntry>, + pub(crate) flow_node_cache: HashMap<(VarRefId, FlowId, bool), CacheEntry>, + pub(in crate::semantic) condition_flow_cache: + HashMap<(VarRefId, FlowId, bool), CacheEntry>, pub index_ref_origin_type_cache: HashMap>, pub expr_var_ref_id_cache: HashMap, pub narrow_by_literal_stop_position_cache: HashSet, @@ -36,6 +42,7 @@ impl LuaInferCache { expr_cache: HashMap::new(), call_cache: HashMap::new(), flow_node_cache: HashMap::new(), + condition_flow_cache: HashMap::new(), index_ref_origin_type_cache: HashMap::new(), expr_var_ref_id_cache: HashMap::new(), narrow_by_literal_stop_position_cache: HashSet::new(), @@ -58,6 +65,7 @@ impl LuaInferCache { self.expr_cache.clear(); self.call_cache.clear(); self.flow_node_cache.clear(); + self.condition_flow_cache.clear(); self.index_ref_origin_type_cache.clear(); self.expr_var_ref_id_cache.clear(); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index e30b8477c..2036d2c9c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -26,6 +26,7 @@ pub use infer_name::{find_self_decl_or_member_id, infer_param}; use infer_table::infer_table_expr; pub use infer_table::{infer_table_field_value_should_be, infer_table_should_be}; use infer_unary::infer_unary_expr; +pub(in crate::semantic) use narrow::ConditionFlowAction; pub use narrow::VarRefId; use rowan::TextRange; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs index f8052cadf..a570878af 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs @@ -12,8 +12,9 @@ use crate::{ narrow::{ ResultTypeOrContinue, condition_flow::{ - InferConditionFlow, always_literal_equal, call_flow::get_type_at_call_expr, - correlated_flow::narrow_var_from_return_overload_condition, + ConditionFlowAction, InferConditionFlow, PendingConditionNarrow, + always_literal_equal, call_flow::get_type_at_call_expr, + correlated_flow::prepare_var_from_return_overload_condition, }, get_single_antecedent, get_type_at_flow::get_type_at_flow, @@ -33,13 +34,13 @@ pub fn get_type_at_binary_expr( flow_node: &FlowNode, binary_expr: LuaBinaryExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(op_token) = binary_expr.get_op_token() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some((left_expr, right_expr)) = binary_expr.get_exprs() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match op_token.get_op() { @@ -76,7 +77,8 @@ pub fn get_type_at_binary_expr( right_expr, condition_flow, true, - ), + ) + .map(Into::into), BinaryOperator::OpGe => try_get_at_gt_or_ge_expr( db, tree, @@ -88,8 +90,9 @@ pub fn get_type_at_binary_expr( right_expr, condition_flow, false, - ), - _ => Ok(ResultTypeOrContinue::Continue), + ) + .map(Into::into), + _ => Ok(ConditionFlowAction::Continue), } } @@ -104,8 +107,8 @@ fn try_get_at_eq_or_neq_expr( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { - if let ResultTypeOrContinue::Result(result_type) = maybe_type_guard_binary( +) -> Result { + if let Some(action) = maybe_type_guard_binary_action( db, tree, cache, @@ -116,21 +119,7 @@ fn try_get_at_eq_or_neq_expr( right_expr.clone(), condition_flow, )? { - return Ok(ResultTypeOrContinue::Result(result_type)); - } - - if let ResultTypeOrContinue::Result(result_type) = maybe_field_literal_eq_narrow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - left_expr.clone(), - right_expr.clone(), - condition_flow, - )? { - return Ok(ResultTypeOrContinue::Result(result_type)); + return Ok(action); } let (left_expr, right_expr) = if !matches!( @@ -145,7 +134,21 @@ fn try_get_at_eq_or_neq_expr( (left_expr, right_expr) }; - maybe_var_eq_narrow( + if let Some(action) = maybe_field_literal_eq_action( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + left_expr.clone(), + right_expr.clone(), + condition_flow, + )? { + return Ok(action); + } + + get_var_eq_condition_action( db, tree, cache, @@ -196,7 +199,7 @@ fn try_get_at_gt_or_ge_expr( } let right_expr_type = infer_expr(db, cache, right_expr)?; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; match (&antecedent_type, &right_expr_type) { @@ -225,7 +228,7 @@ fn try_get_at_gt_or_ge_expr( } #[allow(clippy::too_many_arguments)] -fn maybe_type_guard_binary( +fn maybe_type_guard_binary_action( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -235,7 +238,7 @@ fn maybe_type_guard_binary( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result, InferFailReason> { let (candidate_expr, literal_expr) = match (left_expr, right_expr) { // If either side is a literal expression and the other side is a type guard call expression // (or ref), we can narrow it @@ -249,7 +252,7 @@ fn maybe_type_guard_binary( let (Some(candidate_expr), Some(LuaLiteralToken::String(literal_string))) = (candidate_expr, literal_expr.and_then(|e| e.get_literal())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let candidate_expr = match candidate_expr { @@ -266,53 +269,69 @@ fn maybe_type_guard_binary( LuaExpr::CallExpr(call_expr) if call_expr.is_type() => Some(call_expr), _ => None, }) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(narrow) = type_call_name_to_type(&literal_string.get_value()) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(arg) = type_guard_expr .get_args_list() .and_then(|arg_list| arg_list.get_args().next()) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, arg) else { // If we cannot find a reference declaration ID, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); + }; + + if maybe_var_ref_id == *var_ref_id { + return Ok(Some(ConditionFlowAction::pending( + PendingConditionNarrow::TypeGuard { + narrow, + condition_flow, + }, + ))); + } + + let Some(discriminant_decl_id) = maybe_var_ref_id.get_decl_id_ref() else { + return Ok(None); + }; + let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { + return Ok(None); }; + if !tree.has_decl_multi_return_refs(&discriminant_decl_id) + || !tree.has_decl_multi_return_refs(&target_decl_id) + { + return Ok(None); + } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, &maybe_var_ref_id, antecedent_flow_id)?; let narrowed_discriminant_type = match condition_flow { InferConditionFlow::TrueCondition => { - narrow_down_type(db, antecedent_type, narrow.clone(), None).unwrap_or(narrow) + narrow_down_type(db, antecedent_type, narrow.clone(), None).unwrap_or(narrow.clone()) } InferConditionFlow::FalseCondition => TypeOps::Remove.apply(db, &antecedent_type, &narrow), }; - if maybe_var_ref_id == *var_ref_id { - Ok(ResultTypeOrContinue::Result(narrowed_discriminant_type)) - } else { - let Some(discriminant_decl_id) = maybe_var_ref_id.get_decl_id_ref() else { - return Ok(ResultTypeOrContinue::Continue); - }; - narrow_var_from_return_overload_condition( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - discriminant_decl_id, - type_guard_expr.get_position(), - &narrowed_discriminant_type, - ) - } + Ok(prepare_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + discriminant_decl_id, + type_guard_expr.get_position(), + &narrowed_discriminant_type, + )? + .map(PendingConditionNarrow::Correlated) + .map(ConditionFlowAction::pending)) } /// Maps the string result of Lua's builtin `type()` call to the corresponding `LuaType`. @@ -335,12 +354,31 @@ fn narrow_eq_condition( antecedent_type: LuaType, right_expr_type: LuaType, condition_flow: InferConditionFlow, + allow_literal_equivalence: bool, ) -> LuaType { match condition_flow { InferConditionFlow::TrueCondition => { let left_maybe_type = TypeOps::Intersect.apply(db, &antecedent_type, &right_expr_type); if left_maybe_type.is_never() { + if allow_literal_equivalence { + let literal_matches = match &antecedent_type { + LuaType::Union(union) => union + .into_vec() + .into_iter() + .filter(|candidate| always_literal_equal(candidate, &right_expr_type)) + .collect::>(), + _ if always_literal_equal(&antecedent_type, &right_expr_type) => { + vec![antecedent_type.clone()] + } + _ => Vec::new(), + }; + + if !literal_matches.is_empty() { + return LuaType::from_vec(literal_matches); + } + } + antecedent_type } else { left_maybe_type @@ -353,7 +391,7 @@ fn narrow_eq_condition( } #[allow(clippy::too_many_arguments)] -fn maybe_var_eq_narrow( +fn get_var_eq_condition_action( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -363,28 +401,35 @@ fn maybe_var_eq_narrow( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { // only check left as need narrow match left_expr { LuaExpr::NameExpr(left_name_expr) => { let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::NameExpr(left_name_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let right_expr_type = infer_expr(db, cache, right_expr)?; - if maybe_ref_id != *var_ref_id { let Some(discriminant_decl_id) = maybe_ref_id.get_decl_id_ref() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); + }; + let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { + return Ok(ConditionFlowAction::Continue); }; + if !tree.has_decl_multi_return_refs(&discriminant_decl_id) + || !tree.has_decl_multi_return_refs(&target_decl_id) + { + return Ok(ConditionFlowAction::Continue); + } + let antecedent_flow_id = get_single_antecedent(flow_node)?; + let right_expr_type = infer_expr(db, cache, right_expr)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, &maybe_ref_id, antecedent_flow_id)?; let narrowed_discriminant_type = - narrow_eq_condition(db, antecedent_type, right_expr_type, condition_flow); - return narrow_var_from_return_overload_condition( + narrow_eq_condition(db, antecedent_type, right_expr_type, condition_flow, true); + return Ok(prepare_var_from_return_overload_condition( db, tree, cache, @@ -394,26 +439,33 @@ fn maybe_var_eq_narrow( discriminant_decl_id, left_name_expr.get_position(), &narrowed_discriminant_type, - ); + )? + .map(PendingConditionNarrow::Correlated) + .map(ConditionFlowAction::pending) + .unwrap_or(ConditionFlowAction::Continue)); } - let left_type = - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - + let right_expr_type = infer_expr(db, cache, right_expr)?; let result_type = match condition_flow { InferConditionFlow::TrueCondition => { // self 是特殊的, 我们删除其 nil 类型 if var_ref_id.is_self_ref() && !right_expr_type.is_nil() { TypeOps::Remove.apply(db, &right_expr_type, &LuaType::Nil) } else { - narrow_eq_condition(db, left_type, right_expr_type, condition_flow) + return Ok(ConditionFlowAction::pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + })); } } InferConditionFlow::FalseCondition => { - TypeOps::Remove.apply(db, &left_type, &right_expr_type) + return Ok(ConditionFlowAction::pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + })); } }; - Ok(ResultTypeOrContinue::Result(result_type)) + Ok(ConditionFlowAction::Result(result_type)) } LuaExpr::CallExpr(left_call_expr) => { if let LuaExpr::LiteralExpr(literal_expr) = right_expr { @@ -425,72 +477,61 @@ fn maybe_var_eq_narrow( condition_flow.get_negated() }; - return get_type_at_call_expr( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - left_call_expr, - flow, - ); + return get_type_at_call_expr(db, cache, var_ref_id, left_call_expr, flow); } - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(ConditionFlowAction::Continue), } }; - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } LuaExpr::IndexExpr(left_index_expr) => { let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(left_index_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if maybe_ref_id != *var_ref_id { // If the reference declaration ID does not match, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } let right_expr_type = infer_expr(db, cache, right_expr)?; - let result_type = match condition_flow { - InferConditionFlow::TrueCondition => right_expr_type, - InferConditionFlow::FalseCondition => { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - TypeOps::Remove.apply(db, &antecedent_type, &right_expr_type) - } - }; - Ok(ResultTypeOrContinue::Result(result_type)) + if condition_flow.is_false() { + return Ok(ConditionFlowAction::pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + })); + } + + Ok(ConditionFlowAction::Result(right_expr_type)) } LuaExpr::UnaryExpr(unary_expr) => { let Some(op) = unary_expr.get_op_token() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match op.get_op() { UnaryOperator::OpLen => {} - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(ConditionFlowAction::Continue), }; let Some(expr) = unary_expr.get_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, expr) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if maybe_ref_id != *var_ref_id { // If the reference declaration ID does not match, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } let right_expr_type = infer_expr(db, cache, right_expr)?; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; match (&antecedent_type, &right_expr_type) { @@ -501,25 +542,25 @@ fn maybe_var_eq_narrow( if condition_flow.is_true() { let new_array_type = LuaArrayType::new(array_type.get_base().clone(), LuaArrayLen::Max(*i)); - return Ok(ResultTypeOrContinue::Result(LuaType::Array( + return Ok(ConditionFlowAction::Result(LuaType::Array( new_array_type.into(), ))); } } - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(ConditionFlowAction::Continue), } - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } _ => { // If the left expression is not a name or call expression, we cannot narrow it - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } } } #[allow(clippy::too_many_arguments)] -fn maybe_field_literal_eq_narrow( +fn maybe_field_literal_eq_action( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -529,7 +570,7 @@ fn maybe_field_literal_eq_narrow( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result, InferFailReason> { // only check left as need narrow let syntax_id = left_expr.get_syntax_id(); let (index_expr, literal_expr) = match (left_expr, right_expr) { @@ -539,16 +580,16 @@ fn maybe_field_literal_eq_narrow( (LuaExpr::LiteralExpr(literal_expr), LuaExpr::IndexExpr(index_expr)) => { (index_expr, literal_expr) } - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(None), }; let Some(prefix_expr) = index_expr.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, prefix_expr.clone()) else { // If we cannot find a reference declaration ID, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; if maybe_var_ref_id != *var_ref_id { @@ -557,18 +598,18 @@ fn maybe_field_literal_eq_narrow( .contains(&syntax_id) && var_ref_id.start_with(&maybe_var_ref_id) { - return Ok(ResultTypeOrContinue::Result(get_var_ref_type( + return Ok(Some(ConditionFlowAction::Result(get_var_ref_type( db, cache, var_ref_id, - )?)); + )?))); } - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let left_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; let LuaType::Union(union_type) = left_type else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; cache @@ -599,16 +640,18 @@ fn maybe_field_literal_eq_narrow( match condition_flow { InferConditionFlow::TrueCondition => { if let Some(i) = opt_result { - return Ok(ResultTypeOrContinue::Result(union_types[i].clone())); + return Ok(Some(ConditionFlowAction::Result(union_types[i].clone()))); } } InferConditionFlow::FalseCondition => { if let Some(i) = opt_result { union_types.remove(i); - return Ok(ResultTypeOrContinue::Result(LuaType::from_vec(union_types))); + return Ok(Some(ConditionFlowAction::Result(LuaType::from_vec( + union_types, + )))); } } } - Ok(ResultTypeOrContinue::Continue) + Ok(None) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index 8ad1b8b0e..0e6cdb03c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -1,61 +1,88 @@ use std::{ops::Deref, sync::Arc}; -use emmylua_parser::{LuaCallExpr, LuaChunk, LuaExpr}; +use emmylua_parser::{LuaCallExpr, LuaExpr, LuaIndexMemberExpr}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaAliasCallKind, LuaAliasCallType, - LuaFunctionType, LuaInferCache, LuaSignatureCast, LuaSignatureId, LuaType, TypeOps, - infer_call_expr_func, infer_expr, + DbIndex, InferFailReason, InferGuard, LuaAliasCallKind, LuaAliasCallType, LuaFunctionType, + LuaInferCache, LuaSignatureId, LuaType, infer_call_expr_func, infer_expr, semantic::infer::{ VarRefId, + infer_index::infer_member_by_member_key, narrow::{ - ResultTypeOrContinue, condition_flow::InferConditionFlow, get_single_antecedent, - get_type_at_cast_flow::cast_type, get_type_at_flow::get_type_at_flow, - narrow_false_or_nil, remove_false_or_nil, var_ref_id::get_var_expr_var_ref_id, + ResultTypeOrContinue, + condition_flow::{ConditionFlowAction, InferConditionFlow, PendingConditionNarrow}, + get_var_ref_type, narrow_false_or_nil, remove_false_or_nil, + var_ref_id::get_var_expr_var_ref_id, }, }, }; -#[allow(clippy::too_many_arguments)] pub fn get_type_at_call_expr( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(prefix_expr) = call_expr.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; - let maybe_func = infer_expr(db, cache, prefix_expr.clone())?; + let maybe_func = if call_expr.is_colon_call() { + match &prefix_expr { + LuaExpr::IndexExpr(index_expr) => { + if let Some(self_expr) = index_expr.get_prefix_expr() + && let Some(self_var_ref_id) = get_var_expr_var_ref_id(db, cache, self_expr) + && self_var_ref_id == *var_ref_id + { + let self_type = get_var_ref_type(db, cache, var_ref_id)?; + let member_type = infer_member_by_member_key( + db, + cache, + &self_type, + LuaIndexMemberExpr::IndexExpr(index_expr.clone()), + &InferGuard::new(), + )?; + + if needs_antecedent_same_var_colon_lookup(&member_type) { + // Keep the dedicated pending case here: replay needs the antecedent type + // for member lookup itself, not just for applying a cast after lookup. + return Ok(ConditionFlowAction::pending( + PendingConditionNarrow::SameVarColonCall { + index: LuaIndexMemberExpr::IndexExpr(index_expr.clone()), + condition_flow, + }, + )); + } else { + member_type + } + } else { + infer_expr(db, cache, prefix_expr.clone())? + } + } + _ => infer_expr(db, cache, prefix_expr.clone())?, + } + } else { + infer_expr(db, cache, prefix_expr.clone())? + }; match maybe_func { LuaType::DocFunction(f) => { let return_type = f.get_ret(); match return_type { LuaType::TypeGuard(_) => get_type_at_call_expr_by_type_guard( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, f, condition_flow, ), - _ => { - // If the return type is not a type guard, we cannot infer the type cast. - Ok(ResultTypeOrContinue::Continue) - } + _ => Ok(ConditionFlowAction::Continue), } } LuaType::Signature(signature_id) => { let Some(signature) = db.get_signature_index().get(&signature_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let ret = signature.get_return_type(); @@ -63,99 +90,92 @@ pub fn get_type_at_call_expr( LuaType::TypeGuard(_) => { return get_type_at_call_expr_by_type_guard( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, signature.to_doc_func_type(), condition_flow, ); } LuaType::Call(call) => { - return get_type_at_call_expr_by_call( + return Ok(get_type_at_call_expr_by_call( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, &call, condition_flow, - ); + )? + .into()); } _ => {} } let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match signature_cast.name.as_str() { "self" => get_type_at_call_expr_by_signature_self( db, - tree, cache, - root, var_ref_id, - flow_node, prefix_expr, - signature_cast, signature_id, condition_flow, ), name => get_type_at_call_expr_by_signature_param_name( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, - signature_cast, signature_id, name, condition_flow, ), } } - _ => { - // If the prefix expression is not a function, we cannot infer the type cast. - Ok(ResultTypeOrContinue::Continue) - } + _ => Ok(ConditionFlowAction::Continue), } } -#[allow(clippy::too_many_arguments)] -fn get_type_at_call_expr_by_type_guard( +fn needs_antecedent_same_var_colon_lookup(member_type: &LuaType) -> bool { + let candidate_members = match member_type { + LuaType::Union(union_type) => union_type.into_vec(), + LuaType::MultiLineUnion(multi_union) => match multi_union.to_union() { + LuaType::Union(union_type) => union_type.into_vec(), + _ => return false, + }, + _ => return false, + }; + + candidate_members.len() > 1 + && candidate_members.iter().any(|ty| { + matches!( + ty, + LuaType::DocFunction(_) | LuaType::Signature(_) | LuaType::Call(_) + ) + }) +} + +fn get_type_guard_call_info( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, - var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, func_type: Arc, - condition_flow: InferConditionFlow, -) -> Result { +) -> Result, InferFailReason> { let Some(arg_list) = call_expr.get_args_list() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(first_arg) = arg_list.get_args().next() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, first_arg) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; - if maybe_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); - } - let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { let call_expr_type = LuaType::DocFunction(func_type); @@ -171,142 +191,94 @@ fn get_type_at_call_expr_by_type_guard( return_type = inst_func.get_ret().clone(); } - let guard_type = match return_type { - LuaType::TypeGuard(guard) => guard.deref().clone(), - _ => return Ok(ResultTypeOrContinue::Continue), + let LuaType::TypeGuard(guard) = return_type else { + return Ok(None); }; - match condition_flow { - InferConditionFlow::TrueCondition => Ok(ResultTypeOrContinue::Result(guard_type)), - InferConditionFlow::FalseCondition => { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - Ok(ResultTypeOrContinue::Result(TypeOps::Remove.apply( - db, - &antecedent_type, - &guard_type, - ))) - } + Ok(Some((maybe_ref_id, guard.deref().clone()))) +} + +fn get_type_at_call_expr_by_type_guard( + db: &DbIndex, + cache: &mut LuaInferCache, + var_ref_id: &VarRefId, + call_expr: LuaCallExpr, + func_type: Arc, + condition_flow: InferConditionFlow, +) -> Result { + let Some((maybe_ref_id, guard_type)) = + get_type_guard_call_info(db, cache, call_expr, func_type)? + else { + return Ok(ConditionFlowAction::Continue); + }; + + if maybe_ref_id != *var_ref_id { + return Ok(ConditionFlowAction::Continue); } + + Ok(ConditionFlowAction::pending( + PendingConditionNarrow::TypeGuard { + narrow: guard_type, + condition_flow, + }, + )) } -#[allow(clippy::too_many_arguments)] fn get_type_at_call_expr_by_signature_self( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_prefix: LuaExpr, - signature_cast: &LuaSignatureCast, signature_id: LuaSignatureId, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let LuaExpr::IndexExpr(call_prefix_index) = call_prefix else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(self_expr) = call_prefix_index.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, self_expr) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if name_var_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&signature_id.get_file_id()) else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let signature_root = syntax_tree.get_chunk_node(); - - // Choose the appropriate cast based on condition_flow and whether fallback exists - let result_type = match condition_flow { - InferConditionFlow::TrueCondition => { - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - InferConditionFlow::FalseCondition => { - // Use fallback_cast if available, otherwise use the default behavior - if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast { - let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - fallback_op_type, - antecedent_type.clone(), - InferConditionFlow::TrueCondition, // Apply fallback as force cast - )? - } else { - // Original behavior: remove the true type from antecedent - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - } - }; - - Ok(ResultTypeOrContinue::Result(result_type)) + Ok(get_signature_cast_pending(signature_id, condition_flow)) } #[allow(clippy::too_many_arguments)] fn get_type_at_call_expr_by_signature_param_name( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, - signature_cast: &LuaSignatureCast, signature_id: LuaSignatureId, name: &str, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let colon_call = call_expr.is_colon_call(); let Some(arg_list) = call_expr.get_args_list() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(signature) = db.get_signature_index().get(&signature_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(mut param_idx) = signature.find_param_idx(name) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let colon_define = signature.is_colon_define; match (colon_call, colon_define) { (true, false) => { if param_idx == 0 { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } param_idx -= 1; @@ -318,80 +290,34 @@ fn get_type_at_call_expr_by_signature_param_name( } let Some(expr) = arg_list.get_args().nth(param_idx) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, expr) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if name_var_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&signature_id.get_file_id()) else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let signature_root = syntax_tree.get_chunk_node(); - - // Choose the appropriate cast based on condition_flow and whether fallback exists - let result_type = match condition_flow { - InferConditionFlow::TrueCondition => { - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - InferConditionFlow::FalseCondition => { - // Use fallback_cast if available, otherwise use the default behavior - if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast { - let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - fallback_op_type, - antecedent_type.clone(), - InferConditionFlow::TrueCondition, // Apply fallback as force cast - )? - } else { - // Original behavior: remove the true type from antecedent - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - } - }; + Ok(get_signature_cast_pending(signature_id, condition_flow)) +} - Ok(ResultTypeOrContinue::Result(result_type)) +fn get_signature_cast_pending( + signature_id: LuaSignatureId, + condition_flow: InferConditionFlow, +) -> ConditionFlowAction { + ConditionFlowAction::pending(PendingConditionNarrow::SignatureCast { + signature_id, + condition_flow, + }) } -#[allow(unused, clippy::too_many_arguments)] fn get_type_at_call_expr_by_call( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, alias_call_type: &Arc, condition_flow: InferConditionFlow, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 8d0ddaac1..5b9b88a4c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -4,15 +4,112 @@ use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaChunk}; use crate::{ DbIndex, FlowId, FlowNode, FlowTree, InferFailReason, LuaDeclId, LuaFunctionType, - LuaInferCache, LuaType, TypeOps, infer_expr, instantiate_func_generic, + LuaInferCache, LuaSignature, LuaType, TypeOps, infer_expr, instantiate_func_generic, semantic::infer::{ VarRefId, - narrow::{ResultTypeOrContinue, get_single_antecedent, get_type_at_flow::get_type_at_flow}, + narrow::{get_single_antecedent, get_type_at_flow::get_type_at_flow}, }, }; +#[derive(Debug, Clone)] +pub(in crate::semantic) struct CorrelatedConditionNarrowing { + search_root_correlated_types: Vec, +} + +#[derive(Debug, Clone)] +struct SearchRootCorrelatedTypes { + matching_target_types: Vec, + uncorrelated_target_types: Vec, + deferred_known_call_target_types: Option>, +} + +#[derive(Debug)] +struct CollectedCorrelatedTypes { + matching_target_types: Vec, + correlated_candidate_types: Vec, + unmatched_target_types: Vec, + has_unmatched_discriminant_origin: bool, + has_opaque_target_origin: bool, +} + +impl CorrelatedConditionNarrowing { + pub(in crate::semantic::infer::narrow) fn apply( + &self, + db: &DbIndex, + antecedent_type: LuaType, + ) -> LuaType { + let mut root_target_types = Vec::new(); + let mut found_matching_root = false; + for root_types in &self.search_root_correlated_types { + let matching_target_types = &root_types.matching_target_types; + let mut uncorrelated_target_types = root_types.uncorrelated_target_types.clone(); + let deferred_known_call_target_types = + root_types.deferred_known_call_target_types.as_deref(); + + let root_matching_target_type = if matching_target_types.is_empty() { + None + } else { + let matching_target_type = LuaType::from_vec(matching_target_types.clone()); + let narrowed_correlated_type = + TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type); + if narrowed_correlated_type.is_never() { + None + } else { + found_matching_root = true; + Some(narrowed_correlated_type) + } + }; + + if let Some(known_call_target_types) = deferred_known_call_target_types { + let remaining_root_type = + if known_call_target_types.is_empty() && uncorrelated_target_types.is_empty() { + Some(antecedent_type.clone()) + } else { + subtract_correlated_candidate_types( + db, + antecedent_type.clone(), + &known_call_target_types, + ) + }; + if let Some(remaining_root_type) = remaining_root_type { + uncorrelated_target_types.push(remaining_root_type); + } + } + + let root_uncorrelated_target_type = (!uncorrelated_target_types.is_empty()) + .then(|| LuaType::from_vec(uncorrelated_target_types)); + + match (root_matching_target_type, root_uncorrelated_target_type) { + (Some(root_matching_target_type), Some(root_uncorrelated_target_type)) => { + root_target_types.push(LuaType::from_vec(vec![ + root_matching_target_type, + root_uncorrelated_target_type, + ])); + } + (Some(root_matching_target_type), None) => { + root_target_types.push(root_matching_target_type); + } + (None, Some(root_uncorrelated_target_type)) => { + root_target_types.push(root_uncorrelated_target_type); + } + (None, None) => {} + } + } + + if !found_matching_root { + return antecedent_type; + } + + if root_target_types.is_empty() { + antecedent_type + } else { + LuaType::from_vec(root_target_types) + } + } +} + #[allow(clippy::too_many_arguments)] -pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return_overload_condition( +pub(in crate::semantic::infer::narrow) fn prepare_var_from_return_overload_condition( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -22,27 +119,27 @@ pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return discriminant_decl_id: LuaDeclId, condition_position: rowan::TextSize, narrowed_discriminant_type: &LuaType, -) -> Result { +) -> Result, InferFailReason> { let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; if !tree.has_decl_multi_return_refs(&discriminant_decl_id) || !tree.has_decl_multi_return_refs(&target_decl_id) { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let search_root_flow_ids = tree.get_decl_multi_return_search_roots( &discriminant_decl_id, &target_decl_id, condition_position, antecedent_flow_id, ); - let mut matching_target_types = Vec::new(); - let mut uncorrelated_target_types = Vec::new(); - for search_root_flow_id in search_root_flow_ids { - let (root_matching_target_types, root_uncorrelated_target_type) = + let root_correlated_types = search_root_flow_ids + .iter() + .copied() + .map(|search_root_flow_id| { collect_correlated_types_from_search_root( db, tree, @@ -52,47 +149,23 @@ pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return discriminant_decl_id, target_decl_id, condition_position, + antecedent_flow_id, search_root_flow_id, narrowed_discriminant_type, - )?; - matching_target_types.extend(root_matching_target_types); - if let Some(root_uncorrelated_target_type) = root_uncorrelated_target_type { - uncorrelated_target_types.push(root_uncorrelated_target_type); - } - } + ) + }) + .collect::, _>>()?; - if matching_target_types.is_empty() { - return Ok(ResultTypeOrContinue::Continue); - } - - let matching_target_type = LuaType::from_vec(matching_target_types); - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - let narrowed_correlated_type = - TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type); - if narrowed_correlated_type.is_never() { - return Ok(ResultTypeOrContinue::Continue); - } - - if uncorrelated_target_types.is_empty() { - return Ok(if narrowed_correlated_type == antecedent_type { - ResultTypeOrContinue::Continue - } else { - ResultTypeOrContinue::Result(narrowed_correlated_type) - }); + if root_correlated_types + .iter() + .all(|root_types| root_types.matching_target_types.is_empty()) + { + return Ok(None); } - let uncorrelated_target_type = LuaType::from_vec(uncorrelated_target_types); - let merged_type = if uncorrelated_target_type.is_never() { - narrowed_correlated_type - } else { - LuaType::from_vec(vec![narrowed_correlated_type, uncorrelated_target_type]) - }; - - Ok(if merged_type == antecedent_type { - ResultTypeOrContinue::Continue - } else { - ResultTypeOrContinue::Result(merged_type) - }) + Ok(Some(CorrelatedConditionNarrowing { + search_root_correlated_types: root_correlated_types, + })) } #[allow(clippy::too_many_arguments)] @@ -105,9 +178,10 @@ fn collect_correlated_types_from_search_root( discriminant_decl_id: LuaDeclId, target_decl_id: LuaDeclId, condition_position: rowan::TextSize, + antecedent_flow_id: FlowId, search_root_flow_id: FlowId, narrowed_discriminant_type: &LuaType, -) -> Result<(Vec, Option), InferFailReason> { +) -> Result { let (discriminant_refs, discriminant_has_non_reference_origin) = tree .get_decl_multi_return_ref_summary_at( &discriminant_decl_id, @@ -119,18 +193,13 @@ fn collect_correlated_types_from_search_root( condition_position, search_root_flow_id, ); - if discriminant_refs.is_empty() || target_refs.is_empty() { - return Ok(( - Vec::new(), - get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id).ok(), - )); - } - - let ( - root_matching_target_types, - root_correlated_candidate_types, - has_unmatched_correlated_origin, - ) = collect_matching_correlated_types( + let CollectedCorrelatedTypes { + matching_target_types: root_matching_target_types, + correlated_candidate_types: root_correlated_candidate_types, + unmatched_target_types: root_unmatched_target_types, + has_unmatched_discriminant_origin, + has_opaque_target_origin, + } = collect_matching_correlated_types( db, cache, root, @@ -138,27 +207,48 @@ fn collect_correlated_types_from_search_root( &target_refs, narrowed_discriminant_type, )?; - if root_matching_target_types.is_empty() { - return Ok(( - Vec::new(), - get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id).ok(), - )); - } - let root_uncorrelated_target_type = if discriminant_has_non_reference_origin + let mut root_uncorrelated_target_types = root_unmatched_target_types; + let has_uncorrelated_origin = discriminant_has_non_reference_origin || target_has_non_reference_origin - || has_unmatched_correlated_origin - { - get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) - .ok() - .and_then(|root_type| { - subtract_correlated_candidate_types(db, root_type, &root_correlated_candidate_types) - }) - } else { - None - }; + || has_opaque_target_origin + || has_unmatched_discriminant_origin; + let correlated_candidate_types_is_empty = root_correlated_candidate_types.is_empty(); + let deferred_known_call_target_types = + if has_uncorrelated_origin && search_root_flow_id == antecedent_flow_id { + let mut known_call_target_types = root_correlated_candidate_types.clone(); + known_call_target_types.extend(root_uncorrelated_target_types.iter().cloned()); + Some(known_call_target_types) + } else { + None + }; + if has_uncorrelated_origin && deferred_known_call_target_types.is_none() { + if correlated_candidate_types_is_empty && root_uncorrelated_target_types.is_empty() { + if let Ok(root_type) = + get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) + { + root_uncorrelated_target_types.push(root_type); + } + } else { + let mut known_call_target_types = root_correlated_candidate_types; + known_call_target_types.extend(root_uncorrelated_target_types.iter().cloned()); + if let Some(remaining_root_type) = + get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) + .ok() + .and_then(|root_type| { + subtract_correlated_candidate_types(db, root_type, &known_call_target_types) + }) + { + root_uncorrelated_target_types.push(remaining_root_type); + } + } + } - Ok((root_matching_target_types, root_uncorrelated_target_type)) + Ok(SearchRootCorrelatedTypes { + matching_target_types: root_matching_target_types, + uncorrelated_target_types: root_uncorrelated_target_types, + deferred_known_call_target_types, + }) } fn subtract_correlated_candidate_types( @@ -195,9 +285,10 @@ fn collect_matching_correlated_types( discriminant_refs: &[crate::DeclMultiReturnRef], target_refs: &[crate::DeclMultiReturnRef], narrowed_discriminant_type: &LuaType, -) -> Result<(Vec, Vec, bool), InferFailReason> { +) -> Result { let mut matching_target_types = Vec::new(); let mut correlated_candidate_types = Vec::new(); + let mut unmatched_target_types = Vec::new(); let mut correlated_discriminant_call_expr_ids = HashSet::new(); let mut correlated_target_call_expr_ids = HashSet::new(); @@ -211,7 +302,7 @@ fn collect_matching_correlated_types( continue; } - let overload_rows = instantiate_return_overload_rows(db, cache, call_expr, signature); + let overload_rows = instantiate_return_rows(db, cache, call_expr, signature); let discriminant_call_expr_id = discriminant_ref.call_expr.get_syntax_id(); for target_ref in target_refs { @@ -221,18 +312,16 @@ fn collect_matching_correlated_types( correlated_discriminant_call_expr_ids.insert(discriminant_call_expr_id); correlated_target_call_expr_ids.insert(target_ref.call_expr.get_syntax_id()); correlated_candidate_types.extend(overload_rows.iter().map(|overload| { - crate::LuaSignature::get_overload_row_slot(overload, target_ref.return_index) + LuaSignature::get_overload_row_slot(overload, target_ref.return_index) })); matching_target_types.extend(overload_rows.iter().filter_map(|overload| { - let discriminant_type = crate::LuaSignature::get_overload_row_slot( - overload, - discriminant_ref.return_index, - ); + let discriminant_type = + LuaSignature::get_overload_row_slot(overload, discriminant_ref.return_index); if !TypeOps::Intersect .apply(db, &discriminant_type, narrowed_discriminant_type) .is_never() { - return Some(crate::LuaSignature::get_overload_row_slot( + return Some(LuaSignature::get_overload_row_slot( overload, target_ref.return_index, )); @@ -243,16 +332,36 @@ fn collect_matching_correlated_types( } } - let has_unmatched_correlated_origin = discriminant_refs.iter().any(|discriminant_ref| { + let mut has_opaque_target_origin = false; + for target_ref in target_refs { + if correlated_target_call_expr_ids.contains(&target_ref.call_expr.get_syntax_id()) { + continue; + } + + let Some((call_expr, signature)) = + infer_signature_for_call_ptr(db, cache, root, &target_ref.call_expr)? + else { + has_opaque_target_origin = true; + continue; + }; + let return_rows = instantiate_return_rows(db, cache, call_expr, signature); + unmatched_target_types.extend( + return_rows + .iter() + .map(|row| LuaSignature::get_overload_row_slot(row, target_ref.return_index)), + ); + } + + let has_unmatched_discriminant_origin = discriminant_refs.iter().any(|discriminant_ref| { !correlated_discriminant_call_expr_ids.contains(&discriminant_ref.call_expr.get_syntax_id()) - }) || target_refs.iter().any(|target_ref| { - !correlated_target_call_expr_ids.contains(&target_ref.call_expr.get_syntax_id()) }); - Ok(( + Ok(CollectedCorrelatedTypes { matching_target_types, correlated_candidate_types, - has_unmatched_correlated_origin, - )) + unmatched_target_types, + has_unmatched_discriminant_origin, + has_opaque_target_origin, + }) } fn infer_signature_for_call_ptr<'a>( @@ -260,7 +369,7 @@ fn infer_signature_for_call_ptr<'a>( cache: &mut LuaInferCache, root: &LuaChunk, call_expr_ptr: &LuaAstPtr, -) -> Result, InferFailReason> { +) -> Result, InferFailReason> { let Some(call_expr) = call_expr_ptr.to_node(root) else { return Ok(None); }; @@ -278,16 +387,36 @@ fn infer_signature_for_call_ptr<'a>( Ok(Some((call_expr, signature))) } -fn instantiate_return_overload_rows( +fn instantiate_return_rows( db: &DbIndex, cache: &mut LuaInferCache, call_expr: LuaCallExpr, - signature: &crate::LuaSignature, + signature: &LuaSignature, ) -> Vec> { + if signature.return_overloads.is_empty() { + let return_type = signature.get_return_type(); + let instantiated_return_type = if return_type.contain_tpl() { + let func = LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + signature.get_type_params(), + return_type.clone(), + ); + match instantiate_func_generic(db, cache, &func, call_expr) { + Ok(instantiated) => instantiated.get_ret().clone(), + Err(_) => return_type, + } + } else { + return_type + }; + return vec![LuaSignature::return_type_to_row(instantiated_return_type)]; + } + let mut rows = Vec::with_capacity(signature.return_overloads.len()); for overload in &signature.return_overloads { let type_refs = &overload.type_refs; - let overload_return_type = crate::LuaSignature::row_to_return_type(type_refs.to_vec()); + let overload_return_type = LuaSignature::row_to_return_type(type_refs.to_vec()); let instantiated_return_type = if overload_return_type.contain_tpl() { let overload_func = LuaFunctionType::new( signature.async_state, @@ -304,9 +433,7 @@ fn instantiate_return_overload_rows( overload_return_type }; - rows.push(crate::LuaSignature::return_type_to_row( - instantiated_return_type, - )); + rows.push(LuaSignature::return_type_to_row(instantiated_return_type)); } rows diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs index ef1738954..a2d7eb54a 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs @@ -1,123 +1,52 @@ -use emmylua_parser::{LuaChunk, LuaExpr, LuaIndexExpr, LuaIndexMemberExpr}; +use emmylua_parser::{LuaExpr, LuaIndexExpr, LuaIndexMemberExpr}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaInferCache, LuaType, TypeOps, + DbIndex, InferFailReason, LuaInferCache, semantic::infer::{ VarRefId, - infer_index::infer_member_by_member_key, narrow::{ - ResultTypeOrContinue, condition_flow::InferConditionFlow, get_single_antecedent, - get_type_at_flow::get_type_at_flow, narrow_false_or_nil, remove_false_or_nil, + condition_flow::{ConditionFlowAction, InferConditionFlow, PendingConditionNarrow}, var_ref_id::get_var_expr_var_ref_id, }, }, }; -#[allow(clippy::too_many_arguments)] pub fn get_type_at_index_expr( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, index_expr: LuaIndexExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(index_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; - if name_var_ref_id != *var_ref_id { - return maybe_field_exist_narrow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - index_expr, - condition_flow, - ); + if name_var_ref_id == *var_ref_id { + return Ok(ConditionFlowAction::pending( + PendingConditionNarrow::Truthiness(condition_flow), + )); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let result_type = match condition_flow { - InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), - InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), - }; - - Ok(ResultTypeOrContinue::Result(result_type)) -} - -#[allow(clippy::too_many_arguments)] -fn maybe_field_exist_narrow( - db: &DbIndex, - tree: &FlowTree, - cache: &mut LuaInferCache, - root: &LuaChunk, - var_ref_id: &VarRefId, - flow_node: &FlowNode, - index_expr: LuaIndexExpr, - condition_flow: InferConditionFlow, -) -> Result { let Some(prefix_expr) = index_expr.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, prefix_expr.clone()) else { // If we cannot find a reference declaration ID, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if maybe_var_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); - } - - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let left_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - let LuaType::Union(union_type) = &left_type else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let index = LuaIndexMemberExpr::IndexExpr(index_expr); - let mut result = vec![]; - let union_types = union_type.into_vec(); - for sub_type in &union_types { - let member_type = match infer_member_by_member_key( - db, - cache, - sub_type, - index.clone(), - &InferGuard::new(), - ) { - Ok(member_type) => member_type, - Err(_) => continue, // If we cannot infer the member type, skip this type - }; - // donot use always true - if !member_type.is_always_falsy() { - result.push(sub_type.clone()); - } + return Ok(ConditionFlowAction::Continue); } - match condition_flow { - InferConditionFlow::TrueCondition => { - if !result.is_empty() { - return Ok(ResultTypeOrContinue::Result(LuaType::from_vec(result))); - } - } - InferConditionFlow::FalseCondition => { - if !result.is_empty() { - let target = LuaType::from_vec(result); - let t = TypeOps::Remove.apply(db, &left_type, &target); - return Ok(ResultTypeOrContinue::Result(t)); - } - } - } - - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::pending( + PendingConditionNarrow::FieldTruthy { + index: LuaIndexMemberExpr::IndexExpr(index_expr), + condition_flow, + }, + )) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index 1a28ecb92..4e51cb355 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -1,24 +1,33 @@ mod binary_flow; mod call_flow; -mod correlated_flow; +pub(in crate::semantic::infer::narrow) mod correlated_flow; mod index_flow; -use self::correlated_flow::narrow_var_from_return_overload_condition; -use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator}; +use std::rc::Rc; + +use self::{ + binary_flow::get_type_at_binary_expr, + correlated_flow::{CorrelatedConditionNarrowing, prepare_var_from_return_overload_condition}, +}; +use emmylua_parser::{ + LuaAstNode, LuaChunk, LuaExpr, LuaIndexMemberExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator, +}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, LuaType, + DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaInferCache, LuaSignatureCast, + LuaSignatureId, LuaType, semantic::infer::{ VarRefId, + infer_index::infer_member_by_member_key, narrow::{ ResultTypeOrContinue, condition_flow::{ - binary_flow::get_type_at_binary_expr, call_flow::get_type_at_call_expr, - index_flow::get_type_at_index_expr, + call_flow::get_type_at_call_expr, index_flow::get_type_at_index_expr, }, get_single_antecedent, + get_type_at_cast_flow::cast_type, get_type_at_flow::get_type_at_flow, - narrow_false_or_nil, remove_false_or_nil, + narrow_down_type, narrow_false_or_nil, remove_false_or_nil, var_ref_id::get_var_expr_var_ref_id, }, }, @@ -48,6 +57,236 @@ impl InferConditionFlow { } } +#[derive(Debug, Clone)] +pub(in crate::semantic) enum ConditionFlowAction { + Continue, + Result(LuaType), + Pending(Rc), +} + +impl From for ConditionFlowAction { + fn from(result_or_continue: ResultTypeOrContinue) -> Self { + match result_or_continue { + ResultTypeOrContinue::Continue => ConditionFlowAction::Continue, + ResultTypeOrContinue::Result(result_type) => ConditionFlowAction::Result(result_type), + } + } +} + +impl ConditionFlowAction { + pub(in crate::semantic::infer::narrow) fn pending( + pending_condition_narrow: PendingConditionNarrow, + ) -> Self { + ConditionFlowAction::Pending(Rc::new(pending_condition_narrow)) + } +} + +#[derive(Debug, Clone)] +pub(in crate::semantic) enum PendingConditionNarrow { + Truthiness(InferConditionFlow), + FieldTruthy { + index: LuaIndexMemberExpr, + condition_flow: InferConditionFlow, + }, + SameVarColonCall { + index: LuaIndexMemberExpr, + condition_flow: InferConditionFlow, + }, + SignatureCast { + signature_id: LuaSignatureId, + condition_flow: InferConditionFlow, + }, + Eq { + right_expr_type: LuaType, + condition_flow: InferConditionFlow, + }, + TypeGuard { + narrow: LuaType, + condition_flow: InferConditionFlow, + }, + Correlated(CorrelatedConditionNarrowing), +} + +impl PendingConditionNarrow { + pub(in crate::semantic::infer::narrow) fn apply( + &self, + db: &DbIndex, + cache: &mut LuaInferCache, + antecedent_type: LuaType, + ) -> LuaType { + match self { + PendingConditionNarrow::Truthiness(condition_flow) => match condition_flow.clone() { + InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), + InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), + }, + PendingConditionNarrow::FieldTruthy { + index, + condition_flow, + } => { + let LuaType::Union(union_type) = &antecedent_type else { + return antecedent_type; + }; + + let union_types = union_type.into_vec(); + let mut result = vec![]; + for sub_type in &union_types { + let member_type = match infer_member_by_member_key( + db, + cache, + sub_type, + index.clone(), + &InferGuard::new(), + ) { + Ok(member_type) => member_type, + Err(_) => continue, + }; + + if !member_type.is_always_falsy() { + result.push(sub_type.clone()); + } + } + + if result.is_empty() { + antecedent_type + } else { + match condition_flow.clone() { + InferConditionFlow::TrueCondition => LuaType::from_vec(result), + InferConditionFlow::FalseCondition => { + let target = LuaType::from_vec(result); + crate::TypeOps::Remove.apply(db, &antecedent_type, &target) + } + } + } + } + PendingConditionNarrow::SameVarColonCall { + index, + condition_flow, + } => { + let Ok(member_type) = infer_member_by_member_key( + db, + cache, + &antecedent_type, + index.clone(), + &InferGuard::new(), + ) else { + return antecedent_type; + }; + + let LuaType::Signature(signature_id) = member_type else { + return antecedent_type; + }; + + let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) + else { + return antecedent_type; + }; + + if signature_cast.name != "self" { + return antecedent_type; + } + + apply_signature_cast( + db, + antecedent_type, + signature_id.clone(), + signature_cast, + condition_flow.clone(), + ) + } + PendingConditionNarrow::SignatureCast { + signature_id, + condition_flow, + } => { + let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) + else { + return antecedent_type; + }; + + apply_signature_cast( + db, + antecedent_type, + signature_id.clone(), + signature_cast, + condition_flow.clone(), + ) + } + PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + } => match condition_flow.clone() { + InferConditionFlow::TrueCondition => { + let maybe_type = + crate::TypeOps::Intersect.apply(db, &antecedent_type, right_expr_type); + if maybe_type.is_never() { + antecedent_type + } else { + maybe_type + } + } + InferConditionFlow::FalseCondition => { + crate::TypeOps::Remove.apply(db, &antecedent_type, right_expr_type) + } + }, + PendingConditionNarrow::TypeGuard { + narrow, + condition_flow, + } => match condition_flow.clone() { + InferConditionFlow::TrueCondition => { + narrow_down_type(db, antecedent_type, narrow.clone(), None) + .unwrap_or_else(|| narrow.clone()) + } + InferConditionFlow::FalseCondition => { + crate::TypeOps::Remove.apply(db, &antecedent_type, narrow) + } + }, + PendingConditionNarrow::Correlated(correlated_narrowing) => { + correlated_narrowing.apply(db, antecedent_type) + } + } + } +} + +fn apply_signature_cast( + db: &DbIndex, + antecedent_type: LuaType, + signature_id: LuaSignatureId, + signature_cast: &LuaSignatureCast, + condition_flow: InferConditionFlow, +) -> LuaType { + let file_id = signature_id.get_file_id(); + let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&file_id) else { + return antecedent_type; + }; + let signature_root = syntax_tree.get_chunk_node(); + + let (cast_ptr, cast_flow) = match condition_flow { + InferConditionFlow::TrueCondition => (&signature_cast.cast, condition_flow), + InferConditionFlow::FalseCondition => ( + signature_cast + .fallback_cast + .as_ref() + .unwrap_or(&signature_cast.cast), + signature_cast + .fallback_cast + .as_ref() + .map(|_| InferConditionFlow::TrueCondition) + .unwrap_or(condition_flow), + ), + }; + let Some(cast_op_type) = cast_ptr.to_node(&signature_root) else { + return antecedent_type; + }; + + cast_type( + db, + file_id, + cast_op_type, + antecedent_type.clone(), + cast_flow, + ) + .unwrap_or(antecedent_type) +} + #[allow(clippy::too_many_arguments)] pub fn get_type_at_condition_flow( db: &DbIndex, @@ -58,7 +297,7 @@ pub fn get_type_at_condition_flow( flow_node: &FlowNode, condition: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { match condition { LuaExpr::NameExpr(name_expr) => get_type_at_name_expr( db, @@ -70,28 +309,14 @@ pub fn get_type_at_condition_flow( name_expr, condition_flow, ), - LuaExpr::CallExpr(call_expr) => get_type_at_call_expr( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - call_expr, - condition_flow, - ), - LuaExpr::IndexExpr(index_expr) => get_type_at_index_expr( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - index_expr, - condition_flow, - ), + LuaExpr::CallExpr(call_expr) => { + get_type_at_call_expr(db, cache, var_ref_id, call_expr, condition_flow) + } + LuaExpr::IndexExpr(index_expr) => { + get_type_at_index_expr(db, cache, var_ref_id, index_expr, condition_flow) + } LuaExpr::TableExpr(_) | LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) => { - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } LuaExpr::BinaryExpr(binary_expr) => get_type_at_binary_expr( db, @@ -115,7 +340,7 @@ pub fn get_type_at_condition_flow( ), LuaExpr::ParenExpr(paren_expr) => { let Some(inner_expr) = paren_expr.get_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; get_type_at_condition_flow( @@ -142,11 +367,11 @@ fn get_type_at_name_expr( flow_node: &FlowNode, name_expr: LuaNameExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::NameExpr(name_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if name_var_ref_id != *var_ref_id { @@ -162,15 +387,9 @@ fn get_type_at_name_expr( ); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let result_type = match condition_flow { - InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), - InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), - }; - - Ok(ResultTypeOrContinue::Result(result_type)) + Ok(ConditionFlowAction::pending( + PendingConditionNarrow::Truthiness(condition_flow), + )) } #[allow(clippy::too_many_arguments)] @@ -183,47 +402,57 @@ fn get_type_at_name_ref( flow_node: &FlowNode, name_expr: LuaNameExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(decl_id) = db .get_reference_index() .get_var_reference_decl(&cache.get_file_id(), name_expr.get_range()) else { - return Ok(ResultTypeOrContinue::Continue); - }; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_discriminant_type = get_type_at_flow( - db, - tree, - cache, - root, - &VarRefId::VarRef(decl_id), - antecedent_flow_id, - )?; - let narrowed_discriminant_type = match condition_flow { - InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_discriminant_type), - InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_discriminant_type), + return Ok(ConditionFlowAction::Continue); }; - if let ResultTypeOrContinue::Result(result_type) = narrow_var_from_return_overload_condition( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - decl_id, - name_expr.get_position(), - &narrowed_discriminant_type, - )? { - return Ok(ResultTypeOrContinue::Result(result_type)); + if let Some(target_decl_id) = var_ref_id.get_decl_id_ref() + && tree.has_decl_multi_return_refs(&decl_id) + && tree.has_decl_multi_return_refs(&target_decl_id) + { + let antecedent_flow_id = get_single_antecedent(flow_node)?; + let antecedent_discriminant_type = get_type_at_flow( + db, + tree, + cache, + root, + &VarRefId::VarRef(decl_id), + antecedent_flow_id, + )?; + let narrowed_discriminant_type = match condition_flow { + InferConditionFlow::FalseCondition => { + narrow_false_or_nil(db, antecedent_discriminant_type) + } + InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_discriminant_type), + }; + + if let Some(correlated_narrowing) = prepare_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + decl_id, + name_expr.get_position(), + &narrowed_discriminant_type, + )? { + return Ok(ConditionFlowAction::pending( + PendingConditionNarrow::Correlated(correlated_narrowing), + )); + } } let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(expr) = expr_ptr.to_node(root) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; get_type_at_condition_flow( @@ -274,19 +503,19 @@ fn get_type_at_unary_flow( flow_node: &FlowNode, unary_expr: LuaUnaryExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(inner_expr) = unary_expr.get_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(op) = unary_expr.get_op_token() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match op.get_op() { UnaryOperator::OpNot => {} _ => { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs index 6aa14497a..17bf8193d 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs @@ -50,7 +50,7 @@ fn get_type_at_cast_expr( return Ok(ResultTypeOrContinue::Continue); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let mut antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; for cast_op_type in tag_cast.get_op_types() { @@ -74,7 +74,7 @@ fn get_type_at_inline_cast( flow_node: &FlowNode, tag_cast: LuaDocTagCast, ) -> Result { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let mut antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; for cast_op_type in tag_cast.get_op_types() { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 66ccda945..416d942c9 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -1,18 +1,26 @@ +use std::rc::Rc; + use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaChunk, LuaExpr, LuaVarExpr}; use crate::{ CacheEntry, DbIndex, FlowId, FlowNode, FlowNodeKind, FlowTree, InferFailReason, LuaDeclId, - LuaInferCache, LuaMemberId, LuaSignatureId, LuaType, TypeOps, infer_expr, - semantic::infer::{ - InferResult, VarRefId, infer_expr_list_value_type_at, - narrow::{ - ResultTypeOrContinue, - condition_flow::{InferConditionFlow, get_type_at_condition_flow}, - get_multi_antecedents, get_single_antecedent, - get_type_at_cast_flow::get_type_at_cast_flow, - get_var_ref_type, narrow_down_type, - var_ref_id::get_var_expr_var_ref_id, + LuaInferCache, LuaMemberId, LuaSignatureId, LuaType, TypeOps, check_type_compact, infer_expr, + semantic::{ + infer::{ + InferResult, VarRefId, infer_expr_list_value_type_at, + narrow::{ + ResultTypeOrContinue, + condition_flow::{ + ConditionFlowAction, InferConditionFlow, PendingConditionNarrow, + get_type_at_condition_flow, + }, + get_multi_antecedents, get_single_antecedent, + get_type_at_cast_flow::get_type_at_cast_flow, + get_var_ref_type, narrow_down_type, + var_ref_id::get_var_expr_var_ref_id, + }, }, + member::find_members, }, }; @@ -24,171 +32,325 @@ pub fn get_type_at_flow( var_ref_id: &VarRefId, flow_id: FlowId, ) -> InferResult { - let key = (var_ref_id.clone(), flow_id); - if let Some(cache_entry) = cache.flow_node_cache.get(&key) - && let CacheEntry::Cache(narrow_type) = cache_entry - { - return Ok(narrow_type.clone()); + get_type_at_flow_internal(db, tree, cache, root, var_ref_id, flow_id, true) +} + +fn can_reuse_narrowed_assignment_source( + db: &DbIndex, + narrowed_source_type: &LuaType, + expr_type: &LuaType, +) -> bool { + if matches!(expr_type, LuaType::TableConst(_) | LuaType::Object(_)) { + return is_partial_assignment_expr_compatible(db, narrowed_source_type, expr_type); } - let result_type; - let mut antecedent_flow_id = flow_id; - loop { - let flow_node = tree - .get_flow_node(antecedent_flow_id) - .ok_or(InferFailReason::None)?; - - match &flow_node.kind { - FlowNodeKind::Start | FlowNodeKind::Unreachable => { - result_type = get_var_ref_type(db, cache, var_ref_id)?; - break; - } - FlowNodeKind::LoopLabel | FlowNodeKind::Break | FlowNodeKind::Return => { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + if !is_exact_assignment_expr_type(expr_type) { + return false; + } + + match narrow_down_type(db, narrowed_source_type.clone(), expr_type.clone(), None) { + Some(narrowed_expr_type) => narrowed_expr_type == *expr_type, + None => true, + } +} + +fn preserves_assignment_expr_type(typ: &LuaType) -> bool { + matches!(typ, LuaType::TableConst(_) | LuaType::Object(_)) || is_exact_assignment_expr_type(typ) +} + +fn is_partial_assignment_expr_compatible( + db: &DbIndex, + source_type: &LuaType, + expr_type: &LuaType, +) -> bool { + if check_type_compact(db, source_type, expr_type).is_ok() { + return true; + } + + // Only preserve branch narrowing for concrete partial table/object literals. + // Broader RHS expressions can carry hidden state the current flow/type model cannot represent + // without wider semantic changes. + if !matches!(expr_type, LuaType::TableConst(_) | LuaType::Object(_)) { + return false; + } + + let expr_members = find_members(db, expr_type).unwrap_or_default(); + + if expr_members.is_empty() { + return true; + } + + let Some(source_members) = find_members(db, source_type) else { + return false; + }; + + expr_members.into_iter().all(|expr_member| { + match source_members + .iter() + .find(|source_member| source_member.key == expr_member.key) + { + Some(source_member) => { + is_partial_assignment_expr_compatible(db, &source_member.typ, &expr_member.typ) } - FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => { - let multi_antecedents = get_multi_antecedents(tree, flow_node)?; - - let mut branch_result_type = LuaType::Unknown; - for &flow_id in &multi_antecedents { - let branch_type = get_type_at_flow(db, tree, cache, root, var_ref_id, flow_id)?; - branch_result_type = - TypeOps::Union.apply(db, &branch_result_type, &branch_type); + None => true, + } + }) +} + +fn is_exact_assignment_expr_type(typ: &LuaType) -> bool { + match typ { + LuaType::Nil | LuaType::DocBooleanConst(_) => true, + typ if typ.is_const() => !matches!(typ, LuaType::TableConst(_)), + LuaType::Union(union) => union.into_vec().iter().all(is_exact_assignment_expr_type), + LuaType::MultiLineUnion(multi_union) => { + is_exact_assignment_expr_type(&multi_union.to_union()) + } + LuaType::TypeGuard(inner) => is_exact_assignment_expr_type(inner), + _ => false, + } +} + +fn get_type_at_flow_internal( + db: &DbIndex, + tree: &FlowTree, + cache: &mut LuaInferCache, + root: &LuaChunk, + var_ref_id: &VarRefId, + flow_id: FlowId, + use_condition_narrowing: bool, +) -> InferResult { + let key = (var_ref_id.clone(), flow_id, use_condition_narrowing); + if let Some(cache_entry) = cache.flow_node_cache.get(&key) { + return match cache_entry { + CacheEntry::Cache(narrow_type) => Ok::(narrow_type.clone()), + CacheEntry::Ready => Err(InferFailReason::RecursiveInfer), + }; + } + + cache.flow_node_cache.insert(key.clone(), CacheEntry::Ready); + + let result = (|| { + let result_type; + let mut antecedent_flow_id = flow_id; + let mut pending_condition_narrows: Vec> = Vec::new(); + loop { + let flow_node = tree + .get_flow_node(antecedent_flow_id) + .ok_or(InferFailReason::None)?; + + match &flow_node.kind { + FlowNodeKind::Start | FlowNodeKind::Unreachable => { + result_type = get_var_ref_type(db, cache, var_ref_id)?; + break; } - result_type = branch_result_type; - break; - } - FlowNodeKind::DeclPosition(position) => { - if *position <= var_ref_id.get_position() { - match get_var_ref_type(db, cache, var_ref_id) { - Ok(var_type) => { - result_type = var_type; - break; - } - Err(err) => { - // 尝试推断声明位置的类型, 如果发生错误则返回初始错误, 否则返回当前推断错误 - if let Some(init_type) = - try_infer_decl_initializer_type(db, cache, root, var_ref_id)? - { - result_type = init_type; + FlowNodeKind::LoopLabel | FlowNodeKind::Break | FlowNodeKind::Return => { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } + FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => { + let multi_antecedents = get_multi_antecedents(tree, flow_node)?; + + let mut branch_result_type = LuaType::Unknown; + for &flow_id in &multi_antecedents { + let branch_type = get_type_at_flow_internal( + db, + tree, + cache, + root, + var_ref_id, + flow_id, + use_condition_narrowing, + )?; + branch_result_type = + TypeOps::Union.apply(db, &branch_result_type, &branch_type); + } + result_type = branch_result_type; + break; + } + FlowNodeKind::DeclPosition(position) => { + if *position <= var_ref_id.get_position() { + match get_var_ref_type(db, cache, var_ref_id) { + Ok(var_type) => { + result_type = var_type; break; } - - return Err(err); + Err(err) => { + // 尝试推断声明位置的类型, 如果发生错误则返回初始错误, 否则返回当前推断错误 + if let Some(init_type) = + try_infer_decl_initializer_type(db, cache, root, var_ref_id)? + { + result_type = init_type; + break; + } + + return Err(err); + } } + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; } - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; } - } - FlowNodeKind::Assignment(assign_ptr) => { - let assign_stat = assign_ptr.to_node(root).ok_or(InferFailReason::None)?; - let result_or_continue = get_type_at_assign_stat( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - assign_stat, - )?; - - if let ResultTypeOrContinue::Result(assign_type) = result_or_continue { - result_type = assign_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::Assignment(assign_ptr) => { + let assign_stat = assign_ptr.to_node(root).ok_or(InferFailReason::None)?; + let result_or_continue = get_type_at_assign_stat( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + assign_stat, + )?; + + if let ResultTypeOrContinue::Result(assign_type) = result_or_continue { + result_type = assign_type; + break; + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } } - } - FlowNodeKind::ImplFunc(func_ptr) => { - let func_stat = func_ptr.to_node(root).ok_or(InferFailReason::None)?; - let Some(func_name) = func_stat.get_func_name() else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - continue; - }; - - let Some(ref_id) = get_var_expr_var_ref_id(db, cache, func_name.to_expr()) else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - continue; - }; - - if ref_id == *var_ref_id { - let Some(closure) = func_stat.get_closure() else { - return Err(InferFailReason::None); + FlowNodeKind::ImplFunc(func_ptr) => { + let func_stat = func_ptr.to_node(root).ok_or(InferFailReason::None)?; + let Some(func_name) = func_stat.get_func_name() else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + continue; }; - result_type = LuaType::Signature(LuaSignatureId::from_closure( - cache.get_file_id(), - &closure, - )); - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let Some(ref_id) = get_var_expr_var_ref_id(db, cache, func_name.to_expr()) + else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + continue; + }; + + if ref_id == *var_ref_id { + let Some(closure) = func_stat.get_closure() else { + return Err(InferFailReason::None); + }; + + result_type = LuaType::Signature(LuaSignatureId::from_closure( + cache.get_file_id(), + &closure, + )); + break; + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } } - } - FlowNodeKind::TrueCondition(condition_ptr) => { - let condition = condition_ptr.to_node(root).ok_or(InferFailReason::None)?; - let result_or_continue = get_type_at_condition_flow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - condition, - InferConditionFlow::TrueCondition, - )?; - - if let ResultTypeOrContinue::Result(condition_type) = result_or_continue { - result_type = condition_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::TrueCondition(condition_ptr) + | FlowNodeKind::FalseCondition(condition_ptr) => { + if !use_condition_narrowing { + antecedent_flow_id = get_single_antecedent(flow_node)?; + continue; + } + + let condition_flow = + if matches!(&flow_node.kind, FlowNodeKind::TrueCondition(_)) { + InferConditionFlow::TrueCondition + } else { + InferConditionFlow::FalseCondition + }; + let condition_key = ( + var_ref_id.clone(), + antecedent_flow_id, + matches!(condition_flow, InferConditionFlow::TrueCondition), + ); + let condition_action = { + if let Some(cache_entry) = cache.condition_flow_cache.get(&condition_key) { + match cache_entry { + CacheEntry::Cache(action) => { + Ok::(action.clone()) + } + CacheEntry::Ready => Err(InferFailReason::RecursiveInfer), + } + } else { + let condition = + condition_ptr.to_node(root).ok_or(InferFailReason::None)?; + cache + .condition_flow_cache + .insert(condition_key.clone(), CacheEntry::Ready); + let result = get_type_at_condition_flow( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + condition, + condition_flow, + ); + match &result { + Ok(action) => { + cache + .condition_flow_cache + .insert(condition_key, CacheEntry::Cache(action.clone())); + } + Err(_) => { + cache.condition_flow_cache.remove(&condition_key); + } + } + result + } + }?; + + match condition_action { + ConditionFlowAction::Pending(pending_condition_narrow) => { + pending_condition_narrows.push(pending_condition_narrow); + antecedent_flow_id = get_single_antecedent(flow_node)?; + } + ConditionFlowAction::Result(condition_type) => { + result_type = condition_type; + break; + } + ConditionFlowAction::Continue => { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } + } } - } - FlowNodeKind::FalseCondition(condition_ptr) => { - let condition = condition_ptr.to_node(root).ok_or(InferFailReason::None)?; - let result_or_continue = get_type_at_condition_flow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - condition, - InferConditionFlow::FalseCondition, - )?; - - if let ResultTypeOrContinue::Result(condition_type) = result_or_continue { - result_type = condition_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::ForIStat(_) => { + // todo check for `for i = 1, 10 do end` + antecedent_flow_id = get_single_antecedent(flow_node)?; } - } - FlowNodeKind::ForIStat(_) => { - // todo check for `for i = 1, 10 do end` - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - } - FlowNodeKind::TagCast(cast_ast_ptr) => { - let tag_cast = cast_ast_ptr.to_node(root).ok_or(InferFailReason::None)?; - let cast_or_continue = - get_type_at_cast_flow(db, tree, cache, root, var_ref_id, flow_node, tag_cast)?; - - if let ResultTypeOrContinue::Result(cast_type) = cast_or_continue { - result_type = cast_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::TagCast(cast_ast_ptr) => { + let tag_cast = cast_ast_ptr.to_node(root).ok_or(InferFailReason::None)?; + let cast_or_continue = get_type_at_cast_flow( + db, tree, cache, root, var_ref_id, flow_node, tag_cast, + )?; + + if let ResultTypeOrContinue::Result(cast_type) = cast_or_continue { + result_type = cast_type; + break; + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } } } } + + let result_type = if use_condition_narrowing { + pending_condition_narrows.into_iter().rev().fold( + result_type, + |result_type, pending_condition_narrow| { + pending_condition_narrow.apply(db, cache, result_type) + }, + ) + } else { + result_type + }; + + Ok(result_type) + })(); + + match &result { + Ok(result_type) => { + cache + .flow_node_cache + .insert(key, CacheEntry::Cache(result_type.clone())); + } + Err(_) => { + cache.flow_node_cache.remove(&key); + } } - cache - .flow_node_cache - .insert(key, CacheEntry::Cache(result_type.clone())); - Ok(result_type) + result } fn get_type_at_assign_stat( @@ -231,12 +393,45 @@ fn get_type_at_assign_stat( return Ok(ResultTypeOrContinue::Continue); }; - let source_type = if let Some(explicit) = explicit_var_type.clone() { - explicit - } else { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)? - }; + let (source_type, reuse_source_narrowing) = + if let Some(explicit) = explicit_var_type.clone() { + (explicit, true) + } else { + let antecedent_flow_id = get_single_antecedent(flow_node)?; + if !preserves_assignment_expr_type(&expr_type) { + ( + get_type_at_flow_internal( + db, + tree, + cache, + root, + var_ref_id, + antecedent_flow_id, + false, + )?, + false, + ) + } else { + let narrowed_source_type = + get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; + if can_reuse_narrowed_assignment_source(db, &narrowed_source_type, &expr_type) { + (narrowed_source_type, true) + } else { + ( + get_type_at_flow_internal( + db, + tree, + cache, + root, + var_ref_id, + antecedent_flow_id, + false, + )?, + false, + ) + } + } + }; let narrowed = if source_type == LuaType::Nil { None @@ -252,7 +447,11 @@ fn get_type_at_assign_stat( narrow_down_type(db, source_type.clone(), expr_type.clone(), declared) }; - let result_type = narrowed.unwrap_or(explicit_var_type.unwrap_or(expr_type)); + let result_type = if reuse_source_narrowing || preserves_assignment_expr_type(&expr_type) { + narrowed.unwrap_or_else(|| explicit_var_type.unwrap_or_else(|| expr_type.clone())) + } else { + expr_type + }; return Ok(ResultTypeOrContinue::Result(result_type)); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs index b78b19c0c..4bd3de91d 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs @@ -12,6 +12,7 @@ use crate::{ infer_name::{find_decl_member_type, infer_global_type}, }, }; +pub(in crate::semantic) use condition_flow::ConditionFlowAction; use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr}; pub use get_type_at_cast_flow::get_type_at_call_expr_inline_cast; pub use narrow_type::{narrow_down_type, narrow_false_or_nil, remove_false_or_nil}; @@ -71,22 +72,11 @@ fn get_var_ref_type(db: &DbIndex, cache: &mut LuaInferCache, var_ref_id: &VarRef } } -fn get_single_antecedent(tree: &FlowTree, flow: &FlowNode) -> Result { +fn get_single_antecedent(flow: &FlowNode) -> Result { match &flow.antecedent { Some(antecedent) => match antecedent { FlowAntecedent::Single(id) => Ok(*id), - FlowAntecedent::Multiple(multi_id) => { - let multi_flow = tree - .get_multi_antecedents(*multi_id) - .ok_or(InferFailReason::None)?; - if !multi_flow.is_empty() { - // If there are multiple antecedents, we need to handle them separately - // For now, we just return the first one - Ok(multi_flow[0]) - } else { - Err(InferFailReason::None) - } - } + FlowAntecedent::Multiple(_) => Err(InferFailReason::None), }, None => Err(InferFailReason::None), } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs index 44e85e141..f6d62ea97 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs @@ -207,6 +207,9 @@ pub fn narrow_down_type( .into_iter() .filter_map(|t| narrow_down_type(db, real_source_ref.clone(), t, declared.clone())) .collect::>(); + if source_types.is_empty() { + return None; + } let mut result_type = LuaType::Unknown; for source_type in source_types { result_type = TypeOps::Union.apply(db, &result_type, &source_type);