Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/training/finetune/sd35/validation.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"data": [
{
"caption": "a photo of a cat sitting on a windowsill",
"image_path": null,
"video_path": null,
"num_inference_steps": 28,
"height": 512,
"width": 512,
"num_frames": 1
},
{
"caption": "a beautiful landscape with mountains and a lake",
"image_path": null,
"video_path": null,
"num_inference_steps": 28,
"height": 512,
"width": 512,
"num_frames": 1
}
]
}
28 changes: 28 additions & 0 deletions fastvideo/dataset/dataloader/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,34 @@
])


pyarrow_schema_sd35 = pa.schema([
pa.field("id", pa.string()),
# --- Image VAE latents (C, T, H, W) with T=1 for SD3.5 ---
pa.field("vae_latent_bytes", pa.binary()),
pa.field("vae_latent_shape", pa.list_(pa.int64())),
pa.field("vae_latent_dtype", pa.string()),
# --- Combined text encoder output (CLIP-L + CLIP-G padded + T5) ---
# Shape: [seq_len, 4096] where seq_len = clip_seq + t5_seq
pa.field("text_embedding_bytes", pa.binary()),
pa.field("text_embedding_shape", pa.list_(pa.int64())),
pa.field("text_embedding_dtype", pa.string()),
# --- Pooled CLIP projections (CLIP-L pooled + CLIP-G pooled) ---
# Shape: [2048]
pa.field("pooled_projection_bytes", pa.binary()),
pa.field("pooled_projection_shape", pa.list_(pa.int64())),
pa.field("pooled_projection_dtype", pa.string()),
# --- Metadata ---
pa.field("file_name", pa.string()),
pa.field("caption", pa.string()),
pa.field("media_type", pa.string()),
pa.field("width", pa.int64()),
pa.field("height", pa.int64()),
pa.field("num_frames", pa.int64()),
pa.field("duration_sec", pa.float64()),
pa.field("fps", pa.float64()),
])


pyarrow_schema_matrixgame = pa.schema([
pa.field("id", pa.string()),
# --- Image/Video VAE latents ---
Expand Down
8 changes: 4 additions & 4 deletions fastvideo/fastvideo_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,25 +1003,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--fsdp-sharding-strategy", type=str, help="FSDP sharding strategy")

parser.add_argument(
"--weighting_scheme",
"--weighting-scheme",
type=str,
default="uniform",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
)
parser.add_argument(
"--logit_mean",
"--logit-mean",
type=float,
default=0.0,
help="mean to use when using the `'logit_normal'` weighting scheme.",
)
parser.add_argument(
"--logit_std",
"--logit-std",
type=float,
default=1.0,
help="std to use when using the `'logit_normal'` weighting scheme.",
)
parser.add_argument(
"--mode_scale",
"--mode-scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
Expand Down
3 changes: 2 additions & 1 deletion fastvideo/pipelines/basic/sd35/sd35_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.logger import init_logger
from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase
from fastvideo.pipelines.lora_pipeline import LoRAPipeline
from fastvideo.pipelines.stages.input_validation import InputValidationStage
from fastvideo.pipelines.stages.text_encoding import TextEncodingStage
from fastvideo.pipelines.stages.timestep_preparation import (
Expand All @@ -19,7 +20,7 @@
logger = init_logger(__name__)


class SD35Pipeline(ComposedPipelineBase):
class SD35Pipeline(LoRAPipeline, ComposedPipelineBase):
"""Minimal SD3.5 Medium text-to-image pipeline (treat as num_frames=1)."""

_required_config_modules = [
Expand Down
1 change: 1 addition & 0 deletions fastvideo/pipelines/pipeline_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ class TrainingBatch:
timesteps: torch.Tensor | None = None
sigmas: torch.Tensor | None = None
noise: torch.Tensor | None = None
pooled_projections: torch.Tensor | None = None

attn_metadata_vsa: AttentionMetadata | None = None
attn_metadata: AttentionMetadata | None = None
Expand Down
Loading
Loading