Skip to content

Commit 0efa12f

Browse files
SamChou19815facebook-github-bot
authored andcommitted
Hover type with chosen overload
Summary: In this diff, I setup the infra to record the chosen overload, and use it in hover. When we hover on callees, we will show the chosen overload instead. In the next diff, we will use the same infra to power signature help. Reviewed By: kinto0 Differential Revision: D76222224 fbshipit-source-id: 9ff8578bb387ad4aec2b4292150882b2d39cbd13
1 parent b17eb04 commit 0efa12f

File tree

4 files changed

+210
-2
lines changed

4 files changed

+210
-2
lines changed

pyrefly/lib/alt/answers.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ use crate::table;
6161
use crate::table_for_each;
6262
use crate::table_mut_for_each;
6363
use crate::table_try_for_each;
64+
use crate::types::callable::Callable;
6465
use crate::types::class::Class;
6566
use crate::types::equality::TypeEq;
6667
use crate::types::equality::TypeEqCtx;
@@ -81,9 +82,18 @@ pub struct Index {
8182
pub externally_defined_attribute_references: SmallMap<ModulePath, Vec<(TextRange, TextRange)>>,
8283
}
8384

85+
#[derive(Debug)]
86+
struct OverloadedCallee {
87+
all_overloads: Vec<Callable>,
88+
closest_overload: Callable,
89+
is_closest_overload_chosen: bool,
90+
}
91+
8492
#[derive(Debug, Default)]
8593
pub struct Traces {
8694
types: SmallMap<TextRange, Arc<Type>>,
95+
/// A map from (range of callee, overload information)
96+
overloaded_callees: SmallMap<TextRange, OverloadedCallee>,
8797
}
8898

8999
/// Invariants:
@@ -535,6 +545,42 @@ impl Answers {
535545
let lock = self.trace.as_ref()?.lock();
536546
lock.types.get(&range).duped()
537547
}
548+
549+
pub fn get_chosen_overload_trace(&self, range: TextRange) -> Option<Callable> {
550+
let lock = self.trace.as_ref()?.lock();
551+
let overloaded_callee = lock.overloaded_callees.get(&range)?;
552+
if overloaded_callee.is_closest_overload_chosen {
553+
Some(overloaded_callee.closest_overload.clone())
554+
} else {
555+
None
556+
}
557+
}
558+
559+
/// Returns all the overload, and the index of a chosen one
560+
#[allow(dead_code)]
561+
pub fn get_all_overload_trace(
562+
&self,
563+
range: TextRange,
564+
) -> Option<(Vec<Callable>, Option<usize>)> {
565+
let lock = self.trace.as_ref()?.lock();
566+
let overloaded_callee = lock.overloaded_callees.get(&range)?;
567+
let chosen_overload_index =
568+
overloaded_callee
569+
.all_overloads
570+
.iter()
571+
.enumerate()
572+
.find_map(|(index, signature)| {
573+
if signature == &overloaded_callee.closest_overload {
574+
Some(index)
575+
} else {
576+
None
577+
}
578+
});
579+
Some((
580+
overloaded_callee.all_overloads.clone(),
581+
chosen_overload_index,
582+
))
583+
}
538584
}
539585

540586
impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
@@ -699,6 +745,27 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
699745
}
700746
}
701747

