@@ -68,21 +68,24 @@ def normalize_substrait_type_names(typ: str) -> str:
6868 raise Exception (f"Unrecognized substrait type { typ } " )
6969
7070
71- def violates_integer_option (actual : int , option , parameters : dict ):
71+ def violates_integer_option (actual : int , option , parameters : dict , subset = False ):
72+ option_numeric = None
7273 if isinstance (option , SubstraitTypeParser .NumericLiteralContext ):
73- return actual ! = int (str (option .Number ()))
74+ option_numeric = int (str (option .Number ()))
7475 elif isinstance (option , SubstraitTypeParser .NumericParameterNameContext ):
7576 parameter_name = str (option .Identifier ())
76- if parameter_name in parameters and parameters [parameter_name ] != actual :
77- return True
78- else :
77+
78+ if parameter_name not in parameters :
7979 parameters [parameter_name ] = actual
80+ option_numeric = parameters [parameter_name ]
8081 else :
8182 raise Exception (
8283 f"Input should be either NumericLiteralContext or NumericParameterNameContext, got { type (option )} instead"
8384 )
84-
85- return False
85+ if subset :
86+ return actual < option_numeric
87+ else :
88+ return actual != option_numeric
8689
8790
8891def types_equal (type1 : Type , type2 : Type , check_nullability = False ):
@@ -112,6 +115,27 @@ def handle_parameter_cover(
112115 return True
113116
114117
118+ def _check_nullability (check_nullability , parameterized_type , covered , kind ) -> bool :
119+ if not check_nullability :
120+ return True
121+ # The ANTLR context stores a Token called ``isnull`` – it is
122+ # present when the type is declared as nullable.
123+ nullability = (
124+ Type .Nullability .NULLABILITY_NULLABLE
125+ if getattr (parameterized_type , "isnull" , None ) is not None
126+ else Type .Nullability .NULLABILITY_REQUIRED
127+ )
128+ # if nullability == Type.Nullability.NULLABILITY_NULLABLE:
129+ # return True # is still true even if the covered is required
130+ # The protobuf message stores its own enum – we compare the two.
131+ covered_nullability = getattr (
132+ getattr (covered , kind ), # e.g. covered.varchar
133+ "nullability" ,
134+ None ,
135+ )
136+ return nullability == covered_nullability
137+
138+
115139def covers (
116140 covered : Type ,
117141 covering : SubstraitTypeParser .TypeLiteralContext ,
@@ -123,7 +147,6 @@ def covers(
123147 return handle_parameter_cover (
124148 covered , parameter_name , parameters , check_nullability
125149 )
126-
127150 covering : SubstraitTypeParser .TypeDefContext = covering .typeDef ()
128151
129152 any_type : SubstraitTypeParser .AnyTypeContext = covering .anyType ()
@@ -142,31 +165,99 @@ def covers(
142165
143166 parameterized_type = covering .parameterizedType ()
144167 if parameterized_type :
145- if isinstance (parameterized_type , SubstraitTypeParser .DecimalContext ):
146- if covered .WhichOneof ("kind" ) != "decimal" :
168+ kind = covered .WhichOneof ("kind" )
169+ if isinstance (parameterized_type , SubstraitTypeParser .VarCharContext ):
170+ if kind != "varchar" :
171+ return False
172+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
173+ covered .varchar .length , parameterized_type .length , parameters
174+ ):
147175 return False
148176
149- nullability = (
150- Type .NULLABILITY_NULLABLE
151- if parameterized_type .isnull
152- else Type .NULLABILITY_REQUIRED
177+ return _check_nullability (
178+ check_nullability , parameterized_type , covered , kind
153179 )
154-
155- if (
156- check_nullability
157- and nullability
158- != covered .__getattribute__ ( covered . WhichOneof ( "kind" )). nullability
180+ if isinstance ( parameterized_type , SubstraitTypeParser . FixedCharContext ):
181+ if kind != "fixed_char" :
182+ return False
183+ if hasattr ( parameterized_type , "length" ) and violates_integer_option (
184+ covered .fixed_char . length , parameterized_type . length , parameters
159185 ):
160186 return False
187+ return _check_nullability (
188+ check_nullability , parameterized_type , covered , kind
189+ )
161190
191+ if isinstance (parameterized_type , SubstraitTypeParser .FixedBinaryContext ):
192+ if kind != "fixed_binary" :
193+ return False
194+ if hasattr (parameterized_type , "length" ) and violates_integer_option (
195+ covered .fixed_binary .length , parameterized_type .length , parameters
196+ ):
197+ return False
198+ # return True
199+ return _check_nullability (
200+ check_nullability , parameterized_type , covered , kind
201+ )
202+ if isinstance (parameterized_type , SubstraitTypeParser .DecimalContext ):
203+ if kind != "decimal" :
204+ return False
205+ if not _check_nullability (
206+ check_nullability , parameterized_type , covered , kind
207+ ):
208+ return False
209+ # precision / scale are both optional – a missing value means “no limit”.
210+ covered_scale = getattr (covered .decimal , "scale" , 0 )
211+ param_scale = getattr (parameterized_type , "scale" , 0 )
212+ covered_prec = getattr (covered .decimal , "precision" , 0 )
213+ param_prec = getattr (parameterized_type , "precision" , 0 )
162214 return not (
163- violates_integer_option (
164- covered .decimal .scale , parameterized_type .scale , parameters
165- )
166- or violates_integer_option (
167- covered .decimal .precision , parameterized_type .precision , parameters
168- )
215+ violates_integer_option (covered_scale , param_scale , parameters )
216+ or violates_integer_option (covered_prec , param_prec , parameters )
169217 )
218+ if isinstance (
219+ parameterized_type , SubstraitTypeParser .PrecisionTimestampContext
220+ ):
221+ if kind != "precision_timestamp" :
222+ return False
223+ if not _check_nullability (
224+ check_nullability , parameterized_type , covered , kind
225+ ):
226+ return False
227+ # return True
228+ covered_prec = getattr (covered .precision_timestamp , "precision" , 0 )
229+ param_prec = getattr (parameterized_type , "precision" , 0 )
230+ return not violates_integer_option (covered_prec , param_prec , parameters )
231+
232+ if isinstance (
233+ parameterized_type , SubstraitTypeParser .PrecisionTimestampTZContext
234+ ):
235+ if kind != "precision_timestamp_tz" :
236+ return False
237+ if not _check_nullability (
238+ check_nullability , parameterized_type , covered , kind
239+ ):
240+ return False
241+ # return True
242+ covered_prec = getattr (covered .precision_timestamp_tz , "precision" , 0 )
243+ param_prec = getattr (parameterized_type , "precision" , 0 )
244+ return not violates_integer_option (covered_prec , param_prec , parameters )
245+
246+ kind_mapping = {
247+ SubstraitTypeParser .ListContext : "list" ,
248+ SubstraitTypeParser .MapContext : "map" ,
249+ SubstraitTypeParser .StructContext : "struct" ,
250+ SubstraitTypeParser .UserDefinedContext : "user_defined" ,
251+ SubstraitTypeParser .PrecisionIntervalDayContext : "interval_day" ,
252+ }
253+
254+ for ctx_cls , expected_kind in kind_mapping .items ():
255+ if isinstance (parameterized_type , ctx_cls ):
256+ if kind != expected_kind :
257+ return False
258+ return _check_nullability (
259+ check_nullability , parameterized_type , covered , kind
260+ )
170261 else :
171262 raise Exception (f"Unhandled type { type (parameterized_type )} " )
172263
0 commit comments