Skip to content

Commit 8565990

Browse files
committed
Added bdep solve + example
1 parent 8288b26 commit 8565990

9 files changed

Lines changed: 88 additions & 11 deletions

File tree

examples/mapreduce_paper.ec

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,8 @@ bdep 8 8 [w] [w1] [w1] xor_left_spec predT_W8.
106106
admit.
107107
admit.
108108
qed.
109+
110+
lemma xor_left_eq_xor_right : forall (w: W8.t), xor_left w = xor_right w.
111+
proof.
112+
bdep solve.
113+
qed.

libs/lospecs/hlaig.ml

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ end
5050
module type SMTInterface = sig
5151
val circ_equiv : ?inps:(int * int) list -> reg -> reg -> node -> bool
5252

53-
val circ_sat : node -> bool
53+
val circ_sat : ?inps:(int * int) list -> node -> bool
5454

55-
val circ_taut : node -> bool
55+
val circ_taut : ?inps:(int * int) list -> node -> bool
5656
end
5757

5858
(* TODO Add model printing for circ_sat and circ_taut *)
@@ -139,7 +139,7 @@ module MakeSMTInterface(SMT: SMTInstance) : SMTInterface = struct
139139
end
140140

141141

142-
let circ_sat (n : Aig.node) : bool =
142+
let circ_sat ?(inps: (int * int) list option) (n : Aig.node) : bool =
143143
let bvvars : SMT.bvterm Map.String.t ref = ref Map.String.empty in
144144

145145
let rec bvterm_of_node : Aig.node -> SMT.bvterm =
@@ -175,6 +175,24 @@ module MakeSMTInterface(SMT: SMTInstance) : SMTInterface = struct
175175

176176
let form = bvterm_of_node n in
177177

178+
let inps = Option.bind inps (fun l ->
179+
if List.is_empty l then None
180+
else Some l
181+
) in
182+
183+
let inps = Option.map (fun inps ->
184+
List.map (fun (id,sz) ->
185+
List.init sz (fun i -> ("BV_" ^ (id |> string_of_int) ^ "_" ^ (Printf.sprintf "%X" (i))))) inps
186+
) inps in
187+
let inps = Option.map (fun inps ->
188+
List.map (List.map (fun name -> match Map.String.find_opt name !bvvars with
189+
| Some bv -> bv
190+
| None -> SMT.bvterm_of_name 1 name)) inps) inps
191+
in
192+
let bvinp = Option.map (fun inps ->
193+
List.map (fun i -> List.reduce (SMT.bvterm_concat) i) inps) inps
194+
in
195+
178196
begin
179197
SMT.assert' @@ form;
180198
if SMT.check_sat () = true then
@@ -184,16 +202,19 @@ module MakeSMTInterface(SMT: SMTInstance) : SMTInterface = struct
184202
|> List.map (fun a -> snd a) in
185203
let term = List.reduce SMT.bvterm_concat terms in
186204
Format.eprintf "input: %a@." SMT.pp_term (SMT.get_value term);
187-
205+
Option.may (fun bvinp ->
206+
List.iteri (fun i bv ->
207+
Format.eprintf "input[%d]: %a@." i SMT.pp_term (SMT.get_value bv)
208+
) bvinp) bvinp;
188209
(* Format.eprintf "fc: %a@." SMT.pp_term (SMT.get_value bvinpt1); *)
189210
(* Format.eprintf "block: %a@." SMT.pp_term (SMT.get_value bvinpt2); *)
190211
true
191212
end
192213
else false
193214
end
194215

195-
let circ_taut (n: Aig.node) : bool =
196-
not @@ circ_sat (Aig.neg n)
216+
let circ_taut ?inps (n: Aig.node) : bool =
217+
not @@ circ_sat ?inps (Aig.neg n)
197218

198219
end
199220

src/ecCircuits.ml

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,10 @@ module type CircuitInterface = sig
189189
val circuit_uninit : env -> ty -> circuit
190190
val circuit_has_uninitialized : circuit -> bool
191191

192-
(* Circuit equivalence call, should do some processing and then call some backend *)
192+
(* Logical reasoning over circuits *)
193193
val circ_equiv : ?pcond:(cbool * (cinp list)) -> circuit -> circuit -> bool
194+
val circ_sat : circuit -> bool
195+
val circ_taut : circuit -> bool
194196