748+
/// Record all the overloads and the chosen overload.
749+
/// The trace will be used to power signature help and hover for overloaded functions.
750+
pub fn record_overload_trace(
751+
&self,
752+
loc: TextRange,
753+
all_overloads: &[Callable],
754+
closest_overload: &Callable,
755+
is_closest_overload_chosen: bool,
756+
) {
757+
if let Some(trace) = &self.current.trace {
758+
trace.lock().overloaded_callees.insert(
759+
loc,
760+
OverloadedCallee {
761+
all_overloads: all_overloads.to_vec(),
762+
closest_overload: closest_overload.clone(),
763+
is_closest_overload_chosen,
764+
},
765+
);
766+
}
767+
}
768+
702769
/// Check if `want` matches `got` returning `want` if the check fails.
703770
pub fn check_and_return_type_info(
704771
&self,

pyrefly/lib/alt/call.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
670670
None,
671671
);
672672
if arg_errors.is_empty() && call_errors.is_empty() {
673+
// An overload is chosen, we should record it to power IDE services.
674+
self.record_overload_trace(range, overloads.as_slice(), callable, true);
673675
// It's only safe to return immediately if both arg_errors and call_errors are
674676
// empty, as parameter types from the overload signature may be used as hints when
675677
// evaluating arguments, producing arg_errors for some overloads but not others.
@@ -692,6 +694,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
692694
}
693695
// We're guaranteed to have at least one overload.
694696
let closest_overload = closest_overload.unwrap();
697+
self.record_overload_trace(
698+
range,
699+
overloads.as_slice(),
700+
&closest_overload.signature,
701+
false,
702+
);
695703
errors.extend(closest_overload.arg_errors);
696704
if closest_overload.call_errors.is_empty() {
697705
// No overload evaluated completely successfully, but we still say we found a match if

pyrefly/lib/state/lsp.rs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8+
use std::sync::Arc;
9+
810
use dupe::Dupe;
911
use itertools::Itertools;
1012
use lsp_types::CompletionItem;
@@ -17,7 +19,9 @@ use pyrefly_util::visit::Visit;
1719
use ruff_python_ast::AnyNodeRef;
1820
use ruff_python_ast::Expr;
1921
use ruff_python_ast::ExprAttribute;
22+
use ruff_python_ast::ExprCall;
2023
use ruff_python_ast::Identifier;
24+
use ruff_python_ast::ModModule;
2125
use ruff_python_ast::Stmt;
2226
use ruff_python_ast::name::Name;
2327
use ruff_text_size::Ranged;
@@ -242,13 +246,45 @@ impl<'a> Transaction<'a> {
242246
if let Some(key) = self.definition_at(handle, position) {
243247
return self.get_type(handle, &key);
244248
}
249+
fn callee_at(mod_module: Arc<ModModule>, position: TextSize) -> Option<ExprCall> {
250+
fn f(x: &Expr, find: TextSize, res: &mut Option<ExprCall>) {
251+
if let Expr::Call(call) = x
252+
&& call.func.range().contains_inclusive(find)
253+
{
254+
f(call.func.as_ref(), find, res);
255+
if res.is_some() {
256+
return;
257+
}
258+
*res = Some(call.clone());
259+
} else {
260+
x.recurse(&mut |x| f(x, find, res));
261+
}
262+
}
263+
let mut res = None;
264+
mod_module.visit(&mut |x| f(x, position, &mut res));
265+
res
266+
}
267+
let callee = callee_at(self.get_ast(handle)?, position);
245268
match self.identifier_at(handle, position) {
246269
Some(IdentifierWithContext {
247270
identifier: id,
248271
context: IdentifierContext::Expr,
249272
}) => {
250273
if self.get_bindings(handle)?.is_valid_usage(&id) {
251-
return self.get_type(handle, &Key::BoundName(ShortIdentifier::new(&id)));
274+
if let Some(ExprCall {
275+
range: _,
276+
func,
277+
arguments,
278+
}) = &callee
279+
&& func.range() == id.range
280+
&& let Some(chosen_overload) = self
281+
.get_answers(handle)
282+
.and_then(|answers| answers.get_chosen_overload_trace(arguments.range))
283+
{
284+
return Some(Type::Callable(Box::new(chosen_overload)));
285+
} else {
286+
return self.get_type(handle, &Key::BoundName(ShortIdentifier::new(&id)));
287+
}
252288
} else {
253289
return None;
254290
}
@@ -269,7 +305,20 @@ impl<'a> Transaction<'a> {
269305
None => {}
270306
}
271307
let attribute = self.attribute_at(handle, position)?;
272-
self.get_type_trace(handle, attribute.range)
308+
if let Some(ExprCall {
309+
range: _,
310+
func,
311+
arguments,
312+
}) = &callee
313+
&& func.range() == attribute.range
314+
&& let Some(chosen_overload) = self
315+
.get_answers(handle)
316+
.and_then(|answers| answers.get_chosen_overload_trace(arguments.range))
317+
{
318+
Some(Type::Callable(Box::new(chosen_overload)))
319+
} else {
320+
self.get_type_trace(handle, attribute.range)
321+
}
273322
}
274323

275324
fn resolve_named_import(

pyrefly/lib/test/lsp/hover_type.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use ruff_text_size::TextSize;
1111
use crate::state::handle::Handle;
1212
use crate::state::state::State;
1313
use crate::test::util::get_batched_lsp_operations_report;
14+
use crate::test::util::get_batched_lsp_operations_report_allow_error;
1415

1516
fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String {
1617
if let Some(t) = state.transaction().get_type_at(handle, position) {
@@ -282,3 +283,86 @@ Hover Result: `Literal[5] | int`
282283
report.trim(),
283284
);
284285
}
286+
287+
#[test]
288+
fn overloaded_functions_test() {
289+
let code = r#"
290+
from typing import overload
291+
292+
@overload
293+
def overloaded_func(a: str) -> bool: ...
294+
@overload
295+
def overloaded_func(a: int, b: bool) -> str: ...
296+
def overloaded_func():
297+
pass
298+
299+
overloaded_func("")
300+
# ^
301+
overloaded_func(1, True)
302+
# ^
303+
overloaded_func(False)
304+
# ^
305+
306+
"#;
307+
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
308+
assert_eq!(
309+
r#"
310+
# main.py
311+
11 | overloaded_func("")
312+
^
313+
Hover Result: `(a: str) -> bool`
314+
315+
13 | overloaded_func(1, True)
316+
^
317+
Hover Result: `(a: int, b: bool) -> str`
318+
319+
15 | overloaded_func(False)
320+
^
321+
Hover Result: `Overload[(a: str) -> bool, (a: int, b: bool) -> str]`
322+
"#
323+
.trim(),
324+
report.trim(),
325+
);
326+
}
327+
328+
#[test]
329+
fn overloaded_methods_test() {
330+
let code = r#"
331+
from typing import overload
332+
333+
class Foo:
334+
@overload
335+
def overloaded_meth(self, a: str) -> bool: ...
336+
@overload
337+
def overloaded_meth(self, a: int, b: bool) -> str: ...
338+
def overloaded_meth(self):
339+
pass
340+
341+
foo = Foo()
342+
foo.overloaded_meth("")
343+
# ^
344+
foo.overloaded_meth(1, True)
345+
# ^
346+
foo.overloaded_meth(False)
347+
# ^
348+
"#;
349+
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
350+
assert_eq!(
351+
r#"
352+
# main.py
353+
13 | foo.overloaded_meth("")
354+
^
355+
Hover Result: `(self: Foo, a: str) -> bool`
356+
357+
15 | foo.overloaded_meth(1, True)
358+
^
359+
Hover Result: `(self: Foo, a: int, b: bool) -> str`
360+
361+
17 | foo.overloaded_meth(False)
362+
^
363+
Hover Result: `BoundMethod[Foo, Overload[(self: Self@Foo, a: str) -> bool, (self: Self@Foo, a: int, b: bool) -> str]]`
364+
"#
365+
.trim(),
366+
report.trim(),
367+
);
368+
}

0 commit comments

Comments
 (0)