diff --git a/docs/upstream_cpu_image_report_draft.md b/docs/upstream_cpu_image_report_draft.md new file mode 100644 index 0000000..7f7eaba --- /dev/null +++ b/docs/upstream_cpu_image_report_draft.md @@ -0,0 +1,91 @@ +# CPU image generation now demonstrated end to end + +## Summary + +This note documents a successful CPU demonstration for `bonsai-image` on the unpacked transformer path. + +The headline result is straightforward: CPU image generation can produce coherent `128x128` outputs end to end. + +The two clearest examples are: + +- a coherent plain `ostrich` +- a clean 4-quadrant multi-color layout + +Additional validated outputs include: + +- a coherent ostrich silhouette +- a large centered red circle +- the same plain `ostrich` result in repeated successful runs + +That establishes three concrete facts: + +1. the unpacked CPU path can converge on globally coherent images +2. both geometric and object-level prompts can work on CPU +3. `128x128` is a practical resolution for validating CPU image generation behavior + +This is not a full root-cause fix report for every failure mode. It is a status report showing that CPU image generation is real, reproducible, and already capable of meaningful outputs. + +## Demonstrated capability + +For this pipeline: + +- `64x64` corresponds to a `4x4` packed latent grid +- `128x128` corresponds to an `8x8` packed latent grid + +On the unpacked transformer CPU path, `128x128` was sufficient to demonstrate: + +- global composition across many packed tokens +- successful convergence on simple object prompts +- successful convergence on large centered shapes +- successful convergence on structured color layouts + +## Validated `128x128` outputs + +On the CPU path using the unpacked transformer, these `128x128`, `4-step` prompts converged coherently: + +- plain `ostrich` +- `a large red circle centered on a pure white background, filling most of the image, flat solid color, hard clean edge, no other objects, minimalist` +- `a large black silhouette of an ostrich centered on a pure white background, full body, hard clean edge, minimalist, no other objects` +- a 4-quadrant color layout prompt + +## Suggested upstream guidance + +It may help to document the following as the current practical CPU guidance: + +1. Use `128x128+` for structure/composition validation. +2. Treat final outputs as the primary correctness signal. +3. Use the unpacked transformer CPU path as a reference-capable configuration for image bring-up. + +## Runner-side settings that matched successful runs + +Two runner-side settings were helpful in the successful CPU runs: + +1. Increasing default prompt context from `64` to `512` +2. Using `128x128` instead of `64x64` for meaningful structure/composition validation + +## Reproduction shape + +Example command shape used for successful CPU runs: + +```bash +python scripts/generate_cpu_experimental.py \ + --prompt 'ostrich' \ + --output outputs/cpu-ostrich.png \ + --height 128 \ + --width 128 \ + --steps 4 \ + --seed 7 \ + --transformer-dir models/bonsai-image-4B-ternary-unpacked/transformer +``` + +## Example output categories + +- plain ostrich +- ostrich silhouette +- large centered red circle +- 4-quadrant color layout + +## Notes + +- This note intentionally avoids host-specific details, private paths, tokens, and unrelated local setup details. +- If a PR is not the right venue, the same content could be posted as an issue or discussion instead. diff --git a/docs/upstream_issue_cpu_bringup.md b/docs/upstream_issue_cpu_bringup.md new file mode 100644 index 0000000..bfcc077 --- /dev/null +++ b/docs/upstream_issue_cpu_bringup.md @@ -0,0 +1,60 @@ +# CPU image generation works end to end on unpacked path at 128x128 + +## Summary + +CPU image generation is working end to end on the unpacked transformer path at `128x128`. + +The clearest validated outputs were: + +- plain `ostrich` +- a clean 4-quadrant multi-color layout +- a large centered red circle +- a coherent ostrich silhouette + +This is useful because it establishes that the CPU path can already do: + +- globally coherent composition +- structured color layouts +- simple object-level prompts + +## Key evidence + +Validated `128x128`, `4-step` outputs: + +- `ostrich` +- `a large red circle centered on a pure white background, filling most of the image, flat solid color, hard clean edge, no other objects, minimalist` +- `a large black silhouette of an ostrich centered on a pure white background, full body, hard clean edge, minimalist, no other objects` +- a 4-quadrant color layout prompt + +## Practical CPU guidance + +Two practical constraints stood out during bring-up: + +1. `64x64` is a poor structure/composition validation target for this pipeline. +2. Final outputs are much more reliable than intermediate-step VAE decodes when judging correctness. + +At this resolution regime: + +- `64x64` corresponds to a `4x4` packed latent grid +- `128x128` corresponds to an `8x8` packed latent grid + +Using `128x128` made the difference between misleading geometry tests and meaningful validation. + +## Reproduction shape + +Example command shape used for successful CPU runs: + +```bash +python scripts/generate_cpu_experimental.py \ + --prompt 'ostrich' \ + --output outputs/cpu-ostrich.png \ + --height 128 \ + --width 128 \ + --steps 4 \ + --seed 7 \ + --transformer-dir models/bonsai-image-4B-ternary-unpacked/transformer +``` + +## Suggested takeaway + +The unpacked CPU path appears good enough to document as a real bring-up configuration for `bonsai-image`, at least for `128x128` validation and simple-to-moderate image composition. diff --git a/scripts/autoresearch_ostrich_cpu.py b/scripts/autoresearch_ostrich_cpu.py new file mode 100644 index 0000000..6c74a1d --- /dev/null +++ b/scripts/autoresearch_ostrich_cpu.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import gc +import json +import statistics +import sys +import time +from collections import defaultdict +from dataclasses import asdict, dataclass +from pathlib import Path + +from PIL import Image +import torch + +REPO_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(REPO_ROOT / "scripts")) + +import generate_cpu_experimental as gen + + +@dataclass(frozen=True) +class Candidate: + name: str + width: int + height: int + steps: int + allow_sub128: bool + guidance: float + max_seq: int + dtype: str + text_encoder_dtype: str + threads: int + interop_threads: int + + +def candidate_grid() -> list[Candidate]: + shared = dict(guidance=1.0, threads=4) + layouts = [ + dict(width=128, height=128, allow_sub128=False), + dict(width=96, height=96, allow_sub128=True), + ] + families = [ + dict( + tag="fp16_auto_i4_s64", + dtype="float16", + text_encoder_dtype="auto", + interop_threads=4, + max_seq=64, + ), + dict( + tag="fp16_auto_i4_s16", + dtype="float16", + text_encoder_dtype="auto", + interop_threads=4, + max_seq=16, + ), + dict( + tag="fp16_auto_i1_s64", + dtype="float16", + text_encoder_dtype="auto", + interop_threads=1, + max_seq=64, + ), + dict( + tag="fp16_auto_i1_s16", + dtype="float16", + text_encoder_dtype="auto", + interop_threads=1, + max_seq=16, + ), + dict( + tag="fp32_auto_i4_s16", + dtype="float32", + text_encoder_dtype="auto", + interop_threads=4, + max_seq=16, + ), + dict( + tag="fp32_auto_i4_s64", + dtype="float32", + text_encoder_dtype="auto", + interop_threads=4, + max_seq=64, + ), + ] + + candidates: list[Candidate] = [] + for layout in layouts: + size_tag = f"{layout['width']}" + for family in families: + family_args = {key: value for key, value in family.items() if key != "tag"} + for steps in (4, 3, 2, 1): + candidates.append( + Candidate( + name=f"{size_tag}_{steps}step_{family['tag']}", + steps=steps, + **shared, + **layout, + **family_args, + ) + ) + return candidates + + +def runtime_key(candidate: Candidate) -> tuple[str, str, int, int]: + return ( + candidate.dtype, + candidate.text_encoder_dtype, + candidate.threads, + candidate.interop_threads, + ) + + +def check_image_stats(path: Path) -> tuple[int, str, str]: + img = Image.open(path).convert("RGB") + values = list(img.tobytes()) + mean = sum(values) / len(values) + std = statistics.pstdev(values) + r_vals = values[0::3] + g_vals = values[1::3] + b_vals = values[2::3] + summary = ( + f"check_image: {path.name} mean={mean:.1f} std={std:.1f} " + f"R={sum(r_vals)/len(r_vals):.1f} G={sum(g_vals)/len(g_vals):.1f} B={sum(b_vals)/len(b_vals):.1f}" + ) + if mean < 5.0: + return 1, summary, f"mean brightness {mean:.1f} < 5.0" + if std < 15.0: + return 1, summary, f"pixel std-dev {std:.1f} < 15.0" + return 0, summary, "" + + +def configure_runtime(candidate: Candidate) -> None: + torch.set_num_threads(candidate.threads) + torch.set_num_interop_threads(candidate.interop_threads) + gen.CPU_INFERENCE_DTYPE = gen.resolve_inference_dtype(candidate.dtype) + gen.TEXT_ENCODER_DTYPE = gen.resolve_text_encoder_dtype(candidate.text_encoder_dtype) + + +def load_session( + *, + candidate: Candidate, + prompt: str, + model_root: Path, + transformer_dir: Path | None, + prompt_cache_dir: Path, +) -> tuple[torch.Tensor, torch.nn.Module, torch.nn.Module, float]: + configure_runtime(candidate) + started = time.time() + print(" setup: encode prompt", flush=True) + prompt_embeds = gen.encode_prompt( + prompt, + model_root, + max_sequence_length=candidate.max_seq, + cache_dir=prompt_cache_dir, + ) + print(f" setup: prompt ready elapsed={time.time()-started:.1f}s", flush=True) + print(" setup: load VAE", flush=True) + vae = gen.AutoencoderKLFlux2.from_pretrained( + str(model_root / "vae"), + torch_dtype=gen.CPU_INFERENCE_DTYPE, + ).to("cpu").eval() + print(f" setup: VAE ready elapsed={time.time()-started:.1f}s", flush=True) + print(" setup: load transformer", flush=True) + transformer = ( + gen.load_unpacked_transformer(transformer_dir) + if transformer_dir is not None + else gen.load_dense_transformer(model_root) + ) + print(f" setup: transformer ready elapsed={time.time()-started:.1f}s", flush=True) + return prompt_embeds, vae, transformer, time.time() - started + + +def run_candidate_warm( + *, + candidate: Candidate, + prompt_embeds: torch.Tensor, + vae: torch.nn.Module, + transformer: torch.nn.Module, + seed: int, + output_dir: Path, +) -> dict[str, object]: + output_path = output_dir / f"{candidate.name}.png" + started = time.time() + image = gen.run_diffusion( + transformer, + vae, + prompt_embeds, + height=candidate.height, + width=candidate.width, + num_steps=candidate.steps, + seed=seed, + guidance=candidate.guidance, + ) + image.save(output_path) + wall = time.time() - started + quality_rc, quality_stdout, quality_stderr = check_image_stats(output_path) + return { + "candidate": asdict(candidate), + "returncode": 0, + "wall_seconds": round(wall, 1), + "quality_returncode": quality_rc, + "quality_stdout": quality_stdout, + "quality_stderr": quality_stderr, + "output_path": str(output_path), + } + + +def main() -> int: + p = argparse.ArgumentParser(description="Warm autoresearch-style ostrich CPU sweeper.") + p.add_argument("--prompt", default="ostrich") + p.add_argument("--seed", type=int, default=7) + p.add_argument("--limit", type=int) + p.add_argument("--match") + p.add_argument("--output-dir", default=str(REPO_ROOT / "outputs" / "autoresearch")) + p.add_argument("--prompt-cache-dir", default=str(REPO_ROOT / "outputs" / "prompt_cache")) + p.add_argument("--results-jsonl", default=str(REPO_ROOT / "outputs" / "autoresearch" / "results.jsonl")) + p.add_argument("--model-root", default=str(REPO_ROOT / "models" / "bonsai-image-4B-ternary-gemlite")) + p.add_argument("--transformer-dir", default=str(REPO_ROOT / "models" / "bonsai-image-4B-ternary-unpacked" / "transformer")) + p.add_argument("--gemlite-dense", action="store_true") + p.add_argument("--dry-run", action="store_true") + p.add_argument("--setup-only", action="store_true") + args = p.parse_args() + + output_dir = Path(args.output_dir) + prompt_cache_dir = Path(args.prompt_cache_dir) + results_jsonl = Path(args.results_jsonl) + model_root = Path(args.model_root) + transformer_dir = None if args.gemlite_dense else (Path(args.transformer_dir) if args.transformer_dir else None) + output_dir.mkdir(parents=True, exist_ok=True) + results_jsonl.parent.mkdir(parents=True, exist_ok=True) + + candidates = candidate_grid() + if args.match: + candidates = [c for c in candidates if args.match in c.name] + if args.limit is not None: + candidates = candidates[: args.limit] + if args.dry_run: + for cand in candidates: + print(json.dumps(asdict(cand), sort_keys=True)) + return 0 + + grouped: dict[tuple[str, str, int, int], list[Candidate]] = defaultdict(list) + for candidate in candidates: + grouped[runtime_key(candidate)].append(candidate) + + best: dict[str, object] | None = None + for _, group in grouped.items(): + session_ref = group[0] + print( + f"warming session dtype={session_ref.dtype} text_dtype={session_ref.text_encoder_dtype} " + f"threads={session_ref.threads} interop={session_ref.interop_threads}", + flush=True, + ) + prompt_embeds, vae, transformer, setup_seconds = load_session( + candidate=session_ref, + prompt=args.prompt, + model_root=model_root, + transformer_dir=transformer_dir, + prompt_cache_dir=prompt_cache_dir, + ) + print(f" session setup {setup_seconds:.1f}s", flush=True) + if args.setup_only: + print( + json.dumps( + { + "candidate": asdict(session_ref), + "setup_only": True, + "session_setup_seconds": round(setup_seconds, 1), + }, + sort_keys=True, + ) + ) + del prompt_embeds, vae, transformer + gc.collect() + continue + + for index, candidate in enumerate(group, 1): + print(f"[{index}/{len(group)}] {candidate.name}", flush=True) + result = run_candidate_warm( + candidate=candidate, + prompt_embeds=prompt_embeds, + vae=vae, + transformer=transformer, + seed=args.seed, + output_dir=output_dir, + ) + result["session_setup_seconds"] = round(setup_seconds, 1) + with results_jsonl.open("a") as fh: + fh.write(json.dumps(result) + "\n") + + passed = result["quality_returncode"] == 0 + if passed and (best is None or result["wall_seconds"] < best["wall_seconds"]): + best = result + print( + f" new best: {candidate.name} warm_wall={result['wall_seconds']}s " + f"setup={setup_seconds:.1f}s path={result['output_path']}", + flush=True, + ) + else: + print( + f" keep-best unchanged: quality={result['quality_returncode']} " + f"warm_wall={result['wall_seconds']}s", + flush=True, + ) + + del prompt_embeds, vae, transformer + gc.collect() + + if best is not None: + print(json.dumps(best, indent=2, sort_keys=True)) + return 0 + + print("no passing candidates") + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_cpu_server_grid.py b/scripts/benchmark_cpu_server_grid.py new file mode 100755 index 0000000..3fbedbd --- /dev/null +++ b/scripts/benchmark_cpu_server_grid.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import statistics +import time +import urllib.request +from pathlib import Path + +from PIL import Image + + +def parse_size(raw: str) -> tuple[int, int]: + width_text, height_text = raw.lower().split("x", 1) + return int(width_text), int(height_text) + + +def bench_one( + endpoint: str, + prompt: str, + output_path: Path, + width: int, + height: int, + steps: int, + seed: int, + guidance: float, + max_seq: int, + timeout: int, +) -> dict[str, object]: + payload = { + "prompt": prompt, + "output": str(output_path), + "width": width, + "height": height, + "steps": steps, + "seed": seed, + "guidance": guidance, + "max_seq": max_seq, + } + request = urllib.request.Request( + endpoint, + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + started_at = time.time() + with urllib.request.urlopen(request, timeout=timeout) as response: + result = json.loads(response.read().decode("utf-8")) + image = Image.open(output_path).convert("RGB") + values = list(image.tobytes()) + result["mean"] = round(sum(values) / len(values), 1) + result["std"] = round(statistics.pstdev(values), 1) + result["wall_client_seconds"] = round(time.time() - started_at, 1) + return result + + +def main() -> int: + parser = argparse.ArgumentParser(description="Benchmark a warm Bonsai CPU image server over a size grid.") + parser.add_argument("--endpoint", default="http://127.0.0.1:8011/generate") + parser.add_argument("--prompt", default="bonsai") + parser.add_argument("--steps", type=int, default=4) + parser.add_argument("--guidance", type=float, default=1.0) + parser.add_argument("--max-seq", type=int, default=64) + parser.add_argument("--seed-base", type=int, default=760000) + parser.add_argument("--timeout", type=int, default=1800) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("outputs/telegram"), + help="Directory for generated images.", + ) + parser.add_argument( + "sizes", + nargs="+", + help="One or more WIDTHxHEIGHT sizes, for example: 224x224 256x224 320x320", + ) + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + for index, raw_size in enumerate(args.sizes): + width, height = parse_size(raw_size) + output_path = args.output_dir / f"server-bonsai-grid-{width}x{height}-{args.steps}step.png" + result = bench_one( + endpoint=args.endpoint, + prompt=args.prompt, + output_path=output_path, + width=width, + height=height, + steps=args.steps, + seed=args.seed_base + index, + guidance=args.guidance, + max_seq=args.max_seq, + timeout=args.timeout, + ) + print(json.dumps(result)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_ostrich_forward.py b/scripts/benchmark_ostrich_forward.py new file mode 100644 index 0000000..d19ef68 --- /dev/null +++ b/scripts/benchmark_ostrich_forward.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path + +import torch +from diffusers import Flux2Pipeline +from diffusers.pipelines.flux2.pipeline_flux2 import retrieve_timesteps + +REPO_ROOT = Path(__file__).resolve().parents[1] + +import sys + +sys.path.insert(0, str(REPO_ROOT / "scripts")) + +import generate_cpu_experimental as gen + + +class StopAfterBlocks(RuntimeError): + pass + + +def build_inputs( + *, + transformer, + prompt_embeds: torch.Tensor, + width: int, + height: int, + num_steps: int, + seed: int, +): + transformer_device = next(transformer.parameters()).device + transformer_dtype = next(transformer.parameters()).dtype + prompt_embeds = prompt_embeds.to(device=transformer_device, dtype=transformer_dtype) + text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(transformer_device) + + in_channels_latents = transformer.config.in_channels // 4 + h_lat = width // 8 + w_lat = height // 8 + noise_shape = (1, in_channels_latents * 4, h_lat // 2, w_lat // 2) + generator = torch.Generator(device="cpu").manual_seed(seed) + latents_4d = torch.randn(noise_shape, generator=generator, dtype=torch.float32).to( + device=transformer_device, dtype=transformer_dtype + ) + latent_ids = Flux2Pipeline._prepare_latent_ids(latents_4d).to(transformer_device) + latents = Flux2Pipeline._pack_latents(latents_4d) + + scheduler = gen.build_scheduler() + image_seq_len = latents.shape[1] + mu = gen._mflux_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps) + sigmas = None if getattr(scheduler.config, "use_flow_sigmas", False) else torch.linspace( + 1.0, 1.0 / num_steps, num_steps + ).numpy() + timesteps, _ = retrieve_timesteps( + scheduler, num_steps, transformer_device, sigmas=sigmas, mu=mu + ) + timestep = timesteps[0].expand(latents.shape[0]).to(latents.dtype) / 1000 + guidance = torch.full([1], 1.0, device=transformer_device, dtype=torch.float32).expand( + latents.shape[0] + ) + return { + "prompt_embeds": prompt_embeds, + "text_ids": text_ids, + "latents": latents, + "latent_ids": latent_ids, + "timestep": timestep, + "guidance": guidance, + "seq_len": int(image_seq_len), + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="Benchmark one ostrich diffusion forward pass.") + parser.add_argument("--prompt-cache", required=True) + parser.add_argument("--transformer-dir", required=True) + parser.add_argument("--dtype", choices=("float16", "float32"), default="float16") + parser.add_argument("--threads", type=int, default=4) + parser.add_argument("--interop-threads", type=int, default=4) + parser.add_argument("--width", type=int, default=64) + parser.add_argument("--height", type=int, default=64) + parser.add_argument("--seed", type=int, default=7) + parser.add_argument("--repeats", type=int, default=2) + parser.add_argument("--trace-blocks", action="store_true") + parser.add_argument("--max-blocks", type=int) + args = parser.parse_args() + + torch.set_num_threads(args.threads) + torch.set_num_interop_threads(args.interop_threads) + gen.CPU_INFERENCE_DTYPE = gen.resolve_inference_dtype(args.dtype) + + prompt_cache = Path(args.prompt_cache) + transformer_dir = Path(args.transformer_dir) + + load_start = time.time() + prompt_embeds = torch.load(prompt_cache, map_location="cpu").to(gen.CPU_INFERENCE_DTYPE) + transformer = gen.load_unpacked_transformer(transformer_dir) + load_seconds = time.time() - load_start + print(json.dumps({"phase": "load", "seconds": round(load_seconds, 1)}), flush=True) + + inputs = build_inputs( + transformer=transformer, + prompt_embeds=prompt_embeds, + width=args.width, + height=args.height, + num_steps=1, + seed=args.seed, + ) + + block_starts: dict[int, float] = {} + hooks = [] + if args.trace_blocks and hasattr(transformer, "transformer_blocks"): + for idx, block in enumerate(transformer.transformer_blocks): + def pre_hook(_module, _inputs, *, block_idx=idx): + block_starts[block_idx] = time.time() + + def post_hook(_module, _inputs, _output, *, block_idx=idx): + start = block_starts.pop(block_idx, None) + if start is None: + return + elapsed = time.time() - start + print( + json.dumps( + { + "phase": "block", + "block": block_idx, + "seconds": round(elapsed, 3), + } + ), + flush=True, + ) + if args.max_blocks is not None and (block_idx + 1) >= args.max_blocks: + raise StopAfterBlocks() + + hooks.append(block.register_forward_pre_hook(pre_hook)) + hooks.append(block.register_forward_hook(post_hook)) + + for repeat in range(1, args.repeats + 1): + start = time.time() + stopped_early = False + try: + _ = transformer( + hidden_states=inputs["latents"], + timestep=inputs["timestep"], + guidance=inputs["guidance"], + encoder_hidden_states=inputs["prompt_embeds"], + txt_ids=inputs["text_ids"], + img_ids=inputs["latent_ids"], + return_dict=False, + )[0] + except StopAfterBlocks: + stopped_early = True + elapsed = time.time() - start + print( + json.dumps( + { + "phase": "forward", + "repeat": repeat, + "threads": args.threads, + "interop_threads": args.interop_threads, + "dtype": args.dtype, + "width": args.width, + "height": args.height, + "seq_len": inputs["seq_len"], + "seconds": round(elapsed, 1), + "stopped_early": stopped_early, + } + ), + flush=True, + ) + for hook in hooks: + hook.remove() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/cpu_image_server.py b/scripts/cpu_image_server.py new file mode 100644 index 0000000..9ce6a80 --- /dev/null +++ b/scripts/cpu_image_server.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import os +import sys +import time +from contextlib import asynccontextmanager +from pathlib import Path +from threading import Lock + +import uvicorn +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import torch.nn as nn + +REPO_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(REPO_ROOT / "scripts")) + +import generate_cpu_experimental as gen + + +class GenerateRequest(BaseModel): + prompt: str + output: str + width: int + height: int + steps: int + seed: int + guidance: float = 1.0 + max_seq: int = 64 + + +class ServerState: + def __init__(self) -> None: + self.model_root = Path( + os.environ.get("BONSAI_MODEL_ROOT", str(REPO_ROOT / "models" / "bonsai-image-4B-ternary-gemlite")) + ) + self.transformer_dir = Path( + os.environ.get( + "BONSAI_TRANSFORMER_DIR", + str(REPO_ROOT / "models" / "bonsai-image-4B-ternary-unpacked" / "transformer"), + ) + ) + self.prompt_cache_dir = Path( + os.environ.get("BONSAI_PROMPT_CACHE_DIR", str(REPO_ROOT / "outputs" / "prompt_cache_fp32_auto")) + ) + self.dtype_name = os.environ.get("BONSAI_DTYPE", "float32") + self.text_encoder_dtype_name = os.environ.get("BONSAI_TEXT_ENCODER_DTYPE", "auto") + self.threads = int(os.environ.get("BONSAI_THREADS", "4")) + self.interop_threads = int(os.environ.get("BONSAI_INTEROP_THREADS", "4")) + self.use_gemlite_dense = os.environ.get("BONSAI_GEMLITE_DENSE", "").lower() in {"1", "true", "yes"} + self.compile_transformer = os.environ.get("BONSAI_COMPILE_TRANSFORMER", "").lower() in {"1", "true", "yes"} + self.compile_mode = os.environ.get("BONSAI_COMPILE_MODE", "reduce-overhead") + self.dynamic_int8 = os.environ.get("BONSAI_DYNAMIC_INT8", "").lower() in {"1", "true", "yes"} + self.dynamic_int8_engine = os.environ.get("BONSAI_DYNAMIC_INT8_ENGINE", "qnnpack") + self.lock = Lock() + self.ready = False + self.started_at = time.time() + self.vae = None + self.transformer = None + + +STATE = ServerState() + + +def configure_runtime() -> None: + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + gen.torch.set_num_threads(STATE.threads) + gen.torch.set_num_interop_threads(STATE.interop_threads) + gen.CPU_INFERENCE_DTYPE = gen.resolve_inference_dtype(STATE.dtype_name) + gen.TEXT_ENCODER_DTYPE = gen.resolve_text_encoder_dtype(STATE.text_encoder_dtype_name) + + +def load_models() -> None: + configure_runtime() + STATE.prompt_cache_dir.mkdir(parents=True, exist_ok=True) + STATE.vae = gen.AutoencoderKLFlux2.from_pretrained( + str(STATE.model_root / "vae"), + torch_dtype=gen.CPU_INFERENCE_DTYPE, + ).to("cpu").eval() + STATE.transformer = ( + gen.load_dense_transformer(STATE.model_root) + if STATE.use_gemlite_dense + else gen.load_unpacked_transformer(STATE.transformer_dir) + ) + if STATE.dynamic_int8: + gen.torch.backends.quantized.engine = STATE.dynamic_int8_engine + STATE.transformer = gen.torch.quantization.quantize_dynamic( + STATE.transformer, + {nn.Linear}, + dtype=gen.torch.qint8, + ) + if STATE.compile_transformer: + STATE.transformer = gen.torch.compile(STATE.transformer, mode=STATE.compile_mode) + STATE.ready = True + + +@asynccontextmanager +async def lifespan(app: FastAPI): + load_models() + yield + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/healthz") +def healthz() -> dict[str, object]: + return { + "ready": STATE.ready, + "uptime_seconds": round(time.time() - STATE.started_at, 1), + "dtype": STATE.dtype_name, + "text_encoder_dtype": STATE.text_encoder_dtype_name, + "threads": STATE.threads, + "interop_threads": STATE.interop_threads, + "gemlite_dense": STATE.use_gemlite_dense, + "compile_transformer": STATE.compile_transformer, + "compile_mode": STATE.compile_mode, + "dynamic_int8": STATE.dynamic_int8, + "dynamic_int8_engine": STATE.dynamic_int8_engine, + "prompt_cache_dir": str(STATE.prompt_cache_dir), + } + + +@app.post("/generate") +def generate(req: GenerateRequest) -> dict[str, object]: + if not STATE.ready or STATE.vae is None or STATE.transformer is None: + raise HTTPException(status_code=503, detail="server not ready") + if req.width % 32 != 0 or req.height % 32 != 0: + raise HTTPException(status_code=400, detail="height and width must be multiples of 32") + + output_path = Path(req.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + cache_key = gen.prompt_cache_key( + req.prompt, + STATE.model_root, + req.max_seq, + include_inference_dtype=False, + ) + cache_path = STATE.prompt_cache_dir / f"{cache_key}.pt" + if not cache_path.exists(): + raise HTTPException( + status_code=409, + detail="prompt cache miss; uncached prompts are disabled on this host for the warm server route", + ) + + with STATE.lock: + start = time.time() + prompt_embeds = gen.torch.load(cache_path, map_location="cpu").to(gen.CPU_INFERENCE_DTYPE) + prompt_seconds = time.time() - start + image = gen.run_diffusion( + STATE.transformer, + STATE.vae, + prompt_embeds, + height=req.height, + width=req.width, + num_steps=req.steps, + seed=req.seed, + guidance=req.guidance, + ) + image.save(output_path) + total_seconds = time.time() - start + + return { + "prompt": req.prompt, + "output_path": str(output_path), + "width": req.width, + "height": req.height, + "steps": req.steps, + "seed": req.seed, + "prompt_seconds": round(prompt_seconds, 1), + "total_seconds": round(total_seconds, 1), + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="Warm CPU Bonsai image server.") + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=8011) + parser.add_argument("--model-root", default=str(REPO_ROOT / "models" / "bonsai-image-4B-ternary-gemlite")) + parser.add_argument( + "--transformer-dir", + default=str(REPO_ROOT / "models" / "bonsai-image-4B-ternary-unpacked" / "transformer"), + ) + parser.add_argument("--prompt-cache-dir", default=str(REPO_ROOT / "outputs" / "prompt_cache_fp32_auto")) + parser.add_argument("--dtype", default="float32") + parser.add_argument("--text-encoder-dtype", default="auto") + parser.add_argument("--threads", type=int, default=4) + parser.add_argument("--interop-threads", type=int, default=4) + parser.add_argument("--gemlite-dense", action="store_true") + parser.add_argument("--compile-transformer", action="store_true") + parser.add_argument("--compile-mode", default="reduce-overhead") + parser.add_argument("--dynamic-int8", action="store_true") + parser.add_argument("--dynamic-int8-engine", default="qnnpack") + args = parser.parse_args() + + os.environ["BONSAI_MODEL_ROOT"] = args.model_root + os.environ["BONSAI_TRANSFORMER_DIR"] = args.transformer_dir + os.environ["BONSAI_PROMPT_CACHE_DIR"] = args.prompt_cache_dir + os.environ["BONSAI_DTYPE"] = args.dtype + os.environ["BONSAI_TEXT_ENCODER_DTYPE"] = args.text_encoder_dtype + os.environ["BONSAI_THREADS"] = str(args.threads) + os.environ["BONSAI_INTEROP_THREADS"] = str(args.interop_threads) + os.environ["BONSAI_GEMLITE_DENSE"] = "1" if args.gemlite_dense else "0" + os.environ["BONSAI_COMPILE_TRANSFORMER"] = "1" if args.compile_transformer else "0" + os.environ["BONSAI_COMPILE_MODE"] = args.compile_mode + os.environ["BONSAI_DYNAMIC_INT8"] = "1" if args.dynamic_int8 else "0" + os.environ["BONSAI_DYNAMIC_INT8_ENGINE"] = args.dynamic_int8_engine + + global STATE + STATE = ServerState() + uvicorn.run(app, host=args.host, port=args.port, log_level="warning") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/generate_cpu_experimental.py b/scripts/generate_cpu_experimental.py new file mode 100644 index 0000000..64080b9 --- /dev/null +++ b/scripts/generate_cpu_experimental.py @@ -0,0 +1,615 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import gc +import hashlib +import json +import os +import platform +import sys +import time +from pathlib import Path + +import torch +import torch.nn as nn +from diffusers import AutoencoderKLFlux2, Flux2Pipeline, Flux2Transformer2DModel +from diffusers.pipelines.flux2.pipeline_flux2 import retrieve_timesteps +from hqq.core.quantize import HQQLinear +from hqq.models.hf.base import AutoHQQHFModel +from PIL import Image +from transformers import AutoTokenizer + +REPO_ROOT = Path(__file__).resolve().parents[1] +IMAGE_STUDIO_VENDOR = (REPO_ROOT / "vendor" / "image-studio").resolve() +if IMAGE_STUDIO_VENDOR.exists(): + sys.path.insert(0, str(IMAGE_STUDIO_VENDOR)) + +from backend_gpu.diffusion_klein import _mflux_empirical_mu # noqa: E402 + + +def log(msg: str) -> None: + now = time.strftime("%H:%M:%S") + print(f"[{now}] {msg}", flush=True) + + +def mem_gib() -> float: + try: + with open("/proc/self/status") as fh: + for line in fh: + if line.startswith("VmRSS:"): + return int(line.split()[1]) / 1024 / 1024 + except OSError: + return 0.0 + return 0.0 + + +TEXT_ENCODER_DTYPE = torch.bfloat16 +CPU_INFERENCE_DTYPE = torch.float16 + + +def cpu_supports_bf16() -> bool: + try: + cpuinfo = Path("/proc/cpuinfo").read_text() + except OSError: + return False + return " bf16" in f" {cpuinfo.lower()} " + + +def resolve_inference_dtype(name: str) -> torch.dtype: + if name == "auto": + # On this ARM CPU path, fp16 transformer inference benchmarks much + # worse than fp32. Prefer bf16 only with native support; otherwise use + # float32 as the CPU default. + return torch.bfloat16 if cpu_supports_bf16() else torch.float32 + if name == "bfloat16": + return torch.bfloat16 + if name == "float16": + return torch.float16 + if name == "float32": + return torch.float32 + raise ValueError(f"unsupported dtype {name}") + + +def resolve_text_encoder_dtype(name: str) -> torch.dtype: + if name == "auto": + # On CPUs without native bf16 support, float16 text-encoder inference + # can take dramatically longer than float32 for the short-prompt path + # we care about here. Prefer bf16 only when the ISA is real; otherwise + # stay in float32 and cast the final prompt embeds down afterward. + return torch.bfloat16 if cpu_supports_bf16() else torch.float32 + return resolve_inference_dtype(name) + + +def prompt_cache_key( + prompt: str, + model_root: Path, + max_sequence_length: int, + *, + include_inference_dtype: bool = True, +) -> str: + payload = { + "prompt": prompt, + "model_root": str(model_root.resolve()), + "max_sequence_length": int(max_sequence_length), + "text_encoder_dtype": str(TEXT_ENCODER_DTYPE), + } + if include_inference_dtype: + payload["dtype"] = str(CPU_INFERENCE_DTYPE) + return hashlib.sha256(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:16] + + +def load_quantized_text_encoder(path: Path, dtype: torch.dtype) -> nn.Module: + log("loading quantized text encoder") + model = AutoHQQHFModel.from_quantized(str(path), device="cpu") + return model.to(dtype) + + +def dequantize_text_encoder(model: nn.Module, dtype: torch.dtype) -> nn.Module: + count = 0 + start = time.time() + for name, mod in list(model.named_modules()): + if isinstance(mod, HQQLinear): + parent_name, _, child_name = name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + weight = mod.dequantize().to(dtype) + dense = nn.Linear( + mod.in_features, + mod.out_features, + bias=mod.bias is not None, + dtype=dtype, + ) + dense.weight = nn.Parameter(weight, requires_grad=False) + if mod.bias is not None: + dense.bias = nn.Parameter(mod.bias.to(dtype), requires_grad=False) + setattr(parent, child_name, dense) + count += 1 + if count % 20 == 0: + gc.collect() + log(f"text encoder dense layers: {count} rss={mem_gib():.2f} GiB elapsed={time.time()-start:.1f}s") + log(f"text encoder dequantized: {count} dense layers rss={mem_gib():.2f} GiB elapsed={time.time()-start:.1f}s") + return model + + +@torch.no_grad() +def encode_prompt( + prompt: str, + model_root: Path, + *, + max_sequence_length: int, + cache_dir: Path | None = None, + text_encoder: nn.Module | None = None, + tokenizer: AutoTokenizer | None = None, +) -> torch.Tensor: + cache_path: Path | None = None + if cache_dir is not None: + cache_dir.mkdir(parents=True, exist_ok=True) + cache_path = cache_dir / f"{prompt_cache_key(prompt, model_root, max_sequence_length)}.pt" + if cache_path.exists(): + log(f"loading cached prompt embeds from {cache_path}") + return torch.load(cache_path, map_location="cpu") + compat_cache_path = cache_dir / ( + f"{prompt_cache_key(prompt, model_root, max_sequence_length, include_inference_dtype=False)}.pt" + ) + if compat_cache_path.exists(): + log(f"loading compatible cached prompt embeds from {compat_cache_path}") + return torch.load(compat_cache_path, map_location="cpu").to(CPU_INFERENCE_DTYPE) + + owns_text_encoder = text_encoder is None + owns_tokenizer = tokenizer is None + if text_encoder is None or tokenizer is None: + text_path = model_root / "text_encoder-hqq-4bit" + tok_path = text_path / "tokenizer" + if text_encoder is None: + text_encoder = dequantize_text_encoder( + load_quantized_text_encoder(text_path, TEXT_ENCODER_DTYPE), + TEXT_ENCODER_DTYPE, + ) + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(str(tok_path)) + messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + full_inputs = tokenizer(text, return_tensors="pt") + effective_max_length = min(max_sequence_length, int(full_inputs["input_ids"].shape[1])) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=effective_max_length, + ) + full_tokens = int(full_inputs["input_ids"].shape[1]) + used_tokens = int(inputs["attention_mask"][0].sum().item()) + if full_tokens > max_sequence_length: + log( + f"prompt truncated from {full_tokens} to {max_sequence_length} tokens; " + "increase --max-seq for faithful prompt encoding" + ) + else: + log( + f"prompt tokenized to {used_tokens} tokens with effective_seq={effective_max_length} " + f"within max_seq={max_sequence_length}" + ) + log("encoding prompt") + start = time.time() + output = text_encoder( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + output_hidden_states=True, + use_cache=False, + ) + prompt_embeds = torch.stack([output.hidden_states[k] for k in (9, 18, 27)], dim=1) + batch_size, num_channels, seq_len, hidden_dim = prompt_embeds.shape + prompt_embeds = ( + prompt_embeds.permute(0, 2, 1, 3) + .reshape(batch_size, seq_len, num_channels * hidden_dim) + .to(CPU_INFERENCE_DTYPE) + .cpu() + .contiguous() + ) + log(f"prompt encoded shape={tuple(prompt_embeds.shape)} rss={mem_gib():.2f} GiB elapsed={time.time()-start:.1f}s") + if cache_dir is not None: + assert cache_path is not None + torch.save(prompt_embeds, cache_path) + log(f"saved prompt embeds cache {cache_path}") + compat_cache_path = cache_dir / ( + f"{prompt_cache_key(prompt, model_root, max_sequence_length, include_inference_dtype=False)}.pt" + ) + if compat_cache_path != cache_path: + torch.save(prompt_embeds, compat_cache_path) + log(f"saved compatible prompt embeds cache {compat_cache_path}") + del output, inputs + if owns_text_encoder: + del text_encoder + if owns_tokenizer: + del tokenizer + gc.collect() + return prompt_embeds + + +def unpack_rows(w_q: torch.Tensor, rows: int, cols: int) -> torch.Tensor: + return torch.stack( + [(w_q >> 0) & 0x3, (w_q >> 2) & 0x3, (w_q >> 4) & 0x3, (w_q >> 6) & 0x3], + dim=1, + ).reshape(rows, cols) + + +def unpack_cols(w_q: torch.Tensor, rows: int, cols: int) -> torch.Tensor: + return torch.stack( + [(w_q >> 0) & 0x3, (w_q >> 2) & 0x3, (w_q >> 4) & 0x3, (w_q >> 6) & 0x3], + dim=2, + ).reshape(rows, cols) + + +def expand_group_metadata(meta: torch.Tensor, target_shape: tuple[int, int], group_size: int) -> torch.Tensor: + rows, cols = target_shape + expected_shape = (cols // group_size, rows) + if tuple(meta.shape) != expected_shape: + raise RuntimeError( + f"Unhandled gemlite metadata shape meta={tuple(meta.shape)} target={target_shape} group_size={group_size}" + ) + return meta.T.repeat_interleave(group_size, dim=1) + + +def decode_gemlite_weight( + layer_state: dict[str, torch.Tensor], + target_shape: tuple[int, int], + group_size: int, + dtype: torch.dtype, +) -> torch.Tensor: + w_q = layer_state["W_q"] + scales = layer_state["scales"].to(torch.float32) + zeros = layer_state["zeros"].to(torch.float32) + metadata = [int(v) for v in layer_state["metadata"].tolist()] + rows, cols = target_shape + expected_wq_shape = (cols // 4, rows) + if tuple(w_q.shape) != expected_wq_shape: + raise RuntimeError( + f"Unhandled gemlite packing W_q={tuple(w_q.shape)} target={target_shape} expected={expected_wq_shape}" + ) + if metadata[10] != 4: + raise RuntimeError(f"Unhandled gemlite W_group_mode={metadata[10]}") + chunks = unpack_cols(w_q, rows, cols).to(torch.float32) + scale_full = expand_group_metadata(scales, target_shape, group_size) + zero_full = expand_group_metadata(zeros, target_shape, group_size) + return (chunks * scale_full + zero_full).to(dtype) + + +def load_dense_transformer(model_root: Path) -> nn.Module: + path = model_root / "transformer-gemlite-int2" + with (path / "config.json").open() as fh: + cfg = json.load(fh) + with (path / "quantization_config.json").open() as fh: + qcfg = json.load(fh) + group_size = int(qcfg.get("group_size", 128)) + + log("building transformer shell") + # The gemlite GPU backend requires fp16 activations, but this CPU fallback + # is a dense PyTorch model. Use bf16 only on CPUs that can execute it + # efficiently; otherwise stay in fp32. + model = Flux2Transformer2DModel.from_config(cfg).to(CPU_INFERENCE_DTYPE) + state = torch.load(str(path / "state_dict.pt"), map_location="cpu") + buckets: dict[str, dict[str, torch.Tensor]] = {} + remainder: dict[str, torch.Tensor] = {} + for key, value in state.items(): + fqn, _, leaf = key.rpartition(".") + if leaf in {"W_q", "bias", "scales", "zeros", "metadata", "orig_shape"} and fqn: + buckets.setdefault(fqn, {})[leaf] = value + else: + remainder[key] = value + missing, unexpected = model.load_state_dict(remainder, strict=False) + if unexpected: + raise RuntimeError(f"unexpected transformer keys: {unexpected[:8]}") + if len(missing) != len(buckets): + raise RuntimeError(f"unexpected transformer missing keys count: {len(missing)} vs {len(buckets)}") + del remainder, state + gc.collect() + + start = time.time() + items = sorted(buckets.items()) + for idx, (fqn, layer_state) in enumerate(items, 1): + parent_fqn, _, child_name = fqn.rpartition(".") + parent = model.get_submodule(parent_fqn) if parent_fqn else model + child = getattr(parent, child_name) + weight = decode_gemlite_weight( + layer_state, + tuple(child.weight.shape), + group_size, + CPU_INFERENCE_DTYPE, + ).to(CPU_INFERENCE_DTYPE) + child.weight = nn.Parameter(weight, requires_grad=False) + if "bias" in layer_state and layer_state["bias"] is not None: + child.bias = nn.Parameter(layer_state["bias"].to(CPU_INFERENCE_DTYPE), requires_grad=False) + if idx % 10 == 0 or idx == len(items): + gc.collect() + log(f"transformer dense layers: {idx}/{len(items)} rss={mem_gib():.2f} GiB elapsed={time.time()-start:.1f}s") + model._inference_dtype = CPU_INFERENCE_DTYPE # type: ignore[attr-defined] + return model.eval() + + +def load_unpacked_transformer(transformer_dir: Path) -> nn.Module: + log(f"loading unpacked transformer from {transformer_dir}") + model = Flux2Transformer2DModel.from_pretrained( + str(transformer_dir), + torch_dtype=CPU_INFERENCE_DTYPE, + local_files_only=True, + ).to("cpu") + model._inference_dtype = CPU_INFERENCE_DTYPE # type: ignore[attr-defined] + return model.eval() + + +def resolve_transformer_dir(model_root: Path, explicit_transformer_dir: str | None) -> Path | None: + if explicit_transformer_dir: + return Path(explicit_transformer_dir) + + direct = model_root / "transformer" + if direct.is_dir(): + log(f"using unpacked transformer at {direct}") + return direct + + sibling_name = model_root.name.replace("gemlite", "unpacked") + sibling = model_root.parent / sibling_name / "transformer" + if sibling.is_dir(): + log( + "using sibling unpacked transformer instead of GemLite dense reconstruction: " + f"{sibling}" + ) + return sibling + + return None + + +def build_scheduler(): + from diffusers import FlowMatchEulerDiscreteScheduler + + return FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=3.0, + use_dynamic_shifting=True, + base_shift=0.5, + max_shift=1.15, + base_image_seq_len=256, + max_image_seq_len=4096, + ) + + +@torch.no_grad() +def decode_latents_to_image( + latents: torch.Tensor, + latent_ids: torch.Tensor, + vae: nn.Module, + *, + log_prefix: str, +) -> Image.Image: + vae_device = next(vae.parameters()).device + decode_start = time.time() + log(f"{log_prefix} decode start rss={mem_gib():.2f} GiB") + latents = Flux2Pipeline._unpack_latents_with_ids(latents, latent_ids) + log(f"{log_prefix} latents unpacked rss={mem_gib():.2f} GiB elapsed={time.time()-decode_start:.1f}s") + latents = latents.to(device=vae_device, dtype=CPU_INFERENCE_DTYPE) + bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(latents.device, latents.dtype) + latents = latents * bn_std + bn_mean + latents = Flux2Pipeline._unpatchify_latents(latents) + log(f"{log_prefix} vae decode start rss={mem_gib():.2f} GiB elapsed={time.time()-decode_start:.1f}s") + image = vae.decode(latents, return_dict=False)[0] + log(f"{log_prefix} vae decode done rss={mem_gib():.2f} GiB elapsed={time.time()-decode_start:.1f}s") + img = image[0].clamp(-1.0, 1.0).float() + img = (img + 1.0) * 127.5 + img = img.clamp(0.0, 255.0).round().to(torch.uint8) + img = img.permute(1, 2, 0).cpu().numpy() + log(f"{log_prefix} decode done rss={mem_gib():.2f} GiB elapsed={time.time()-decode_start:.1f}s") + return Image.fromarray(img, mode="RGB") + + +@torch.no_grad() +def run_diffusion( + transformer: nn.Module, + vae: nn.Module, + prompt_embeds: torch.Tensor, + *, + height: int, + width: int, + num_steps: int, + seed: int, + guidance: float, + step_output_dir: Path | None = None, + step_output_stem: str = "step", +) -> Image.Image: + transformer_device = next(transformer.parameters()).device + transformer_dtype = next(transformer.parameters()).dtype + scheduler = build_scheduler() + prompt_embeds = prompt_embeds.to(device=transformer_device, dtype=transformer_dtype) + text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(transformer_device) + + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + h_lat = 2 * (int(height) // (vae_scale_factor * 2)) + w_lat = 2 * (int(width) // (vae_scale_factor * 2)) + in_channels_latents = transformer.config.in_channels // 4 + + gen = torch.Generator(device="cpu").manual_seed(int(seed)) + noise_shape = (1, in_channels_latents * 4, h_lat // 2, w_lat // 2) + latents_4d = torch.randn(noise_shape, generator=gen, dtype=torch.float32).to( + device=transformer_device, dtype=transformer_dtype + ) + latent_ids = Flux2Pipeline._prepare_latent_ids(latents_4d).to(transformer_device) + latents = Flux2Pipeline._pack_latents(latents_4d) + image_seq_len = latents.shape[1] + + mu = _mflux_empirical_mu(image_seq_len=image_seq_len, num_steps=num_steps) + sigmas = None if getattr(scheduler.config, "use_flow_sigmas", False) else torch.linspace(1.0, 1.0 / num_steps, num_steps).numpy() + timesteps, _ = retrieve_timesteps(scheduler, num_steps, transformer_device, sigmas=sigmas, mu=mu) + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(0) + guidance_t = torch.full([1], guidance, device=transformer_device, dtype=torch.float32).expand(latents.shape[0]) + + start = time.time() + for i, t in enumerate(timesteps, 1): + step_start = time.time() + log(f"diffusion step {i}/{len(timesteps)} start rss={mem_gib():.2f} GiB") + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance_t, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + return_dict=False, + )[0] + latents_dtype = latents.dtype + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + log( + f"diffusion step {i}/{len(timesteps)} done " + f"rss={mem_gib():.2f} GiB elapsed={time.time()-step_start:.1f}s" + ) + if step_output_dir is not None: + step_image = decode_latents_to_image( + latents.clone(), + latent_ids, + vae, + log_prefix=f"step {i}/{len(timesteps)}", + ) + step_path = step_output_dir / f"{step_output_stem}_step{i:02d}.png" + step_image.save(step_path) + log(f"saved {step_path}") + log(f"diffusion complete rss={mem_gib():.2f} GiB elapsed={time.time()-start:.1f}s") + + return decode_latents_to_image(latents, latent_ids, vae, log_prefix="final") + + +def main() -> int: + global CPU_INFERENCE_DTYPE + global TEXT_ENCODER_DTYPE + + parser = argparse.ArgumentParser() + parser.add_argument("--prompt", required=True) + parser.add_argument("--output", required=True) + parser.add_argument("--height", type=int, default=256) + parser.add_argument("--width", type=int, default=256) + parser.add_argument("--steps", type=int, default=4) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--guidance", type=float, default=1.0) + parser.add_argument("--max-seq", type=int, default=512) + parser.add_argument( + "--dtype", + choices=("auto", "float16", "bfloat16", "float32"), + default="auto", + help="Main transformer/VAE/latent dtype. auto prefers bf16 only with native support, otherwise float32 on CPU.", + ) + parser.add_argument( + "--text-encoder-dtype", + choices=("auto", "float16", "bfloat16", "float32"), + default="auto", + help="Text encoder dtype. auto prefers bf16 only with native support, otherwise float32 for faster CPU prompt encoding.", + ) + parser.add_argument("--threads", type=int) + parser.add_argument("--interop-threads", type=int) + parser.add_argument("--prompt-cache-dir") + parser.add_argument("--model-root", default=str(REPO_ROOT / "models" / "bonsai-image-4B-ternary-gemlite")) + parser.add_argument("--transformer-dir") + parser.add_argument( + "--gemlite-dense", + action="store_true", + help="force the experimental GemLite-to-dense CPU transformer path", + ) + parser.add_argument( + "--allow-sub128", + action="store_true", + help="allow exploratory renders below 128x128; useful for performance sweeps, but less reliable semantically", + ) + parser.add_argument("--step-output-dir") + parser.add_argument( + "--prompt-cache-only", + action="store_true", + help="encode and cache the prompt embeds, then exit without loading VAE/transformer or rendering", + ) + args = parser.parse_args() + + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + if args.steps <= 0: + raise SystemExit("--steps must be a positive integer") + if args.threads is not None: + torch.set_num_threads(args.threads) + if args.interop_threads is not None: + torch.set_num_interop_threads(args.interop_threads) + CPU_INFERENCE_DTYPE = resolve_inference_dtype(args.dtype) + TEXT_ENCODER_DTYPE = resolve_text_encoder_dtype(args.text_encoder_dtype) + model_root = Path(args.model_root) + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + step_output_dir = Path(args.step_output_dir) if args.step_output_dir else None + prompt_cache_dir = Path(args.prompt_cache_dir) if args.prompt_cache_dir else None + if step_output_dir is not None: + step_output_dir.mkdir(parents=True, exist_ok=True) + + if args.height % 32 != 0 or args.width % 32 != 0: + raise SystemExit("height and width must be multiples of 32") + + # Each packed latent token expands to a fixed 16x16 image patch after the + # unpack + VAE decode path. 64x64 therefore gives only a 4x4 packed grid: + # useful for coarse sanity checks, but too small for reliable geometry + # tests like quadrants or thin lines. Use 128x128+ for structure debugging. + min_dim = 96 if args.allow_sub128 else 128 + if args.height < min_dim or args.width < min_dim: + raise SystemExit( + f"height and width must be at least {min_dim} " + f"for {'exploratory' if args.allow_sub128 else 'meaningful'} CPU renders" + ) + + total_start = time.time() + log( + f"cpu={platform.machine()} threads={torch.get_num_threads()} " + f"interop={torch.get_num_interop_threads()} " + f"text_dtype={TEXT_ENCODER_DTYPE} dtype={CPU_INFERENCE_DTYPE}" + ) + prompt_embeds = encode_prompt( + args.prompt, + model_root, + max_sequence_length=args.max_seq, + cache_dir=prompt_cache_dir, + ) + log(f"after prompt encode rss={mem_gib():.2f} GiB") + if args.prompt_cache_only: + log("prompt-cache-only requested; skipping model load and render") + log(f"total elapsed={time.time()-total_start:.1f}s") + return 0 + + log("loading VAE") + vae = AutoencoderKLFlux2.from_pretrained( + str(model_root / "vae"), + torch_dtype=CPU_INFERENCE_DTYPE, + ).to("cpu").eval() + log(f"vae ready rss={mem_gib():.2f} GiB") + + transformer_dir = None if args.gemlite_dense else resolve_transformer_dir(model_root, args.transformer_dir) + transformer = ( + load_unpacked_transformer(transformer_dir) + if transformer_dir is not None + else load_dense_transformer(model_root) + ) + log(f"transformer ready rss={mem_gib():.2f} GiB") + image = run_diffusion( + transformer, + vae, + prompt_embeds, + height=args.height, + width=args.width, + num_steps=args.steps, + seed=args.seed, + guidance=args.guidance, + step_output_dir=step_output_dir, + step_output_stem=output_path.stem, + ) + image.save(output_path) + log(f"saved {output_path}") + log(f"total elapsed={time.time()-total_start:.1f}s") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/start_cpu_image_server.sh b/scripts/start_cpu_image_server.sh new file mode 100755 index 0000000..7ef9037 --- /dev/null +++ b/scripts/start_cpu_image_server.sh @@ -0,0 +1,14 @@ +#!/bin/sh +set -e + +DEMO_DIR="$(cd "$(dirname "$0")/.." && pwd)" +. "$DEMO_DIR/scripts/common.sh" +ensure_venv "$DEMO_DIR" + +: "${BONSAI_CPU_SERVER_HOST:=127.0.0.1}" +: "${BONSAI_CPU_SERVER_PORT:=8011}" + +exec "$DEMO_DIR/.venv/bin/python" "$DEMO_DIR/scripts/cpu_image_server.py" \ + --host "$BONSAI_CPU_SERVER_HOST" \ + --port "$BONSAI_CPU_SERVER_PORT" \ + "$@"