diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py index 7543ac17..199f82a5 100644 --- a/stable_audio_tools/data/dataset.py +++ b/stable_audio_tools/data/dataset.py @@ -11,6 +11,7 @@ import torch import torchaudio import webdataset as wds +import math from os import path from torch import nn @@ -319,8 +320,15 @@ def __getitem__(self, idx): start = random.randint(0, last_ix - self.latent_crop_length) else: start = 0 - - latents = latents[:, start:start+self.latent_crop_length] + + # Update seconds_start based on latent crop + original_length = info["seconds_total"] * ( + info["timestamps"][1] - info["timestamps"][0] + ) + seconds_per_latent = original_length / info["padding_mask"].count(1) + info["seconds_start"] += math.floor(start * seconds_per_latent) + + latents = latents[:, start : start + self.latent_crop_length] info["padding_mask"] = info["padding_mask"][start:start+self.latent_crop_length]