@@ -206,7 +206,7 @@ module type CircuitInterface = sig
206206
207207 (* Mapreduce/Dependecy analysis related functions *)
208208 val is_decomposable : int -> int -> cbitstring cfun -> bool
209- val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun ) list
209+ val decompose : int -> int -> cbitstring cfun -> (cbitstring cfun ) list * ( int * int )
210210 val permute : int -> (int -> int ) -> cbitstring cfun -> cbitstring cfun
211211
212212 (* Wraps the backend call to deal with args/inputs *)
@@ -320,6 +320,10 @@ module type CBackend = sig
320320 val is_splittable : int -> int -> deps -> bool
321321
322322 val are_independent : block_deps -> bool
323+
324+ val single_dep : deps -> bool
325+ (* Assumes single_dep *)
326+ val dep_range : deps -> int * int
323327 end
324328end
325329
@@ -425,11 +429,14 @@ module TestBack : CBackend = struct
425429 let get (r : reg ) (idx : int ) = r.(idx)
426430
427431 let permute (w : int ) (perm : int -> int ) (r : reg ) : reg =
432+ Format. eprintf " Applying permutation to reg of size %d with block size of %d@." (size_of_reg r) w;
428433 Array. init (size_of_reg r) (fun i ->
429- let block_idx, bit_idx = (i / w), (i mod w) in
430- let idx = (perm block_idx)* w + bit_idx in
431- r.(idx)
432- )
434+ let block_idx, bit_idx = perm (i / w), (i mod w) in
435+ if block_idx < 0 then None
436+ else
437+ let idx = block_idx* w + bit_idx in
438+ Some r.(idx)
439+ ) |> Array. filter_map (fun x -> x)
433440
434441
435442 (* Node operations *)
@@ -536,17 +543,17 @@ module TestBack : CBackend = struct
536543 | 0 -> true
537544 | 1 ->
538545 let blocks = block_deps_of_deps w_out d in
539- (* Format.eprintf "Checking block width...@."; *)
546+ Format. eprintf " Checking block width...@." ;
540547 Array. for_all (fun (_ , d ) ->
541548 if Map. is_empty d then true
542549 else
543550 let _, bits = Map. any d in
544551 Set. is_empty bits ||
545552 let base = Set. at_rank_exn 0 bits in
546- (* Format.eprintf "Base for current block: %d@." base; *)
553+ Format. eprintf " Base for current block: %d@." base;
547554 Set. for_all (fun bit ->
548555 let dist = bit - base in
549- (* Format.eprintf "Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in; *)
556+ Format. eprintf " Current bit: %d | Current dist: %d | Limit: %d@." bit dist w_in;
550557 0 < = dist && dist < w_in
551558 ) bits
552559 ) blocks
@@ -576,6 +583,28 @@ module TestBack : CBackend = struct
576583 true
577584 with BreakOut ->
578585 false
586+
587+
588+ let single_dep (d : deps ) : bool =
589+ match Set. cardinal
590+ (Array. fold_left (Set. union) Set. empty
591+ (Array. map (fun dep -> Map. keys dep |> Set. of_enum) d))
592+ with
593+ | 0 | 1 -> true
594+ | _ -> false
595+
596+ (* Assumes single_dep, returns range (bot, top) such that valid idxs are bot <= i < top *)
597+ let dep_range (d : deps ) : int * int =
598+ assert (single_dep d);
599+ let idxs =
600+ Array. fold_left (fun acc d ->
601+ Set. union (Map. fold Set. union d Set. empty) acc) Set. empty d
602+ in
603+ Format. eprintf " %a@." pp_deps d;
604+ Format. eprintf " Dep range for dependencies:@." ;
605+ Set. iter (fun i -> Format. eprintf " %d " i) idxs;
606+ Format. eprintf " @.Min: %d | Max: %d@." (Set. min_elt idxs) (Set. max_elt idxs);
607+ (Set. min_elt idxs, Set. max_elt idxs + 1 )
579608 end
580609
581610end
@@ -1272,7 +1301,7 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
12721301 let array_oflist (circs : circuit list ) (dfl : circuit ) (len : int ) : circuit =
12731302 let circs, inps = List. split circs in
12741303 let dif = len - List. length circs in
1275- Format. eprintf " Len, Dif in array_oflist: %d, %d@." len dif;
1304+ (* Format.eprintf "Len, Dif in array_oflist: %d, %d@." len dif; *)
12761305 let circs = circs @ (List. init dif (fun _ -> fst dfl)) in
12771306 let inps = if dif > 0 then inps @ [snd dfl] else inps in
12781307 let circs = List. map
@@ -1518,14 +1547,32 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15181547 (* For more complex circuits, we might be able to simulate this with a int -> (int, int) map *)
15191548 let is_decomposable (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : bool =
15201549 match inps with
1521- | {type_ =`CIBitstring w } :: [] when w mod in_w = 0 && Backend. size_of_reg r mod out_w = 0 ->
1550+ | {type_ =`CIBitstring w } :: [] when ( Backend. size_of_reg r mod out_w = 0 ) ->
15221551 let deps = Backend.Deps. deps_of_reg r in
1523- Backend.Deps. is_splittable in_w out_w deps
1552+ Backend.Deps. is_splittable in_w out_w deps &&
1553+ let base, top = Backend.Deps. dep_range deps in
1554+ let () = Format. eprintf " Passed backend check, checking width of deps (top - base = %d | in_w = %d)@." (top - base) in_w in
1555+ (top - base) mod in_w = 0
15241556 | _ ->
15251557 Format. eprintf " Failed decomposition type check@\n " ;
15261558 Format. eprintf " In_w: %d | Out_w : %d | Circ: %a" in_w out_w pp_circuit c;
15271559 false
15281560
1561+ (* TODO: Extend this for multiple inputs? *)
1562+ let align_renamer ((`CBitstring r , inps ) : cbitstring cfun ) : (int * int) * cinp * (Backend.inp -> Backend.inp option) =
1563+ match inps with
1564+ | [{type_ = `CIBitstring w; id}] ->
1565+ let d = Backend.Deps. deps_of_reg r in
1566+ assert (Backend.Deps. single_dep d);
1567+ let (start_idx, end_idx) as range = Backend.Deps. dep_range d in
1568+ range,
1569+ {type_ = `CIBitstring (end_idx - start_idx); id},
1570+ (fun (id_ , w ) ->
1571+ if id <> id_ then None else
1572+ if w < start_idx || w > = end_idx then None
1573+ else Some (id_, w - start_idx))
1574+ | _ -> assert false
1575+
15291576 let split_renamer (n : count ) (in_w : width ) (inp : cinp ) : (cinp array) * (Backend.inp -> cbool_type option) =
15301577 match inp with
15311578 | {type_ = `CIBitstring w ; id} when w mod in_w = 0 ->
@@ -1535,9 +1582,12 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15351582 if id <> id_ then None else
15361583 let id_idx, bit_idx = (w / in_w), (w mod in_w) in
15371584 Some (Backend. input_node ~id: ids.(id_idx) bit_idx))
1585+ | {type_ = `CIBitstring w ; id} ->
1586+ Format. eprintf " Failed to build split renamer for n=%d in_w=%d w=%d@." n in_w w;
1587+ assert false
15381588 | _ -> assert false
15391589
1540- let decompose (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : cbitstring cfun list =
1590+ let decompose (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : cbitstring cfun list * (int * int) =
15411591 if not (is_decomposable in_w out_w c) then
15421592 let deps = Backend.Deps. block_deps_of_reg out_w r in
15431593 Format. eprintf " Failed to decompose. in_w=%d out_w=%d Deps:@.%a" in_w out_w (Backend.Deps. pp_block_deps) deps;
@@ -1546,11 +1596,13 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
15461596 let n = (Backend. size_of_reg r) / out_w in
15471597 let blocks = Array. init n (fun i ->
15481598 Backend. slice r (i* out_w) out_w) in
1549- let cinps, renamer = split_renamer n in_w (List. hd inps) in
1599+ let range, cinp, aligner = align_renamer c in
1600+ let cinps, renamer = split_renamer n in_w cinp in
1601+ let renamer = fun i -> Option. bind (aligner i) renamer in
15501602 Array. map2 (fun r inp ->
15511603 let r = Backend. applys renamer r in
15521604 (`CBitstring r, [inp])
1553- ) blocks cinps |> Array. to_list
1605+ ) blocks cinps |> Array. to_list, range
15541606
15551607 let permute (w : width ) (perm : (int -> int) ) ((`CBitstring r , inps ): cbitstring cfun ) : cbitstring cfun =
15561608 `CBitstring (Backend. permute w perm r), inps
@@ -2164,13 +2216,13 @@ let circuit_permute (bsz: int) (perm: int -> int) (c: circuit) : circuit =
21642216 in
21652217 (permute bsz perm c :> circuit )
21662218
2167- let circuit_mapreduce ?(perm : (int -> int) option ) (c : circuit ) (w_in : width ) (w_out : width ) : circuit list =
2219+ let circuit_mapreduce ?(perm : (int -> int) option ) (c : circuit ) (w_in : width ) (w_out : width ) : circuit list * (int * int) =
21682220 let c = match c, perm with
21692221 | (`CBitstring _ , inps ) as c , None -> c
21702222 | (`CBitstring _ , inps ) as c , Some perm -> permute w_out perm c
21712223 | _ -> assert false
21722224 in
2173- (decompose w_in w_out c :> circuit list )
2225+ (decompose w_in w_out c :> circuit list * (int * int ) )
21742226
21752227type circuit = ExampleInterface .circuit
21762228type pstate = ExampleInterface.PState .pstate
0 commit comments