Skip to content

Commit 9f22ece

Browse files
committed
Added input alignment check for bdep mapreduce
1 parent 46886fd commit 9f22ece

1 file changed

Lines changed: 30 additions & 1 deletion

File tree

src/ecCircuits.ml

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ module type CBackend = sig
307307

308308
val flatten : reg list -> reg
309309

310+
310311
module Deps : sig
311312
type deps
312313
type block_deps
@@ -326,6 +327,8 @@ module type CBackend = sig
326327
val single_dep : deps -> bool
327328
(* Assumes single_dep *)
328329
val dep_range : deps -> int * int
330+
(* Checks if all the deps are in a given list of inputs *)
331+
val check_inputs : reg -> (int * int) list -> bool
329332
end
330333
end
331334

@@ -492,6 +495,8 @@ module TestBack : CBackend = struct
492495
let concat (r1: reg) (r2: reg) : reg = Array.append r1 r2
493496
let flatten (rs: reg list) : reg = Array.concat rs
494497

498+
499+
495500
module Deps = struct
496501
type dep = (int, int Set.t) Map.t
497502
type deps = dep array
@@ -607,6 +612,18 @@ module TestBack : CBackend = struct
607612
Set.iter (fun i -> Format.eprintf "%d " i) idxs;
608613
Format.eprintf "@.Min: %d | Max: %d@." (Set.min_elt idxs) (Set.max_elt idxs);
609614
(Set.min_elt idxs, Set.max_elt idxs + 1)
615+
616+
(* Checks that all dependencies of r are in the set inps *)
617+
(* Each elements of inps is (id, width) *)
618+
let check_inputs (r: reg) (inps: (int * int) list) : bool =
619+
let ds = deps_of_reg r in
620+
Array.for_all (fun d ->
621+
Map.for_all (fun id b ->
622+
match List.find_opt (fun (id_, _) -> id = id_) inps with
623+
| Some (_, b_) -> Set.for_all (fun b -> 0 <= b && b < b_) b
624+
| None -> false
625+
) d
626+
) ds
610627
end
611628

612629
end
@@ -1651,6 +1668,15 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
16511668
assert false
16521669
| _ -> assert false
16531670

1671+
let check_decomp_inputs ((`CBitstring r, inps): cbitstring cfun) : bool =
1672+
let inps = List.map (function
1673+
| {type_ = `CIBitstring w; id} ->
1674+
(id, w)
1675+
| _ -> assert false
1676+
) inps in
1677+
Backend.Deps.check_inputs r inps
1678+
1679+
(* FIXME: what is the last return value for? *)
16541680
let decompose (in_w: width) (out_w: width) ((`CBitstring r, inps) as c: cbitstring cfun) : cbitstring cfun list * (int * int) =
16551681
if not (is_decomposable in_w out_w c) then
16561682
let deps = Backend.Deps.block_deps_of_reg out_w r in
@@ -1664,10 +1690,13 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
16641690
let cinp = (List.hd inps) in
16651691
let cinps, renamer = split_renamer n in_w cinp in
16661692
(* let renamer = fun i -> Option.bind (aligner i) renamer in *)
1667-
Array.map2 (fun r inp ->
1693+
let res = Array.map2 (fun r inp ->
16681694
let r = Backend.applys renamer r in
16691695
(`CBitstring r, [inp])
16701696
) blocks cinps |> Array.to_list, (0,0)
1697+
in
1698+
if not (List.for_all check_decomp_inputs (fst res)) then assert false else
1699+
res
16711700

16721701
let permute (w: width) (perm: (int -> int)) ((`CBitstring r, inps): cbitstring cfun) : cbitstring cfun =
16731702
`CBitstring (Backend.permute w perm r), inps

0 commit comments

Comments
 (0)