Skip to content

Commit d9312be

Browse files
ineolAlasdair
authored andcommitted
Lean: add a simp set for the Sail wrappers
The Lean support library defines wrappers to adapt the API from Sail to the Lean standard library. This commit adds a simp set `simp_sail` to simplify them away.
1 parent 0c87bbf commit d9312be

File tree

5 files changed

+85
-13
lines changed

5 files changed

+85
-13
lines changed

src/bin/dune

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@
273273
(%{workspace_root}/src/gen_lib/sail2_values.lem
274274
as
275275
src/gen_lib/sail2_values.lem)
276+
(%{workspace_root}/src/sail_lean_backend/Sail/Attr.lean
277+
as
278+
src/sail_lean_backend/Sail/Attr.lean)
276279
(%{workspace_root}/src/sail_lean_backend/Sail/Sail.lean
277280
as
278281
src/sail_lean_backend/Sail/Sail.lean)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import Lean
2+
3+
register_simp_attr simp_sail

src/sail_lean_backend/Sail/Sail.lean

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import THE_MODULE_NAME.Sail.Attr
2+
13
import Std.Data.ExtDHashMap
24
import Std.Data.ExtHashMap
35

@@ -7,73 +9,87 @@ namespace BitVec
79

810
abbrev length {w : Nat} (_ : BitVec w) : Nat := w
911

