Skip to content
This repository was archived by the owner on Sep 12, 2024. It is now read-only.

Commit a9c6a2b

Browse files
committed
expression typeinfo implemented
1 parent 4fdfafd commit a9c6a2b

File tree

3 files changed

+169
-107
lines changed

3 files changed

+169
-107
lines changed

jaclang/compiler/absyntree.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None:
5050
self.meta: dict[str, str] = {}
5151
self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range())
5252

53+
# NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the
54+
# subclasses, Adding it here, this needs a review.
55+
self.expr_type: str = ""
56+
5357
@property
5458
def sym_tab(self) -> SymbolTable:
5559
"""Get symbol table."""

jaclang/compiler/passes/main/fuse_typeinfo_pass.py

Lines changed: 74 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66

77
from __future__ import annotations
88

9-
from typing import Callable, TypeVar
9+
from types import MethodType
10+
from typing import Callable, Optional, TypeVar
1011

1112
import jaclang.compiler.absyntree as ast
1213
from jaclang.compiler.passes import Pass
14+
from jaclang.compiler.passes.transform import Transform
1315
from jaclang.settings import settings
1416
from jaclang.utils.helpers import pascal_to_snake
1517
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack
1618

17-
1819
import mypy.nodes as MypyNodes # noqa N812
1920
import mypy.types as MypyTypes # noqa N812
2021
from mypy.checkexpr import Type as MyType
@@ -23,11 +24,82 @@
2324
T = TypeVar("T", bound=ast.AstSymbolNode)
2425

2526

27+
# List of expression nodes which we'll be extracting the type info from.
28+
JAC_EXPR_NODES = (
29+
ast.AwaitExpr,
30+
ast.BinaryExpr,
31+
ast.CompareExpr,
32+
ast.BoolExpr,
33+
ast.LambdaExpr,
34+
ast.UnaryExpr,
35+
ast.IfElseExpr,
36+
ast.AtomTrailer,
37+
ast.AtomUnit,
38+
ast.YieldExpr,
39+
ast.YieldExpr,
40+
ast.FuncCall,
41+
ast.EdgeRefTrailer,
42+
ast.ListVal,
43+
ast.SetVal,
44+
ast.TupleVal,
45+
ast.DictVal,
46+
ast.ListCompr,
47+
ast.DictCompr,
48+
)
49+
50+
2651
class FuseTypeInfoPass(Pass):
2752
"""Python and bytecode file self.__debug_printing pass."""
2853

2954
node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}
3055

56+
@staticmethod
57+
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
58+
"""
59+
Enter an expression node.
60+
61+
This function is dynamically bound as a method on insntace of this class, since the
62+
group of functions to handle expressions has a the exact same logic.
63+
"""
64+
if len(node.gen.mypy_ast) == 0:
65+
return
66+
67+
# If the corrosponding mypy ast node type has stored here, get the values.
68+
mypy_node = node.gen.mypy_ast[0]
69+
if mypy_node in self.node_type_hash:
70+
mytype: MyType = self.node_type_hash[mypy_node]
71+
node.expr_type = str(mytype)
72+
73+
# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
74+
# expression. Time and memory wasted here.
75+
collection_types_map = {
76+
ast.ListVal: "builtins.list",
77+
ast.SetVal: "builtins.set",
78+
ast.TupleVal: "builtins.tuple",
79+
ast.DictVal: "builtins.dict",
80+
ast.ListCompr: None,
81+
ast.DictCompr: None,
82+
}
83+
84+
# Set they symbol type for collection expression.
85+
if type(node) in tuple(collection_types_map.keys()):
86+
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
87+
if mypy_node in self.node_type_hash:
88+
node.name_spec.sym_type = str(mytype)
89+
collection_type = collection_types_map[type(node)]
90+
if collection_type is not None:
91+
node.name_spec.sym_type = collection_type
92+
93+
def __init__(self, input_ir: T, prior: Optional[Transform]) -> None:
94+
"""Initialize the FuseTpeInfoPass instance."""
95+
for expr_node in JAC_EXPR_NODES:
96+
method_name = "enter_" + pascal_to_snake(expr_node.__name__)
97+
method = MethodType(
98+
FuseTypeInfoPass.__handle_node(FuseTypeInfoPass.enter_expr), self
99+
)
100+
setattr(self, method_name, method)
101+
super().__init__(input_ir, prior)
102+
31103
def __debug_print(self, *argv: object) -> None:
32104
if settings.fuse_type_info_debug:
33105
self.log_info("FuseTypeInfo::", *argv)
@@ -310,54 +382,6 @@ def enter_f_string(self, node: ast.FString) -> None:
310382
"""Pass handler for FString nodes."""
311383
self.__debug_print("Getting type not supported in", type(node))
312384

