Skip to content

Commit 8892c5f

Browse files
committed
feat: let-traced convenience macro
1 parent a7d8668 commit 8892c5f

2 files changed

Lines changed: 40 additions & 1 deletion

File tree

src/gen/distribution.cljc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
(ns gen.distribution
22
"Collection of protocols and functions for working with primitive
33
distributions."
4-
(:require [gen.dynamic.choice-map :as cm]
4+
(:require [gen.distribution.math.log-likelihood :as ll]
5+
[gen.dynamic.choice-map :as cm]
56
[gen.generative-function :as gf]
67
[gen.dynamic.trace :as trace])
78
#?(:clj
@@ -184,3 +185,16 @@
184185
v)`"
185186
[ctor encode decode]
186187
(comp #(->Encoded % encode decode) ctor))
188+
189+
(defn delta-distribution
190+
"Deterministic distribution"
191+
[x]
192+
(reify
193+
Sample
194+
(sample [_] x)
195+
196+
LogPDF
197+
(logpdf [_ v] (ll/delta x v))))
198+
199+
(def delta
200+
(->GenerativeFn delta-distribution))

src/gen/dynamic.cljc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,28 @@
211211
[& args]
212212
{:clj-kondo/lint-as 'clojure.core/fn}
213213
(apply gen-body args))
214+
215+
(defmacro let-traced
216+
"Similar to `clojure.core/let`, but wraps all values in [[trace!]] calls
217+
addressed via the binding symbol.
218+
219+
Example usage:
220+
221+
```clojure
222+
(def func
223+
(gen []
224+
(let-traced [a (gen.distribution/delta \"face\")
225+
b (gen.distribution/delta \"cake\")]
226+
(str a \",\" b))))
227+
228+
(into {} (gf/simulate func []))
229+
;; => {:a \"face\" :b \"cake\"}
230+
```"
231+
[bindings & body]
232+
(let [bents (partition 2 bindings)]
233+
(assert (every? symbol? (map first bents)))
234+
`(let ~(into []
235+
(mapcat (fn [[sym expr]]
236+
[sym `(trace! (quote ~sym) ~@expr)]))
237+
bents)
238+
~@body)))

0 commit comments

Comments
 (0)