
UNet1DModel loss plateaus at ≈0.5: how to fix time blindness when training on sequential data #13703

@elcaidaprojects-art

Description

Describe the bug

The UNet1DModel class from the diffusers library does not seem to converge below a certain training loss (usually around 0.5 to 0.6). This can be fixed by slightly modifying the configuration shown in the reproduction below so that the network is forced to process the timestep, which is concatenated to the input tensors. To do so, set extra_in_channels to twice the channel count of the first down block (block_out_channels[0]), use DownBlock1DNoSkip as the first entry of down_block_types, and use UpBlock1DNoSkip as the last entry of up_block_types. The NoSkip blocks stop the network from saving a skip connection at those two stages, which would otherwise result in a traceback.
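The three changes above can be derived mechanically from `block_out_channels`. The sketch below is a hypothetical helper (`time_aware_unet1d_kwargs` is not part of diffusers) that builds the keyword arguments for the time-aware layout. It encodes the "2 × `block_out_channels[0]`" rule exactly as stated in this report; if the timestep projection in your diffusers version has a different width, adjust that one line (a mismatch should surface as a channel-count error when the first block runs).

```python
# Hypothetical helper, not part of diffusers: builds UNet1DModel keyword
# arguments for the "time-aware" layout described in this report.
def time_aware_unet1d_kwargs(block_out_channels, in_channels=1, out_channels=1,
                             sample_size=64, layers_per_block=2):
    n = len(block_out_channels)
    if n == 0:
        raise ValueError("block_out_channels must not be empty")
    # Only the first down block runs without saving a skip connection.
    down_block_types = ("DownBlock1DNoSkip",) + ("DownBlock1D",) * (n - 1)
    # Mirror image: only the last up block runs without consuming one.
    up_block_types = ("UpBlock1D",) * (n - 1) + ("UpBlock1DNoSkip",)
    return {
        "sample_size": sample_size,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "layers_per_block": layers_per_block,
        # Assumption taken from this report: twice the first block's width.
        "extra_in_channels": 2 * block_out_channels[0],
        "block_out_channels": tuple(block_out_channels),
        "down_block_types": down_block_types,
        "up_block_types": up_block_types,
    }
```

With diffusers installed, `UNet1DModel(**time_aware_unet1d_kwargs((64, 64, 128, 128, 256)))` should produce the same arguments as the working configuration in the reproduction.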

Reproduction

from diffusers import UNet1DModel

Failing configuration (time-blind): loss plateaus at 0.5

model = UNet1DModel(
    sample_size=64,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 64, 128, 128, 256),
    down_block_types=("DownBlock1D", "DownBlock1D", "DownBlock1D", "DownBlock1D", "DownBlock1D"),
    up_block_types=("UpBlock1D", "UpBlock1D", "UpBlock1D", "UpBlock1D", "UpBlock1D"),
)

Working (time-aware) configuration

model = UNet1DModel(
    sample_size=64,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    # has to be exactly twice the first down block's width (block_out_channels[0])
    extra_in_channels=128,
    block_out_channels=(64, 64, 128, 128, 256),
    # no skip connection saved in the first down block
    down_block_types=("DownBlock1DNoSkip", "DownBlock1D", "DownBlock1D", "DownBlock1D", "DownBlock1D"),
    # no skip connection required in the last up block
    up_block_types=("UpBlock1D", "UpBlock1D", "UpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
)

Logs

No traceback is produced and the code runs successfully. However, the loss does not drop below about 0.5 and the model fails to learn the underlying patterns of the data. Instead, it produces a "safe" average prediction of the noise across all timesteps.
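Time blindness can be checked directly: feed the same noisy sample at two different timesteps and compare the predictions. A model that ignores the timestep entirely returns identical outputs. The sketch below is a hypothetical, framework-agnostic helper (`is_time_sensitive` and `max_abs_diff` are illustrative names, not diffusers APIs); `predict` stands for any callable that maps `(sample, timestep)` to a flat sequence of floats.

```python
# Hypothetical diagnostic helpers, not part of diffusers.
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two sequences."""
    return max(abs(x - y) for x, y in zip(a, b))


def is_time_sensitive(predict, sample, t_a, t_b, tol=1e-6):
    """Return True if predict(sample, t) changes when only t changes.

    predict: callable mapping (sample, timestep) to a flat sequence of floats.
    """
    out_a = list(predict(sample, t_a))
    out_b = list(predict(sample, t_b))
    if len(out_a) != len(out_b):
        raise ValueError("predictions must have the same length")
    # A time-blind model gives (numerically) identical outputs for both steps.
    return max_abs_diff(out_a, out_b) > tol


# Toy stand-ins: one ignores the timestep, the other depends on it.
time_blind = lambda s, t: [0.5 * v for v in s]
time_aware = lambda s, t: [v / (1 + t) for v in s]
```

For a UNet1DModel, one could wrap the model roughly as `predict = lambda x, t: model(x, t).sample.flatten().tolist()`, with `x` a tensor of shape `(1, in_channels, sample_size)` and the call made under `torch.no_grad()`.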

System Info

diffusers version: 0.38.0

Who can help?

@sayakpaul @patrickvonplaten

Metadata

Labels: bug, models