313-
@__handle_node
314-
def enter_list_val(self, node: ast.ListVal) -> None:
315-
"""Pass handler for ListVal nodes."""
316-
mypy_node = node.gen.mypy_ast[0]
317-
if mypy_node in self.node_type_hash:
318-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
319-
else:
320-
node.name_spec.sym_type = "builtins.list"
321-
322-
@__handle_node
323-
def enter_set_val(self, node: ast.SetVal) -> None:
324-
"""Pass handler for SetVal nodes."""
325-
mypy_node = node.gen.mypy_ast[0]
326-
if mypy_node in self.node_type_hash:
327-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
328-
else:
329-
node.name_spec.sym_type = "builtins.set"
330-
331-
@__handle_node
332-
def enter_tuple_val(self, node: ast.TupleVal) -> None:
333-
"""Pass handler for TupleVal nodes."""
334-
mypy_node = node.gen.mypy_ast[0]
335-
if mypy_node in self.node_type_hash:
336-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
337-
else:
338-
node.name_spec.sym_type = "builtins.tuple"
339-
340-
@__handle_node
341-
def enter_dict_val(self, node: ast.DictVal) -> None:
342-
"""Pass handler for DictVal nodes."""
343-
mypy_node = node.gen.mypy_ast[0]
344-
if mypy_node in self.node_type_hash:
345-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
346-
else:
347-
node.name_spec.sym_type = "builtins.dict"
348-
349-
@__handle_node
350-
def enter_list_compr(self, node: ast.ListCompr) -> None:
351-
"""Pass handler for ListCompr nodes."""
352-
mypy_node = node.gen.mypy_ast[0]
353-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
354-
355-
@__handle_node
356-
def enter_dict_compr(self, node: ast.DictCompr) -> None:
357-
"""Pass handler for DictCompr nodes."""
358-
mypy_node = node.gen.mypy_ast[0]
359-
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
360-
361385
@__handle_node
362386
def enter_index_slice(self, node: ast.IndexSlice) -> None:
363387
"""Pass handler for IndexSlice nodes."""

jaclang/compiler/passes/utils/mypy_ast_build.py

Lines changed: 91 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,20 @@
44

55
import ast
66
import os
7+
from types import MethodType
78

89
from jaclang.compiler.absyntree import AstNode
910
from jaclang.compiler.passes import Pass
1011
from jaclang.compiler.passes.main.fuse_typeinfo_pass import (
1112
FuseTypeInfoPass,
1213
)
14+
from jaclang.utils.helpers import pascal_to_snake
1315

1416
import mypy.build as myb
1517
import mypy.checkexpr as mycke
1618
import mypy.errors as mye
1719
import mypy.fastparse as myfp
20+
import mypy.nodes as mypy_nodes
1821
from mypy.build import BuildSource
1922
from mypy.build import BuildSourceSet
2023
from mypy.build import FileSystemCache
@@ -29,6 +32,55 @@
2932
from mypy.semanal_main import semantic_analysis_for_scc
3033

3134

