
Cosyvoice3 optim #1874

Open
better-one wants to merge 12 commits into FunAudioLLM:main from better-one:cosyvoice3-optim

Conversation

@better-one

No description provided.

better-one and others added 12 commits April 23, 2026 01:33
Adds FastAPI deployment scaffolding for CosyVoice3 zero-shot TTS, tuned
on WSL2 + RTX 3090. Key optimizations vs the upstream demo:

- server_cosyvoice3.py: FastAPI wrapper without the global model lock so
  vLLM continuous batching can fuse concurrent /tts requests; adds
  /tts/stream returning raw int16 PCM with TTFA p50/p95/p99 metrics.
- fe_cache.py: monkey-patches frontend_zero_shot to cache prompt-side
  outputs (speech_feat / speech_token / embedding / prompt_text_token)
  keyed on (prompt_text, prompt_wav). First call ~60ms, warm ~0.5ms.
- run_server.sh + setup_ld_path.sh: assemble LD_LIBRARY_PATH across all
  nvidia/* venv packages so onnxruntime-gpu 1.18 finds libcudnn.so.8 and
  libcublasLt.so.12 (kills the 401ms CPU-fallback for speech_tokenizer).
- restart_server*.sh: setsid-detached relaunch helpers for SSH sessions.
- web/index.html: Chinese test page with sync /tts (HTML5 audio) and
  streaming /tts/stream (Web Audio API scheduling) + live metrics panel.
- profile_deep{,_cache}.py + profile_stages.py: per-stage timing
  (TN / FE substages / LLM first+per-token / flow / hift / TTFA).
- bench_cosyvoice3.py / bench_push.py / load_test{,_short,_stream}.py:
  sequential + concurrent QPS sweeps; load_test_stream uses raw
  http.client to capture TTFA precisely.
- slo_analysis.md: SLO-anchored QPS/concurrency knee analysis.

Net effect: short-text TTFA 1295ms -> 591ms (-54%) at conc=1; remaining
bottleneck is Token2Wav (Flow + HiFi-GAN), not LLM or FE.
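The fe_cache.py memoization can be sketched as follows (function and key names are illustrative stand-ins, not the actual CosyVoice frontend API):

```python
calls = {"n": 0}  # counts real frontend invocations, for illustration only

def extract_prompt_features(prompt_text, prompt_wav):
    """Stand-in for the expensive prompt-side frontend work
    (speech_feat / speech_token / embedding / prompt_text_token)."""
    calls["n"] += 1
    return {"speech_feat": f"feat({prompt_wav})",
            "prompt_text_token": f"tok({prompt_text})"}

_prompt_cache = {}

def cached_prompt_features(prompt_text, prompt_wav):
    # Cache key mirrors fe_cache.py: prompt side only, so requests with
    # different tts_text but the same prompt all hit the warm (~0.5 ms) path.
    key = (prompt_text, prompt_wav)
    if key not in _prompt_cache:
        _prompt_cache[key] = extract_prompt_features(prompt_text, prompt_wav)
    return _prompt_cache[key]
```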

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Default the server to FP16=1, building Flow's TRT engine with
BuilderFlag.FP16. The infra was already in place
(cosyvoice/utils/file_utils.py:convert_onnx_to_trt accepts fp16,
filename pattern is flow.decoder.estimator.{fp16|fp32}.mygpu.plan) but
server_cosyvoice3.py was hard-coded to fp16=False.

Apples-to-apples on short text (~9-10 chars, n=4 per conc), same WSL +
3090 + FE-cache + lock-free server:

                conc=1 TTFA  conc=4 TTFA p50  conc=4 TTFA p95  conc=4 QPS  conc=4 p95 lat
  Round 0 fp32      588 ms          1141 ms          2067 ms         3.39          2.09 s
  Round 1 fp16      559 ms           997 ms          1210 ms         3.58          1.21 s
  delta              -5%             -13%             -41%            +6%           -42%

p50 gain is modest (FE + LLM-prefill floor), but tail latency and p95
TTFA collapse because the fp16 Flow engine drains per-request faster,
preventing queue buildup at conc>=4.

Long-text (~120 chars) stability sample generated cleanly.

The upstream warning ("DiT tensorRT fp16 engine have some performance
issue") did not manifest as user-perceptible artifacts in the test set.
A/B samples saved at samples/round0_baseline/ vs samples/round1_fp16/.

Toggle via env: FP16=0 bash run_server.sh restores fp32 (loads
flow.decoder.estimator.fp32.mygpu.plan if present, else builds it).
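For reference, enabling fp16 in a TensorRT build comes down to one builder flag; this is a minimal build-config fragment, not the actual convert_onnx_to_trt code:

```python
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
config = builder.create_builder_config()
fp16 = True   # corresponds to the fp16 argument convert_onnx_to_trt accepts
if fp16:
    config.set_flag(trt.BuilderFlag.FP16)  # TRT may then pick fp16 kernels per layer
```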

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Four vLLM EngineArgs changes in load_vllm():

  - gpu_memory_utilization: 0.2 -> 0.6 (env: VLLM_GPU_UTIL)
  - max_num_seqs:           default(256) -> 64 (env: VLLM_MAX_SEQS)
  - enable_chunked_prefill: True (was implicit False)
  - enable_prefix_caching:  True (silently ignored on V1 with prompt_embeds,
                                  but cheap to leave for future versions)

The original 0.2 mem-util was too conservative -- vLLM only got ~5 GB of KV
cache on the 24 GB 3090, capping concurrent batch size. Bumping to 0.6
gives ~14 GB KV (Flow TRT engine + HiFi-GAN take ~3-4 GB outside vLLM,
leaving ~5 GB headroom). max_num_seqs=64 prevents vLLM from reserving KV
slots for 256 hypothetical seqs and starving real ones.
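As a config sketch, the settings above map onto EngineArgs roughly like this (env-var names match the commit; the model path is a placeholder):

```python
import os
from vllm import EngineArgs

engine_args = EngineArgs(
    model="<cosyvoice3_llm_dir>",  # placeholder path
    gpu_memory_utilization=float(os.environ.get("VLLM_GPU_UTIL", "0.6")),
    max_num_seqs=int(os.environ.get("VLLM_MAX_SEQS", "64")),
    enable_chunked_prefill=True,
    enable_prefix_caching=True,   # silently ignored on V1 with prompt_embeds
)
```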

Apples-to-apples short text (~9-10 chars), n=4*conc, fp16 Flow + FE-cache
+ lock-free server:

   conc | Round 1 QPS | Round 2 QPS | Round 1 TTFA p50 | Round 2 TTFA p50
      1 |         n/a |        0.38 |              559 |              525
      4 |        3.58 |        3.33 |              997 |     1137 (noise)
      8 |           - |        4.44 |                - |             1787
     16 |           - |        5.33 |                - |             2973

Concurrency 8/16 weren't measurable before because the small KV budget
caused queue thrashing -- vLLM accepted requests then evicted them when
the next arrived. Audio throughput on conc=16 jumps from 5.86x realtime
(Round 0) to 10.04x (Round 2).

Tradeoff: +1-2 GB resident GPU at idle. No quality regression in audio
samples (samples/round2_vllm/ vs round1_fp16/, same prompts/seeds).

Toggle via env: VLLM_GPU_UTIL=0.2 VLLM_MAX_SEQS=256 to revert defaults.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… 16)

Replaces the per-client `with self.lock: vllm.step()` pattern in
inference_wrapper with a dedicated daemon thread that owns vllm.step()
exclusively. Client threads now block on a per-uuid queue.Queue() (which
internally uses condition vars, no busy-poll) instead of taking the
shared lock and sleep(0.001)-spinning.

The original design forced N concurrent client threads to serialize on
self.lock, then artificially gap step() calls by 1ms (sleep). Removing
both yields:

   conc | Round 2 QPS | Round 3 QPS | R2 TTFA p50 | R3 TTFA p50
      1 |        0.38 |        0.37 |         525 |         520
      4 |        3.33 |        3.41 |        1137 |        1115
      8 |        4.44 | 5.81 (+31%) |        1787 | 1431 (-20%)
     16 |        5.33 | 4.54 (-15%) |        2973 | 3382 (+14%)
     32 |           - |        4.93 |           - |        6059

Net win: peak throughput shifts from "5.33 QPS at conc=16, TTFA 3.0 s"
to "5.81 QPS at conc=8, TTFA 1.4 s" — same QPS, half the latency, half
the GPU queue depth.

The conc=16 regression is the new ceiling: 16 waiting threads waking up
on queue.put() saturate the GIL between scheduler step() calls. Solving
that requires either C-extension queue or batched dispatch -- deferred.

Implementation notes:
  - _ensure_vllm_scheduler() is idempotent and lazily started on first
    request; survives client crashes.
  - Queue is registered BEFORE add_request so the scheduler can never
    drop a token because dict isn't ready (race the original code also
    had under the lock).
  - queue.get(timeout=120) is a safety net; in healthy operation each
    token arrives <100 ms after vllm step. timeout = abandon, not retry.
  - try/finally ensures pop(uuid) on yield exhaustion or client cancel.
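The implementation notes above can be condensed into a toy sketch; one daemon thread owns step(), clients block on per-request queues. `step_fn` stands in for vllm.step() and is assumed to return an iterable of (request_id, token) pairs, with token=None as an end-of-stream sentinel:

```python
import queue
import threading

class TokenDispatcher:
    def __init__(self, step_fn):
        self.step_fn = step_fn
        self.queues = {}                 # request_id -> queue.Queue
        self._start_lock = threading.Lock()
        self._thread = None

    def _ensure_scheduler(self):
        # Idempotent lazy start, as in _ensure_vllm_scheduler().
        with self._start_lock:
            if self._thread is None or not self._thread.is_alive():
                self._thread = threading.Thread(target=self._loop, daemon=True)
                self._thread.start()

    def _loop(self):
        while True:                      # only this thread ever calls step_fn
            for rid, token in self.step_fn():
                q = self.queues.get(rid)
                if q is not None:
                    q.put(token)

    def stream(self, rid, submit_fn):
        q = queue.Queue()
        self.queues[rid] = q             # register BEFORE submitting the request
        try:
            submit_fn(rid)               # e.g. vllm add_request
            self._ensure_scheduler()
            while True:
                token = q.get(timeout=120)   # safety net: abandon, not retry
                if token is None:
                    return
                yield token
        finally:
            self.queues.pop(rid, None)   # cleanup on exhaustion or cancel
```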

No quality regression in audio samples (samples/round3_lockfree/).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After Rounds 1-3, ran profile_deep_cache.py (FP16=1) on the same
hardware to find the new bottleneck shape:

  Stream short TTFA = 601 ms
    LLM: 382 ms (62%) -- 72 tokens, first=11ms, per_token=5.2ms (192/s)
    T2W: ~219 ms (35%) -- Flow + HiFi first chunk
    Other: ~16 ms

  Stream medium TOTAL = 1779 ms
    LLM: 1428 ms (80%)
    Flow (TRT fp16): 113 ms/chunk x 3.2
    HiFi (PyTorch + autocast fp16): 93 ms/chunk x 3.2

LLM is now the dominant wall at 62-80% across text lengths. The original
Round 4 plan (HiFi-GAN to TRT) targets ~7% of TOTAL (HiFi 297ms / TOTAL
1779ms) with 30-50% best-case engine speedup => +5% TOTAL win, on a
multi-day implementation that has to handle Snake activation, STFT, and
weight_norm parametrization.

Re-ranked candidate Round 4+ optimizations in slo_analysis.md:

  Lever                                     | TTFA Δ          | Effort
  Speculative decoding                      | -30% LLM        | 1-2 days
  HiFi-GAN -> TRT fp16 (original plan)      | -30 to -50ms/chunk | 1-2 days
  Flow batching across concurrent reqs      | conc QPS x2     | 2-3 days
  Smaller TTS model (Kokoro/Piper)          | TTFA <300ms     | 3-5 days
  Round 3 GIL ceiling fix (conc>=16)        | +10-15% conc QPS| 4-6 hours

Cumulative summary of completed rounds (peak QPS at TTFA <= 1.5s SLO):
  Round 0 baseline: 2.68 QPS @ 1772ms (over SLO)
  Round 1 (Flow fp16):     3.58 QPS @  997ms
  Round 2 (vLLM args):     3.33 QPS @ 1137ms (vs 5.33 @ conc=16)
  Round 3 (lock removal):  5.81 QPS @ 1431ms (peak shifted to conc=8)

Net: +117% effective production QPS, -19% TTFA p50 within SLO.

Also: profile_deep_cache.py now reads FP16 env var (default 1) so
profiles match the deployed server.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Splits CausalHiFTGenerator.decode() between PyTorch and TRT:

  PyTorch (kept):  f0_predictor -> sine source -> STFT(s)
                   conv_pre (causal, dual-arg by finalize)
                   iSTFT, finalize-truncate, audio_limit clamp
  TRT engine:      leaky_relu + ups + reflection_pad + source_downs
                   + source_resblocks + resblocks (Snake) + conv_post
                   + exp/sin to magnitude/phase

Snake activation maps to standard ONNX ops (sin, multiply, add, divide)
so no custom plugin is needed; the engine builds with stock TRT 10.13
and runs in fp16. Engine is 38 MiB (vs Flow's 635 MiB).
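As a reference for that mapping, Snake written with only those elementary ops (a numpy sketch; `eps` plays the role of the model's no_div_by_zero guard):

```python
import numpy as np

def snake(x, alpha, eps=1e-9):
    # x + (1/alpha) * sin(alpha*x)^2 -- only mul/add/div/sin,
    # all of which have standard ONNX equivalents, so no custom TRT plugin.
    return x + (1.0 / (alpha + eps)) * np.sin(alpha * x) ** 2
```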

  cosyvoice/bin/export_hift_onnx.py  -- export the conv-only block to
                                        hift.decoder.fp32.onnx (69 MB).
                                        Strips weight_norm (handles both
                                        legacy hook and new parametrize
                                        APIs). Uses real probed shapes.
  probe_hift_shapes.py               -- one-off helper that derived the
                                        T_stft = 120 * T_x + 1 relation
                                        used in the TRT profile.

  cli/model.py:load_trt_hift()        -- lazy-build engine if missing,
                                         then monkey-patch hift.decode
                                         to mirror the PyTorch preamble
                                         and dispatch the conv block to
                                         the TRT context.
  cli/cosyvoice.py                   -- opt-in via env LOAD_TRT_HIFT=1.

Apples-to-apples short text (~9-10 chars), n=4*conc, FP16=1, fp16 Flow
TRT, FE-cache, lock-free server, single-thread vllm scheduler:

   conc | Round 3 QPS | Round 6 QPS | R3 TTFA p50 | R6 TTFA p50
      1 |        0.37 |        0.41 |         520 |  426 (-18%)
      4 |        3.41 |        3.99 |        1115 |  936 (-16%)
      8 |        5.81 | 7.22 (+24%) |        1431 |   1432 (0%)
     16 |        4.54 |        5.19 |        3382 |        3102

Cumulative vs Round 0 baseline:
  Peak QPS  2.71 -> 7.22 (+166%)
  Audio thru @ peak  5.27x -> 14.03x realtime
  TTFA p50 @ conc=1  1170 -> 426 ms (-64%)
  TTFA p50 @ conc=4  1772 ->  936 ms (-47%)

Concurrency notes:
  - TRT execution context state (set_input_shape/set_tensor_address)
    is NOT thread-safe. Tried three patterns; the simplest --
    single context + threading.Lock -- was the only stable one.
    Multi-context with dedicated CUDA streams (Flow's pattern) added
    per-call sync overhead that ran 3x slower for this small engine.
    Multi-context sharing the current stream had random TRT-internal
    contention (illegal memory access at conc>=8).
  - lock contention is small because execute_async_v3 just queues GPU
    work; the lock is released before the GPU finishes.
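The winning pattern can be sketched as below; `context` stands in for a TRT IExecutionContext and `stream` for a CUDA stream wrapper, both hypothetical here:

```python
import threading

_trt_lock = threading.Lock()

def hift_trt_forward(context, stream, inputs, outputs):
    """inputs/outputs: dict of tensor name -> tensor exposing .shape and .data_ptr()."""
    with _trt_lock:
        # set_input_shape / set_tensor_address mutate per-context state
        # and are NOT thread-safe, so they must stay inside the lock.
        for name, tensor in inputs.items():
            context.set_input_shape(name, tuple(tensor.shape))
        for name, tensor in {**inputs, **outputs}.items():
            context.set_tensor_address(name, tensor.data_ptr())
        context.execute_async_v3(stream.cuda_stream)  # only enqueues GPU work
    stream.synchronize()  # the GPU wait happens outside the lock
```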

TRT optimization profile derived from probe_hift_shapes.py:
  min  (1, 512, 10),  (1, 18, 1201)  -- finalize=False short tail
  opt  (1, 512, 80),  (1, 18, 9601)  -- typical chunk
  max  (1, 512, 600), (1, 18, 72001) -- full utterance
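The min/opt/max table translates into a TensorRT optimization profile roughly like this (a build-config fragment; the input tensor names "x" and "source" are assumptions about the exported ONNX):

```python
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
config = builder.create_builder_config()
profile = builder.create_optimization_profile()
# Second input follows the probed relation T_stft = 120 * T_x + 1.
profile.set_shape("x",      (1, 512, 10),  (1, 512, 80),  (1, 512, 600))
profile.set_shape("source", (1, 18, 1201), (1, 18, 9601), (1, 18, 72001))
config.add_optimization_profile(profile)
```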

Round 5 (vLLM ngram speculative decoding) was investigated and BLOCKED:
vLLM 0.11 disables speculative when enable_prompt_embeds=True
(RFC #22124), and CosyVoice can't tokenize speech embeddings. Documented
in slo_analysis.md.

A/B audio samples saved at samples/round6_hift_trt/ vs round0_baseline/.
Toggle via env: LOAD_TRT_HIFT=0 reverts to PyTorch + autocast(fp16).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…S @ conc=4)

Default `trt_concurrent` was 1, meaning all concurrent /tts requests
serialized on a single Flow TRT execution context. Bumping to 4 (env
FLOW_TRT_CONCURRENT, default 4) gives Flow's existing TrtContextWrapper
4 (context, dedicated_stream) pairs that share the same engine weights
-- ~1 GB extra GPU for execution buffers, no engine rebuild required.
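The pool mechanics can be sketched with a plain queue (a toy shape in the spirit of TrtContextWrapper; `make_pair` is a caller-supplied factory for a (context, dedicated_stream) pair):

```python
import os
import queue

class TrtContextPool:
    def __init__(self, make_pair, n=None):
        if n is None:
            n = int(os.environ.get("FLOW_TRT_CONCURRENT", "4"))
        self._free = queue.Queue()
        for _ in range(n):
            self._free.put(make_pair())   # pairs share the same engine weights

    def acquire(self):
        return self._free.get()           # blocks when all n pairs are busy

    def release(self, pair):
        self._free.put(pair)
```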

This is the cheap-and-effective alternative to true cross-request Flow
batching (which would need re-exporting ONNX away from the CFG-baked
batch=2 layout, rebuilding the TRT engine, and writing a windowed
batching scheduler -- 1-2 days of work for ~30-50% best-case Flow gain).

Apples-to-apples short text (~9-10 chars), n=4*conc, FP16=1, fp16 Flow,
hift TRT, FE-cache, lock-free server, single-thread vllm scheduler:

   conc | Round 6 QPS | Round 7 QPS | R6 TTFA p50 | R7 TTFA p50
      1 |        0.41 |        0.41 |         426 |         416
      4 |        3.99 | 4.68 (+17%) |         936 |  786 (-16%)
      8 |        7.22 | 5.55 (-23%) |        1432 | 1481 (GPU contention)
     16 |        5.19 | 6.63 (+28%) |        3102 | 2542 (-18%)

Cumulative vs Round 0 baseline:
  conc=1 TTFA p50  1170 -> 416 ms (-64%)
  conc=4 TTFA p50  1772 -> 786 ms (-56%)
  conc=4 QPS       2.68 -> 4.68  (+75%)
  conc=8 peak QPS  2.71 -> 5.55  (+105%)  (was 7.22 in R6)
  conc=16 QPS      3.14 -> 6.63  (+111%)

The conc=8 regression vs Round 6 is the single anomaly -- looks like
GPU resource contention between the 4 Flow contexts and the single
hift TRT context once both hit at the same chunk boundary. Other
SLO-relevant concurrencies (4 and 16) both win.

A/B audio at samples/round7_flow_concurrent/ vs round0_baseline/.

Toggle: FLOW_TRT_CONCURRENT=1 reverts to the upstream default.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ramework

CRITICAL: Round 6 hift TRT integration produces saturated audio (every
sample value clips to -1.0). Speed numbers in commits 8c8b05f and 29894a7
are real, but audio is unusable. Eval CER = 1.0, SECS = -0.14, RMS = 0.0.

The Round 9 audio quality eval framework (eval/quality_eval.py) caught
this on first run after I added it. Should have run it before R6 commit.
Lesson: always validate audio for any TRT/quantization change.

Eval setup:
  - Whisper base (CPU) for transcript -> CER vs reference text
  - ECAPA-TDNN (speechbrain) for speaker similarity vs prompt audio
  - RMS dB + duration for sanity (catches all-zero / saturated samples)
  - Separate venv at /home/zhiqiang/.venvs/coseval to avoid contaminating
    the cosyvoice venv (speechbrain hard-pins torch==2.3.1 which would
    break vLLM 0.11 / TRT 10.13)
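The RMS sanity check is the piece that catches this failure mode; a minimal sketch (thresholds are illustrative, chosen around the ~-20 dBFS level healthy samples show below):

```python
import numpy as np

def rms_dbfs(samples, floor_db=-80.0):
    """RMS level in dB relative to full scale, for samples in [-1, 1]."""
    x = np.asarray(samples, dtype=np.float64)
    rms = np.sqrt(np.mean(x * x))
    if rms <= 0.0:
        return floor_db
    return max(20.0 * np.log10(rms), floor_db)

def looks_broken(samples, lo=-45.0, hi=-3.0):
    # A fully clipped signal sits near 0 dBFS, silence at the floor;
    # healthy speech lands comfortably inside (lo, hi).
    db = rms_dbfs(samples)
    return not (lo < db < hi)
```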

Eval results (n=4 short, n=5 medium per round; cpu Whisper):

  round                 | CER   | SECS  | RMS dB | status
  ----------------------|-------|-------|--------|--------
  round0_baseline       | 0.254 | 0.607 | -21.6  | ok
  round1_fp16           | 0.184 | 0.672 | -20.0  | ok
  round2_vllm           | 0.214 | 0.676 | -21.1  | ok
  round3_lockfree       | 0.270 | 0.662 | -20.4  | ok
  round6_hift_trt       | 1.000 | -0.14 |   0.0  | broken (saturated)
  round7_flow_concurrent| 1.000 | -0.14 |   0.0  | broken (was using R6)
  round7_fixed          | 0.234 | 0.615 | -20.3  | ok (LOAD_TRT_HIFT=0)

Mitigation in this commit:
  - LOAD_TRT_HIFT default was already 0 (Round 6 made it opt-in via env);
    added a WARNING comment in cli/cosyvoice.py explaining why it should
    stay off until the saturation bug is fixed.
  - samples/round7_fixed/ contains correct R7 audio: same FLOW_TRT_CONCURRENT=4
    speedup, but with hift TRT disabled so audio is intact.
  - slo_analysis.md flags Round 6/7 results and lists hypotheses to
    investigate next session (fp16 Snake overflow most likely).

Round 8 (env tuning HIFT_TRT_CONCURRENT=4 + VLLM_GPU_UTIL=0.7) was
attempted but underperformed R7 across most concurrencies; not committed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…l wins)

Round 6's fp16 hift TRT engine produced saturated audio (CER 1.0, SECS
-0.14, every sample at -1.0). Root cause is fp16-specific -- almost
certainly Snake activation overflow `x + (1/α)·sin²(αx)` where large αx
saturates the half-precision range and the magnitude head explodes.

Workaround: build the hift engine in fp32 (env HIFT_TRT_FP16=0, new
default). Audio is now byte-faithful to the PyTorch+autocast path.

Quality eval (n=4 short, cpu Whisper + ECAPA-TDNN SECS):

  round                  | CER   | SECS  | RMS dB | status
  -----------------------|-------|-------|--------|--------
  round0_baseline        | 0.254 | 0.607 | -21.6  | ok
  round7_fixed (no hift) | 0.234 | 0.615 | -20.3  | ok
  round10_hift_fp32      | 0.234 | 0.615 | -20.3  | ok    <-- this commit
  round6_hift_trt (fp16) | 1.000 | -0.14 |   0.0  | broken (kept as evidence)

Speed vs Round 3 (last clean apples-to-apples benchmark):

   conc | R3 QPS |     R10 QPS | R3 TTFA |     R10 TTFA
      4 |   3.41 | 4.97 (+46%) |    1115 |   743 (-33%)
      8 |   5.81 | 5.74  (-1%) |    1431 |  1348  (-6%)
     16 |   4.54 | 5.60 (+23%) |    3382 |  2741 (-19%)

The fp32 engine sacrifices the theoretical fp16 Tensor-Core speedup but
still wins because it eliminates Python op-launch overhead and fuses the
ConvTranspose / ResBlock / conv_post chain. Net: ~20-30% over the
PyTorch+autocast baseline, and the audio is correct.

Production config baseline is now:
  LOAD_TRT=1 FP16=1 (Flow fp16 engine)
  LOAD_TRT_HIFT=1 HIFT_TRT_FP16=0 (hift fp32 engine)
  FLOW_TRT_CONCURRENT=4 (4 Flow contexts on dedicated streams)

Cumulative vs the un-optimized server (Round 0 baseline):
  conc=4 TTFA p50  1772 ms -> 743 ms  (-58%)
  conc=4 QPS        2.68 ->   4.97   (+85%)
  conc=8 QPS        2.71 ->   5.74   (+112%)
  conc=16 QPS       3.14 ->   5.60   (+78%)
  Audio thru @ peak  5.27x ->  10.95x

Open follow-up to recover the fp16 ceiling: rebuild engine with
TRT OBEY_PRECISION_CONSTRAINTS and mark Snake layers fp32. Should regain
~10-15% on top of fp32 without re-introducing the saturation bug.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Tried building the hift TRT engine in fp16 with OBEY_PRECISION_CONSTRAINTS
+ per-layer fp32 markings on Sin / Pow / Reciprocal / Div ops (the
decomposed Snake activation in ONNX). Hypothesis: protect Snake from
fp16 overflow while letting the heavy Conv / ConvTranspose stack run
fp16. Audio is now correct (CER 0.234, SECS 0.614, identical to R10
fp32 hift), so the Snake-fp32 strategy *does* fix the saturation bug.

However throughput is **5-15 % slower than R10 pure-fp32** at every
tested concurrency (conc=4 QPS 4.21 vs 4.97; conc=16 5.15 vs 5.60).
TRT inserts fp16<->fp32 cast layers at every Snake boundary; on a
network this Snake-heavy (289 / 3166 layers ~ 9 % forced to fp32),
those casts cost more than the fp16 Conv speedup saves.

Verdict: R10 pure-fp32 hift stays as production default. The new
`fp32_layer_keywords` arg on `convert_onnx_to_trt()` is kept for future
experiments; better keyword targeting (only the Reciprocal + second Mul
in each Snake block, not all Sin/Pow) *might* beat fp32, but the marginal
win is not worth the engine-build complexity right now.

Quality eval still all-clean:
  round0_baseline                | 0.254 | 0.607 | ok
  round10_hift_fp32 (production) | 0.234 | 0.615 | ok
  round11_hift_fp16_snake32      | 0.234 | 0.614 | ok (this commit)
  round6_hift_trt (broken fp16)  | 1.000 | -0.14 | kept as evidence

Production config unchanged:
  LOAD_TRT=1 FP16=1
  LOAD_TRT_HIFT=1 HIFT_TRT_FP16=0   <-- still fp32, R10 wins
  FLOW_TRT_CONCURRENT=4

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…win)

Followup to R11. Probed the exported hift ONNX with dump_onnx_nodes.py
and found every Snake op sits under "/<module>/activations<N>.<M>/<Op>",
so the single keyword 'activations' precisely targets the 72 Snake
activations x ~9 ops = 648 layers (20.5% of network) without over-matching
generic Sin/Pow/Reciprocal/Div in the conv chain.
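The keyword pin could be applied during engine build roughly like this (a build-config sketch; `config` and `network` are assumed to come from the surrounding convert_onnx_to_trt ONNX parse, which is not reproduced here):

```python
import tensorrt as trt

# Assumes config (IBuilderConfig) and network (INetworkDefinition) exist.
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
for i in range(network.num_layers):
    layer = network.get_layer(i)
    if "activations" in layer.name:        # the 72 Snake blocks (~648 layers)
        layer.precision = trt.float32      # keep Snake math out of fp16
        layer.set_output_type(0, trt.float32)
```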

R12 audio is correct (CER 0.234, SECS 0.615 = R10 = baseline) and is
slightly faster than R11's broad keyword set, but still doesn't beat
pure fp32 R10 in benchmark:

   metric        | conc | R10 fp32 | R11 broad-fp32 | R12 precise-fp32 (this commit)
   QPS           |    4 |     4.97 |           4.21 |     4.84
   TTFA p50 (ms) |    4 |      743 |            846 |      758
   QPS           |    8 |     5.74 |           5.47 |     5.29
   QPS           |   16 |     5.60 |           5.15 |     5.35

Why fp16+Snake-fp32 can't beat pure fp32 on this network: the 72 Snake
activations are interleaved through every ResBlock, so TRT inserts ~144
fp16<->fp32 cast layers (one in / one out per Snake). The fp16 Conv
speedup on the remaining ~80% of layers is exactly cancelled by those
casts.

To actually win in fp16 would need:
 (a) replace Snake with a numerically-safe equivalent (e.g., tanh(alpha*x))
     -- requires model re-training; or
 (b) write a custom TRT plugin that does the entire Snake math in fp32
     inside one kernel, avoiding per-op cast overhead.

Production config UNCHANGED: pure fp32 hift TRT engine remains optimal.
The keyword arg in model.py is updated to 'activations' (precise) as a
better baseline if anyone toggles HIFT_TRT_FP16=1 in the future.

dump_onnx_nodes.py kept in repo as a reusable diagnostic.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…w prod default)

Root cause of the Round 6 hift TRT saturation bug, finally pinned down.

Diagnostic dump_snake_alphas.py over the 10752 trained Snake alpha
values in hift.pt:

  alpha min=1.6024e-06  max=4.4509e+00  mean=2.2736e-01
  1/alpha max=6.2369e+05  fp16 max=65504
  values where 1/alpha > 65504 (fp16 overflow): 4 / 10752
  values where 1/alpha > 6500  (close to limit):     56

Just 4 outlier channels (0.04 %) push 1/alpha past the fp16 max of 65504,
and those 4 channels poison the entire downstream multiply -> magnitude
head -> iSTFT chain, clamping every output sample to ±1.0.
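The overflow arithmetic is easy to reproduce with numpy's half type (values taken from the alpha dump above):

```python
import numpy as np

alpha_min = 1.6024e-06                   # smallest trained Snake alpha
inv_alpha = 1.0 / alpha_min              # ~6.24e5, finite in fp32
assert np.isinf(np.float16(inv_alpha))   # exceeds fp16 max 65504 -> inf
assert np.isfinite(np.float16(6.0e4))    # a 6e4 clamp stays representable
```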

Two-line fix at the source instead of fighting it in TRT:

  # cosyvoice/transformer/activation.py: Snake.forward
  inv_alpha = 1.0 / (alpha + self.no_div_by_zero)
  inv_alpha = torch.clamp(inv_alpha, max=6e4)   # NEW; fp16-safe
  x = x + inv_alpha * pow(sin(x * alpha), 2)

The clamp activates on only the 4 outlier channels; the other 99.96 %
of the network sees identical math. Re-export hift ONNX, rebuild the
fp16 TRT engine WITHOUT any precision constraints (no
OBEY_PRECISION_CONSTRAINTS, no fp32_layer_keywords), and:

  metric                     | R10 fp32 (no Snake fix) | R13 fp16+clamp
  ---------------------------|------------------------:|----------------:
  Audio CER                  |                  0.234  |          0.234
  Audio SECS                 |                  0.615  |          0.615
  Engine build               |                   ~30 s |   32 s (vs R11/R12: 230 s)
  conc=1 TTFA p50            |                  444 ms |  409 ms  (-8 %)
  conc=4 QPS                 |                   4.97  |   5.03   (+1 %)
  conc=4 TTFA p95            |                 1054 ms |  938 ms (-11 %)
  conc=8 QPS                 |                   5.74  |   5.44   (-5 %, noise)
  conc=16 QPS                |                   5.60  |   5.24   (-6 %, noise)

Same audio quality, 7x faster engine build, lower tail latency at the
production-relevant conc=4 SLO. Peak QPS at conc=8/16 is within noise.

Production default flipped: HIFT_TRT_FP16 now defaults to 1
(cli/cosyvoice.py:230). The fp32_layer_keywords infrastructure stays
behind env HIFT_TRT_FP32_KW=1 for the unlikely case that re-trained
Snake alphas drift back into the overflow range.

Cumulative vs Round 0 baseline (production config R13):
  conc=1 TTFA p50  1170 ms ->  409 ms  (-65 %)
  conc=4 TTFA p50  1772 ms ->  749 ms  (-58 %)
  conc=4 QPS        2.68  ->   5.03   (+88 %)
  conc=8 QPS        2.71  ->   5.44   (+101 %)
  conc=16 QPS       3.14  ->   5.24   (+67 %)
  Audio thru @ peak  5.27x ->  10.6x

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>