Skip to content

Commit 55d624e

Browse files
authored
fix some convert corner cases (#22)
* fix some convert corner cases * fix earlier julia versions * fixes * accomodate julia 1.1
1 parent a5f989c commit 55d624e

File tree

4 files changed

+43
-12
lines changed

4 files changed

+43
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SumTypes"
22
uuid = "8e1ec7a9-0e02-4297-b0fe-6433085c89f2"
33
authors = ["MasonProtter <[email protected]>"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

src/SumTypes.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,28 @@ show_sumtype(io::IO, m::MIME, x) = show_sumtype(io, x)
4545
function show_sumtype(io::IO, x::T) where {T}
4646
tag = get_tag(x)
4747
sym = flag_to_symbol(T, tag)
48-
T_stripped = if length(T.parameters) == 2
49-
String(T.name.name)
50-
else
51-
string(String(T.name.name), "{", join(repr.(T.parameters[1:end-2]), ", "), "}")
52-
end
48+
T_stripped = T_string_stripped(T)
5349
if unwrap(x) isa Variant{(), Tuple{}}
5450
print(io, String(sym), "::", T_stripped)
5551
else
5652
print(io, String(sym), '(', join((repr(data) for data unwrap(x)), ", "), ")::", T_stripped)
5753
end
5854
end
55+
function T_string_stripped(::Type{_T}) where {_T}
56+
@assert is_sumtype(_T)
57+
T = full_type(_T)
58+
T_stripped = if length(T.parameters) == 2
59+
String(T.name.name)
60+
else
61+
string(String(T.name.name), "{", join(repr.(T.parameters[1:end-2]), ", "), "}")
62+
end
63+
end
64+
65+
66+
struct Converter{T, U} end
67+
(::Converter{T, U})(x) where {T, U} = convert(T, U(x))
68+
Base.show(io::IO, x::Converter{T, U}) where {T, U} = print(io, "$(T_string_stripped(T))'.$U")
69+
5970

6071
include("compute_storage.jl")
6172
include("sum_type.jl") # @sum_type defined here

src/sum_type.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
function generate_constructor_exprs(T_name, T_params, T_params_constrained, T_nameparam, constructors)
138138
out = Expr(:toplevel)
139139
converts = []
140-
foreach(constructors) do nt
140+
for nt constructors
141141
name = nt.name
142142
gname = nt.gname
143143
params = nt.params
@@ -190,16 +190,27 @@ function generate_constructor_exprs(T_name, T_params, T_params_constrained, T_na
190190
enumerate_constructors = collect(enumerate(constructors))
191191

192192
if true
193+
@gensym N M _tag _T x
194+
195+
if_nest_conv = mapfoldr(((cond, data), old) -> Expr(:if, cond, data, old), enumerate_constructors, init=:(error("invalid tag"))) do (i, nt)
196+
:($_tag == $(i-1) ), :($make($T_init, $unwrap(x, $(nt.store_type)) , $_tag))
197+
end
198+
193199
push!(converts, T_uninit => quote
194-
$Base.convert(::Type{$T_init}, x::$T_uninit) where {$(T_params...)} =
195-
$make($T_init, $unwrap(x), $getfield(x, $(QuoteNode(tag)) ))
196-
$T_init(x::$T_uninit) where {$(T_params...)} = $convert($T_init, x)
200+
$Base.convert(::$Type{$_T}, $x::$_T) where {$_T <: $T_name} = $x
201+
$Base.convert(::$Type{<:$T_init}, x::$T_uninit) where {$(T_params...)} = let $_tag = $get_tag(x)
202+
$if_nest_conv
203+
end
204+
(::$Type{<:$T_init})(x::$T_uninit) where {$(T_params...)} = $convert($T_init, x)
205+
$Base.convert(::$Type{<:$T_init}, x::$T_uninit{$N, $M}) where {$(T_params...), $N, $M} = let $_tag = $get_tag(x)
206+
$if_nest_conv
207+
end
208+
(::$Type{<:$_T})(x::$T_name) where {$_T <: $T_name} = $convert($_T, x)
197209
end)
198210
end
199211
end
200212
unique!(x -> x[1], converts)
201213
append!(out.args, map(x -> x[2], converts))
202-
203214
out
204215
end
205216

@@ -228,11 +239,12 @@ function generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_p
228239
end
229240

230241
only_define_with_params = if !isempty(T_params)
242+
@gensym x
231243
quote
232244
$SumTypes.constructors(::Type{<:$T_nameparam}) where {$(T_params...)} =
233245
$NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.store_type for nt constructors)...)))
234246
$Base.adjoint(::Type{<:$T_nameparam}) where {$(T_params...)} =
235-
$NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.value ? :($T_nameparam($(nt.gname))) : nt.gouter_type for nt constructors)...)))
247+
$NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.value ? :($T_nameparam($(nt.gname))) : :($Converter{$T_nameparam, $(nt.gouter_type)}()) for nt constructors)...)))
236248
$SumTypes.variants_Tuple(::Type{<:$T_nameparam}) where {$(T_params...)} =
237249
$Tuple{$((nt.store_type for nt constructors)...)}
238250
$SumTypes.full_type(::Type{$T_name}) = $full_type($T_name{$(T_param_bounds...)}, $variants_Tuple($T_nameparam{$(T_param_bounds...)}))
@@ -269,6 +281,7 @@ function generate_sum_struct_expr(T, T_name, T_params, T_params_constrained, T_p
269281
$NamedTuple{$tags($T_name)}($(Expr(:tuple, (nt.gname for nt constructors)...)))
270282

271283
$SumTypes.full_type(::Type{$T_nameparam}) where {$(T_params...)} = $full_type($T_nameparam, $variants_Tuple($T_nameparam))
284+
$SumTypes.full_type(::Type{$T_nameparam{$N, $M}}) where {$(T_params...), $N, $M} = $T_nameparam{$N, $M}
272285

273286
$Base.show(io::IO, x::$T_name) = $show_sumtype(io, x)
274287
$Base.show(io::IO, m::MIME"text/plain", x::$T_name) = $show_sumtype(io, m, x)

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ end
8282
@test x != z
8383
end
8484
@test SumTypes.get_tag_sym(Left([1])) == :Left
85+
86+
@test convert(full_type(Either{Int, Int}), Left(1)) == Left(1)
87+
@test convert(full_type(Either{Int, Int}), Left(1)) !== Left(1)
88+
@test convert(full_type(Either{Int, Int}), Left(1)) === Either{Int, Int}'.Left(1)
89+
@test Either{Int, Int, 15, 0}(Left(1)) isa Either{Int, Int, 15, 0}
90+
@test Either{Int, Int, 15, 0}(Either{Int, Int}(Left(1))) isa Either{Int, Int, 15, 0}
8591

8692
@test_throws MethodError Left{Int}("hi")
8793
@test_throws MethodError Right{String}(1)
@@ -254,5 +260,6 @@ end
254260
@test repr(Right(3)) == "R(3)"
255261
end
256262
@test repr(apple) == "apple::Fruit"
263+
@test repr(Either{Int, Int}'.Left) ("Either{Int64, Int64}'.Left{Int64}", "Either{Int64,Int64}'.Left{Int64}")
257264
end
258265

0 commit comments

Comments
 (0)