35+
# All the expression nodes of mypy.
36+
EXPRESSION_NODES = (
37+
mypy_nodes.AssertTypeExpr,
38+
mypy_nodes.AssignmentExpr,
39+
mypy_nodes.AwaitExpr,
40+
mypy_nodes.BytesExpr,
41+
mypy_nodes.CallExpr,
42+
mypy_nodes.CastExpr,
43+
mypy_nodes.ComparisonExpr,
44+
mypy_nodes.ComplexExpr,
45+
mypy_nodes.ConditionalExpr,
46+
mypy_nodes.DictionaryComprehension,
47+
mypy_nodes.DictExpr,
48+
mypy_nodes.EllipsisExpr,
49+
mypy_nodes.EnumCallExpr,
50+
mypy_nodes.Expression,
51+
mypy_nodes.FloatExpr,
52+
mypy_nodes.GeneratorExpr,
53+
mypy_nodes.IndexExpr,
54+
mypy_nodes.IntExpr,
55+
mypy_nodes.LambdaExpr,
56+
mypy_nodes.ListComprehension,
57+
mypy_nodes.ListExpr,
58+
mypy_nodes.MemberExpr,
59+
mypy_nodes.NamedTupleExpr,
60+
mypy_nodes.NameExpr,
61+
mypy_nodes.NewTypeExpr,
62+
mypy_nodes.OpExpr,
63+
mypy_nodes.ParamSpecExpr,
64+
mypy_nodes.PromoteExpr,
65+
mypy_nodes.RefExpr,
66+
mypy_nodes.RevealExpr,
67+
mypy_nodes.SetComprehension,
68+
mypy_nodes.SetExpr,
69+
mypy_nodes.SliceExpr,
70+
mypy_nodes.StarExpr,
71+
mypy_nodes.StrExpr,
72+
mypy_nodes.SuperExpr,
73+
mypy_nodes.TupleExpr,
74+
mypy_nodes.TypeAliasExpr,
75+
mypy_nodes.TypedDictExpr,
76+
mypy_nodes.TypeVarExpr,
77+
mypy_nodes.TypeVarTupleExpr,
78+
mypy_nodes.UnaryExpr,
79+
mypy_nodes.YieldExpr,
80+
mypy_nodes.YieldFromExpr,
81+
)
82+
83+
3284
mypy_to_jac_node_map: dict[
3385
tuple[int, int | None, int | None, int | None], list[AstNode]
3486
] = {}
@@ -87,63 +139,45 @@ def __init__(
87139
"""Override to mypy expression checker for direct AST pass through."""
88140
super().__init__(tc, msg, plugin, per_line_checking_time_ns)
89141

90-
def visit_list_expr(self, e: mycke.ListExpr) -> mycke.Type:
91-
"""Type check a list expression [...]."""
92-
out = super().visit_list_expr(e)
93-
FuseTypeInfoPass.node_type_hash[e] = out
94-
return out
95-
96-
def visit_set_expr(self, e: mycke.SetExpr) -> mycke.Type:
97-
"""Type check a set expression {...}."""
98-
out = super().visit_set_expr(e)
99-
FuseTypeInfoPass.node_type_hash[e] = out
100-
return out
101-
102-
def visit_tuple_expr(self, e: myfp.TupleExpr) -> myb.Type:
103-
"""Type check a tuple expression (...)."""
104-
out = super().visit_tuple_expr(e)
105-
FuseTypeInfoPass.node_type_hash[e] = out
106-
return out
107-
108-
def visit_dict_expr(self, e: myfp.DictExpr) -> myb.Type:
109-
"""Type check a dictionary expression {...}."""
110-
out = super().visit_dict_expr(e)
111-
FuseTypeInfoPass.node_type_hash[e] = out
112-
return out
113-
114-
def visit_list_comprehension(self, e: myfp.ListComprehension) -> myb.Type:
115-
"""Type check a list comprehension."""
116-
out = super().visit_list_comprehension(e)
117-
FuseTypeInfoPass.node_type_hash[e] = out
118-
return out
119-
120-
def visit_set_comprehension(self, e: myfp.SetComprehension) -> myb.Type:
121-
"""Type check a set comprehension."""
122-
out = super().visit_set_comprehension(e)
123-
FuseTypeInfoPass.node_type_hash[e] = out
124-
return out
125-
126-
def visit_generator_expr(self, e: myfp.GeneratorExpr) -> myb.Type:
127-
"""Type check a generator expression."""
128-
out = super().visit_generator_expr(e)
129-
FuseTypeInfoPass.node_type_hash[e] = out
130-
return out
131-
132-
def visit_dictionary_comprehension(
133-
self, e: myfp.DictionaryComprehension
134-
) -> myb.Type:
135-
"""Type check a dict comprehension."""
136-
out = super().visit_dictionary_comprehension(e)
137-
FuseTypeInfoPass.node_type_hash[e] = out
138-
return out
139-
140-
def visit_member_expr(
141-
self, e: myfp.MemberExpr, is_lvalue: bool = False
142-
) -> myb.Type:
143-
"""Type check a member expr."""
144-
out = super().visit_member_expr(e, is_lvalue)
145-
FuseTypeInfoPass.node_type_hash[e] = out
146-
return out
142+
# For every expression there, create attach a method on this instance (self) named "enter_expr()"
143+
for expr_node in EXPRESSION_NODES:
144+
method_name = "visit_" + pascal_to_snake(expr_node.__name__)
145+
146+
# We call the super() version of the method so ensure the parent class has the method or else continue.
147+
if not hasattr(mycke.ExpressionChecker, method_name):
148+
continue
149+
150+
# If the method already overriden then don't override it again here. Continue. Note that the method exists
151+
# on the parent class and if it's also exists on this class and it's a different object that means it's
152+
# overrident method.
153+
if getattr(mycke.ExpressionChecker, method_name) != getattr(
154+
ExpressionChecker, method_name
155+
):
156+
continue
157+
158+
# Since the "closure" function bellow captures the method name inside it, we cannot use it directly as the
159+
# "method_name" variable is used inside a loop and by the time the closure close the "method_name" value,
160+
# it'll be changed by the loop, so we need another method ("make_closure") to persist the value.
161+
def make_closure(method_name: str): # noqa: ANN201
162+
def closure(
163+
self: ExpressionChecker,
164+
e: mycke.Expression,
165+
*args, # noqa: ANN002
166+
**kwargs, # noqa: ANN003
167+
) -> mycke.Type:
168+
# Ignore B023 here since we bind loop variable properly but flake8 raise a false alarm
169+
# (in some version of it), a bug in flake8 (https://github.com/PyCQA/flake8-bugbear/issues/269).
170+
out = getattr(mycke.ExpressionChecker, method_name)( # noqa: B023
171+
self, e, *args, **kwargs
172+
)
173+
FuseTypeInfoPass.node_type_hash[e] = out
174+
return out
175+
176+
return closure
177+
178+
# Attach the new "visit_expr()" method to this instance.
179+
method = make_closure(method_name)
180+
setattr(self, method_name, MethodType(method, self))
147181

148182

149183
class State(myb.State):

0 commit comments

Comments
 (0)