diff --git a/Benchmarks/Blake3.lean b/Benchmarks/Blake3.lean index 1371db37..74c36892 100644 --- a/Benchmarks/Blake3.lean +++ b/Benchmarks/Blake3.lean @@ -44,10 +44,7 @@ def blake3Bench : IO $ Array BenchReport := do let data := Array.range dataSize |>.map -- Add `idx` so every preimage is different and avoids memoization. fun i => Aiur.G.ofUInt8 (i + idx).toUInt8 - let ioKeyInfo := ⟨ioBuffer.data.size, dataSize⟩ - { ioBuffer with - data := ioBuffer.data ++ data - map := ioBuffer.map.insert #[.ofNat idx] ioKeyInfo } + ioBuffer.extend 0 #[.ofNat idx] data throughput (.ElementsAndBytes numHashes.toUInt64 (dataSize * numHashes).toUInt64 "hashes") bench s!"dataSize={dataSize} numHashes={numHashes}" (aiurSystem.prove friParameters funIdx #[Aiur.G.ofNat numHashes]) ioBuffer diff --git a/Benchmarks/Sha256.lean b/Benchmarks/Sha256.lean index be7598d5..beea833e 100644 --- a/Benchmarks/Sha256.lean +++ b/Benchmarks/Sha256.lean @@ -44,10 +44,7 @@ def sha256Bench : IO $ Array BenchReport := do let data := Array.range dataSize |>.map -- Add `idx` so every preimage is different and avoids memoization. fun i => Aiur.G.ofUInt8 (i + idx).toUInt8 - let ioKeyInfo := ⟨ioBuffer.data.size, dataSize⟩ - { ioBuffer with - data := ioBuffer.data ++ data - map := ioBuffer.map.insert #[.ofNat idx] ioKeyInfo } + ioBuffer.extend 0 #[.ofNat idx] data throughput (.ElementsAndBytes numHashes.toUInt64 (dataSize * numHashes).toUInt64 "hashes") bench s!"dataSize={dataSize} numHashes={numHashes}" (aiurSystem.prove friParameters funIdx #[Aiur.G.ofNat numHashes]) ioBuffer diff --git a/Ix/Aiur/Compiler/Check.lean b/Ix/Aiur/Compiler/Check.lean index 935192b4..e47e931b 100644 --- a/Ix/Aiur/Compiler/Check.lean +++ b/Ix/Aiur/Compiler/Check.lean @@ -759,13 +759,15 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with let a' ← checkNoEscape a .field let b' ← checkNoEscape b .field pure (Typed.Term.u32LessThan .field false a' b') - | .ioGetInfo key => do + | .ioGetInfo channel key => do + let channel' ← checkNoEscape channel .field let key' ← inferNoEscape key match ← walkTyp key'.typ with | .array .. => - pure (Typed.Term.ioGetInfo (.tuple #[.field, .field]) false key') + pure (Typed.Term.ioGetInfo (.tuple #[.field, .field]) false channel' key') | typ' => throw $ .notAnArray typ' - | .ioSetInfo key idx len ret => do + | .ioSetInfo channel key idx len ret => do + let channel' ← checkNoEscape channel .field let key' ← inferNoEscape key match ← walkTyp key'.typ with | .array keyEltTyp _ => @@ -773,19 +775,21 @@ def inferTerm (t : Term) : CheckM Typed.Term := match t with let idx' ← checkNoEscape idx .field let len' ← checkNoEscape len .field let ret' ← inferTerm ret - pure (Typed.Term.ioSetInfo ret'.typ ret'.escapes key' idx' len' ret') + pure (Typed.Term.ioSetInfo ret'.typ ret'.escapes channel' key' idx' len' ret') | typ' => throw $ .notAnArray typ' - | .ioRead idx len => do + | .ioRead channel idx len => do if len = 0 then throw .emptyArray + let channel' ← checkNoEscape channel .field let idx' ← checkNoEscape idx .field - pure (Typed.Term.ioRead (.array .field len) false idx' len) - | .ioWrite data ret => do + pure (Typed.Term.ioRead (.array .field len) false channel' idx' len) + | .ioWrite channel data ret => do + let channel' ← checkNoEscape channel .field let data' ← inferNoEscape data match ← walkTyp data'.typ with | .array dataEltTyp _ => unless ← unifyTyp dataEltTyp .field do throw $ .typeMismatch .field dataEltTyp let ret' ← inferTerm ret - pure (Typed.Term.ioWrite ret'.typ ret'.escapes data' ret') + pure (Typed.Term.ioWrite ret'.typ ret'.escapes channel' data' ret') | typ' => throw $ .notAnArray typ' | .assertEq a b ret => do let a' ← inferNoEscape a @@ -900,12 +904,15 @@ def zonkTypedTerm (t : Typed.Term) : CheckM Typed.Term := match t with | .ptrVal τ e a => do pure (.ptrVal (← zonkTyp τ) e (← zonkTypedTerm a)) | .assertEq τ e a b r => do pure (.assertEq (← zonkTyp τ) e (← zonkTypedTerm a) (← zonkTypedTerm b) (← zonkTypedTerm r)) - | .ioGetInfo τ e k => do pure (.ioGetInfo (← zonkTyp τ) e (← zonkTypedTerm k)) - | .ioSetInfo τ e k i l r => do - pure (.ioSetInfo (← zonkTyp τ) e (← zonkTypedTerm k) (← zonkTypedTerm i) - (← zonkTypedTerm l) (← zonkTypedTerm r)) - | .ioRead τ e i n => do pure (.ioRead (← zonkTyp τ) e (← zonkTypedTerm i) n) - | .ioWrite τ e d r => do pure (.ioWrite (← zonkTyp τ) e (← zonkTypedTerm d) (← zonkTypedTerm r)) + | .ioGetInfo τ e c k => do + pure (.ioGetInfo (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm k)) + | .ioSetInfo τ e c k i l r => do + pure (.ioSetInfo (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm k) + (← zonkTypedTerm i) (← zonkTypedTerm l) (← zonkTypedTerm r)) + | .ioRead τ e c i n => do + pure (.ioRead (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm i) n) + | .ioWrite τ e c d r => do + pure (.ioWrite (← zonkTyp τ) e (← zonkTypedTerm c) (← zonkTypedTerm d) (← zonkTypedTerm r)) | .u8BitDecomposition τ e a => do pure (.u8BitDecomposition (← zonkTyp τ) e (← zonkTypedTerm a)) | .u8ShiftLeft τ e a => do pure (.u8ShiftLeft (← zonkTyp τ) e (← zonkTypedTerm a)) diff --git a/Ix/Aiur/Compiler/Concretize.lean b/Ix/Aiur/Compiler/Concretize.lean index 8d89fbb8..87c45050 100644 --- a/Ix/Aiur/Compiler/Concretize.lean +++ b/Ix/Aiur/Compiler/Concretize.lean @@ -304,15 +304,21 @@ def termToConcrete | .assertEq τ e a b r => do pure (.assertEq (← typToConcrete mono τ) e (← termToConcrete mono a) (← termToConcrete mono b) (← termToConcrete mono r)) - | .ioGetInfo τ e k => do pure (.ioGetInfo (← typToConcrete mono τ) e (← termToConcrete mono k)) - | .ioSetInfo τ e k i l r => do + | .ioGetInfo τ e c k => do + pure (.ioGetInfo (← typToConcrete mono τ) e + (← termToConcrete mono c) (← termToConcrete mono k)) + | .ioSetInfo τ e c k i l r => do pure (.ioSetInfo (← typToConcrete mono τ) e - (← termToConcrete mono k) (← termToConcrete mono i) + (← termToConcrete mono c) (← termToConcrete mono k) + (← termToConcrete mono i) (← termToConcrete mono l) (← termToConcrete mono r)) - | .ioRead τ e i n => do pure (.ioRead (← typToConcrete mono τ) e (← termToConcrete mono i) n) - | .ioWrite τ e d r => do + | .ioRead τ e c i n => do + pure (.ioRead (← typToConcrete mono τ) e + (← termToConcrete mono c) (← termToConcrete mono i) n) + | .ioWrite τ e c d r => do pure (.ioWrite (← typToConcrete mono τ) e - (← termToConcrete mono d) (← termToConcrete mono r)) + (← termToConcrete mono c) (← termToConcrete mono d) + (← termToConcrete mono r)) | .u8BitDecomposition τ e a => do pure (.u8BitDecomposition (← typToConcrete mono τ) e (← termToConcrete mono a)) | .u8ShiftLeft τ e a => do @@ -510,16 +516,21 @@ def rewriteTypedTerm (decls : Typed.Decls) .assertEq (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono a) (rewriteTypedTerm decls subst mono b) (rewriteTypedTerm decls subst mono r) - | .ioGetInfo τ e k => - .ioGetInfo (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono k) - | .ioSetInfo τ e k i l r => + | .ioGetInfo τ e c k => + .ioGetInfo (rewriteTyp subst mono τ) e + (rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono k) + | .ioSetInfo τ e c k i l r => .ioSetInfo (rewriteTyp subst mono τ) e - (rewriteTypedTerm decls subst mono k) (rewriteTypedTerm decls subst mono i) + (rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono k) + (rewriteTypedTerm decls subst mono i) (rewriteTypedTerm decls subst mono l) (rewriteTypedTerm decls subst mono r) - | .ioRead τ e i n => .ioRead (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono i) n - | .ioWrite τ e d r => + | .ioRead τ e c i n => + .ioRead (rewriteTyp subst mono τ) e + (rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono i) n + | .ioWrite τ e c d r => .ioWrite (rewriteTyp subst mono τ) e - (rewriteTypedTerm decls subst mono d) (rewriteTypedTerm decls subst mono r) + (rewriteTypedTerm decls subst mono c) (rewriteTypedTerm decls subst mono d) + (rewriteTypedTerm decls subst mono r) | .u8BitDecomposition τ e a => .u8BitDecomposition (rewriteTyp subst mono τ) e (rewriteTypedTerm decls subst mono a) | .u8ShiftLeft τ e a => @@ -618,20 +629,26 @@ def collectInTypedTerm (seen : Std.HashSet (Global × Array Typ)) : | .u8LessThan τ _ a b | .u32LessThan τ _ a b => collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b | .eqZero τ _ a | .store τ _ a | .load τ _ a | .ptrVal τ _ a - | .u8BitDecomposition τ _ a | .u8ShiftLeft τ _ a | .u8ShiftRight τ _ a - | .ioGetInfo τ _ a => collectInTypedTerm (collectInTyp seen τ) a + | .u8BitDecomposition τ _ a | .u8ShiftLeft τ _ a | .u8ShiftRight τ _ a => + collectInTypedTerm (collectInTyp seen τ) a + | .ioGetInfo τ _ c k => + collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) k | .proj τ _ a _ | .get τ _ a _ | .slice τ _ a _ _ => collectInTypedTerm (collectInTyp seen τ) a | .set τ _ a _ v => collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) v | .assertEq τ _ a b r => collectInTypedTerm (collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) a) b) r - | .ioSetInfo τ _ k i l r => + | .ioSetInfo τ _ c k i l r => collectInTypedTerm (collectInTypedTerm - (collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) k) i) l) r - | .ioRead τ _ i _ => collectInTypedTerm (collectInTyp seen τ) i - | .ioWrite τ _ d r => collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) d) r + (collectInTypedTerm + (collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) k) i) l) r + | .ioRead τ _ c i _ => + collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) i + | .ioWrite τ _ c d r => + collectInTypedTerm + (collectInTypedTerm (collectInTypedTerm (collectInTyp seen τ) c) d) r | .debug τ _ _ t r => let seen := collectInTyp seen τ let seen := match t with | some t => collectInTypedTerm seen t | none => seen @@ -681,17 +698,22 @@ def collectCalls (decls : Typed.Decls) | .u8LessThan _ _ a b | .u32LessThan _ _ a b => collectCalls decls (collectCalls decls seen a) b | .eqZero _ _ a | .store _ _ a | .load _ _ a | .ptrVal _ _ a - | .u8BitDecomposition _ _ a | .u8ShiftLeft _ _ a | .u8ShiftRight _ _ a - | .ioGetInfo _ _ a => collectCalls decls seen a + | .u8BitDecomposition _ _ a | .u8ShiftLeft _ _ a | .u8ShiftRight _ _ a => + collectCalls decls seen a + | .ioGetInfo _ _ c k => + collectCalls decls (collectCalls decls seen c) k | .proj _ _ a _ | .get _ _ a _ | .slice _ _ a _ _ => collectCalls decls seen a | .set _ _ a _ v => collectCalls decls (collectCalls decls seen a) v | .assertEq _ _ a b r => collectCalls decls (collectCalls decls (collectCalls decls seen a) b) r - | .ioSetInfo _ _ k i l r => + | .ioSetInfo _ _ c k i l r => collectCalls decls - (collectCalls decls (collectCalls decls (collectCalls decls seen k) i) l) r - | .ioRead _ _ i _ => collectCalls decls seen i - | .ioWrite _ _ d r => collectCalls decls (collectCalls decls seen d) r + (collectCalls decls + (collectCalls decls + (collectCalls decls (collectCalls decls seen c) k) i) l) r + | .ioRead _ _ c i _ => collectCalls decls (collectCalls decls seen c) i + | .ioWrite _ _ c d r => + collectCalls decls (collectCalls decls (collectCalls decls seen c) d) r | .debug _ _ _ t r => let seen := match t with | some t => collectCalls decls seen t | none => seen collectCalls decls seen r @@ -748,14 +770,19 @@ def substInTypedTerm (subst : Global → Option Typ) : Typed.Term → Typed.Term | .assertEq τ e a b r => .assertEq (Typ.instantiate subst τ) e (substInTypedTerm subst a) (substInTypedTerm subst b) (substInTypedTerm subst r) - | .ioGetInfo τ e k => .ioGetInfo (Typ.instantiate subst τ) e (substInTypedTerm subst k) - | .ioSetInfo τ e k i l r => + | .ioGetInfo τ e c k => + .ioGetInfo (Typ.instantiate subst τ) e + (substInTypedTerm subst c) (substInTypedTerm subst k) + | .ioSetInfo τ e c k i l r => .ioSetInfo (Typ.instantiate subst τ) e - (substInTypedTerm subst k) (substInTypedTerm subst i) + (substInTypedTerm subst c) (substInTypedTerm subst k) + (substInTypedTerm subst i) (substInTypedTerm subst l) (substInTypedTerm subst r) - | .ioRead τ e i n => .ioRead (Typ.instantiate subst τ) e (substInTypedTerm subst i) n - | .ioWrite τ e d r => .ioWrite (Typ.instantiate subst τ) e - (substInTypedTerm subst d) (substInTypedTerm subst r) + | .ioRead τ e c i n => + .ioRead (Typ.instantiate subst τ) e + (substInTypedTerm subst c) (substInTypedTerm subst i) n + | .ioWrite τ e c d r => .ioWrite (Typ.instantiate subst τ) e + (substInTypedTerm subst c) (substInTypedTerm subst d) (substInTypedTerm subst r) | .u8BitDecomposition τ e a => .u8BitDecomposition (Typ.instantiate subst τ) e (substInTypedTerm subst a) | .u8ShiftLeft τ e a => diff --git a/Ix/Aiur/Compiler/Layout.lean b/Ix/Aiur/Compiler/Layout.lean index 8fda9d36..c7eff3e4 100644 --- a/Ix/Aiur/Compiler/Layout.lean +++ b/Ix/Aiur/Compiler/Layout.lean @@ -178,9 +178,9 @@ def opLayout : Bytecode.Op → LayoutM Unit pushDegrees $ .replicate size 1 bumpAuxiliaries size; bumpLookups; addMemSize size | .assertEq .. => pure () - | .ioGetInfo _ => do pushDegrees #[1, 1]; bumpAuxiliaries 2 + | .ioGetInfo _ _ => do pushDegrees #[1, 1]; bumpAuxiliaries 2 | .ioSetInfo .. => pure () - | .ioRead _ len => do pushDegrees $ .replicate len 1; bumpAuxiliaries len + | .ioRead _ _ len => do pushDegrees $ .replicate len 1; bumpAuxiliaries len | .ioWrite .. => pure () | .u8BitDecomposition _ => do pushDegrees $ .replicate 8 1; bumpAuxiliaries 8; bumpLookups | .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. | .u8And .. | .u8Or .. => do diff --git a/Ix/Aiur/Compiler/Lower.lean b/Ix/Aiur/Compiler/Lower.lean index 744d6e3d..3071226a 100644 --- a/Ix/Aiur/Compiler/Lower.lean +++ b/Ix/Aiur/Compiler/Lower.lean @@ -256,21 +256,25 @@ def toIndex let b ← toIndex layoutMap bindings b modify fun stt => { stt with ops := stt.ops.push (.assertEq a b) } toIndex layoutMap bindings ret - | .ioGetInfo _ _ key => do + | .ioGetInfo _ _ channel key => do + let channel ← expectIdx layoutMap bindings channel let key ← toIndex layoutMap bindings key - pushOp (.ioGetInfo key) 2 - | .ioSetInfo _ _ key idx len ret => do + pushOp (.ioGetInfo channel key) 2 + | .ioSetInfo _ _ channel key idx len ret => do + let channel ← expectIdx layoutMap bindings channel let key ← toIndex layoutMap bindings key let idx ← expectIdx layoutMap bindings idx let len ← expectIdx layoutMap bindings len - modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo key idx len) } + modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo channel key idx len) } toIndex layoutMap bindings ret - | .ioRead _ _ idx len => do + | .ioRead _ _ channel idx len => do + let channel ← expectIdx layoutMap bindings channel let idx ← expectIdx layoutMap bindings idx - pushOp (.ioRead idx len) len - | .ioWrite _ _ data ret => do + pushOp (.ioRead channel idx len) len + | .ioWrite _ _ channel data ret => do + let channel ← expectIdx layoutMap bindings channel let data ← toIndex layoutMap bindings data - modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) } + modify fun stt => { stt with ops := stt.ops.push (.ioWrite channel data) } toIndex layoutMap bindings ret | .u8BitDecomposition _ _ byte => do let byte ← expectIdx layoutMap bindings byte @@ -442,15 +446,17 @@ def Concrete.Term.compile let b ← toIndex layoutMap bindings b modify fun stt => { stt with ops := stt.ops.push (.assertEq a b) } ret.compile returnTyp layoutMap bindings yieldCtrl - | .ioSetInfo _ _ key idx len ret => do + | .ioSetInfo _ _ channel key idx len ret => do + let channel ← toIndex layoutMap bindings channel let key ← toIndex layoutMap bindings key let idx ← toIndex layoutMap bindings idx let len ← toIndex layoutMap bindings len - modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo key idx[0]! len[0]!) } + modify fun stt => { stt with ops := stt.ops.push (.ioSetInfo channel[0]! key idx[0]! len[0]!) } ret.compile returnTyp layoutMap bindings yieldCtrl - | .ioWrite _ _ data ret => do + | .ioWrite _ _ channel data ret => do + let channel ← toIndex layoutMap bindings channel let data ← toIndex layoutMap bindings data - modify fun stt => { stt with ops := stt.ops.push (.ioWrite data) } + modify fun stt => { stt with ops := stt.ops.push (.ioWrite channel[0]! data) } ret.compile returnTyp layoutMap bindings yieldCtrl | .match _ _ scrut cases defaultOpt => do let idxs := bindings[scrut]?.getD #[0] diff --git a/Ix/Aiur/Compiler/Match.lean b/Ix/Aiur/Compiler/Match.lean index d57c27c8..22d044f5 100644 --- a/Ix/Aiur/Compiler/Match.lean +++ b/Ix/Aiur/Compiler/Match.lean @@ -370,11 +370,12 @@ def typedToSimple : Term → Simple.Term | .load τ e a => .load τ e (typedToSimple a) | .ptrVal τ e a => .ptrVal τ e (typedToSimple a) | .assertEq τ e a b r => .assertEq τ e (typedToSimple a) (typedToSimple b) (typedToSimple r) - | .ioGetInfo τ e k => .ioGetInfo τ e (typedToSimple k) - | .ioSetInfo τ e k i l r => - .ioSetInfo τ e (typedToSimple k) (typedToSimple i) (typedToSimple l) (typedToSimple r) - | .ioRead τ e i n => .ioRead τ e (typedToSimple i) n - | .ioWrite τ e d r => .ioWrite τ e (typedToSimple d) (typedToSimple r) + | .ioGetInfo τ e c k => .ioGetInfo τ e (typedToSimple c) (typedToSimple k) + | .ioSetInfo τ e c k i l r => + .ioSetInfo τ e (typedToSimple c) (typedToSimple k) (typedToSimple i) + (typedToSimple l) (typedToSimple r) + | .ioRead τ e c i n => .ioRead τ e (typedToSimple c) (typedToSimple i) n + | .ioWrite τ e c d r => .ioWrite τ e (typedToSimple c) (typedToSimple d) (typedToSimple r) | .u8BitDecomposition τ e a => .u8BitDecomposition τ e (typedToSimple a) | .u8ShiftLeft τ e a => .u8ShiftLeft τ e (typedToSimple a) | .u8ShiftRight τ e a => .u8ShiftRight τ e (typedToSimple a) diff --git a/Ix/Aiur/Compiler/Simple.lean b/Ix/Aiur/Compiler/Simple.lean index ec33f504..504a1aed 100644 --- a/Ix/Aiur/Compiler/Simple.lean +++ b/Ix/Aiur/Compiler/Simple.lean @@ -90,16 +90,18 @@ def simplifyTypedTerm (decls : Source.Decls) : Term → Except CheckError Term let b' ← simplifyTypedTerm decls b let r' ← simplifyTypedTerm decls r pure (.assertEq τ e a' b' r') - | .ioSetInfo τ e k i l r => do + | .ioSetInfo τ e c k i l r => do + let c' ← simplifyTypedTerm decls c let k' ← simplifyTypedTerm decls k let i' ← simplifyTypedTerm decls i let l' ← simplifyTypedTerm decls l let r' ← simplifyTypedTerm decls r - pure (.ioSetInfo τ e k' i' l' r') - | .ioWrite τ e d r => do + pure (.ioSetInfo τ e c' k' i' l' r') + | .ioWrite τ e c d r => do + let c' ← simplifyTypedTerm decls c let d' ← simplifyTypedTerm decls d let r' ← simplifyTypedTerm decls r - pure (.ioWrite τ e d' r') + pure (.ioWrite τ e c' d' r') | .u8LessThan τ e a b => do let a' ← simplifyTypedTerm decls a let b' ← simplifyTypedTerm decls b diff --git a/Ix/Aiur/Interpret.lean b/Ix/Aiur/Interpret.lean index 1ea7c985..be7431ca 100644 --- a/Ix/Aiur/Interpret.lean +++ b/Ix/Aiur/Interpret.lean @@ -408,34 +408,41 @@ partial def interp (decls : Decls) (bindings : Bindings) : Term → InterpM Valu let store ← getStore dbg_trace s!"{label}: {Value.ppDeref store 16 v}" interp decls bindings cont - | .ioGetInfo key => do + | .ioGetInfo channel key => do + let channelG ← expectField (← interp decls bindings channel) let keyGs ← expectFieldArray (← interp decls bindings key) let io ← getIOBuffer - match io.map[keyGs]? with + match io.map[(channelG, keyGs)]? with | some info => return .tuple #[.field (.ofNat info.idx), .field (.ofNat info.len)] | none => throwErr s!"ioGetInfo: key not found" - | .ioSetInfo key idx len ret => do + | .ioSetInfo channel key idx len ret => do + let channelG ← expectField (← interp decls bindings channel) let keyGs ← expectFieldArray (← interp decls bindings key) let idxVal ← expectField (← interp decls bindings idx) let lenVal ← expectField (← interp decls bindings len) let io ← getIOBuffer - if io.map.contains keyGs then + if io.map.contains (channelG, keyGs) then throwErr s!"ioSetInfo: key already set" let info : IOKeyInfo := ⟨idxVal.val.toNat, lenVal.val.toNat⟩ - modifyIOBuffer fun io => { io with map := io.map.insert keyGs info } + modifyIOBuffer fun io => { io with map := io.map.insert (channelG, keyGs) info } interp decls bindings ret - | .ioRead idx len => do + | .ioRead channel idx len => do + let channelG ← expectField (← interp decls bindings channel) let idxVal ← expectField (← interp decls bindings idx) let io ← getIOBuffer + let arena := io.data.getD channelG #[] let start := idxVal.val.toNat - if start + len > io.data.size then + if start + len > arena.size then throwErr s!"ioRead: out-of-bounds read at {start} for length {len} \ - (buffer size {io.data.size})" - return .array (io.data.extract start (start + len) |>.map .field) - | .ioWrite data ret => do + (channel {channelG.val}, arena size {arena.size})" + return .array (arena.extract start (start + len) |>.map .field) + | .ioWrite channel data ret => do + let channelG ← expectField (← interp decls bindings channel) let dataGs ← expectFieldArray (← interp decls bindings data) - modifyIOBuffer fun io => { io with data := io.data ++ dataGs } + modifyIOBuffer fun io => + let arena := io.data.getD channelG #[] + { io with data := io.data.insert channelG (arena ++ dataGs) } interp decls bindings ret end diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index 159e2193..4f432dea 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -163,11 +163,11 @@ syntax "load" "(" aiur_trm ")" : a syntax "ptr_val" "(" aiur_trm ")" : aiur_trm syntax "assert_eq!" "(" aiur_trm ", " aiur_trm ")" ";" (aiur_trm)? : aiur_trm syntax aiur_trm ": " aiur_typ : aiur_trm -syntax "io_get_info" "(" aiur_trm ")" : aiur_trm -syntax "io_set_info" "(" aiur_trm ", " aiur_trm ", " aiur_trm ")" ";" +syntax "io_get_info" "(" aiur_trm ", " aiur_trm ")" : aiur_trm +syntax "io_set_info" "(" aiur_trm ", " aiur_trm ", " aiur_trm ", " aiur_trm ")" ";" (aiur_trm)? : aiur_trm -syntax "io_read" "(" aiur_trm ", " num ")" : aiur_trm -syntax "io_write" "(" aiur_trm ")" ";" (aiur_trm)? : aiur_trm +syntax "io_read" "(" aiur_trm ", " aiur_trm ", " num ")" : aiur_trm +syntax "io_write" "(" aiur_trm ", " aiur_trm ")" ";" (aiur_trm)? : aiur_trm syntax "u8_bit_decomposition" "(" aiur_trm ")" : aiur_trm syntax "u8_shift_left" "(" aiur_trm ")" : aiur_trm syntax "u8_shift_right" "(" aiur_trm ")" : aiur_trm @@ -276,15 +276,15 @@ partial def elabTrm : ElabStxCat `aiur_trm mkAppM ``Source.Term.assertEq #[← elabTrm a, ← elabTrm b, ← elabRet ret] | `(aiur_trm| $v:aiur_trm : $t:aiur_typ) => do mkAppM ``Source.Term.ann #[← elabTyp t, ← elabTrm v] - | `(aiur_trm| io_get_info($key:aiur_trm)) => do - mkAppM ``Source.Term.ioGetInfo #[← elabTrm key] - | `(aiur_trm| io_set_info($key:aiur_trm, $idx:aiur_trm, $len:aiur_trm); $[$ret:aiur_trm]?) => do + | `(aiur_trm| io_get_info($ch:aiur_trm, $key:aiur_trm)) => do + mkAppM ``Source.Term.ioGetInfo #[← elabTrm ch, ← elabTrm key] + | `(aiur_trm| io_set_info($ch:aiur_trm, $key:aiur_trm, $idx:aiur_trm, $len:aiur_trm); $[$ret:aiur_trm]?) => do mkAppM ``Source.Term.ioSetInfo - #[← elabTrm key, ← elabTrm idx, ← elabTrm len, ← elabRet ret] - | `(aiur_trm| io_read($idx:aiur_trm, $len:num)) => do - mkAppM ``Source.Term.ioRead #[← elabTrm idx, mkNatLit len.getNat] - | `(aiur_trm| io_write($data:aiur_trm); $[$ret:aiur_trm]?) => do - mkAppM ``Source.Term.ioWrite #[← elabTrm data, ← elabRet ret] + #[← elabTrm ch, ← elabTrm key, ← elabTrm idx, ← elabTrm len, ← elabRet ret] + | `(aiur_trm| io_read($ch:aiur_trm, $idx:aiur_trm, $len:num)) => do + mkAppM ``Source.Term.ioRead #[← elabTrm ch, ← elabTrm idx, mkNatLit len.getNat] + | `(aiur_trm| io_write($ch:aiur_trm, $data:aiur_trm); $[$ret:aiur_trm]?) => do + mkAppM ``Source.Term.ioWrite #[← elabTrm ch, ← elabTrm data, ← elabRet ret] | `(aiur_trm| u8_bit_decomposition($byte:aiur_trm)) => do mkAppM ``Source.Term.u8BitDecomposition #[← elabTrm byte] | `(aiur_trm| u8_shift_left($byte:aiur_trm)) => do @@ -456,22 +456,26 @@ where | `(aiur_trm| $v:aiur_trm : $t:aiur_typ) => do let v ← replaceToken old new v `(aiur_trm| $v : $t) - | `(aiur_trm| io_get_info($key:aiur_trm)) => do + | `(aiur_trm| io_get_info($ch:aiur_trm, $key:aiur_trm)) => do + let ch ← replaceToken old new ch let key ← replaceToken old new key - `(aiur_trm| io_get_info($key)) - | `(aiur_trm| io_set_info($key:aiur_trm, $idx:aiur_trm, $len:aiur_trm); $[$ret:aiur_trm]?) => do + `(aiur_trm| io_get_info($ch, $key)) + | `(aiur_trm| io_set_info($ch:aiur_trm, $key:aiur_trm, $idx:aiur_trm, $len:aiur_trm); $[$ret:aiur_trm]?) => do + let ch ← replaceToken old new ch let key ← replaceToken old new key let idx ← replaceToken old new idx let len ← replaceToken old new len let ret' ← ret.mapM $ replaceToken old new - `(aiur_trm| io_set_info($key, $idx, $len); $[$ret']?) - | `(aiur_trm| io_read($idx:aiur_trm, $len:num)) => do + `(aiur_trm| io_set_info($ch, $key, $idx, $len); $[$ret']?) + | `(aiur_trm| io_read($ch:aiur_trm, $idx:aiur_trm, $len:num)) => do + let ch ← replaceToken old new ch let idx ← replaceToken old new idx - `(aiur_trm| io_read($idx, $len)) - | `(aiur_trm| io_write($data:aiur_trm); $[$ret:aiur_trm]?) => do + `(aiur_trm| io_read($ch, $idx, $len)) + | `(aiur_trm| io_write($ch:aiur_trm, $data:aiur_trm); $[$ret:aiur_trm]?) => do + let ch ← replaceToken old new ch let data ← replaceToken old new data let ret' ← ret.mapM $ replaceToken old new - `(aiur_trm| io_write($data); $[$ret']?) + `(aiur_trm| io_write($ch, $data); $[$ret']?) | `(aiur_trm| u8_bit_decomposition($byte:aiur_trm)) => do let byte ← replaceToken old new byte `(aiur_trm| u8_bit_decomposition($byte)) diff --git a/Ix/Aiur/Protocol.lean b/Ix/Aiur/Protocol.lean index b76aae24..df320811 100644 --- a/Ix/Aiur/Protocol.lean +++ b/Ix/Aiur/Protocol.lean @@ -49,9 +49,10 @@ opaque build : @&Bytecode.Toplevel → @&CommitmentParameters → AiurSystem @[extern "rs_aiur_system_prove"] private opaque prove' : @& AiurSystem → @& FriParameters → - @& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) → - (ioMap : @& Array (Array G × IOKeyInfo)) → - Array G × Proof × Array G × Array (Array G × IOKeyInfo) + @& Bytecode.FunIdx → @& Array G → + (ioData : @& Array (G × Array G)) → + (ioMap : @& Array ((G × Array G) × IOKeyInfo)) → + Array G × Proof × Array (G × Array G) × Array ((G × Array G) × IOKeyInfo) /-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`, then generates a proof of the computation. Returns the claim @@ -60,10 +61,11 @@ updated `IOBuffer`. -/ def prove (system : @& AiurSystem) (friParameters : @& FriParameters) (funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) : Array G × Proof × IOBuffer := - let ioData := ioBuffer.data - let ioMap := ioBuffer.map + let ioData := ioBuffer.data.toArray + let ioMap := ioBuffer.map.toArray let (claim, proof, ioData, ioMap) := prove' system friParameters funIdx args - ioData ioMap.toArray + ioData ioMap + let ioData := ioData.foldl (fun acc (k, v) => acc.insert k v) ∅ let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅ (claim, proof, ⟨ioData, ioMap⟩) diff --git a/Ix/Aiur/Semantics/BytecodeEval.lean b/Ix/Aiur/Semantics/BytecodeEval.lean index 0118dd6f..b8d7c25d 100644 --- a/Ix/Aiur/Semantics/BytecodeEval.lean +++ b/Ix/Aiur/Semantics/BytecodeEval.lean @@ -199,33 +199,39 @@ def evalOp (t : Bytecode.Toplevel) (fuel : Nat) (op : Op) (st : EvalState) : let aGs ← readIdxs st as let bGs ← readIdxs st bs if aGs == bGs then .ok st else .error .assertFailed - | .ioGetInfo keyIdxs => do + | .ioGetInfo channelIdx keyIdxs => do + let channelG ← readIdx st channelIdx let keyGs ← readIdxs st keyIdxs - match st.ioBuffer.map[keyGs]? with + match st.ioBuffer.map[(channelG, keyGs)]? with | some info => let st1 := pushMap st (.ofNat info.idx) pure (pushMap st1 (.ofNat info.len)) | none => .error .ioKeyNotFound - | .ioSetInfo keyIdxs idxIdx lenIdx => do + | .ioSetInfo channelIdx keyIdxs idxIdx lenIdx => do + let channelG ← readIdx st channelIdx let keyGs ← readIdxs st keyIdxs let iG ← readIdx st idxIdx let lG ← readIdx st lenIdx - if st.ioBuffer.map.contains keyGs then + if st.ioBuffer.map.contains (channelG, keyGs) then .error .ioKeyAlreadySet else let info : IOKeyInfo := ⟨iG.val.toNat, lG.val.toNat⟩ - let newMap := st.ioBuffer.map.insert keyGs info + let newMap := st.ioBuffer.map.insert (channelG, keyGs) info pure (setIoBuffer st { st.ioBuffer with map := newMap }) - | .ioRead idxIdx len => do + | .ioRead channelIdx idxIdx len => do + let channelG ← readIdx st channelIdx let iG ← readIdx st idxIdx let start := iG.val.toNat - if start + len > st.ioBuffer.data.size then + let arena := st.ioBuffer.data.getD channelG #[] + if start + len > arena.size then .error .ioReadOoB else - pure (appendMap st (st.ioBuffer.data.extract start (start + len))) - | .ioWrite dataIdxs => do + pure (appendMap st (arena.extract start (start + len))) + | .ioWrite channelIdx dataIdxs => do + let channelG ← readIdx st channelIdx let dataGs ← readIdxs st dataIdxs - let newData := st.ioBuffer.data ++ dataGs + let arena := st.ioBuffer.data.getD channelG #[] + let newData := st.ioBuffer.data.insert channelG (arena ++ dataGs) pure (setIoBuffer st { st.ioBuffer with data := newData }) | .u8BitDecomposition idx => do let g ← readIdx st idx diff --git a/Ix/Aiur/Semantics/BytecodeFfi.lean b/Ix/Aiur/Semantics/BytecodeFfi.lean index 08a99894..e3970d70 100644 --- a/Ix/Aiur/Semantics/BytecodeFfi.lean +++ b/Ix/Aiur/Semantics/BytecodeFfi.lean @@ -29,20 +29,28 @@ instance : LawfulBEq IOKeyInfo where rfl {a} := by simp [BEq.beq] structure IOBuffer where - data : Array G - map : Std.HashMap (Array G) IOKeyInfo + /-- Per-channel data arenas. `idx` slots into `data[channel]`. -/ + data : Std.HashMap G (Array G) + /-- Channel-keyed info map. Same `key` on different channels resolves + to distinct `IOKeyInfo`. -/ + map : Std.HashMap (G × Array G) IOKeyInfo deriving Inhabited -def IOBuffer.extend (ioBuffer : IOBuffer) (key data : Array G) : IOBuffer := - let idx := ioBuffer.data.size +/-- Append `data` to the `channel` arena and register `key → (idx, len)` +on the same channel. -/ +def IOBuffer.extend (ioBuffer : IOBuffer) (channel : G) (key data : Array G) : + IOBuffer := + let arena := ioBuffer.data.getD channel #[] + let idx := arena.size let len := data.size { ioBuffer with - data := ioBuffer.data ++ data - map := ioBuffer.map.insert key { idx, len } } + data := ioBuffer.data.insert channel (arena ++ data) + map := ioBuffer.map.insert (channel, key) { idx, len } } instance : BEq IOBuffer where beq x y := - x.data == y.data && @BEq.beq _ Std.HashMap.instBEq x.map y.map + @BEq.beq _ Std.HashMap.instBEq x.data y.data && + @BEq.beq _ Std.HashMap.instBEq x.map y.map -- A `LawfulBEq IOBuffer` instance is not provided here. The reflexivity/ -- symmetry/transitivity facts needed downstream (`IOBuffer.equiv_refl`, @@ -63,9 +71,12 @@ namespace Bytecode.Toplevel @[extern "rs_aiur_toplevel_execute"] private opaque execute' : @& Bytecode.Toplevel → - @& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) → - (ioMap : @& Array (Array G × IOKeyInfo)) → - Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat)) + @& Bytecode.FunIdx → @& Array G → + (ioData : @& Array (G × Array G)) → + (ioMap : @& Array ((G × Array G) × IOKeyInfo)) → + Except String (Array G × + (Array (G × Array G) × Array ((G × Array G) × IOKeyInfo)) × + Array (Nat × Nat)) /-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`, returning the raw output of the function, the updated `IOBuffer`, and an array @@ -75,11 +86,12 @@ callers can recover instead of crashing. -/ def execute (toplevel : @& Bytecode.Toplevel) (funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) : Except String (Array G × IOBuffer × Array QueryCount) := - let ioData := ioBuffer.data - let ioMap := ioBuffer.map - match execute' toplevel funIdx args ioData ioMap.toArray with + let ioData := ioBuffer.data.toArray + let ioMap := ioBuffer.map.toArray + match execute' toplevel funIdx args ioData ioMap with | .error e => .error e | .ok (output, (ioData, ioMap), queryCounts) => + let ioData := ioData.foldl (fun acc (k, v) => acc.insert k v) ∅ let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅ let queryCounts := queryCounts.map fun (uniqueRows, totalHits) => { uniqueRows, totalHits } .ok (output, ⟨ioData, ioMap⟩, queryCounts) diff --git a/Ix/Aiur/Semantics/SourceEval.lean b/Ix/Aiur/Semantics/SourceEval.lean index 2e2dd5ec..87951250 100644 --- a/Ix/Aiur/Semantics/SourceEval.lean +++ b/Ix/Aiur/Semantics/SourceEval.lean @@ -432,22 +432,28 @@ def interp (decls : Decls) (fuel : Nat) (bindings : Bindings) (interp decls fuel bindings t1 st) (fun st1 => interp decls fuel bindings t2 st1) | .debug _ _ ret => interp decls fuel bindings ret st - | .ioGetInfo key => - match interp decls fuel bindings key st with + | .ioGetInfo channel key => + match interp decls fuel bindings channel st with + | .error e => .error e + | .ok (vc, stc) => + match interp decls fuel bindings key stc with | .error e => .error e | .ok (v, st') => - match v with - | .array vs => + match vc, v with + | .field channelG, .array vs => match expectFieldArray vs with | none => .error (.typeMismatch "ioGetInfo key") | some keyGs => - match st'.ioBuffer.map[keyGs]? with + match st'.ioBuffer.map[(channelG, keyGs)]? with | some info => .ok (.tuple #[.field (.ofNat info.idx), .field (.ofNat info.len)], st') | none => .error .ioKeyNotFound - | _ => .error (.typeMismatch "ioGetInfo") - | .ioSetInfo key idx len ret => - match interp decls fuel bindings key st with + | _, _ => .error (.typeMismatch "ioGetInfo") + | .ioSetInfo channel key idx len ret => + match interp decls fuel bindings channel st with + | .error e => .error e + | .ok (vc, stc) => + match interp decls fuel bindings key stc with | .error e => .error e | .ok (vk, stk) => match interp decls fuel bindings idx stk with @@ -456,42 +462,50 @@ def interp (decls : Decls) (fuel : Nat) (bindings : Bindings) match interp decls fuel bindings len sti with | .error e => .error e | .ok (vl, stl) => - match vk, vi, vl with - | .array vs, .field iG, .field lG => + match vc, vk, vi, vl with + | .field channelG, .array vs, .field iG, .field lG => match expectFieldArray vs with | none => .error (.typeMismatch "ioSetInfo key") | some keyGs => - if stl.ioBuffer.map.contains keyGs then + if stl.ioBuffer.map.contains (channelG, keyGs) then .error .ioKeyAlreadySet else let info : IOKeyInfo := ⟨iG.val.toNat, lG.val.toNat⟩ let st' := { stl with ioBuffer := - { stl.ioBuffer with map := stl.ioBuffer.map.insert keyGs info } } + { stl.ioBuffer with map := stl.ioBuffer.map.insert (channelG, keyGs) info } } interp decls fuel bindings ret st' - | _, _, _ => .error (.typeMismatch "ioSetInfo") - | .ioRead idx len => - match interp decls fuel bindings idx st with + | _, _, _, _ => .error (.typeMismatch "ioSetInfo") + | .ioRead channel idx len => + match interp decls fuel bindings channel st with + | .error e => .error e + | .ok (vc, stc) => + match interp decls fuel bindings idx stc with | .error e => .error e | .ok (v, st') => - match v with - | .field g => + match vc, v with + | .field channelG, .field g => let start := g.val.toNat - if start + len > st'.ioBuffer.data.size then .error .ioReadOoB - else .ok (.array (st'.ioBuffer.data.extract start (start + len) |>.map .field), st') - | _ => .error (.typeMismatch "ioRead") - | .ioWrite data ret => - match interp decls fuel bindings data st with + let arena := st'.ioBuffer.data.getD channelG #[] + if start + len > arena.size then .error .ioReadOoB + else .ok (.array (arena.extract start (start + len) |>.map .field), st') + | _, _ => .error (.typeMismatch "ioRead") + | .ioWrite channel data ret => + match interp decls fuel bindings channel st with + | .error e => .error e + | .ok (vc, stc) => + match interp decls fuel bindings data stc with | .error e => .error e | .ok (v, st') => - match v with - | .array vs => + match vc, v with + | .field channelG, .array vs => match expectFieldArray vs with | none => .error (.typeMismatch "ioWrite") | some dataGs => + let arena := st'.ioBuffer.data.getD channelG #[] let st'' := { st' with ioBuffer := - { st'.ioBuffer with data := st'.ioBuffer.data ++ dataGs } } + { st'.ioBuffer with data := st'.ioBuffer.data.insert channelG (arena ++ dataGs) } } interp decls fuel bindings ret st'' - | _ => .error (.typeMismatch "ioWrite") + | _, _ => .error (.typeMismatch "ioWrite") termination_by (fuel, 2, sizeOf t) decreasing_by all_goals first diff --git a/Ix/Aiur/Stages/Bytecode.lean b/Ix/Aiur/Stages/Bytecode.lean index e1e407e6..1892c7de 100644 --- a/Ix/Aiur/Stages/Bytecode.lean +++ b/Ix/Aiur/Stages/Bytecode.lean @@ -28,10 +28,10 @@ inductive Op | store : Array ValIdx → Op | load : (size : Nat) → ValIdx → Op | assertEq : Array ValIdx → Array ValIdx → Op - | ioGetInfo : Array ValIdx → Op - | ioSetInfo : Array ValIdx → ValIdx → ValIdx → Op - | ioRead : ValIdx → Nat → Op - | ioWrite : Array ValIdx → Op + | ioGetInfo : ValIdx → Array ValIdx → Op + | ioSetInfo : ValIdx → Array ValIdx → ValIdx → ValIdx → Op + | ioRead : ValIdx → ValIdx → Nat → Op + | ioWrite : ValIdx → Array ValIdx → Op | u8BitDecomposition : ValIdx → Op | u8ShiftLeft : ValIdx → Op | u8ShiftRight : ValIdx → Op diff --git a/Ix/Aiur/Stages/Concrete.lean b/Ix/Aiur/Stages/Concrete.lean index 9a1b5d3a..17454bd3 100644 --- a/Ix/Aiur/Stages/Concrete.lean +++ b/Ix/Aiur/Stages/Concrete.lean @@ -69,10 +69,10 @@ inductive Term : Type where | load (typ : Typ) (escapes : Bool) (a : Term) : Term | ptrVal (typ : Typ) (escapes : Bool) (a : Term) : Term | assertEq (typ : Typ) (escapes : Bool) (a : Term) (b : Term) (r : Term) : Term - | ioGetInfo (typ : Typ) (escapes : Bool) (k : Term) : Term - | ioSetInfo (typ : Typ) (escapes : Bool) (k i l r : Term) : Term - | ioRead (typ : Typ) (escapes : Bool) (i : Term) (n : Nat) : Term - | ioWrite (typ : Typ) (escapes : Bool) (d r : Term) : Term + | ioGetInfo (typ : Typ) (escapes : Bool) (c k : Term) : Term + | ioSetInfo (typ : Typ) (escapes : Bool) (c k i l r : Term) : Term + | ioRead (typ : Typ) (escapes : Bool) (c i : Term) (n : Nat) : Term + | ioWrite (typ : Typ) (escapes : Bool) (c d r : Term) : Term | u8BitDecomposition (typ : Typ) (escapes : Bool) (a : Term) : Term | u8ShiftLeft (typ : Typ) (escapes : Bool) (a : Term) : Term | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term @@ -98,8 +98,8 @@ def Term.typ : Term → Typ | .add t _ _ _ | .sub t _ _ _ | .mul t _ _ _ | .eqZero t _ _ | .proj t _ _ _ | .get t _ _ _ | .slice t _ _ _ _ | .set t _ _ _ _ | .store t _ _ | .load t _ _ | .ptrVal t _ _ - | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ - | .ioRead t _ _ _ | .ioWrite t _ _ _ + | .assertEq t _ _ _ _ | .ioGetInfo t _ _ _ | .ioSetInfo t _ _ _ _ _ _ + | .ioRead t _ _ _ _ | .ioWrite t _ _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ @@ -115,8 +115,8 @@ def Term.escapes : Term → Bool | .add _ e _ _ | .sub _ e _ _ | .mul _ e _ _ | .eqZero _ e _ | .proj _ e _ _ | .get _ e _ _ | .slice _ e _ _ _ | .set _ e _ _ _ | .store _ e _ | .load _ e _ | .ptrVal _ e _ - | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ - | .ioRead _ e _ _ | .ioWrite _ e _ _ + | .assertEq _ e _ _ _ | .ioGetInfo _ e _ _ | .ioSetInfo _ e _ _ _ _ _ + | .ioRead _ e _ _ _ | .ioWrite _ e _ _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ diff --git a/Ix/Aiur/Stages/Simple.lean b/Ix/Aiur/Stages/Simple.lean index 8b987f4b..b19486a9 100644 --- a/Ix/Aiur/Stages/Simple.lean +++ b/Ix/Aiur/Stages/Simple.lean @@ -65,10 +65,10 @@ inductive Term : Type where | load (typ : Typ) (escapes : Bool) (a : Term) : Term | ptrVal (typ : Typ) (escapes : Bool) (a : Term) : Term | assertEq (typ : Typ) (escapes : Bool) (a : Term) (b : Term) (r : Term) : Term - | ioGetInfo (typ : Typ) (escapes : Bool) (k : Term) : Term - | ioSetInfo (typ : Typ) (escapes : Bool) (k i l r : Term) : Term - | ioRead (typ : Typ) (escapes : Bool) (i : Term) (n : Nat) : Term - | ioWrite (typ : Typ) (escapes : Bool) (d r : Term) : Term + | ioGetInfo (typ : Typ) (escapes : Bool) (c k : Term) : Term + | ioSetInfo (typ : Typ) (escapes : Bool) (c k i l r : Term) : Term + | ioRead (typ : Typ) (escapes : Bool) (c i : Term) (n : Nat) : Term + | ioWrite (typ : Typ) (escapes : Bool) (c d r : Term) : Term | u8BitDecomposition (typ : Typ) (escapes : Bool) (a : Term) : Term | u8ShiftLeft (typ : Typ) (escapes : Bool) (a : Term) : Term | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term @@ -93,8 +93,8 @@ def Term.typ : Term → Typ | .add t _ _ _ | .sub t _ _ _ | .mul t _ _ _ | .eqZero t _ _ | .proj t _ _ _ | .get t _ _ _ | .slice t _ _ _ _ | .set t _ _ _ _ | .store t _ _ | .load t _ _ | .ptrVal t _ _ - | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ - | .ioRead t _ _ _ | .ioWrite t _ _ _ + | .assertEq t _ _ _ _ | .ioGetInfo t _ _ _ | .ioSetInfo t _ _ _ _ _ _ + | .ioRead t _ _ _ _ | .ioWrite t _ _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ @@ -109,8 +109,8 @@ def Term.escapes : Term → Bool | .add _ e _ _ | .sub _ e _ _ | .mul _ e _ _ | .eqZero _ e _ | .proj _ e _ _ | .get _ e _ _ | .slice _ e _ _ _ | .set _ e _ _ _ | .store _ e _ | .load _ e _ | .ptrVal _ e _ - | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ - | .ioRead _ e _ _ | .ioWrite _ e _ _ + | .assertEq _ e _ _ _ | .ioGetInfo _ e _ _ | .ioSetInfo _ e _ _ _ _ _ + | .ioRead _ e _ _ _ | .ioWrite _ e _ _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ diff --git a/Ix/Aiur/Stages/Source.lean b/Ix/Aiur/Stages/Source.lean index 19fb7e04..2e08a5ba 100644 --- a/Ix/Aiur/Stages/Source.lean +++ b/Ix/Aiur/Stages/Source.lean @@ -373,10 +373,10 @@ inductive Term | ptrVal : Term → Term | ann : Typ → Term → Term | assertEq : Term → Term → (ret : Term) → Term - | ioGetInfo : (key : Term) → Term - | ioSetInfo : (key : Term) → (idx : Term) → (len : Term) → (ret : Term) → Term - | ioRead : (idx : Term) → (len : Nat) → Term - | ioWrite : (data : Term) → (ret : Term) → Term + | ioGetInfo : (channel : Term) → (key : Term) → Term + | ioSetInfo : (channel : Term) → (key : Term) → (idx : Term) → (len : Term) → (ret : Term) → Term + | ioRead : (channel : Term) → (idx : Term) → (len : Nat) → Term + | ioWrite : (channel : Term) → (data : Term) → (ret : Term) → Term | u8BitDecomposition : Term → Term | u8ShiftLeft : Term → Term | u8ShiftRight : Term → Term diff --git a/Ix/Aiur/Stages/Typed.lean b/Ix/Aiur/Stages/Typed.lean index da403e2d..2fc3374c 100644 --- a/Ix/Aiur/Stages/Typed.lean +++ b/Ix/Aiur/Stages/Typed.lean @@ -41,10 +41,10 @@ inductive Term : Type where | load (typ : Typ) (escapes : Bool) (a : Term) : Term | ptrVal (typ : Typ) (escapes : Bool) (a : Term) : Term | assertEq (typ : Typ) (escapes : Bool) (a : Term) (b : Term) (r : Term) : Term - | ioGetInfo (typ : Typ) (escapes : Bool) (k : Term) : Term - | ioSetInfo (typ : Typ) (escapes : Bool) (k i l r : Term) : Term - | ioRead (typ : Typ) (escapes : Bool) (i : Term) (n : Nat) : Term - | ioWrite (typ : Typ) (escapes : Bool) (d r : Term) : Term + | ioGetInfo (typ : Typ) (escapes : Bool) (c k : Term) : Term + | ioSetInfo (typ : Typ) (escapes : Bool) (c k i l r : Term) : Term + | ioRead (typ : Typ) (escapes : Bool) (c i : Term) (n : Nat) : Term + | ioWrite (typ : Typ) (escapes : Bool) (c d r : Term) : Term | u8BitDecomposition (typ : Typ) (escapes : Bool) (a : Term) : Term | u8ShiftLeft (typ : Typ) (escapes : Bool) (a : Term) : Term | u8ShiftRight (typ : Typ) (escapes : Bool) (a : Term) : Term @@ -69,8 +69,8 @@ def Term.typ : Term → Typ | .add t _ _ _ | .sub t _ _ _ | .mul t _ _ _ | .eqZero t _ _ | .proj t _ _ _ | .get t _ _ _ | .slice t _ _ _ _ | .set t _ _ _ _ | .store t _ _ | .load t _ _ | .ptrVal t _ _ - | .assertEq t _ _ _ _ | .ioGetInfo t _ _ | .ioSetInfo t _ _ _ _ _ - | .ioRead t _ _ _ | .ioWrite t _ _ _ + | .assertEq t _ _ _ _ | .ioGetInfo t _ _ _ | .ioSetInfo t _ _ _ _ _ _ + | .ioRead t _ _ _ _ | .ioWrite t _ _ _ _ | .u8BitDecomposition t _ _ | .u8ShiftLeft t _ _ | .u8ShiftRight t _ _ | .u8Xor t _ _ _ | .u8Add t _ _ _ | .u8Mul t _ _ _ | .u8Sub t _ _ _ | .u8And t _ _ _ | .u8Or t _ _ _ | .u8LessThan t _ _ _ | .u32LessThan t _ _ _ @@ -85,8 +85,8 @@ def Term.escapes : Term → Bool | .add _ e _ _ | .sub _ e _ _ | .mul _ e _ _ | .eqZero _ e _ | .proj _ e _ _ | .get _ e _ _ | .slice _ e _ _ _ | .set _ e _ _ _ | .store _ e _ | .load _ e _ | .ptrVal _ e _ - | .assertEq _ e _ _ _ | .ioGetInfo _ e _ | .ioSetInfo _ e _ _ _ _ - | .ioRead _ e _ _ | .ioWrite _ e _ _ + | .assertEq _ e _ _ _ | .ioGetInfo _ e _ _ | .ioSetInfo _ e _ _ _ _ _ + | .ioRead _ e _ _ _ | .ioWrite _ e _ _ _ | .u8BitDecomposition _ e _ | .u8ShiftLeft _ e _ | .u8ShiftRight _ e _ | .u8Xor _ e _ _ | .u8Add _ e _ _ | .u8Mul _ e _ _ | .u8Sub _ e _ _ | .u8And _ e _ _ | .u8Or _ e _ _ | .u8LessThan _ e _ _ | .u32LessThan _ e _ _ diff --git a/Ix/IxVM.lean b/Ix/IxVM.lean index 1ed8bf2c..db3aa9d9 100644 --- a/Ix/IxVM.lean +++ b/Ix/IxVM.lean @@ -33,8 +33,8 @@ def entrypoints := ⟦ 0 => (), _ => let n_minus_1 = n - 1; - let (idx, len) = io_get_info([n_minus_1]); - let bytes = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, [n_minus_1]); + let bytes = #read_byte_stream(0, idx, len); let (const, rest) = get_constant(bytes); assert_eq!(load(rest), ListNode.Nil); let bytes2 = put_constant(const, store(ListNode.Nil)); @@ -140,8 +140,8 @@ def entrypoints := ⟦ 0 => (), _ => let n_minus_1 = n - 1; - let (idx, len) = io_get_info([n_minus_1]); - let bytes = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, [n_minus_1]); + let bytes = #read_byte_stream(0, idx, len); let (const, rest) = get_constant(bytes); assert_eq!(load(rest), ListNode.Nil); let bytes2 = put_constant(const, store(ListNode.Nil)); diff --git a/Ix/IxVM/Blake3.lean b/Ix/IxVM/Blake3.lean index b484d715..de8c6ad2 100644 --- a/Ix/IxVM/Blake3.lean +++ b/Ix/IxVM/Blake3.lean @@ -9,8 +9,8 @@ def blake3 := ⟦ /- # Test entrypoints -/ pub fn blake3_test() -> [[G; 4]; 8] { - let (idx, len) = io_get_info([0]); - let byte_stream = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, [0]); + let byte_stream = #read_byte_stream(0, idx, len); blake3(byte_stream) } @@ -19,8 +19,8 @@ def blake3 := ⟦ pub fn blake3_bench(num_hashes: G) -> G { let num_hashes_pred = num_hashes - 1; let key = [num_hashes_pred]; - let (idx, len) = io_get_info(key); - let byte_stream = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, key); + let byte_stream = #read_byte_stream(0, idx, len); let _ = blake3(byte_stream); match num_hashes_pred { 0 => 0, diff --git a/Ix/IxVM/ByteStream.lean b/Ix/IxVM/ByteStream.lean index c5a0043d..7cdc0099 100644 --- a/Ix/IxVM/ByteStream.lean +++ b/Ix/IxVM/ByteStream.lean @@ -10,12 +10,12 @@ def byteStream := ⟦ type U64 = [G; 8] - fn read_byte_stream(idx: G, len: G) -> ByteStream { + fn read_byte_stream(channel: G, idx: G, len: G) -> ByteStream { match len { 0 => store(ListNode.Nil), _ => - let tail = read_byte_stream(idx + 1, len - 1); - let [byte] = io_read(idx, 1); + let tail = read_byte_stream(channel, idx + 1, len - 1); + let [byte] = io_read(channel, idx, 1); store(ListNode.Cons(byte, tail)), } } diff --git a/Ix/IxVM/ClaimHarness.lean b/Ix/IxVM/ClaimHarness.lean index 2e93e9a6..25d2c9d0 100644 --- a/Ix/IxVM/ClaimHarness.lean +++ b/Ix/IxVM/ClaimHarness.lean @@ -118,7 +118,7 @@ partial def closureFrom (env : Ixon.Env) (target : Address) : Std.HashSet Addres def buildSerdeIOBuffer (ixonEnv : Ixon.Env) : Aiur.IOBuffer × Nat := ixonEnv.consts.valuesIter.fold (init := (default, 0)) fun (ioBuffer, i) c => let (_, bytes) := Ixon.Serialize.put c |>.run default - (ioBuffer.extend #[.ofNat i] (bytes.data.map .ofUInt8), i + 1) + (ioBuffer.extend 0 #[.ofNat i] (bytes.data.map .ofUInt8), i + 1) /-- Encode a `Lean.ReducibilityHints` as a single `G` per the convention Aiur's `load_constant_hint` decodes (opaque → 0, abbrev → 0xFFFFFFFF, @@ -129,16 +129,18 @@ private def hintToG : Lean.ReducibilityHints → Aiur.G | .regular n => .ofNat (min (1 + n.toNat) 0xFFFFFFFE) /-- Insert all per-address entries for `addr`s satisfying `keep` into - `ioBuffer`, following the IOBuffer convention: - - | key | value | meaning | - |------------------------|----------------|---------| - | `addr` (32 G) | const bytes | primary data; empty value = `addr` is a blob | - | `addr ++ [0]` (33 G) | raw blob bytes | referenced data (verified by Aiur via blake3) | - | `addr ++ [1]` (33 G) | single G | Defn `ReducibilityHints` encoding | - - Suffix tags use `Array.push` (O(1) amortized) rather than prefix - `++ Array` (O(n) allocation). -/ + `ioBuffer`. Each address kind lives on its own channel; the key is + always the 32-G blake3 hash, with no disambiguating suffix. + + | channel | key (32 G) | value | meaning | + |---------|------------|----------------|---------| + | 0 | `addr` | const bytes | constant data (empty marker = `addr` is a blob) | + | 1 | `addr` | raw blob bytes | referenced data (verified by Aiur via blake3) | + | 2 | `addr` | single G | Defn `ReducibilityHints` encoding | + + Blob addrs also get an empty entry on channel 0 so the kernel's + constant-vs-blob detection (`io_get_info(0, addr) ⇒ len=0`) still + works without a separate query path. -/ def addEntries (ixonEnv : Ixon.Env) (keep : Address → Bool) (ioBuffer : Aiur.IOBuffer) : Aiur.IOBuffer := Id.run do let mut ioBuffer := ioBuffer @@ -146,19 +148,18 @@ def addEntries (ixonEnv : Ixon.Env) (keep : Address → Bool) if !keep addr then continue let (_, bytes) := Ixon.Serialize.put c |>.run default let key : Array Aiur.G := addr.hash.data.map .ofUInt8 - ioBuffer := ioBuffer.extend key (bytes.data.map .ofUInt8) + ioBuffer := ioBuffer.extend 0 key (bytes.data.map .ofUInt8) for (addr, rawBytes) in ixonEnv.blobs do if !keep addr then continue - let hashKey : Array Aiur.G := addr.hash.data.map .ofUInt8 - ioBuffer := ioBuffer.extend (hashKey.push 0) - (rawBytes.data.map fun b => .ofNat b.toNat) - ioBuffer := ioBuffer.extend hashKey #[] + let key : Array Aiur.G := addr.hash.data.map .ofUInt8 + ioBuffer := ioBuffer.extend 1 key (rawBytes.data.map fun b => .ofNat b.toNat) + ioBuffer := ioBuffer.extend 0 key #[] for (_, named) in ixonEnv.named do if !keep named.addr then continue match named.constMeta with | .defn _ _ hints _ _ _ _ _ => - let hashKey : Array Aiur.G := named.addr.hash.data.map .ofUInt8 - ioBuffer := ioBuffer.extend (hashKey.push 1) #[hintToG hints] + let key : Array Aiur.G := named.addr.hash.data.map .ofUInt8 + ioBuffer := ioBuffer.extend 2 key #[hintToG hints] | _ => pure () return ioBuffer @@ -190,7 +191,7 @@ private def seedTreeAt (root : Address) match trees.get? root with | some tree => let bytes := Ix.AssumptionTree.ser tree - .ok (ioBuffer.extend (addrKey tree.root) (bytes.data.map .ofUInt8)) + .ok (ioBuffer.extend 0 (addrKey tree.root) (bytes.data.map .ofUInt8)) | none => .error s!"no assumption tree supplied for root {root}" /-- Build the witness for `verify_claim` against `claim`. @@ -221,7 +222,7 @@ def buildClaimWitness (env : Ixon.Env) (claim : Ix.Claim) let claimBytes := Ix.Claim.ser claim let digestKey := addrKey (Address.blake3 claimBytes) let mut ioBuffer : Aiur.IOBuffer := default - ioBuffer := ioBuffer.extend digestKey (claimBytes.data.map .ofUInt8) + ioBuffer := ioBuffer.extend 0 digestKey (claimBytes.data.map .ofUInt8) let seedAsm asm buf := match asm with | some r => seedTreeAt r trees buf | none => .ok buf diff --git a/Ix/IxVM/Ingress.lean b/Ix/IxVM/Ingress.lean index 299df997..58a0c556 100644 --- a/Ix/IxVM/Ingress.lean +++ b/Ix/IxVM/Ingress.lean @@ -10,8 +10,8 @@ def ingress := ⟦ -- Load a constant from IOBuffer by address, verify blake3, deserialize fn load_verified_constant(addr: Addr) -> Constant { let raw = load(addr); - let (idx, len) = io_get_info(raw); - let bytes = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, raw); + let bytes = #read_byte_stream(0, idx, len); let h = blake3(bytes); assert_eq!( [ @@ -55,48 +55,26 @@ def ingress := ⟦ } } - -- Load reducibility hint G for a Defn at `addr`. Stored under suffixed - -- key `addr ++ [1]` (suffix tag 1 = metadata-tier). Encoding (mirror - -- Lean.ReducibilityHints): + -- Load reducibility hint G for a Defn at `addr`. Lives on channel 2. + -- Encoding (mirror Lean.ReducibilityHints): -- 0 = Opaque -- 1 + h = Regular(h) -- 0xFFFFFFFF = Abbrev - -- If absent (no entry under suffixed key), defaults to 1 (Regular(0)). + -- Caller MUST only invoke this for Defn addrs; the harness only seeds + -- channel 2 for defns. A missing key aborts execution (correct). fn load_constant_hint(addr: Addr) -> G { - let [a0, a1, a2, a3, a4, a5, a6, a7, - a8, a9, a10, a11, a12, a13, a14, a15, - a16, a17, a18, a19, a20, a21, a22, a23, - a24, a25, a26, a27, a28, a29, a30, a31] = load(addr); - let key = [a0, a1, a2, a3, a4, a5, a6, a7, - a8, a9, a10, a11, a12, a13, a14, a15, - a16, a17, a18, a19, a20, a21, a22, a23, - a24, a25, a26, a27, a28, a29, a30, a31, 1]; - let (idx, len) = io_get_info(key); - match len { - 0 => 1, - _ => - let bytes = #read_byte_stream(idx, len); - match load(bytes) { - ListNode.Cons(b, _) => b, - ListNode.Nil => 1, - }, + let (idx, len) = io_get_info(2, load(addr)); + let bytes = #read_byte_stream(2, idx, len); + match load(bytes) { + ListNode.Cons(b, _) => b, } } -- Load a blob from IOBuffer by address, verify blake3, return raw bytes. - -- Blobs are stored under key `addr ++ [0]` (suffix tag 0 = referenced - -- data) so they don't collide with constants stored at bare `addr`. + -- Blobs live on channel 1; constants live on channel 0 with the same key. fn load_verified_blob(addr: Addr) -> ByteStream { - let [a0, a1, a2, a3, a4, a5, a6, a7, - a8, a9, a10, a11, a12, a13, a14, a15, - a16, a17, a18, a19, a20, a21, a22, a23, - a24, a25, a26, a27, a28, a29, a30, a31] = load(addr); - let blob_key = [a0, a1, a2, a3, a4, a5, a6, a7, - a8, a9, a10, a11, a12, a13, a14, a15, - a16, a17, a18, a19, a20, a21, a22, a23, - a24, a25, a26, a27, a28, a29, a30, a31, 0]; - let (idx, len) = io_get_info(blob_key); - let bytes = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(1, load(addr)); + let bytes = #read_byte_stream(1, idx, len); let h = blake3(bytes); assert_eq!( [ @@ -1233,7 +1211,7 @@ def ingress := ⟦ -- Check if this address has constant data in IOBuffer. -- io_get_info is unconstrained; the prover provides (0, 0) for blob addresses. -- Soundness: if the prover lies and skips a real constant, type checking will fail. - let (_, len) = io_get_info(load(addr)); + let (_, len) = io_get_info(0, load(addr)); match len { 0 => -- Blob address: skip (blob values are loaded on demand in build_lit_blobs) diff --git a/Ix/IxVM/Kernel/Claim.lean b/Ix/IxVM/Kernel/Claim.lean index 26c7b008..fa9903fb 100644 --- a/Ix/IxVM/Kernel/Claim.lean +++ b/Ix/IxVM/Kernel/Claim.lean @@ -439,8 +439,8 @@ def claim := ⟦ -- `load_verified_constant`: read bytes, recompute blake3, assert -- equality, deserialize, assert no trailing data. fn load_verified_claim(digest: [G; 32]) -> Claim { - let (idx, len) = io_get_info(digest); - let bytes = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, digest); + let bytes = #read_byte_stream(0, idx, len); let h = blake3(bytes); assert_eq!( [ @@ -780,8 +780,8 @@ def claim := ⟦ fn load_assumption_tree(root: Addr) -> List‹Addr› { let raw = load(root); - let (idx, len) = io_get_info(raw); - let bytes = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, raw); + let bytes = #read_byte_stream(0, idx, len); let (tag, s) = get_tag4(bytes); let (flag, size) = tag; assert_eq!(flag, 0xE); diff --git a/Ix/IxVM/Sha256.lean b/Ix/IxVM/Sha256.lean index 54ed3d65..80eb73e5 100644 --- a/Ix/IxVM/Sha256.lean +++ b/Ix/IxVM/Sha256.lean @@ -9,8 +9,8 @@ def sha256 := ⟦ /- # Test entrypoints -/ pub fn sha256_test() -> [[G; 4]; 8] { - let (idx, len) = io_get_info([0]); - let byte_stream = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, [0]); + let byte_stream = #read_byte_stream(0, idx, len); sha256(byte_stream) } @@ -19,8 +19,8 @@ def sha256 := ⟦ pub fn sha256_bench(num_hashes: G) -> G { let num_hashes_pred = num_hashes - 1; let key = [num_hashes_pred]; - let (idx, len) = io_get_info(key); - let byte_stream = #read_byte_stream(idx, len); + let (idx, len) = io_get_info(0, key); + let byte_stream = #read_byte_stream(0, idx, len); let _ = sha256(byte_stream); match num_hashes_pred { 0 => 0, diff --git a/Tests/Aiur/Aiur.lean b/Tests/Aiur/Aiur.lean index a6019943..7836d763 100644 --- a/Tests/Aiur/Aiur.lean +++ b/Tests/Aiur/Aiur.lean @@ -304,11 +304,17 @@ def toplevel := ⟦ --------------------------------------------------------------------------- -- IO --------------------------------------------------------------------------- + -- Exercises channel disambiguation: same key #[0] on channels 0 and 1 + -- resolves to distinct (idx, len) and arenas. Reads from each, writes + -- the concatenation back to channel 2, and registers `[1]` on channel 0. pub fn read_write_io() { - let (idx, len) = io_get_info([0]); - let xs: [G; 4] = io_read(idx, 4); - io_write(xs); - io_set_info([1], idx, len + 4); + let (idx_a, len_a) = io_get_info(0, [0]); + let (idx_b, _len_b) = io_get_info(1, [0]); + let xs: [G; 4] = io_read(0, idx_a, 4); + let ys: [G; 4] = io_read(1, idx_b, 4); + io_write(2, xs); + io_write(2, ys); + io_set_info(0, [1], idx_a, len_a + 4); } --------------------------------------------------------------------------- @@ -862,8 +868,15 @@ def aiurTestCases : List AiurTestCase := [ -- IO { functionName := `read_write_io - inputIOBuffer := ⟨#[1, 2, 3, 4], .ofList [(#[0], ⟨0, 4⟩)]⟩ - expectedIOBuffer := ⟨#[1, 2, 3, 4, 1, 2, 3, 4], .ofList [(#[0], ⟨0, 4⟩), (#[1], ⟨0, 8⟩)]⟩ }, + inputIOBuffer := + ⟨.ofList [(0, #[1, 2, 3, 4]), (1, #[5, 6, 7, 8])], + .ofList [((0, #[0]), ⟨0, 4⟩), ((1, #[0]), ⟨0, 4⟩)]⟩ + expectedIOBuffer := + ⟨.ofList [(0, #[1, 2, 3, 4]), + (1, #[5, 6, 7, 8]), + (2, #[1, 2, 3, 4, 5, 6, 7, 8])], + .ofList [((0, #[0]), ⟨0, 4⟩), ((1, #[0]), ⟨0, 4⟩), + ((0, #[1]), ⟨0, 8⟩)]⟩ }, -- Byte operations .noIO `shr_shr_shl_decompose #[87] #[0, 1, 0, 1, 0, 1, 0, 0], diff --git a/Tests/Aiur/Cross.lean b/Tests/Aiur/Cross.lean index fc8ac19b..d3debaa1 100644 --- a/Tests/Aiur/Cross.lean +++ b/Tests/Aiur/Cross.lean @@ -68,7 +68,7 @@ def toplevel : Source.Toplevel := ⟦ -- Assertions / IO / pointer ops pub fn assert_same(x: G, y: G) -> G { assert_eq!(x, y); x } pub fn io_roundtrip(x: G) -> [G; 1] { - io_write([x]); io_read(0, 1) + io_write(0, [x]); io_read(0, 0, 1) } pub fn ptr_index(x: G) -> G { let p = store(x); @@ -261,11 +261,17 @@ def toplevel : Source.Toplevel := ⟦ } -- Full IO: get_info + read + write + set_info + -- Exercises channel disambiguation: same key #[0] on channels 0 and 1 + -- resolves to distinct (idx, len) and arenas. Reads from each, writes + -- the concatenation back to channel 2, and registers `[1]` on channel 0. pub fn read_write_io() { - let (idx, len) = io_get_info([0]); - let xs: [G; 4] = io_read(idx, 4); - io_write(xs); - io_set_info([1], idx, len + 4); + let (idx_a, len_a) = io_get_info(0, [0]); + let (idx_b, _len_b) = io_get_info(1, [0]); + let xs: [G; 4] = io_read(0, idx_a, 4); + let ys: [G; 4] = io_read(1, idx_b, 4); + io_write(2, xs); + io_write(2, ys); + io_set_info(0, [1], idx_a, len_a + 4); } -- u8 shifts + bit decomposition chain @@ -1249,7 +1255,8 @@ def tests : TestSeq := runAgreement "ntm_large(5)" "ntm_large" [5] ++ runAgreement "ntm_shape_let" "ntm_shape_let" [] ++ runAgreement "read_write_io" "read_write_io" [] - (io := { data := #[1, 2, 3, 4], map := .ofList [(#[0], ⟨0, 4⟩)] }) ++ + (io := { data := .ofList [(0, #[1, 2, 3, 4]), (1, #[5, 6, 7, 8])], + map := .ofList [((0, #[0]), ⟨0, 4⟩), ((1, #[0]), ⟨0, 4⟩)] }) ++ runAgreement "template_basic" "template_basic" [] ++ runAgreement "template_unwrap_some" "template_unwrap_some" [] ++ runAgreement "template_unwrap_none" "template_unwrap_none" [] ++ diff --git a/Tests/Aiur/Hashes.lean b/Tests/Aiur/Hashes.lean index bfd95ccd..817b6351 100644 --- a/Tests/Aiur/Hashes.lean +++ b/Tests/Aiur/Hashes.lean @@ -12,7 +12,9 @@ def mkBlake3HashTestCase (size : Nat) : AiurTestCase := let outputBytes := Blake3.Rust.hash ⟨inputBytes⟩ |>.val.data let input := inputBytes.map .ofUInt8 let output := outputBytes.map .ofUInt8 - let buffer := ⟨input, .ofList [(#[0], ⟨0, size⟩)]⟩ -- key is fixed as #[0] + let buffer : Aiur.IOBuffer := + ⟨.ofList [(0, input)], .ofList [((0, #[0]), ⟨0, size⟩)]⟩ + -- channel 0; key fixed as #[0] { functionName := `blake3_test, label := s!"blake3 (size={size})" expectedOutput := output, inputIOBuffer := buffer, expectedIOBuffer := buffer interpret := false } @@ -22,7 +24,9 @@ def mkSha256HashTestCase (size : Nat) : AiurTestCase := let outputBytes := Sha256.hash ⟨inputBytes⟩ |>.data let input := inputBytes.map .ofUInt8 let output := outputBytes.map .ofUInt8 - let buffer := ⟨input, .ofList [(#[0], ⟨0, size⟩)]⟩ -- key is fixed as #[0] + let buffer : Aiur.IOBuffer := + ⟨.ofList [(0, input)], .ofList [((0, #[0]), ⟨0, size⟩)]⟩ + -- channel 0; key fixed as #[0] { functionName := `sha256_test, label := s!"sha256 (size={size})" expectedOutput := output, inputIOBuffer := buffer, expectedIOBuffer := buffer interpret := false } diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 11288c07..8d9f65b1 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -43,10 +43,10 @@ pub enum Op { Store(Vec), Load(usize, ValIdx), AssertEq(Vec, Vec), - IOGetInfo(Vec), - IOSetInfo(Vec, ValIdx, ValIdx), - IORead(ValIdx, usize), - IOWrite(Vec), + IOGetInfo(ValIdx, Vec), + IOSetInfo(ValIdx, Vec, ValIdx, ValIdx), + IORead(ValIdx, ValIdx, usize), + IOWrite(ValIdx, Vec), U8BitDecomposition(ValIdx), U8ShiftLeft(ValIdx), U8ShiftRight(ValIdx), diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index d3d17aaf..333cfa14 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -452,11 +452,11 @@ impl Op { state.constraints.zeros.push(sel.clone() * (x.clone() - y.clone())); } }, - Op::IOGetInfo(_) => (0..2).for_each(|_| { + Op::IOGetInfo(_, _) => (0..2).for_each(|_| { let col = state.next_auxiliary(); state.map.push((col, 1)); }), - Op::IORead(_, len) => (0..*len).for_each(|_| { + Op::IORead(_, _, len) => (0..*len).for_each(|_| { let col = state.next_auxiliary(); state.map.push((col, 1)); }), @@ -661,7 +661,7 @@ impl Op { let output = Expr::ONE - carry; state.map.push((output, 1)); }, - Op::IOSetInfo(..) | Op::IOWrite(_) | Op::Debug(..) => (), + Op::IOSetInfo(..) | Op::IOWrite(..) | Op::Debug(..) => (), } } } diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 25ee18fb..00fb8b29 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -50,36 +50,50 @@ pub(crate) struct IOKeyInfo { } pub struct IOBuffer { - pub(crate) data: Vec, - pub(crate) map: FxHashMap, IOKeyInfo>, + /// Per-channel data arenas. `idx` slots into `data[&channel]`. + pub(crate) data: FxHashMap>, + /// Channel-keyed info map; same `key` on different channels resolves + /// to distinct `IOKeyInfo`. + pub(crate) map: FxHashMap<(G, Vec), IOKeyInfo>, } impl IOBuffer { #[inline] - pub(crate) fn get_info(&self, key: &[G]) -> Result<&IOKeyInfo, ExecError> { - self.map.get(key).ok_or(ExecError::InvalidIOKey) + pub(crate) fn get_info( + &self, + channel: G, + key: &[G], + ) -> Result<&IOKeyInfo, ExecError> { + self.map.get(&(channel, key.to_vec())).ok_or(ExecError::InvalidIOKey) } fn set_info( &mut self, + channel: G, key: Vec, idx: usize, len: usize, ) -> Result<(), ExecError> { - let Entry::Vacant(e) = self.map.entry(key) else { + let Entry::Vacant(e) = self.map.entry((channel, key)) else { return Err(ExecError::IOMappingAlreadySet); }; e.insert(IOKeyInfo { idx, len }); Ok(()) } #[inline] - pub(crate) fn read(&self, idx: usize, len: usize) -> Result<&[G], ExecError> { - self - .data + pub(crate) fn read( + &self, + channel: G, + idx: usize, + len: usize, + ) -> Result<&[G], ExecError> { + let empty: &[G] = &[]; + let arena = self.data.get(&channel).map_or(empty, |v| v.as_slice()); + arena .get(idx..idx.saturating_add(len)) .ok_or(ExecError::IOReadOutOfBounds { idx, len }) } - fn write(&mut self, data: impl Iterator) { - self.data.extend(data) + fn write(&mut self, channel: G, data: impl Iterator) { + self.data.entry(channel).or_default().extend(data) } } @@ -297,30 +311,34 @@ impl Function { } } }, - ExecEntry::Op(Op::IOGetInfo(key)) => { + ExecEntry::Op(Op::IOGetInfo(channel, key)) => { + let channel = map[*channel]; let key = key.iter().map(|v| map[*v]).collect::>(); - let IOKeyInfo { idx, len } = io_buffer.get_info(&key)?; + let IOKeyInfo { idx, len } = io_buffer.get_info(channel, &key)?; map.push(G::from_usize(*idx)); map.push(G::from_usize(*len)); }, - ExecEntry::Op(Op::IOSetInfo(key, idx, len)) => { + ExecEntry::Op(Op::IOSetInfo(channel, key, idx, len)) => { + let channel = map[*channel]; let key = key.iter().map(|v| map[*v]).collect::>(); let get = |x: &usize| { let v = map[*x].as_canonical_u64(); usize::try_from(v).ok().ok_or(ExecError::IndexTooLarge(v)) }; - io_buffer.set_info(key, get(idx)?, get(len)?)?; + io_buffer.set_info(channel, key, get(idx)?, get(len)?)?; }, - ExecEntry::Op(Op::IORead(idx, len)) => { + ExecEntry::Op(Op::IORead(channel, idx, len)) => { + let channel = map[*channel]; let idx_val = map[*idx].as_canonical_u64(); let idx = usize::try_from(idx_val) .ok() .ok_or(ExecError::IndexTooLarge(idx_val))?; - let data = io_buffer.read(idx, *len)?; + let data = io_buffer.read(channel, idx, *len)?.to_vec(); map.extend(data); }, - ExecEntry::Op(Op::IOWrite(data)) => { - io_buffer.write(data.iter().map(|v| map[*v])) + ExecEntry::Op(Op::IOWrite(channel, data)) => { + let channel = map[*channel]; + io_buffer.write(channel, data.iter().map(|v| map[*v])) }, ExecEntry::Op(Op::U8BitDecomposition(byte)) => { if unconstrained { diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 193c6d50..8e19b909 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -360,22 +360,28 @@ impl Op { let lookup = Memory::lookup(G::ONE, G::from_usize(*size), ptr, values); slice.push_lookup(index, lookup); }, - Op::IOGetInfo(key) => { + Op::IOGetInfo(channel, key) => { + let channel = map[*channel].0; let key = key.iter().map(|a| map[*a].0).collect::>(); let IOKeyInfo { idx, len } = - io_buffer.get_info(&key).expect("Invalid IO key"); + io_buffer.get_info(channel, &key).expect("Invalid IO key"); for f in [G::from_usize(*idx), G::from_usize(*len)] { map.push((f, 1)); slice.push_auxiliary(index, f); } }, - Op::IORead(idx, len) => { + Op::IORead(channel, idx, len) => { + let channel = map[*channel].0; let idx = map[*idx] .0 .as_canonical_u64() .try_into() .expect("Index is too big for an usize"); - for &f in io_buffer.read(idx, *len).expect("IO read out of bounds") { + let data = io_buffer + .read(channel, idx, *len) + .expect("IO read out of bounds") + .to_vec(); + for f in data { map.push((f, 1)); slice.push_auxiliary(index, f); } @@ -545,8 +551,10 @@ impl Op { let result = G::from_bool(a_u32 < b_u32); map.push((result, 1)); }, - Op::AssertEq(..) | Op::IOSetInfo(..) | Op::IOWrite(_) | Op::Debug(..) => { - }, + Op::AssertEq(..) + | Op::IOSetInfo(..) + | Op::IOWrite(..) + | Op::Debug(..) => {}, } } } diff --git a/src/ffi/aiur/protocol.rs b/src/ffi/aiur/protocol.rs index 772b18ef..ed2ec7b4 100644 --- a/src/ffi/aiur/protocol.rs +++ b/src/ffi/aiur/protocol.rs @@ -204,23 +204,37 @@ fn decode_io_buffer( io_data_arr: &LeanArray>, io_map_arr: &LeanArray>, ) -> IOBuffer { - let data = io_data_arr.map(|x| lean_unbox_g(&x)); + let data = decode_io_buffer_data(io_data_arr); let map = decode_io_buffer_map(io_map_arr); IOBuffer { data, map } } -/// Build a Lean `Array G × Array (Array G × IOKeyInfo)` from an `IOBuffer`. +/// Build a Lean +/// `Array (G × Array G) × Array ((G × Array G) × IOKeyInfo)` from an +/// `IOBuffer`. The first array enumerates per-channel data arenas; +/// the second is the channel-keyed info map. fn build_lean_io_buffer(io_buffer: &IOBuffer) -> LeanOwned { - let lean_io_data = build_g_array(&io_buffer.data); + let lean_io_data = { + let arr = LeanArray::alloc(io_buffer.data.len()); + for (i, (channel, arena)) in io_buffer.data.iter().enumerate() { + let channel_box = LeanOwned::box_u64(channel.as_canonical_u64()); + let arena_arr = build_g_array(arena); + let elt = LeanProd::new(channel_box, arena_arr); + arr.set(i, elt); + } + arr + }; let lean_io_map = { let arr = LeanArray::alloc(io_buffer.map.len()); - for (i, (key, info)) in io_buffer.map.iter().enumerate() { + for (i, ((channel, key), info)) in io_buffer.map.iter().enumerate() { + let channel_box = LeanOwned::box_u64(channel.as_canonical_u64()); let key_arr = build_g_array(key); + let channel_key = LeanProd::new(channel_box, key_arr); let key_info = LeanProd::new( LeanOwned::box_usize(info.idx), LeanOwned::box_usize(info.len), ); - let map_elt = LeanProd::new(key_arr, key_info); + let map_elt = LeanProd::new(channel_key, key_info); arr.set(i, map_elt); } arr @@ -252,19 +266,34 @@ fn decode_fri_parameters( } } +fn decode_io_buffer_data( + arr: &LeanArray>, +) -> FxHashMap> { + let mut data = FxHashMap::with_capacity_and_hasher(arr.len(), FxBuildHasher); + for elt in arr.iter() { + let pair = elt.as_ctor(); + let channel = lean_unbox_g(&pair.get(0)); + let arena = pair.get(1).as_array().map(|x| lean_unbox_g(&x)); + data.insert(channel, arena); + } + data +} + fn decode_io_buffer_map( arr: &LeanArray>, -) -> FxHashMap, IOKeyInfo> { +) -> FxHashMap<(G, Vec), IOKeyInfo> { let mut map = FxHashMap::with_capacity_and_hasher(arr.len(), FxBuildHasher); for elt in arr.iter() { let pair = elt.as_ctor(); - let key = pair.get(0).as_array().map(|x| lean_unbox_g(&x)); + let channel_key = pair.get(0).as_ctor(); + let channel = lean_unbox_g(&channel_key.get(0)); + let key = channel_key.get(1).as_array().map(|x| lean_unbox_g(&x)); let info_ctor = pair.get(1).as_ctor(); let info = IOKeyInfo { idx: lean_unbox_nat_as_usize(&info_ctor.get(0)), len: lean_unbox_nat_as_usize(&info_ctor.get(1)), }; - map.insert(key, info); + map.insert((channel, key), info); } map } diff --git a/src/ffi/aiur/toplevel.rs b/src/ffi/aiur/toplevel.rs index 9258ff57..a78b4ee8 100644 --- a/src/ffi/aiur/toplevel.rs +++ b/src/ffi/aiur/toplevel.rs @@ -66,24 +66,29 @@ fn decode_op(ctor: LeanCtor>) -> Op { Op::AssertEq(decode_vec_val_idx(a), decode_vec_val_idx(b)) }, 9 => { - let [key] = ctor.objs::<1>(); - Op::IOGetInfo(decode_vec_val_idx(key)) + let [channel, key] = ctor.objs::<2>(); + Op::IOGetInfo(lean_unbox_nat_as_usize(&channel), decode_vec_val_idx(key)) }, 10 => { - let [key, idx, len] = ctor.objs::<3>(); + let [channel, key, idx, len] = ctor.objs::<4>(); Op::IOSetInfo( + lean_unbox_nat_as_usize(&channel), decode_vec_val_idx(key), lean_unbox_nat_as_usize(&idx), lean_unbox_nat_as_usize(&len), ) }, 11 => { - let [idx, len] = ctor.objs::<2>(); - Op::IORead(lean_unbox_nat_as_usize(&idx), lean_unbox_nat_as_usize(&len)) + let [channel, idx, len] = ctor.objs::<3>(); + Op::IORead( + lean_unbox_nat_as_usize(&channel), + lean_unbox_nat_as_usize(&idx), + lean_unbox_nat_as_usize(&len), + ) }, 12 => { - let [data] = ctor.objs::<1>(); - Op::IOWrite(decode_vec_val_idx(data)) + let [channel, data] = ctor.objs::<2>(); + Op::IOWrite(lean_unbox_nat_as_usize(&channel), decode_vec_val_idx(data)) }, 13 => { let [byte] = ctor.objs::<1>();