12+
@[simp_sail]
1013
def signExtend {w : Nat} (x : BitVec w) (w' : Nat) : BitVec w' :=
1114
x.signExtend w'
1215

16+
@[simp_sail]
1317
def zeroExtend {w : Nat} (x : BitVec w) (w' : Nat) : BitVec w' :=
1418
x.zeroExtend w'
1519

20+
@[simp_sail]
1621
def truncate {w : Nat} (x : BitVec w) (w' : Nat) : BitVec w' :=
1722
x.truncate w'
1823

24+
@[simp_sail]
1925
def truncateLsb {w : Nat} (x : BitVec w) (w' : Nat) : BitVec w' :=
2026
x.extractLsb' (w - w') w'
2127

28+
@[simp_sail]
2229
def extractLsb {w : Nat} (x : BitVec w) (hi lo : Nat) : BitVec (hi - lo + 1) :=
2330
x.extractLsb hi lo
2431

32+
@[simp_sail]
2533
def updateSubrange' {w : Nat} (x : BitVec w) (start len : Nat) (y : BitVec len) : BitVec w :=
2634
let mask := ~~~(((BitVec.allOnes len).zeroExtend w) <<< start)
2735
let y' := ((y.zeroExtend w) <<< start)
2836
(mask &&& x) ||| y'
2937

38+
@[simp_sail]
3039
def slice {w : Nat} (x : BitVec w) (start len : Nat) : BitVec len :=
3140
x.extractLsb' start len
3241

42+
@[simp_sail]
3343
def sliceBE {w : Nat} (x : BitVec w) (start len : Nat) : BitVec len :=
3444
x.extractLsb' (w - start - len) len
3545

46+
@[simp_sail]
3647
def subrangeBE {w : Nat} (x : BitVec w) (lo hi : Nat) : BitVec (hi - lo + 1) :=
3748
x.extractLsb' (w - hi - 1) _
3849

50+
@[simp_sail]
3951
def updateSubrange {w : Nat} (x : BitVec w) (hi lo : Nat) (y : BitVec (hi - lo + 1)) : BitVec w :=
4052
updateSubrange' x lo _ y
4153

54+
@[simp_sail]
4255
def updateSubrangeBE {w : Nat} (x : BitVec w) (lo hi : Nat) (y : BitVec (hi - lo + 1)) : BitVec w :=
4356
updateSubrange' x (w - hi - 1) _ y
4457

58+
@[simp_sail]
4559
def replicateBits {w : Nat} (x : BitVec w) (i : Nat) := BitVec.replicate i x
4660

61+
@[simp_sail]
4762
def access {w : Nat} (x : BitVec w) (i : Nat) : BitVec 1 :=
4863
BitVec.ofBool x[i]!
4964

65+
@[simp_sail]
5066
def accessBE {w : Nat} (x : BitVec w) (i : Nat) : BitVec 1 :=
5167
BitVec.ofBool x[w - i - 1]!
5268

69+
@[simp_sail]
5370
def addInt {w : Nat} (x : BitVec w) (i : Int) : BitVec w :=
5471
x + BitVec.ofInt w i
5572

73+
@[simp_sail]
5674
def subInt (x : BitVec w) (i : Int) : BitVec w :=
5775
x - BitVec.ofInt w i
5876

59-
def countLeadingZeros (x : BitVec w) : Nat :=
60-
if h : w = 0 || BitVec.msb x then
61-
0
62-
else
63-
1 + countLeadingZeros (x.extractLsb' 0 (w - 1))
64-
decreasing_by
65-
simp only [Bool.or_eq_true, decide_eq_true_eq, not_or, Bool.not_eq_true] at h
66-
omega
77+
@[simp_sail]
78+
def countLeadingZeros (x : BitVec w) : Nat := x.clz.toNat
6779

80+
@[simp_sail]
6881
def countTrailingZeros (x : BitVec w) : Nat :=
6982
countLeadingZeros (x.reverse)
7083

84+
@[simp_sail]
7185
def append' (x : BitVec n) (y : BitVec m) {mn}
7286
(hmn : mn = n + m := by (conv => rhs; simp); try rfl) : BitVec mn :=
7387
(x.append y).cast hmn.symm
7488

89+
@[simp_sail]
7590
def update (x : BitVec m) (n : Nat) (b : BitVec 1) := updateSubrange' x n _ b
7691

92+
@[simp_sail]
7793
def updateBE (x : BitVec m) (n : Nat) (b : BitVec 1) := updateSubrange' x (m - n - 1) _ b
7894

7995
def toBin {w : Nat} (x : BitVec w) : String :=
@@ -85,6 +101,7 @@ def toFormatted {w : Nat} (x : BitVec w) : String :=
85101
else
86102
s!"0b{BitVec.toBin x}"
87103

104+
@[simp_sail]
88105
def join1 (xs : List (BitVec 1)) : BitVec xs.length :=
89106
(BitVec.ofBoolListBE (xs.map fun x => x[0])).cast (by simp)
90107

@@ -135,18 +152,23 @@ def valid_dec_bits (_ : Nat) (str : String) : Bool :=
135152
str.all fun x => x.toLower ∈
136153
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
137154

155+
@[simp_sail]
138156
def shift_bits_left (bv : BitVec n) (sh : BitVec m) : BitVec n :=
139157
bv <<< sh
140158

159+
@[simp_sail]
141160
def shift_bits_right (bv : BitVec n) (sh : BitVec m) : BitVec n :=
142161
bv >>> sh
143162

163+
@[simp_sail]
144164
def shiftl (bv : BitVec n) (m : Nat) : BitVec n :=
145165
bv <<< m
146166

167+
@[simp_sail]
147168
def shiftr (bv : BitVec n) (m : Nat) : BitVec n :=
148169
bv >>> m
149170

171+
@[simp_sail]
150172
def pow2 := (2 ^ ·)
151173

152174
namespace Nat
@@ -171,11 +193,13 @@ namespace Int
171193

172194
def intAbs (x : Int) : Int := Int.ofNat (Int.natAbs x)
173195

196+
@[simp_sail]
174197
def shiftl (a : Int) (n : Int) : Int :=
175198
match n with
176199
| Int.ofNat n => Sail.Nat.iterate (fun x => x * 2) n a
177200
| Int.negSucc n => Sail.Nat.iterate (fun x => x / 2) (n+1) a
178201

202+
@[simp_sail]
179203
def shiftr (a : Int) (n : Int) : Int :=
180204
match n with
181205
| Int.ofNat n => Sail.Nat.iterate (fun x => x / 2) n a
@@ -194,12 +218,15 @@ def toHexUpper (i : Int) : String :=
194218

195219
end Int
196220

221+
@[simp_sail]
197222
def get_slice_int (len : Nat) (n : Int) (lo : Nat) : BitVec len :=
198223
BitVec.extractLsb' lo len (BitVec.ofInt (lo + len + 1) n)
199224

225+
@[simp_sail]
200226
def set_slice_int (len : Nat) (n : Int) (lo : Nat) (x : BitVec len) : Int :=
201227
BitVec.toInt (BitVec.updateSubrange' (BitVec.ofInt len n) lo len x)
202228

229+
@[simp_sail]
203230
def set_slice {n : Nat} (m : Nat) (bv : BitVec n) (start : Nat) (bv' : BitVec m) : BitVec n :=
204231
BitVec.updateSubrange' bv start m bv'
205232

@@ -208,16 +235,20 @@ def String.leadingSpaces (s : String) : Nat :=
208235

209236
abbrev Vector.length (_v : Vector α n) : Nat := n
210237

238+
@[simp_sail]
211239
def vectorInit {n : Nat} (a : α) : Vector α n := Vector.replicate n a
212240

241+
@[simp_sail]
213242
def vectorUpdate (v : Vector α m) (n : Nat) (a : α) := v.set! n a
214243

244+
@[simp_sail]
215245
instance : HShiftLeft (BitVec w) Int (BitVec w) where
216246
hShiftLeft b i :=
217247
match i with
218248
| .ofNat n => BitVec.shiftLeft b n
219249
| .negSucc n => BitVec.ushiftRight b (n + 1)
220250

251+
@[simp_sail]
221252
instance : HShiftRight (BitVec w) Int (BitVec w) where
222253
hShiftRight b i := b <<< (- i)
223254

@@ -391,13 +422,15 @@ export RegisterRef (Reg)
391422
abbrev PreSailM (RegisterType : Register → Type) (c : ChoiceSource) (ue: Type) :=
392423
EStateM (Error ue) (SequentialState RegisterType c)
393424

425+
@[simp_sail]
394426
def sailTryCatch (e : PreSailM RegisterType c ue α) (h : ue → PreSailM RegisterType c ue α) :
395427
PreSailM RegisterType c ue α :=
396428
EStateM.tryCatch e fun e =>
397429
match e with
398430
| .User u => h u
399431
| _ => EStateM.throw e
400432

433+
@[simp_sail]
401434
def sailThrow (e : ue) : PreSailM RegisterType c ue α := EStateM.throw (.User e)
402435

403436
def choose (p : Primitive) : PreSailM RegisterType c ue p.reflect :=
@@ -436,54 +469,66 @@ def internal_pick {α : Type} : List α → PreSailM RegisterType c ue α
436469
let idx ← choose <| .fin (as.length)
437470
pure <| (a :: as).get idx
438471

472+
@[simp_sail]
439473
def writeReg (r : Register) (v : RegisterType r) : PreSailM RegisterType c ue PUnit :=
440474
modify fun s => { s with regs := s.regs.insert r v }
441475

476+
@[simp_sail]
442477
def readReg (r : Register) : PreSailM RegisterType c ue (RegisterType r) := do
443478
let .some s := (← get).regs.get? r
444479
| throw .Unreachable
445480
pure s
446481

482+
@[simp_sail]
447483
def readRegRef (reg_ref : @RegisterRef Register RegisterType α) : PreSailM RegisterType c ue α := do
448484
match reg_ref with | .Reg r => readReg r
449485

486+
@[simp_sail]
450487
def writeRegRef (reg_ref : @RegisterRef Register RegisterType α) (a : α) :
451488
PreSailM RegisterType c ue Unit := do
452489
match reg_ref with | .Reg r => writeReg r a
453490

491+
@[simp_sail]
454492
def reg_deref (reg_ref : @RegisterRef Register RegisterType α) : PreSailM RegisterType c ue α :=
455493
readRegRef reg_ref
456494

495+
@[simp_sail]
457496
def assert (p : Bool) (s : String) : PreSailM RegisterType c ue Unit :=
458497
if p then pure () else throw (.Assertion s)
459498

460499
section ConcurrencyInterface
461500

501+
@[simp_sail]
462502
def writeByte (addr : Nat) (value : BitVec 8) : PreSailM RegisterType c ue PUnit := do
463503
modify fun s => { s with mem := s.mem.insert addr value }
464504

505+
@[simp_sail]
465506
def writeBytes (addr : Nat) (value : BitVec (8 * n)) : PreSailM RegisterType c ue Bool := do
466507
let list := List.ofFn (λ i : Fin n => (addr + i.val, value.extractLsb' (8 * i.val) 8))
467508
List.forM list (λ (a, v) => writeByte a v)
468509
pure true
469510

511+
@[simp_sail]
470512
def sail_mem_write [Arch] (req : Mem_write_request n vasize (BitVec pa_size) ts arch) : PreSailM RegisterType c ue (Result (Option Bool) Arch.abort) := do
471513
let addr := req.pa.toNat
472514
let b ← match req.value with
473515
| some v => writeBytes addr v
474516
| none => pure true
475517
pure (Ok (some b))
476518

519+
@[simp_sail]
477520
def write_ram (addr_size data_size : Nat) (_hex_ram addr : BitVec addr_size) (value : BitVec (8 * data_size)) :
478521
PreSailM RegisterType c ue Unit := do
479522
let _ ← writeBytes addr.toNat value
480523
pure ()
481524

525+
@[simp_sail]
482526
def readByte (addr : Nat) : PreSailM RegisterType c ue (BitVec 8) := do
483527
let .some s := (← get).mem.get? addr
484528
| throw (.OutOfMemoryRange addr)
485529
pure s
486530

531+
@[simp_sail]
487532
def readBytes (size : Nat) (addr : Nat) : PreSailM RegisterType c ue ((BitVec (8 * size)) × Option Bool) :=
488533
match size with
489534
| 0 => pure (default, none)
@@ -497,26 +542,37 @@ def readBytes (size : Nat) (addr : Nat) : PreSailM RegisterType c ue ((BitVec (8
497542
have h : 8 * n + 8 = 8 * (n + 1) := by omega
498543
return (h ▸ bytes.append b, bool)
499544

545+
@[simp_sail]
500546
def sail_mem_read [Arch] (req : Mem_read_request n vasize (BitVec pa_size) ts arch) : PreSailM RegisterType c ue (Result ((BitVec (8 * n)) × (Option Bool)) Arch.abort) := do
501547
let addr := req.pa.toNat
502548
let value ← readBytes n addr
503549
pure (Ok value)
504550

551+
@[simp_sail]
505552
def read_ram (addr_size data_size : Nat) (_hex_ram addr : BitVec addr_size) : PreSailM RegisterType c ue (BitVec (8 * data_size)) := do
506553
let ⟨bytes, _⟩ ← readBytes data_size addr.toNat
507554
pure bytes
508555

556+
@[simp_sail]
509557
def sail_barrier (_ : α) : PreSailM RegisterType c ue Unit := pure ()
558+
@[simp_sail]
510559
def sail_cache_op [Arch] (_ : Arch.cache_op) : PreSailM RegisterType c ue Unit := pure ()
560+
@[simp_sail]
511561
def sail_tlbi [Arch] (_ : Arch.tlb_op) : PreSailM RegisterType c ue Unit := pure ()
562+
@[simp_sail]
512563
def sail_translation_start [Arch] (_ : Arch.trans_start) : PreSailM RegisterType c ue Unit := pure ()
564+
@[simp_sail]
513565
def sail_translation_end [Arch] (_ : Arch.trans_end) : PreSailM RegisterType c ue Unit := pure ()
566+
@[simp_sail]
514567
def sail_take_exception [Arch] (_ : Arch.fault) : PreSailM RegisterType c ue Unit := pure ()
568+
@[simp_sail]
515569
def sail_return_exception [Arch] (_ : Arch.pa) : PreSailM RegisterType c ue Unit := pure ()
516570

571+
@[simp_sail]
517572
def cycle_count (_ : Unit) : PreSailM RegisterType c ue Unit :=
518573
modify fun s => { s with cycleCount := s.cycleCount + 1 }
519574

575+
@[simp_sail]
520576
def get_cycle_count (_ : Unit) : PreSailM RegisterType c ue Nat := do
521577
pure (← get).cycleCount
522578

@@ -534,6 +590,7 @@ def print_bits_effect {w : Nat} (str : String) (x : BitVec w) : PreSailM Registe
534590
def print_endline_effect (str : String) : PreSailM RegisterType c ue Unit :=
535591
print_effect s!"{str}\n"
536592

593+
@[simp_sail]
537594
def sailTryCatchE (e : ExceptT β (PreSailM RegisterType c ue) α) (h : ue → ExceptT β (PreSailM RegisterType c ue) α) : ExceptT β (PreSailM RegisterType c ue) α :=
538595
EStateM.tryCatch e fun e =>
539596
match e with

src/sail_lean_backend/Sail/Specialization.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ import THE_MODULE_NAME.Defs
33

44
namespace Sail
55

6+
@[simp_sail]
67
def sailTryCatch (e : SailM α) (h : exception → SailM α) : SailM α := PreSail.sailTryCatch e h
78

9+
@[simp_sail]
810
def sailThrow (e : exception) : SailM α := PreSail.sailThrow e
911

1012
abbrev undefined_unit (_ : Unit) : SailM Unit := PreSail.undefined_unit ()

src/sail_lean_backend/sail_plugin_lean.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,14 @@ open Sail
247247

248248
let path_to_static_library sail_dir str = Filename.quote (sail_dir ^ "/src/sail_lean_backend/Sail/" ^ str ^ ".lean")
249249

250-
let copy_from_static_library sail_dir lean_sail_dir str =
251-
Unix.system ("cp " ^ path_to_static_library sail_dir str ^ " " ^ Filename.quote lean_sail_dir)
250+
let copy_from_static_library out_name_camel sail_dir lean_sail_dir str =
251+
Unix.system
252+
(Printf.sprintf "sed 's/THE_MODULE_NAME/%s/g' %s > %s.lean" out_name_camel (path_to_static_library sail_dir str)
253+
(lean_sail_dir ^ "/" ^ str |> Filename.quote)
254+
)
255+
|> ignore
256+
257+
(* Unix.system ("cp " ^ path_to_static_library sail_dir str ^ " " ^ Filename.quote lean_sail_dir) *)
252258

253259
let print_function_file_prelude file out_name_camel (imp_refs : string list) =
254260
let _ =
@@ -292,9 +298,10 @@ let start_lean_output (out_name : string) (import_names : string list) (import_r
292298
if not (Sys.file_exists lean_src_dir) then Unix.mkdir lean_src_dir 0o775;
293299
let lean_sail_dir = lean_src_dir ^ "/Sail/" in
294300
Unix.mkdir lean_sail_dir 0o775;
295-
let _ = copy_from_static_library sail_dir lean_sail_dir "BitVec" in
296-
let _ = copy_from_static_library sail_dir lean_sail_dir "IntRange" in
297-
let _ = copy_from_static_library sail_dir lean_sail_dir "Sail" in
301+
let _ = copy_from_static_library out_name_camel sail_dir lean_sail_dir "Attr" in
302+
let _ = copy_from_static_library out_name_camel sail_dir lean_sail_dir "BitVec" in
303+
let _ = copy_from_static_library out_name_camel sail_dir lean_sail_dir "IntRange" in
304+
let _ = copy_from_static_library out_name_camel sail_dir lean_sail_dir "Sail" in
298305
let real_numbers_file =
299306
if !opt_lean_real_numbers then "/src/sail_lean_backend/Sail/Real.lean"
300307
else "/src/sail_lean_backend/Sail/FakeReal.lean"

0 commit comments

Comments
 (0)