Skip to content

Commit 1af7cd9

Browse files
authored
Add some quasi-internal API nicities (#37)
* fix hygiene of `full_type` * add isvariant and getindex(::Variant) * bump version
1 parent ea855d4 commit 1af7cd9

File tree

6 files changed

+33
-5
lines changed

6 files changed

+33
-5
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ jobs:
1212
version:
1313
- '1.1'
1414
- '1.6'
15-
- '1.8'
15+
- '1.8'
16+
- '1.9'
1617
- 'nightly'
1718
os:
1819
- ubuntu-latest

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SumTypes"
22
uuid = "8e1ec7a9-0e02-4297-b0fe-6433085c89f2"
33
authors = ["MasonProtter <[email protected]>"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

src/SumTypes.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@ function variants_Tuple end
2121
function strip_size_params end
2222
function full_type end
2323

24+
"""
25+
isvariant(x::SumType, s::Symbol)
26+
27+
For an `x` which was created as a `@sum_type`, check if it's variant tag is `s`. e.g.
28+
29+
@sum_type Either{L, R} begin
30+
Left{L}(::L)
31+
Right{R}(::R)
32+
end
33+
34+
let x::Either{Int, Int} = Left(1)
35+
isvariant(x, :Left) # true
36+
isvariant(x, :Right) # false
37+
end
38+
"""
39+
isvariant(x::T, s::Symbol) where {T} = get_tag(x) == symbol_to_flag(T, s)
2440

2541
struct Unsafe end
2642
const unsafe = Unsafe()
@@ -36,6 +52,7 @@ Base.:(==)(v1::Variant, v2::Variant) = v1.data == v2.data
3652

3753
Base.iterate(x::Variant, s = 1) = iterate(x.data, s)
3854
Base.indexed_iterate(x::Variant, i::Int, state=1) = (Base.@_inline_meta; (getfield(x.data, i), i+1))
55+
Base.getindex(x::Variant, i) = x.data[i]
3956

4057
const tag = Symbol("#tag#")
4158
get_tag(x) = getfield(x, tag)

src/cases.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ macro cases(to_match, block)
5656
@gensym nt
5757
variants = map(x -> x.variant, stmts)
5858

59-
ex = :(if $get_tag($data) === $symbol_to_flag($Typ, $(QuoteNode(stmts[1].variant)));
59+
ex = :(if $isvariant($data, $(QuoteNode(stmts[1].variant)));
6060
$(stmts[1].iscall ? :(($(stmts[1].fieldnames...),) =
6161
$unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[1].variant))}), $variants_Tuple($Typ)) ) : nothing);
6262
$(stmts[1].rhs)
@@ -65,7 +65,7 @@ macro cases(to_match, block)
6565
pushfirst!(ex.args[2].args, lnns[1])
6666
to_push = ex.args
6767
for i 2:length(stmts)
68-
_if = :(if $get_tag($data) === $symbol_to_flag($Typ, $(QuoteNode(stmts[i].variant)));
68+
_if = :(if $isvariant($data, $(QuoteNode(stmts[i].variant)));
6969
$(stmts[i].iscall ? :(($(stmts[i].fieldnames...),) =
7070
$unwrap($data, $constructor($Typ, $Val{$(QuoteNode(stmts[i].variant))}), $variants_Tuple($Typ))) : nothing);
7171
$(stmts[i].rhs)
@@ -87,4 +87,3 @@ macro cases(to_match, block)
8787
end
8888
end |> esc
8989
end
90-

src/compute_storage.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ make(::Type{ST}, to_make, tag) where {ST} = make(ST, to_make, tag, variants_Tupl
109109
)
110110
end
111111

112+
113+
function unwrap(x::ST, s::Symbol) where {ST}
114+
isvariant(x, s) || error("Incorrect tag used in unwrap")
115+
unwrap(x, constructor(ST, Val{s}), variants_Tuple(ST))
116+
end
112117
unwrap(x::ST, var) where {ST} = unwrap(x, var, variants_Tuple(ST))
113118
@generated function unwrap(x::ST, ::Type{Var}, ::Type{var_Tuple}) where {ST, Var, var_Tuple}
114119
variants = var_Tuple.parameters

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ end
115115
@test full_type(Either{Nothing, Int16}) == Either{Nothing, Int16, 2, 0, UInt16}
116116
@test full_type(Either{Int32, Int32}) == Either{Int32, Int32, 4, 0, UInt32}
117117
@test convert(full_type(Result{Float64}), Success(1.0)) == Success(1.0)
118+
119+
let x = Left(1.0)
120+
@test SumTypes.isvariant(x, :Left) == true
121+
@test SumTypes.isvariant(x, :Right) == false
122+
@test SumTypes.unwrap(x, :Left)[1] == 1.0
123+
end
118124
end
119125

120126
#--------------------------------------------------------

0 commit comments

Comments
 (0)