@@ -307,6 +307,7 @@ module type CBackend = sig
307307
308308 val flatten : reg list -> reg
309309
310+
310311 module Deps : sig
311312 type deps
312313 type block_deps
@@ -326,6 +327,8 @@ module type CBackend = sig
326327 val single_dep : deps -> bool
327328 (* Assumes single_dep *)
328329 val dep_range : deps -> int * int
330+ (* Checks if all the deps are in a given list of inputs *)
331+ val check_inputs : reg -> (int * int ) list -> bool
329332 end
330333end
331334
@@ -492,6 +495,8 @@ module TestBack : CBackend = struct
492495 let concat (r1 : reg ) (r2 : reg ) : reg = Array. append r1 r2
493496 let flatten (rs : reg list ) : reg = Array. concat rs
494497
498+
499+
495500 module Deps = struct
496501 type dep = (int , int Set .t ) Map .t
497502 type deps = dep array
@@ -607,6 +612,18 @@ module TestBack : CBackend = struct
607612 Set. iter (fun i -> Format. eprintf " %d " i) idxs;
608613 Format. eprintf " @.Min: %d | Max: %d@." (Set. min_elt idxs) (Set. max_elt idxs);
609614 (Set. min_elt idxs, Set. max_elt idxs + 1 )
615+
616+ (* Checks that all dependencies of r are in the set inps *)
617+ (* Each elements of inps is (id, width) *)
618+ let check_inputs (r : reg ) (inps : (int * int) list ) : bool =
619+ let ds = deps_of_reg r in
620+ Array. for_all (fun d ->
621+ Map. for_all (fun id b ->
622+ match List. find_opt (fun (id_ , _ ) -> id = id_) inps with
623+ | Some (_ , b_ ) -> Set. for_all (fun b -> 0 < = b && b < b_) b
624+ | None -> false
625+ ) d
626+ ) ds
610627 end
611628
612629end
@@ -1651,6 +1668,15 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
16511668 assert false
16521669 | _ -> assert false
16531670
1671+ let check_decomp_inputs ((`CBitstring r , inps ): cbitstring cfun ) : bool =
1672+ let inps = List. map (function
1673+ | {type_ = `CIBitstring w ; id} ->
1674+ (id, w)
1675+ | _ -> assert false
1676+ ) inps in
1677+ Backend.Deps. check_inputs r inps
1678+
1679+ (* FIXME: what is the last return value for? *)
16541680 let decompose (in_w : width ) (out_w : width ) ((`CBitstring r , inps ) as c : cbitstring cfun ) : cbitstring cfun list * (int * int) =
16551681 if not (is_decomposable in_w out_w c) then
16561682 let deps = Backend.Deps. block_deps_of_reg out_w r in
@@ -1664,10 +1690,13 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
16641690 let cinp = (List. hd inps) in
16651691 let cinps, renamer = split_renamer n in_w cinp in
16661692(* let renamer = fun i -> Option.bind (aligner i) renamer in *)
1667- Array. map2 (fun r inp ->
1693+ let res = Array. map2 (fun r inp ->
16681694 let r = Backend. applys renamer r in
16691695 (`CBitstring r, [inp])
16701696 ) blocks cinps |> Array. to_list, (0 ,0 )
1697+ in
1698+ if not (List. for_all check_decomp_inputs (fst res)) then assert false else
1699+ res
16711700
16721701 let permute (w : width ) (perm : (int -> int) ) ((`CBitstring r , inps ): cbitstring cfun ) : cbitstring cfun =
16731702 `CBitstring (Backend. permute w perm r), inps
0 commit comments