195197
(* Composition of circuit functions, should deal with inputs and call some backend *)
196198
val circuit_compose : circuit -> circuit list -> circuit
@@ -248,6 +250,8 @@ module type CBackend = sig
248250
val applys : (inp -> node option) -> reg -> reg
249251
val circuit_from_spec : Lospecs.Ast.adef -> reg list -> reg
250252
val equiv : ?inps:inp list -> pcond:node -> reg -> reg -> bool
253+
val sat : ?inps:inp list -> node -> bool
254+
val taut : ?inps:inp list -> node -> bool
251255

252256
val slice : reg -> int -> int -> reg
253257
val insert : reg -> int -> reg -> reg
@@ -395,6 +399,16 @@ module TestBack : CBackend = struct
395399
let module BWZ = (val makeBWZinterface ()) in
396400
BWZ.circ_equiv ?inps (node_list_of_reg r1) (node_list_of_reg r2) pcond
397401

402+
let sat ?(inps: inp list option) (n: node) : bool =
403+
let open HL in
404+
let module BWZ = (val makeBWZinterface ()) in
405+
BWZ.circ_sat ?inps n
406+
407+
let taut ?(inps: inp list option) (n: node) : bool =
408+
let open HL in
409+
let module BWZ = (val makeBWZinterface ()) in
410+
BWZ.circ_taut ?inps n
411+
398412
let slice (r: reg) (idx: int) (len: int) : reg =
399413
Array.sub r idx len
400414

