diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index 39ed5e6..92b560a 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -1,4 +1,5 @@ from typing import Iterable + import substrait.gen.proto.type_pb2 as stt diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index f4d18d7..17950c6 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -1,8 +1,10 @@ from typing import Optional -from antlr4 import InputStream, CommonTokenStream + +from antlr4 import CommonTokenStream, InputStream + from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser -from substrait.gen.proto.type_pb2 import Type +from substrait.gen.proto.type_pb2 import NamedStruct, Type def _evaluate(x, values: dict): @@ -65,22 +67,128 @@ def _evaluate(x, values: dict): return Type(fp64=Type.FP64(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext): return Type(bool=Type.Boolean(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.StringContext): + return Type(string=Type.String(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext): + return Type(timestamp=Type.Timestamp(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.DateContext): + return Type(date=Type.Date(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext): + return Type(interval_year=Type.IntervalYear(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.UuidContext): + return Type(uuid=Type.UUID(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext): + return Type(binary=Type.Binary(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimeContext): + return Type(time=Type.Time(nullability=nullability)) + elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext): + return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability)) else: raise Exception(f"Unknown scalar type {type(scalar_type)}") elif parametrized_type: + nullability = ( + Type.NULLABILITY_NULLABLE + if parametrized_type.isnull + else Type.NULLABILITY_REQUIRED + ) if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext): precision = _evaluate(parametrized_type.precision, values) scale = _evaluate(parametrized_type.scale, values) - nullability = ( - Type.NULLABILITY_NULLABLE - if parametrized_type.isnull - else Type.NULLABILITY_REQUIRED - ) return Type( decimal=Type.Decimal( precision=precision, scale=scale, nullability=nullability ) ) + elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext): + length = _evaluate(parametrized_type.length, values) + return Type( + varchar=Type.VarChar( + length=length, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext): + length = _evaluate(parametrized_type.length, values) + return Type( + fixed_char=Type.FixedChar( + length=length, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext): + length = _evaluate(parametrized_type.length, values) + return Type( + fixed_binary=Type.FixedBinary( + length=length, + nullability=nullability, + ) + ) + elif isinstance( + parametrized_type, SubstraitTypeParser.PrecisionTimestampContext + ): + precision = _evaluate(parametrized_type.precision, values) + return Type( + precision_timestamp=Type.PrecisionTimestamp( + precision=precision, + nullability=nullability, + ) + ) + elif isinstance( + parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext + ): + precision = _evaluate(parametrized_type.precision, values) + return Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + precision=precision, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext): + return Type( + interval_year=Type.IntervalYear( + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.StructContext): + types = list( + map(lambda x: _evaluate(x, values), parametrized_type.expr()) + ) + return Type( + struct=Type.Struct( + types=types, + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.ListContext): + list_type = _evaluate(parametrized_type.expr(), values) + return Type( + list=Type.List( + type=list_type, + nullability=nullability, + ) + ) + + elif isinstance(parametrized_type, SubstraitTypeParser.MapContext): + return Type( + map=Type.Map( + key=_evaluate(parametrized_type.key, values), + value=_evaluate(parametrized_type.value, values), + nullability=nullability, + ) + ) + elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext): + names = list(map(lambda k: k.getText(), parametrized_type.Identifier())) + struct = Type.Struct( + types=list( + map(lambda k: _evaluate(k, values), parametrized_type.expr()) + ), + nullability=nullability, + ) + return NamedStruct( + names=names, + struct=struct, + ) + raise Exception(f"Unknown parametrized type {type(parametrized_type)}") elif any_type: any_var = any_type.AnyVar() diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index c854c02..58bd6f5 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -1,17 +1,19 @@ -import yaml import itertools import re -from substrait.gen.proto.type_pb2 import Type -from importlib.resources import files as importlib_files from collections import defaultdict +from importlib.resources import files as importlib_files from pathlib import Path from typing import Optional, Union -from .derivation_expression import evaluate, _evaluate, _parse + +import yaml + from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser from substrait.gen.json import simple_extensions as se +from substrait.gen.proto.type_pb2 import Type from substrait.simple_extension_utils import build_simple_extensions -from .bimap import UriUrnBiDiMap +from .bimap import UriUrnBiDiMap +from .derivation_expression import _evaluate, _parse, evaluate DEFAULT_URN_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" @@ -69,20 +71,20 @@ def normalize_substrait_type_names(typ: str) -> str: def violates_integer_option(actual: int, option, parameters: dict): + option_numeric = None if isinstance(option, SubstraitTypeParser.NumericLiteralContext): - return actual != int(str(option.Number())) + option_numeric = int(str(option.Number())) elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext): parameter_name = str(option.Identifier()) - if parameter_name in parameters and parameters[parameter_name] != actual: - return True - else: + + if parameter_name not in parameters: parameters[parameter_name] = actual + option_numeric = parameters[parameter_name] else: raise Exception( f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" ) - - return False + return actual != option_numeric def types_equal(type1: Type, type2: Type, check_nullability=False): @@ -112,6 +114,27 @@ def handle_parameter_cover( return True +def _check_nullability(check_nullability, parameterized_type, covered, kind) -> bool: + if not check_nullability: + return True + # The ANTLR context stores a Token called ``isnull`` – it is + # present when the type is declared as nullable. + nullability = ( + Type.Nullability.NULLABILITY_NULLABLE + if getattr(parameterized_type, "isnull", None) is not None + else Type.Nullability.NULLABILITY_REQUIRED + ) + # if nullability == Type.Nullability.NULLABILITY_NULLABLE: + # return True # is still true even if the covered is required + # The protobuf message stores its own enum – we compare the two. + covered_nullability = getattr( + getattr(covered, kind), # e.g. covered.varchar + "nullability", + None, + ) + return nullability == covered_nullability + + def covers( covered: Type, covering: SubstraitTypeParser.TypeLiteralContext, @@ -123,7 +146,6 @@ def covers( return handle_parameter_cover( covered, parameter_name, parameters, check_nullability ) - covering: SubstraitTypeParser.TypeDefContext = covering.typeDef() any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType() @@ -142,33 +164,119 @@ def covers( parameterized_type = covering.parameterizedType() if parameterized_type: - if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): - if covered.WhichOneof("kind") != "decimal": - return False + return _cover_parametrized_type( + covered, parameterized_type, parameters, check_nullability + ) - nullability = ( - Type.NULLABILITY_NULLABLE - if parameterized_type.isnull - else Type.NULLABILITY_REQUIRED + +def check_violates_integer_option_parameters( + covered, parameterized_type, attributes, parameters +): + for attr in attributes: + if not hasattr(covered, attr) and not hasattr(parameterized_type, attr): + return False + covered_attr = getattr(covered, attr) + param_attr = getattr(parameterized_type, attr) + if violates_integer_option(covered_attr, param_attr, parameters): + return True + return False + + +def _cover_parametrized_type( + covered: Type, + parameterized_type: SubstraitTypeParser.ParameterizedTypeContext, + parameters: dict, + check_nullability=False, +): + kind = covered.WhichOneof("kind") + + if not _check_nullability(check_nullability, parameterized_type, covered, kind): + return False + + if isinstance(parameterized_type, SubstraitTypeParser.VarCharContext): + return kind == "varchar" and not check_violates_integer_option_parameters( + covered.varchar, parameterized_type, ["length"], parameters + ) + + if isinstance(parameterized_type, SubstraitTypeParser.FixedCharContext): + return kind == "fixed_char" and not check_violates_integer_option_parameters( + covered.fixed_char, parameterized_type, ["length"], parameters + ) + + if isinstance(parameterized_type, SubstraitTypeParser.FixedBinaryContext): + return kind == "fixed_binary" and not check_violates_integer_option_parameters( + covered.fixed_binary, parameterized_type, ["length"], parameters + ) + + if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): + return kind == "decimal" and not check_violates_integer_option_parameters( + covered.decimal, parameterized_type, ["scale", "precision"], parameters + ) + + if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampContext): + return ( + kind == "precision_timestamp" + and not check_violates_integer_option_parameters( + covered.precision_timestamp, + parameterized_type, + ["precision"], + parameters, ) + ) + + if isinstance(parameterized_type, SubstraitTypeParser.PrecisionTimestampTZContext): + return ( + kind == "precision_timestamp_tz" + and not check_violates_integer_option_parameters( + covered.precision_timestamp_tz, + parameterized_type, + ["precision"], + parameters, + ) + ) + + if isinstance(parameterized_type, SubstraitTypeParser.ListContext): + return kind == "list" and covers( + covered.list.type, + parameterized_type.expr(), + parameters, + check_nullability, + ) + + if isinstance(parameterized_type, SubstraitTypeParser.MapContext): + return ( + kind == "map" + and covers( + covered.map.key, parameterized_type.key, parameters, check_nullability + ) + and covers( + covered.map.value, + parameterized_type.value, + parameters, + check_nullability, + ) + ) - if ( - check_nullability - and nullability - != covered.__getattribute__(covered.WhichOneof("kind")).nullability + if isinstance(parameterized_type, SubstraitTypeParser.StructContext): + if kind != "struct": + return False + covered_types = covered.struct.types + param_types = parameterized_type.expr() or [] + if not isinstance(param_types, list): + param_types = [param_types] + if len(covered_types) != len(param_types): + return False + for covered_field, param_field_ctx in zip(covered_types, param_types): + if not covers( + covered_field, + param_field_ctx, + parameters, + check_nullability, # type: ignore ): return False + return True - return not ( - violates_integer_option( - covered.decimal.scale, parameterized_type.scale, parameters - ) - or violates_integer_option( - covered.decimal.precision, parameterized_type.precision, parameters - ) - ) - else: - raise Exception(f"Unhandled type {type(parameterized_type)}") + raise Exception(f"Unhandled type {type(parameterized_type)}") class FunctionEntry: @@ -231,14 +339,12 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: output_type = evaluate(self.impl.return_, parameters) if self.nullability == se.NullabilityHandling.MIRROR: - sig_contains_nullable = any( - [ - p.__getattribute__(p.WhichOneof("kind")).nullability - == Type.NULLABILITY_NULLABLE - for p in signature - if isinstance(p, Type) - ] - ) + sig_contains_nullable = any([ + p.__getattribute__(p.WhichOneof("kind")).nullability + == Type.NULLABILITY_NULLABLE + for p in signature + if isinstance(p, Type) + ]) output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( Type.NULLABILITY_NULLABLE if sig_contains_nullable diff --git a/tests/test_derivation_expression.py b/tests/test_derivation_expression.py index 4b11b3d..68c29b0 100644 --- a/tests/test_derivation_expression.py +++ b/tests/test_derivation_expression.py @@ -1,4 +1,4 @@ -from substrait.gen.proto.type_pb2 import Type +from substrait.gen.proto.type_pb2 import NamedStruct, Type from substrait.derivation_expression import evaluate @@ -113,3 +113,59 @@ def func(P1, S1, P2, S2): ) == func_eval ) + + +def test_struct_simple(): + """Test simple struct with two i32 fields.""" + result = evaluate("struct", {}) + expected = Type( + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + ], + nullability=Type.NULLABILITY_REQUIRED, + ) + ) + assert result == expected + + +def test_nstruct_simple(): + """Test named struct with field names and types.""" + result = evaluate("nStruct", {}) + expected = NamedStruct( + names=["a", "b"], + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + ], + nullability=Type.NULLABILITY_REQUIRED, + ), + ) + assert result == expected + + +def test_nstruct_nested(): + """Test named struct with nested struct field.""" + result = evaluate("nStruct>", {}) + expected = NamedStruct( + names=["a", "b", "c"], + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type( + struct=Type.Struct( + types=[ + Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)), + Type(fp32=Type.FP32(nullability=Type.NULLABILITY_REQUIRED)), + ], + nullability=Type.NULLABILITY_REQUIRED, + ) + ), + ], + nullability=Type.NULLABILITY_REQUIRED, + ), + ) + assert result == expected diff --git a/tests/test_extension_registry.py b/tests/test_extension_registry.py index f9d63bd..1b17f5c 100644 --- a/tests/test_extension_registry.py +++ b/tests/test_extension_registry.py @@ -1,9 +1,22 @@ import pytest import yaml -from substrait.gen.proto.type_pb2 import Type -from substrait.extension_registry import ExtensionRegistry, covers +from substrait.builders.type import ( + decimal, + i8, + i16, + i32, + struct, +) +from substrait.builders.type import ( + list as list_, +) +from substrait.builders.type import ( + map as map_, +) from substrait.derivation_expression import _parse +from substrait.extension_registry import ExtensionRegistry, covers +from substrait.gen.proto.type_pb2 import Type content = """%YAML 1.2 --- @@ -104,10 +117,19 @@ value: decimal nullability: DISCRETE return: decimal? + - name: "equal_test" + impls: + - args: + - name: x + value: any + - name: y + value: any + nullability: DISCRETE + return: any """ -registry = ExtensionRegistry() +registry = ExtensionRegistry(load_default_extensions=True) registry.register_extension_dict( yaml.safe_load(content), @@ -115,52 +137,12 @@ ) -def i8(nullable=False): - return Type( - i8=Type.I8( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def i16(nullable=False): - return Type( - i16=Type.I16( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def bool(nullable=False): - return Type( - bool=Type.Boolean( - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE - ) - ) - - -def decimal(precision, scale, nullable=False): - return Type( - decimal=Type.Decimal( - scale=scale, - precision=precision, - nullability=Type.NULLABILITY_REQUIRED - if not nullable - else Type.NULLABILITY_NULLABLE, - ) - ) - - def test_non_existing_urn(): assert ( registry.lookup_function( - urn="non_existent", function_name="add", signature=[i8(), i8()] + urn="non_existent", + function_name="add", + signature=[i8(nullable=False), i8(nullable=False)], ) is None ) @@ -169,7 +151,9 @@ def test_non_existing_urn(): def test_non_existing_function(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="sub", signature=[i8(), i8()] + urn="extension:test:functions", + function_name="sub", + signature=[i8(nullable=False), i8(nullable=False)], ) is None ) @@ -178,7 +162,9 @@ def test_non_existing_function(): def test_non_existing_function_signature(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8()] + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=False)], ) is None ) @@ -186,7 +172,9 @@ def test_non_existing_function_signature(): def test_exact_match(): assert registry.lookup_function( - urn="extension:test:functions", function_name="add", signature=[i8(), i8()] + urn="extension:test:functions", + function_name="add", + signature=[i8(nullable=False), i8(nullable=False)], )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) @@ -194,7 +182,7 @@ def test_wildcard_match(): assert registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(), i8(), bool()], + signature=[i8(nullable=False), i8(nullable=False), bool()], )[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) @@ -203,49 +191,42 @@ def test_wildcard_match_fails_with_constraits(): registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(), i16(), i16()], + signature=[i8(nullable=False), i16(nullable=False), i16(nullable=False)], ) is None ) def test_wildcard_match_with_constraits(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="add", - signature=[i16(), i16(), i8()], - )[1] - == i8() - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="add", + signature=[i16(nullable=False), i16(nullable=False), i8(nullable=False)], + )[1] == i8(nullable=False) def test_variadic(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_fn", - signature=[i8(), i8(), i8()], - )[1] - == i8() - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_fn", + signature=[i8(nullable=False), i8(nullable=False), i8(nullable=False)], + )[1] == i8(nullable=False) def test_variadic_any(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_fn_variadic_any", - signature=[i16(), i16(), i16()], - )[1] - == i16() - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_fn_variadic_any", + signature=[i16(nullable=False), i16(nullable=False), i16(nullable=False)], + )[1] == i16(nullable=False) def test_variadic_fails_min_constraint(): assert ( registry.lookup_function( - urn="extension:test:functions", function_name="test_fn", signature=[i8()] + urn="extension:test:functions", + function_name="test_fn", + signature=[i8(nullable=False)], ) is None ) @@ -255,8 +236,8 @@ def test_decimal_happy_path(): assert registry.lookup_function( urn="extension:test:functions", function_name="test_decimal", - signature=[decimal(10, 8), decimal(8, 6)], - )[1] == decimal(11, 7) + signature=[decimal(8, 10, nullable=False), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=False) def test_decimal_violates_constraint(): @@ -264,7 +245,7 @@ def test_decimal_violates_constraint(): registry.lookup_function( urn="extension:test:functions", function_name="test_decimal", - signature=[decimal(10, 8), decimal(12, 10)], + signature=[decimal(8, 10, nullable=False), decimal(10, 12, nullable=False)], ) is None ) @@ -274,19 +255,16 @@ def test_decimal_happy_path_discrete(): assert registry.lookup_function( urn="extension:test:functions", function_name="test_decimal_discrete", - signature=[decimal(10, 8, nullable=True), decimal(8, 6)], - )[1] == decimal(11, 7, nullable=True) + signature=[decimal(8, 10, nullable=True), decimal(6, 8, nullable=False)], + )[1] == decimal(7, 11, nullable=True) def test_enum_with_valid_option(): - assert ( - registry.lookup_function( - urn="extension:test:functions", - function_name="test_enum", - signature=["FLIP", i8()], - )[1] - == i8() - ) + assert registry.lookup_function( + urn="extension:test:functions", + function_name="test_enum", + signature=["FLIP", i8(nullable=False)], + )[1] == i8(nullable=False) def test_enum_with_nonexistent_option(): @@ -294,7 +272,7 @@ def test_enum_with_nonexistent_option(): registry.lookup_function( urn="extension:test:functions", function_name="test_enum", - signature=["NONEXISTENT", i8()], + signature=["NONEXISTENT", i8(nullable=False)], ) is None ) @@ -304,7 +282,7 @@ def test_function_with_nullable_args(): assert registry.lookup_function( urn="extension:test:functions", function_name="add", - signature=[i8(nullable=True), i8()], + signature=[i8(nullable=True), i8(nullable=False)], )[1] == i8(nullable=True) @@ -312,7 +290,7 @@ def test_function_with_declared_output_nullability(): assert registry.lookup_function( urn="extension:test:functions", function_name="add_declared", - signature=[i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False)], )[1] == i8(nullable=True) @@ -320,7 +298,7 @@ def test_function_with_discrete_nullability(): assert registry.lookup_function( urn="extension:test:functions", function_name="add_discrete", - signature=[i8(nullable=True), i8()], + signature=[i8(nullable=True), i8(nullable=False)], )[1] == i8(nullable=True) @@ -329,7 +307,7 @@ def test_function_with_discrete_nullability_nonexisting(): registry.lookup_function( urn="extension:test:functions", function_name="add_discrete", - signature=[i8(), i8()], + signature=[i8(nullable=False), i8(nullable=False)], ) is None ) @@ -337,7 +315,7 @@ def test_function_with_discrete_nullability_nonexisting(): def test_covers(): params = {} - assert covers(i8(), _parse("i8"), params) + assert covers(i8(nullable=False), _parse("i8"), params) assert params == {} @@ -346,18 +324,132 @@ def test_covers_nullability(): assert covers(i8(nullable=True), _parse("i8?"), {}, check_nullability=True) -def test_covers_decimal(): - assert not covers(decimal(10, 8), _parse("decimal<11, A>"), {}) +def test_covers_decimal(nullable=False): + assert not covers(decimal(8, 10), _parse("decimal<11, A>"), {}) + assert covers(decimal(8, 10), _parse("decimal<10, A>"), {}) + assert covers(decimal(8, 10), _parse("decimal<10, 8>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<10, 9>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<11, 8>"), {}) + assert not covers(decimal(8, 10), _parse("decimal<11, 9>"), {}) def test_covers_decimal_happy_path(): params = {} - assert covers(decimal(10, 8), _parse("decimal<10, A>"), params) + assert covers(decimal(8, 10), _parse("decimal<10, A>"), params) assert params == {"A": 8} def test_covers_any(): - assert covers(decimal(10, 8), _parse("any"), {}) + assert covers(decimal(8, 10), _parse("any"), {}) + + +def test_covers_varchar_length_ok(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=15) + ) + param_ctx = _parse("varchar<15>") + assert covers(covered, param_ctx, {}, check_nullability=True) + + +def test_covers_varchar_length_fail(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_ctx = _parse("varchar<5>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_varchar_nullability(): + covered = Type( + varchar=Type.VarChar(nullability=Type.NULLABILITY_REQUIRED, length=10) + ) + param_tx = _parse("varchar?<10>") + assert covers(covered, param_tx, {}) + assert not covers(covered, param_tx, {}, True) + param_ctx2 = _parse("varchar<10>") + assert covers(covered, param_ctx2, {}, True) + + +def test_covers_fixed_char_length_ok(): + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) + ) + param_ctx = _parse("fixedchar<8>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_char_length_fail(): + covered = Type( + fixed_char=Type.FixedChar(nullability=Type.NULLABILITY_REQUIRED, length=8) + ) + param_ctx = _parse("fixedchar<4>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_ok(): + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) + ) + param_ctx = _parse("fixedbinary<16>") + assert covers(covered, param_ctx, {}) + + +def test_covers_fixed_binary_length_fail(): + covered = Type( + fixed_binary=Type.FixedBinary(nullability=Type.NULLABILITY_REQUIRED, length=16) + ) + param_ctx = _parse("fixedbinary<10>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_decimal_precision_scale_fail(): + covered = decimal(8, 10, nullable=False) + param_ctx = _parse("decimal<6, 5>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_ok(): + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=5 + ) + ) + param_ctx = _parse("precision_timestamp<5>") + assert covers(covered, param_ctx, {}) + param_ctx = _parse("precision_timestamp") + assert covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_fail(): + covered = Type( + precision_timestamp=Type.PrecisionTimestamp( + nullability=Type.NULLABILITY_REQUIRED, precision=3 + ) + ) + param_ctx = _parse("precision_timestamp<2>") + assert not covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_tz_ok(): + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=4 + ) + ) + param_ctx = _parse("precision_timestamp_tz<4>") + assert covers(covered, param_ctx, {}) + param_ctx = _parse("precision_timestamp_tz") + assert covers(covered, param_ctx, {}) + + +def test_covers_precision_timestamp_tz_fail(): + covered = Type( + precision_timestamp_tz=Type.PrecisionTimestampTZ( + nullability=Type.NULLABILITY_REQUIRED, precision=4 + ) + ) + param_ctx = _parse("precision_timestamp_tz<3>") + assert not covers(covered, param_ctx, {}) def test_registry_uri_urn(): @@ -489,3 +581,53 @@ def test_register_requires_uri(): # During migration, URI is required - this should fail with TypeError with pytest.raises(TypeError): registry.register_extension_dict(yaml.safe_load(content)) + + +def test_covers_list_of_i8(): + """Test that a list of i8 covers list.""" + covered = list_(i8(nullable=False), nullable=False) + param_ctx = _parse("list") + assert covers(covered, param_ctx, {}) + + +def test_covers_map_string_to_i8(): + """Test that a map with string keys and i8 values covers map.""" + covered = map_( + key=Type(string=Type.String(nullability=Type.NULLABILITY_REQUIRED)), + value=i8(nullable=False), + nullable=False, + ) + param_ctx = _parse("map") + assert covers(covered, param_ctx, {}) + + +def test_covers_struct_with_two_fields(): + """Test that a struct with two i8 fields covers struct.""" + covered = struct([i8(nullable=False), i8(nullable=False)], nullable=False) + param_ctx = _parse("struct") + assert covers(covered, param_ctx, {}) + + +def test_covers_list_of_i16_fails_i8(): + """Test that a list of i16 does not cover list.""" + covered = list_(i16(nullable=False), nullable=False) + param_ctx = _parse("list") + assert not covers(covered, param_ctx, {}) + + +def test_covers_map_i8_to_i16_fails(): + """Test that a map with i8 keys and i16 values does not cover map.""" + covered = map_( + key=i8(nullable=False), + value=i16(nullable=False), + nullable=False, + ) + param_ctx = _parse("map") + assert not covers(covered, param_ctx, {}) + + +def test_covers_struct_mismatched_types_fails(): + """Test that a struct with mismatched field types does not cover struct.""" + covered = struct([i32(nullable=False), i8(nullable=False)], nullable=False) + param_ctx = _parse("struct") + assert not covers(covered, param_ctx, {})