-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlorenz.py
More file actions
42 lines (32 loc) · 1.02 KB
/
lorenz.py
File metadata and controls
42 lines (32 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import logging
from pathlib import Path
# import os
# os.environ['KERAS_BACKEND']='torch'
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import bayesflow as bf
import keras
import h5py
from tqdm import tqdm
# Only show errors from the "bayesflow" logger.
logging.getLogger("bayesflow").setLevel(logging.ERROR)

# Seed NumPy's global RNG so the np.random draws below are reproducible.
np.random.seed(302)

# Copy the "traj" dataset into memory inside a context manager so the HDF5
# file handle is closed right after the read instead of staying open for the
# lifetime of the process.
with h5py.File("traj_train.h5", "r") as traj_file:
    x_train = np.array(traj_file["traj"])

# Collapse all leading axes so that every row holds exactly 3 values.
# NOTE(review): presumably the 3 columns are Lorenz-system (x, y, z) states,
# going by the file name - confirm against whatever wrote traj_train.h5.
x = np.reshape(x_train, [-1, 3])

# Rows paired with their row numbers; not referenced again in the visible
# part of this script.
data = {"x": x, "index": np.arange(x.shape[0])}
def prior():
    """Draw one random row index into the module-level array ``x``.

    Returns:
        A dict with the single key ``"index"`` whose value is an integer
        drawn by ``np.random.randint`` from the half-open range
        ``[0, x.shape[0])``.
    """
    n_rows = x.shape[0]
    return {"index": np.random.randint(0, n_rows)}
def sim(index):
    """Look up the entry of the module-level array ``x`` at ``index``.

    Args:
        index: Index into the first axis of ``x`` (as produced by ``prior``).

    Returns:
        A dict with the single key ``"x"`` holding ``x[index]``.
    """
    row = x[index]
    return {"x": row}
# Combine the index draw (prior) and the row lookup (sim) into one simulator.
simulator = bf.simulators.make_simulator([prior, sim])

# Inference network used by the workflow below.
diffusion_net = bf.networks.DiffusionModel()

# Workflow that infers "x" with the diffusion model; no summary network and
# no inference conditions are configured.
workflow = bf.BasicWorkflow(
    simulator=simulator,
    summary_network=None,
    inference_network=diffusion_net,
    inference_variables="x",
    inference_conditions=None,
)

# Draw the offline training set from the simulator, then fit on it.
train_data = simulator.sample(500000)
workflow.fit_offline(train_data, epochs=20, batch_size=64, verbose=2)

# Request 5 samples from the fitted workflow, passing an empty conditions
# dict. The return value is not stored.
workflow.sample(num_samples=5, conditions={})