@@ -1397,7 +1411,28 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
13971411
| `CBool b1, `CBool b2 ->
13981412
Backend.equiv ~inps ~pcond:pcc (Backend.reg_of_node b1) (Backend.reg_of_node b2)
13991413
| _ -> false
1414+
1415+
let circ_sat (c: circuit) : bool =
1416+
let `CBool c, inps = cbool_of_circuit ~strict:false c in
1417+
let inps = List.map (function
1418+
| { type_ = `CIBool; id } -> (id, 1)
1419+
| { type_ = `CIBitstring w; id } -> (id, w)
1420+
| { type_ = `CIArray (w1, w2); id } -> (id, w1*w2)
1421+
| { type_ = `CITuple szs; id } -> (id, List.sum szs)
1422+
1423+
) inps in
1424+
Backend.sat ~inps c
14001425

1426+
let circ_taut (c: circuit) : bool =
1427+
let `CBool c, inps = cbool_of_circuit ~strict:false c in
1428+
let inps = List.map (function
1429+
| { type_ = `CIBool; id } -> (id, 1)
1430+
| { type_ = `CIBitstring w; id } -> (id, w)
1431+
| { type_ = `CIArray (w1, w2); id } -> (id, w1*w2)
1432+
| { type_ = `CITuple szs; id } -> (id, List.sum szs)
1433+
1434+
) inps in
1435+
Backend.taut ~inps c
14011436

14021437
module CircuitSpec = struct
14031438
let circuit_from_spec env (c : [`Path of path | `Bind of EcDecl.crb_circuit ] ) : [> `CBitstring of cbitstring_type] cfun =
@@ -1788,10 +1823,8 @@ let circuit_of_form
17881823
begin match EcFol.op_kind (destr_op f_ |> fst) with
17891824
| Some `True ->
17901825
hyps, (circuit_true :> circuit)
1791-
(*hyps, {circ = BWCirc([C.true_]); inps=[]}*)
17921826
| Some `False ->
17931827
hyps, (circuit_false :> circuit)
1794-
(*hyps, {circ = BWCirc([C.false_]); inps=[]}*)
17951828
| _ ->
17961829
let err = Format.sprintf "Unsupported op kind%s@." (EcPath.tostring pth) in
17971830
raise (CircError err)
@@ -1848,8 +1881,8 @@ let circuit_of_form
18481881
let cache = open_circ_lambda_cache env cache binds in
18491882
let hyps, circ = doit cache hyps f in
18501883
begin match qnt with
1851-
| Llambda -> hyps, close_circ_lambda env binds circ
18521884
| Lforall
1885+
| Llambda -> hyps, close_circ_lambda env binds circ
18531886
| Lexists -> raise (CircError "Universal/Existential quantification not supported ")
18541887
(* TODO: figure out how to handle quantifiers *)
18551888
end
@@ -2078,6 +2111,9 @@ let circ_equiv ?(pcond: circuit option) c1 c2 =
20782111
in
20792112
circ_equiv ?pcond c1 c2
20802113

2114+
let circ_sat = circ_sat
2115+
let circ_taut = circ_taut
2116+
20812117
let circuit_permute (bsz: int) (perm: int -> int) (c: circuit) : circuit =
20822118
let c = match c with
20832119
| (`CBitstring r, inps) as c -> c

src/ecCircuits.mli

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ val circuit_permute : int -> (int -> int) -> circuit -> circuit
3636
val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list
3737

3838
(* Use circuits *)
39-
val compute : sign:bool -> circuit -> BI.zint list -> BI.zint
39+
val compute : sign:bool -> circuit -> BI.zint list -> BI.zint
4040
val circ_equiv : ?pcond:circuit -> circuit -> circuit -> bool
41+
val circ_sat : circuit -> bool
42+
val circ_taut : circuit -> bool
4143

4244
(* Generate circuits *)
4345
(* Form processors *)

src/ecHiTacticals.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ and process1_phl (_ : ttenv) (t : phltactic located) (tc : tcenv1) =
235235
| Pbdep bdinfo -> EcPhlBDep.process_bdep bdinfo
236236
| Pbdepeval bdeinfo -> EcPhlBDep.process_bdep_eval bdeinfo
237237
| Pbdepeq bdeinfo -> EcPhlBDep.process_bdepeq bdeinfo
238+
| Pbdepsolve -> EcPhlBDep.t_bdep_solve
238239
| Pcirc (invs, f, v) -> EcPhlBDep.process_bdep_form invs f v
239240
| Prwprgm infos -> EcPhlRwPrgm.process_rw_prgm infos
240241
in

src/ecParser.mly

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3264,6 +3264,9 @@ bdepeq_out_info:
32643264
| BDEP BITSTRING invs=bracket(bd_vars) f=bracket(form) v=lident
32653265
{ Pcirc (invs, f, (`Var v :> bdepvar)) }
32663266

3267+
| BDEP SOLVE
3268+
{ Pbdepsolve }
3269+
32673270
bdhoare_split:
32683271
| b1=sform b2=sform b3=sform?
32693272
{ BDH_split_bop (b1,b2,b3) }

src/ecParsetree.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,7 @@ type phltactic =
831831
| Pbdepeval of bdep_eval_info
832832
| Pbdepeq of bdepeq_info
833833
| Pcirc of (bdepvar list * pformula * bdepvar)
834+
| Pbdepsolve
834835

835836
(* Program rewriting *)
836837
| Prwprgm of rwprgm

src/phl/ecPhlBDep.ml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,4 +1000,10 @@ let process_bdep_eval (bdeinfo: bdep_eval_info) (tc: tcenv1) =
10001000
let tc = EcPhlConseq.t_hoareS_conseq_nm pre post tc in
10011001
FApi.t_last (t_bdep_eval n m inpvs outvs lane frange sign) tc
10021002

1003+
let t_bdep_solve
1004+
(tc : tcenv1) =
1005+
if circ_taut (circuit_of_form (FApi.tc1_hyps tc) (FApi.tc1_goal tc)) then
1006+
FApi.close (!@ tc) VBdep
1007+
else
1008+
tc_error (FApi.tc1_penv tc) "Failed to solve goal through circuit reasoning@\n"
10031009

src/phl/ecPhlBDep.mli

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ val t_bdep : int -> int -> variable list -> variable list -> psymbol -> psymbol
1313
val t_bdepeq : variable list * variable list -> int -> (int * variable list * variable list) list -> form option -> bool -> tcenv1 -> tcenv
1414

1515
val t_bdep_eval : int -> int -> variable list -> variable list -> psymbol -> form list -> bool -> tcenv1 -> tcenv
16+
17+
val t_bdep_solve : tcenv1 -> tcenv
1618

1719
val process_bdep : bdep_info -> tcenv1 -> tcenv
1820

0 commit comments

Comments
 (0)