Skip to content

[Feature] Add Latent Diffusion and generic latent inference wrapper for all inference networks.#656

Open
bhanuprasanna2001 wants to merge 31 commits into
bayesflow-org:devfrom
bhanuprasanna2001:ldm-dev
Open

[Feature] Add Latent Diffusion and generic latent inference wrapper for all inference networks.#656
bhanuprasanna2001 wants to merge 31 commits into
bayesflow-org:devfrom
bhanuprasanna2001:ldm-dev

Conversation

@bhanuprasanna2001
Copy link
Copy Markdown

  • Included the use of any inference network as part of the latent diffusion model.
  • Only remaining task is to update the model's name to a more suitable one.
  • Developed a notebook experiment with two moons to test all inference networks and observe their behavior in the latent space.
  • The experiment is functioning well.

@bhanuprasanna2001 bhanuprasanna2001 marked this pull request as ready for review March 17, 2026 21:37
Copilot AI review requested due to automatic review settings March 17, 2026 21:38
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new latent-diffusion-based inference network to BayesFlow by combining a VAE-style encoder/decoder with an existing latent-space inference network (DiffusionModel by default), along with a dedicated pytest suite and public exports.

Changes:

  • Introduces LatentDiffusionModel plus standalone Encoder/Decoder components under bayesflow.networks.inference.latent_diffusion.
  • Exposes the new classes via bayesflow.networks.inference and bayesflow.networks package exports.
  • Adds comprehensive tests for build/shape/metrics/serialization and flow-matching integration.

Reviewed changes

Copilot reviewed 10 out of 11 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
bayesflow/networks/inference/latent_diffusion/latent_diffusion_model.py Implements the composed latent diffusion inference model and training objective.
bayesflow/networks/inference/latent_diffusion/encoder.py Adds a VAE-style encoder producing (z, mean, log_var) with reparameterization.
bayesflow/networks/inference/latent_diffusion/decoder.py Adds a decoder mapping latent vectors back to the original parameter space.
bayesflow/networks/inference/latent_diffusion/__init__.py Registers and exports latent diffusion components for the subpackage.
bayesflow/networks/inference/__init__.py Re-exports LatentDiffusionModel, Encoder, and Decoder at the inference package level.
bayesflow/networks/__init__.py Re-exports the new inference components at the top-level networks package.
tests/test_networks/test_latent_diffusion_model/conftest.py Adds fixtures for latent diffusion model tests (dims, conditions, model variants).
tests/test_networks/test_latent_diffusion_model/test_latent_diffusion_model.py Tests LDM build, forward/inverse shapes, metrics, warmup, serialization, save/load.
tests/test_networks/test_latent_diffusion_model/test_encoder.py Tests encoder build/output shapes, auto latent dim, stochasticity, serialization, save/load.
tests/test_networks/test_latent_diffusion_model/test_decoder.py Tests decoder build/output shapes, required output_dim, serialization, save/load.
tests/test_networks/test_latent_diffusion_model/__init__.py Declares the new test package.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment on lines +195 to +202
# Build encoder
self.encoder.build(xz_shape)
actual_latent_dim = self.encoder.latent_dim

# Build decoder with correct output dimension
self.decoder.output_dim = input_dim
latent_shape = tuple(xz_shape[:-1]) + (actual_latent_dim,)
self.decoder.build(latent_shape)
Comment on lines +58 to +71
Network for diffusion noise prediction (only used when
``inference_network`` is None). Default is ``"time_mlp"``.
diffusion_subnet_kwargs : dict[str, any], optional
Additional arguments for diffusion subnet (only used when
``inference_network`` is None). Default is None.
noise_schedule : str, optional
Noise schedule for diffusion (only used when ``inference_network``
is None). Default is ``"cosine"``.
schedule_kwargs : dict[str, any], optional
Additional arguments for noise schedule (only used when
``inference_network`` is None). Default is None.
integrate_kwargs : dict[str, any], optional
Configuration for ODE integration during sampling (only used when
``inference_network`` is None). Default is None.
Comment on lines +172 to +178
elif isinstance(inference_network, str):
raise ValueError(
f"Unknown inference_network specification: {inference_network}. "
f"Expected 'auto' or an InferenceNetwork instance."
)
else:
self.inference_network = inference_network
Comment on lines +9 to +26
def test_build(latent_diffusion_model, random_samples, random_conditions):
xz_shape = keras.ops.shape(random_samples)
cond_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None

assert not latent_diffusion_model.built
latent_diffusion_model.build(xz_shape, conditions_shape=cond_shape)
assert latent_diffusion_model.built
assert latent_diffusion_model.variables


def test_forward_output_shape(latent_diffusion_model, random_samples, random_conditions):
xz_shape = keras.ops.shape(random_samples)
cond_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None
latent_diffusion_model.build(xz_shape, conditions_shape=cond_shape)

z = latent_diffusion_model(random_samples, conditions=random_conditions)
assert keras.ops.shape(z) == (xz_shape[0], latent_diffusion_model.latent_dim)

@@ -0,0 +1,84 @@
import keras
import numpy as np
import pytest
@@ -0,0 +1,66 @@
import keras
import numpy as np
@LarsKue LarsKue added the feature New feature or request label Mar 18, 2026
@LarsKue LarsKue added this to the BayesFlow 2.1 milestone Mar 18, 2026
@LarsKue LarsKue changed the title Latent Space support for all inference networks. Tested. [Feature] Add Latent Diffusion and generic latent inference wrapper for all inference networks. Mar 18, 2026
Copy link
Copy Markdown
Author

@bhanuprasanna2001 bhanuprasanna2001 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Added Latent Inference Network, which can be used with any Inference Network to operate in the latent space.
  • Added Latent Diffusion Model as a thin wrapper for the Latent Inference Network.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants