Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions Benchmarks/Blake3.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ def blake3Bench : IO $ Array BenchReport := do
let data := Array.range dataSize |>.map
-- Add `idx` so every preimage is different and avoids memoization.
fun i => Aiur.G.ofUInt8 (i + idx).toUInt8
let ioKeyInfo := ⟨ioBuffer.data.size, dataSize⟩
{ ioBuffer with
data := ioBuffer.data ++ data
map := ioBuffer.map.insert #[.ofNat idx] ioKeyInfo }
ioBuffer.extend 0 #[.ofNat idx] data
throughput (.ElementsAndBytes numHashes.toUInt64 (dataSize * numHashes).toUInt64 "hashes")
bench s!"dataSize={dataSize} numHashes={numHashes}"
(aiurSystem.prove friParameters funIdx #[Aiur.G.ofNat numHashes]) ioBuffer
Expand Down
5 changes: 1 addition & 4 deletions Benchmarks/Sha256.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ def sha256Bench : IO $ Array BenchReport := do
let data := Array.range dataSize |>.map
-- Add `idx` so every preimage is different and avoids memoization.
fun i => Aiur.G.ofUInt8 (i + idx).toUInt8
let ioKeyInfo := ⟨ioBuffer.data.size, dataSize⟩
{ ioBuffer with
data := ioBuffer.data ++ data
map := ioBuffer.map.insert #[.ofNat idx] ioKeyInfo }
ioBuffer.extend 0 #[.ofNat idx] data
throughput (.ElementsAndBytes numHashes.toUInt64 (dataSize * numHashes).toUInt64 "hashes")
bench s!"dataSize={dataSize} numHashes={numHashes}"
(aiurSystem.prove friParameters funIdx #[Aiur.G.ofNat numHashes]) ioBuffer
Expand Down
35 changes: 21 additions & 14 deletions Ix/Aiur/Compiler/Check.lean
Original file line number Diff line number Diff line change
Expand Up @@ -759,33 +759,37 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with
let a' ← checkNoEscape a .field
let b' ← checkNoEscape b .field
pure (Typed.Term.u32LessThan .field false a' b')
| .ioGetInfo key => do
| .ioGetInfo channel key => do
let channel' ← checkNoEscape channel .field
let key' ← inferNoEscape key
match ← walkTyp key'.typ with
| .array .. =>
pure (Typed.Term.ioGetInfo (.tuple #[.field, .field]) false key')
pure (Typed.Term.ioGetInfo (.tuple #[.field, .field]) false channel' key')
| typ' => throw $ .notAnArray typ'
| .ioSetInfo key idx len ret => do
| .ioSetInfo channel key idx len ret => do
let channel' ← checkNoEscape channel .field
let key' ← inferNoEscape key
match ← walkTyp key'.typ with
| .array keyEltTyp _ =>
unless ← unifyTyp keyEltTyp .field do throw $ .typeMismatch .field keyEltTyp
let idx' ← checkNoEscape idx .field
let len' ← checkNoEscape len .field
let ret' ← inferTerm ret
pure (Typed.Term.ioSetInfo ret'.typ ret'.escapes key' idx' len' ret')
pure (Typed.Term.ioSetInfo ret'.typ ret'.escapes channel' key' idx' len' ret')
| typ' => throw $ .notAnArray typ'
| .ioRead idx len => do
| .ioRead channel idx len => do
if len = 0 then throw .emptyArray
let channel' ← checkNoEscape channel .field
let idx' ← checkNoEscape idx .field
pure (Typed.Term.ioRead (.array .field len) false idx' len)
| .ioWrite data ret => do
pure (Typed.Term.ioRead (.array .field len) false channel' idx' len)
| .ioWrite channel data ret => do
let channel' ← checkNoEscape channel .field
let data' ← inferNoEscape data
match ← walkTyp data'.typ with
| .array dataEltTyp _ =>
unless ← unifyTyp dataEltTyp .field do throw $ .typeMismatch .field dataEltTyp
let ret' ← inferTerm ret
pure (Typed.Term.ioWrite ret'.typ ret'.escapes data' ret')
pure (Typed.Term.ioWrite ret'.typ ret'.escapes channel' data' ret')
| typ' => throw $ .notAnArray typ'
| .assertEq a b ret => do
let a' ← inferNoEscape a
Expand Down Expand Up @@ -900,12 +904,15 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with
| .ptrVal τ e a => do pure (.ptrVal (← zonkTyp τ) e (← zonkTypedTerm a))
| .assertEq τ e a b r => do
pure (.assertEq (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b) (← zonkTypedTerm r))
| .ioGetInfo τ e k => do pure (.ioGetInfo (← zonkTyp τ) e (← zonkTypedTerm k))
| .ioSetInfo τ e k i l r => do
pure (.ioSetInfo (← zonkTyp τ) e (← zonkTypedTerm k) (← zonkTypedTerm i)
(← zonkTypedTerm l) (← zonkTypedTerm r))
| .ioRead τ e i n => do pure (.ioRead (← zonkTyp τ) e (← zonkTypedTerm i) n)
| .ioWrite τ e d r => do pure (.ioWrite (← zonkTyp τ) e (← zonkTypedTerm d) (← zonkTypedTerm r))
| .ioGetInfo τ e c k => do
pure (.ioGetInfo (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm k))
| .ioSetInfo τ e c k i l r => do
pure (.ioSetInfo (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm k)
(← zonkTypedTerm i) (← zonkTypedTerm l) (← zonkTypedTerm r))
| .ioRead τ e c i n => do
pure (.ioRead (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm i) n)
| .ioWrite τ e c d r => do
pure (.ioWrite (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm d) (← zonkTypedTerm r))
| .u8BitDecomposition τ e a => do
pure (.u8BitDecomposition (← zonkTyp τ) e (← zonkTypedTerm a))
| .u8ShiftLeft τ e a => do pure (.u8ShiftLeft (← zonkTyp τ) e (← zonkTypedTerm a))
Expand Down
89 changes: 58 additions & 31 deletions Ix/Aiur/Compiler/Concretize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -304,15 +304,21 @@ def termToConcrete
| .assertEq τ e a b r => do
pure (.assertEq (← typToConcrete mono τ) e
(← termToConcrete mono a) (← termToConcrete mono b) (← termToConcrete mono r))
| .ioGetInfo τ e k => do pure (.ioGetInfo (← typToConcrete mono τ) e (← termToConcrete mono k))
| .ioSetInfo τ e k i l r => do
| .ioGetInfo τ e c k => do
pure (.ioGetInfo (← typToConcrete mono τ) e
(← termToConcrete mono c) (← termToConcrete mono k))
| .ioSetInfo τ e c k i l r => do
pure (.ioSetInfo (← typToConcrete mono τ) e
(← termToConcrete mono k) (← termToConcrete mono i)
(← termToConcrete mono c) (← termToConcrete mono k)
(← termToConcrete mono i)
(← termToConcrete mono l) (← termToConcrete mono r))
| .ioRead τ e i n => do pure (.ioRead (← typToConcrete mono τ) e (← termToConcrete mono i) n)
| .ioWrite τ e d r => do
| .ioRead τ e c i n => do
pure (.ioRead (← typToConcrete mono τ) e
(← termToConcrete mono c) (← termToConcrete mono i) n)
| .ioWrite τ e c d r => do
pure (.ioWrite (← typToConcrete mono τ) e
(← termToConcrete mono d) (← termToConcrete mono r))
(← termToConcrete mono c) (← termToConcrete mono d)
(← termToConcrete mono r))
| .u8BitDecomposition τ e a => do
pure (.u8BitDecomposition (← typToConcrete mono τ) e (← termToConcrete mono a))
| .u8ShiftLeft τ e a => do
Expand Down Expand Up @@ -510,16 +516,21 @@ def rewriteTypedTerm (decls : Typed.Decls)
.assertEq (rewriteTyp subst mono τ) e
(rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b)
(rewriteTypedTerm decls subst mono r)
| .ioGetInfo τ e k =>
.ioGetInfo (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono k)
| .ioSetInfo τ e k i l r =>
| .ioGetInfo τ e c k =>
.ioGetInfo (rewriteTyp subst mono τ) e
(rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono k)
| .ioSetInfo τ e c k i l r =>
.ioSetInfo (rewriteTyp subst mono τ) e
(rewriteTypedTerm decls subst mono k) (rewriteTypedTerm decls subst mono i)
(rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono k)
(rewriteTypedTerm decls subst mono i)
(rewriteTypedTerm decls subst mono l) (rewriteTypedTerm decls subst mono r)
| .ioRead τ e i n => .ioRead (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono i) n
| .ioWrite τ e d r =>
| .ioRead τ e c i n =>
.ioRead (rewriteTyp subst mono τ) e
(rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono i) n
| .ioWrite τ e c d r =>
.ioWrite (rewriteTyp subst mono τ) e
(rewriteTypedTerm decls subst mono d) (rewriteTypedTerm decls subst mono r)
(rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono d)
(rewriteTypedTerm decls subst mono r)
| .u8BitDecomposition τ e a =>
.u8BitDecomposition (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono a)
| .u8ShiftLeft τ e a =>
Expand Down Expand Up @@ -618,20 +629,26 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) :
| .u8LessThan τ _ a b | .u32LessThan τ _ a b =>
collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b
| .eqZero τ _ a | .store τ _ a | .load τ _ a | .ptrVal τ _ a
| .u8BitDecomposition τ _ a | .u8ShiftLeft τ _ a | .u8ShiftRight τ _ a
| .ioGetInfo τ _ a => collectInTypedTerm (collectInTyp seen τ) a
| .u8BitDecomposition τ _ a | .u8ShiftLeft τ _ a | .u8ShiftRight τ _ a =>
collectInTypedTerm (collectInTyp seen τ) a
| .ioGetInfo τ _ c k =>
collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) k
| .proj τ _ a _ | .get τ _ a _ | .slice τ _ a _ _ =>
collectInTypedTerm (collectInTyp seen τ) a
| .set τ _ a _ v => collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) v
| .assertEq τ _ a b r =>
collectInTypedTerm
(collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b) r
| .ioSetInfo τ _ k i l r =>
| .ioSetInfo τ _ c k i l r =>
collectInTypedTerm
(collectInTypedTerm
(collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) k) i) l) r
| .ioRead τ _ i _ => collectInTypedTerm (collectInTyp seen τ) i
| .ioWrite τ _ d r => collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) d) r
(collectInTypedTerm
(collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) k) i) l) r
| .ioRead τ _ c i _ =>
collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) i
| .ioWrite τ _ c d r =>
collectInTypedTerm
(collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) d) r
| .debug τ _ _ t r =>
let seen := collectInTyp seen τ
let seen := match t with | some t => collectInTypedTerm seen t | none => seen
Expand Down Expand Up @@ -681,17 +698,22 @@ def collectCalls (decls : Typed.Decls)
| .u8LessThan _ _ a b | .u32LessThan _ _ a b =>
collectCalls decls (collectCalls decls seen a) b
| .eqZero _ _ a | .store _ _ a | .load _ _ a | .ptrVal _ _ a
| .u8BitDecomposition _ _ a | .u8ShiftLeft _ _ a | .u8ShiftRight _ _ a
| .ioGetInfo _ _ a => collectCalls decls seen a
| .u8BitDecomposition _ _ a | .u8ShiftLeft _ _ a | .u8ShiftRight _ _ a =>
collectCalls decls seen a
| .ioGetInfo _ _ c k =>
collectCalls decls (collectCalls decls seen c) k
| .proj _ _ a _ | .get _ _ a _ | .slice _ _ a _ _ => collectCalls decls seen a
| .set _ _ a _ v => collectCalls decls (collectCalls decls seen a) v
| .assertEq _ _ a b r =>
collectCalls decls (collectCalls decls (collectCalls decls seen a) b) r
| .ioSetInfo _ _ k i l r =>
| .ioSetInfo _ _ c k i l r =>
collectCalls decls
(collectCalls decls (collectCalls decls (collectCalls decls seen k) i) l) r
| .ioRead _ _ i _ => collectCalls decls seen i
| .ioWrite _ _ d r => collectCalls decls (collectCalls decls seen d) r
(collectCalls decls
(collectCalls decls
(collectCalls decls (collectCalls decls seen c) k) i) l) r
| .ioRead _ _ c i _ => collectCalls decls (collectCalls decls seen c) i
| .ioWrite _ _ c d r =>
collectCalls decls (collectCalls decls (collectCalls decls seen c) d) r
| .debug _ _ _ t r =>
let seen := match t with | some t => collectCalls decls seen t | none => seen
collectCalls decls seen r
Expand Down Expand Up @@ -748,14 +770,19 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term
| .assertEq τ e a b r =>
.assertEq (Typ.instantiate subst τ) e (substInTypedTerm subst a)
(substInTypedTerm subst b) (substInTypedTerm subst r)
| .ioGetInfo τ e k => .ioGetInfo (Typ.instantiate subst τ) e (substInTypedTerm subst k)
| .ioSetInfo τ e k i l r =>
| .ioGetInfo τ e c k =>
.ioGetInfo (Typ.instantiate subst τ) e
(substInTypedTerm subst c) (substInTypedTerm subst k)
| .ioSetInfo τ e c k i l r =>
.ioSetInfo (Typ.instantiate subst τ) e
(substInTypedTerm subst k) (substInTypedTerm subst i)
(substInTypedTerm subst c) (substInTypedTerm subst k)
(substInTypedTerm subst i)
(substInTypedTerm subst l) (substInTypedTerm subst r)
| .ioRead τ e i n => .ioRead (Typ.instantiate subst τ) e (substInTypedTerm subst i) n
| .ioWrite τ e d r => .ioWrite (Typ.instantiate subst τ) e
(substInTypedTerm subst d) (substInTypedTerm subst r)
| .ioRead τ e c i n =>
.ioRead (Typ.instantiate subst τ) e
(substInTypedTerm subst c) (substInTypedTerm subst i) n
| .ioWrite τ e c d r => .ioWrite (Typ.instantiate subst τ) e
(substInTypedTerm subst c) (substInTypedTerm subst d) (substInTypedTerm subst r)
| .u8BitDecomposition τ e a =>
.u8BitDecomposition (Typ.instantiate subst τ) e (substInTypedTerm subst a)
| .u8ShiftLeft τ e a =>
Expand Down
4 changes: 2 additions & 2 deletions Ix/Aiur/Compiler/Layout.lean
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ def opLayout : Bytecode.Op → LayoutM Unit
pushDegrees $ .replicate size 1
bumpAuxiliaries size; bumpLookups; addMemSize size
| .assertEq .. => pure ()
| .ioGetInfo _ => do pushDegrees #[1, 1]; bumpAuxiliaries 2
| .ioGetInfo _ _ => do pushDegrees #[1, 1]; bumpAuxiliaries 2
| .ioSetInfo .. => pure ()
| .ioRead _ len => do pushDegrees $ .replicate len 1; bumpAuxiliaries len
| .ioRead _ _ len => do pushDegrees $ .replicate len 1; bumpAuxiliaries len
| .ioWrite .. => pure ()
| .u8BitDecomposition _ => do pushDegrees $ .replicate 8 1; bumpAuxiliaries 8; bumpLookups
| .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. | .u8And .. | .u8Or .. => do
Expand Down
30 changes: 18 additions & 12 deletions Ix/Aiur/Compiler/Lower.lean
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,25 @@ def toIndex
let b ← toIndex layoutMap bindings b
modify fun stt => { stt with ops := stt.ops.push (.assertEq a b) }
toIndex layoutMap bindings ret
| .ioGetInfo _ _ key => do
| .ioGetInfo _ _ channel key => do
let channel ← expectIdx layoutMap bindings channel
let key ← toIndex layoutMap bindings key
pushOp (.ioGetInfo key) 2
| .ioSetInfo _ _ key idx len ret => do
pushOp (.ioGetInfo channel key) 2
| .ioSetInfo _ _ channel key idx len ret => do
let channel ← expectIdx layoutMap bindings channel
let key ← toIndex layoutMap bindings key
let idx ← expectIdx layoutMap bindings idx
let len ← expectIdx layoutMap bindings len
modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo key idx len) }
modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo channel key idx len) }
toIndex layoutMap bindings ret
| .ioRead _ _ idx len => do
| .ioRead _ _ channel idx len => do
let channel ← expectIdx layoutMap bindings channel
let idx ← expectIdx layoutMap bindings idx
pushOp (.ioRead idx len) len
| .ioWrite _ _ data ret => do
pushOp (.ioRead channel idx len) len
| .ioWrite _ _ channel data ret => do
let channel ← expectIdx layoutMap bindings channel
let data ← toIndex layoutMap bindings data
modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) }
modify fun stt => { stt with ops := stt.ops.push (.ioWrite channel data) }
toIndex layoutMap bindings ret
| .u8BitDecomposition _ _ byte => do
let byte ← expectIdx layoutMap bindings byte
Expand Down Expand Up @@ -442,15 +446,17 @@ def Concrete.Term.compile
let b ← toIndex layoutMap bindings b
modify fun stt => { stt with ops := stt.ops.push (.assertEq a b) }
ret.compile returnTyp layoutMap bindings yieldCtrl
| .ioSetInfo _ _ key idx len ret => do
| .ioSetInfo _ _ channel key idx len ret => do
let channel ← toIndex layoutMap bindings channel
let key ← toIndex layoutMap bindings key
let idx ← toIndex layoutMap bindings idx
let len ← toIndex layoutMap bindings len
modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo key idx[0]! len[0]!) }
modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo channel[0]! key idx[0]! len[0]!) }
ret.compile returnTyp layoutMap bindings yieldCtrl
| .ioWrite _ _ data ret => do
| .ioWrite _ _ channel data ret => do
let channel ← toIndex layoutMap bindings channel
let data ← toIndex layoutMap bindings data
modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) }
modify fun stt => { stt with ops := stt.ops.push (.ioWrite channel[0]! data) }
ret.compile returnTyp layoutMap bindings yieldCtrl
| .match _ _ scrut cases defaultOpt => do
let idxs := bindings[scrut]?.getD #[0]
Expand Down
11 changes: 6 additions & 5 deletions Ix/Aiur/Compiler/Match.lean
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,12 @@ def typedToSimple : Term → Simple.Term
| .load τ e a => .load τ e (typedToSimple a)
| .ptrVal τ e a => .ptrVal τ e (typedToSimple a)
| .assertEq τ e a b r => .assertEq τ e (typedToSimple a) (typedToSimple b) (typedToSimple r)
| .ioGetInfo τ e k => .ioGetInfo τ e (typedToSimple k)
| .ioSetInfo τ e k i l r =>
.ioSetInfo τ e (typedToSimple k) (typedToSimple i) (typedToSimple l) (typedToSimple r)
| .ioRead τ e i n => .ioRead τ e (typedToSimple i) n
| .ioWrite τ e d r => .ioWrite τ e (typedToSimple d) (typedToSimple r)
| .ioGetInfo τ e c k => .ioGetInfo τ e (typedToSimple c) (typedToSimple k)
| .ioSetInfo τ e c k i l r =>
.ioSetInfo τ e (typedToSimple c) (typedToSimple k) (typedToSimple i)
(typedToSimple l) (typedToSimple r)
| .ioRead τ e c i n => .ioRead τ e (typedToSimple c) (typedToSimple i) n
| .ioWrite τ e c d r => .ioWrite τ e (typedToSimple c) (typedToSimple d) (typedToSimple r)
| .u8BitDecomposition τ e a => .u8BitDecomposition τ e (typedToSimple a)
| .u8ShiftLeft τ e a => .u8ShiftLeft τ e (typedToSimple a)
| .u8ShiftRight τ e a => .u8ShiftRight τ e (typedToSimple a)
Expand Down
10 changes: 6 additions & 4 deletions Ix/Aiur/Compiler/Simple.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,18 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term
let b' ← simplifyTypedTerm decls b
let r' ← simplifyTypedTerm decls r
pure (.assertEq τ e a' b' r')
| .ioSetInfo τ e k i l r => do
| .ioSetInfo τ e c k i l r => do
let c' ← simplifyTypedTerm decls c
let k' ← simplifyTypedTerm decls k
let i' ← simplifyTypedTerm decls i
let l' ← simplifyTypedTerm decls l
let r' ← simplifyTypedTerm decls r
pure (.ioSetInfo τ e k' i' l' r')
| .ioWrite τ e d r => do
pure (.ioSetInfo τ e c' k' i' l' r')
| .ioWrite τ e c d r => do
let c' ← simplifyTypedTerm decls c
let d' ← simplifyTypedTerm decls d
let r' ← simplifyTypedTerm decls r
pure (.ioWrite τ e d' r')
pure (.ioWrite τ e c' d' r')
| .u8LessThan τ e a b => do
let a' ← simplifyTypedTerm decls a
let b' ← simplifyTypedTerm decls b
Expand Down
Loading