|
1 | 1 | from typing import Optional |
2 | | -from antlr4 import InputStream, CommonTokenStream |
| 2 | + |
| 3 | +from antlr4 import CommonTokenStream, InputStream |
| 4 | + |
3 | 5 | from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer |
4 | 6 | from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser |
5 | | -from substrait.gen.proto.type_pb2 import Type |
| 7 | +from substrait.gen.proto.type_pb2 import NamedStruct, Type |
6 | 8 |
|
7 | 9 |
|
8 | 10 | def _evaluate(x, values: dict): |
@@ -65,22 +67,128 @@ def _evaluate(x, values: dict): |
65 | 67 | return Type(fp64=Type.FP64(nullability=nullability)) |
66 | 68 | elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext): |
67 | 69 | return Type(bool=Type.Boolean(nullability=nullability)) |
| 70 | + elif isinstance(scalar_type, SubstraitTypeParser.StringContext): |
| 71 | + return Type(string=Type.String(nullability=nullability)) |
| 72 | + elif isinstance(scalar_type, SubstraitTypeParser.TimestampContext): |
| 73 | + return Type(timestamp=Type.Timestamp(nullability=nullability)) |
| 74 | + elif isinstance(scalar_type, SubstraitTypeParser.DateContext): |
| 75 | + return Type(date=Type.Date(nullability=nullability)) |
| 76 | + elif isinstance(scalar_type, SubstraitTypeParser.IntervalYearContext): |
| 77 | + return Type(interval_year=Type.IntervalYear(nullability=nullability)) |
| 78 | + elif isinstance(scalar_type, SubstraitTypeParser.UuidContext): |
| 79 | + return Type(uuid=Type.UUID(nullability=nullability)) |
| 80 | + elif isinstance(scalar_type, SubstraitTypeParser.BinaryContext): |
| 81 | + return Type(binary=Type.Binary(nullability=nullability)) |
| 82 | + elif isinstance(scalar_type, SubstraitTypeParser.TimeContext): |
| 83 | + return Type(time=Type.Time(nullability=nullability)) |
| 84 | + elif isinstance(scalar_type, SubstraitTypeParser.TimestampTzContext): |
| 85 | + return Type(timestamp_tz=Type.TimestampTZ(nullability=nullability)) |
68 | 86 | else: |
69 | 87 | raise Exception(f"Unknown scalar type {type(scalar_type)}") |
70 | 88 | elif parametrized_type: |
| 89 | + nullability = ( |
| 90 | + Type.NULLABILITY_NULLABLE |
| 91 | + if parametrized_type.isnull |
| 92 | + else Type.NULLABILITY_REQUIRED |
| 93 | + ) |
71 | 94 | if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext): |
72 | 95 | precision = _evaluate(parametrized_type.precision, values) |
73 | 96 | scale = _evaluate(parametrized_type.scale, values) |
74 | | - nullability = ( |
75 | | - Type.NULLABILITY_NULLABLE |
76 | | - if parametrized_type.isnull |
77 | | - else Type.NULLABILITY_REQUIRED |
78 | | - ) |
79 | 97 | return Type( |
80 | 98 | decimal=Type.Decimal( |
81 | 99 | precision=precision, scale=scale, nullability=nullability |
82 | 100 | ) |
83 | 101 | ) |
| 102 | + elif isinstance(parametrized_type, SubstraitTypeParser.VarCharContext): |
| 103 | + length = _evaluate(parametrized_type.length, values) |
| 104 | + return Type( |
| 105 | + varchar=Type.VarChar( |
| 106 | + length=length, |
| 107 | + nullability=nullability, |
| 108 | + ) |
| 109 | + ) |
| 110 | + elif isinstance(parametrized_type, SubstraitTypeParser.FixedCharContext): |
| 111 | + length = _evaluate(parametrized_type.length, values) |
| 112 | + return Type( |
| 113 | + fixed_char=Type.FixedChar( |
| 114 | + length=length, |
| 115 | + nullability=nullability, |
| 116 | + ) |
| 117 | + ) |
| 118 | + elif isinstance(parametrized_type, SubstraitTypeParser.FixedBinaryContext): |
| 119 | + length = _evaluate(parametrized_type.length, values) |
| 120 | + return Type( |
| 121 | + fixed_binary=Type.FixedBinary( |
| 122 | + length=length, |
| 123 | + nullability=nullability, |
| 124 | + ) |
| 125 | + ) |
| 126 | + elif isinstance( |
| 127 | + parametrized_type, SubstraitTypeParser.PrecisionTimestampContext |
| 128 | + ): |
| 129 | + precision = _evaluate(parametrized_type.precision, values) |
| 130 | + return Type( |
| 131 | + precision_timestamp=Type.PrecisionTimestamp( |
| 132 | + precision=precision, |
| 133 | + nullability=nullability, |
| 134 | + ) |
| 135 | + ) |
| 136 | + elif isinstance( |
| 137 | + parametrized_type, SubstraitTypeParser.PrecisionTimestampTZContext |
| 138 | + ): |
| 139 | + precision = _evaluate(parametrized_type.precision, values) |
| 140 | + return Type( |
| 141 | + precision_timestamp_tz=Type.PrecisionTimestampTZ( |
| 142 | + precision=precision, |
| 143 | + nullability=nullability, |
| 144 | + ) |
| 145 | + ) |
| 146 | + elif isinstance(parametrized_type, SubstraitTypeParser.IntervalYearContext): |
| 147 | + return Type( |
| 148 | + interval_year=Type.IntervalYear( |
| 149 | + nullability=nullability, |
| 150 | + ) |
| 151 | + ) |
| 152 | + elif isinstance(parametrized_type, SubstraitTypeParser.StructContext): |
| 153 | + types = list( |
| 154 | + map(lambda x: _evaluate(x, values), parametrized_type.expr()) |
| 155 | + ) |
| 156 | + return Type( |
| 157 | + struct=Type.Struct( |
| 158 | + types=types, |
| 159 | + nullability=nullability, |
| 160 | + ) |
| 161 | + ) |
| 162 | + elif isinstance(parametrized_type, SubstraitTypeParser.ListContext): |
| 163 | + list_type = _evaluate(parametrized_type.expr(), values) |
| 164 | + return Type( |
| 165 | + list=Type.List( |
| 166 | + type=list_type, |
| 167 | + nullability=nullability, |
| 168 | + ) |
| 169 | + ) |
| 170 | + |
| 171 | + elif isinstance(parametrized_type, SubstraitTypeParser.MapContext): |
| 172 | + return Type( |
| 173 | + map=Type.Map( |
| 174 | + key=_evaluate(parametrized_type.key, values), |
| 175 | + value=_evaluate(parametrized_type.value, values), |
| 176 | + nullability=nullability, |
| 177 | + ) |
| 178 | + ) |
| 179 | + elif isinstance(parametrized_type, SubstraitTypeParser.NStructContext): |
| 180 | + names = list(map(lambda k: k.getText(), parametrized_type.Identifier())) |
| 181 | + struct = Type.Struct( |
| 182 | + types=list( |
| 183 | + map(lambda k: _evaluate(k, values), parametrized_type.expr()) |
| 184 | + ), |
| 185 | + nullability=nullability, |
| 186 | + ) |
| 187 | + return NamedStruct( |
| 188 | + names=names, |
| 189 | + struct=struct, |
| 190 | + ) |
| 191 | + |
84 | 192 | raise Exception(f"Unknown parametrized type {type(parametrized_type)}") |
85 | 193 | elif any_type: |
86 | 194 | any_var = any_type.AnyVar() |
|
0 commit comments