Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepSeek Sparse Attention (DSA) backend for TRT-LLM with indexer-based TopK selection."""
import math
import threading
Expand Down Expand Up @@ -1504,13 +1518,16 @@ def __init__(self,
# attribute queries do not end up frozen into a captured graph.
warmup_heuristic_topk_decode(top_k=self.index_topk)

def post_load_weights(self):
def cache_derived_state(self) -> None:
    """Build and cache the fused ``wk`` / ``weights_proj`` weight tensor.

    The two projection weights are stacked along dim 0 and stored on
    ``self._fused_wk_wp_weight`` so both projections can be served by a
    single ``F.linear`` GEMM (run under allow_tf32, i.e. TF32 tensor cores
    on Ampere+).

    NOTE(review): the result is described as a single FP32 weight, but no
    cast happens here; presumably both source weights are already FP32 —
    confirm against the module's construction.
    """
    # Row layout of the concatenation:
    #   wk.weight           : [head_dim, hidden_size]
    #   weights_proj.weight : [n_heads, hidden_size]
    #   fused               : [head_dim + n_heads, hidden_size]
    key_proj_weight = self.wk.weight.data
    head_weights_proj = self.weights_proj.weight.data
    self._fused_wk_wp_weight = torch.cat(
        (key_proj_weight, head_weights_proj), dim=0)

def post_load_weights(self) -> None:
    """Run the post-load step by delegating to ``cache_derived_state()``."""
    self.cache_derived_state()

@staticmethod
def prepare_one_prefill_chunk(
metadata: DSAtrtllmAttentionMetadata,
Expand Down Expand Up @@ -2404,7 +2421,7 @@ def pre_indexer_proj(
split in MLA.forward_dsa_proj sees a stable signature.
"""
assert self._fused_wk_wp_weight is not None, \
"post_load_weights() must be called before forward()"
"cache_derived_state() must be called before forward()"
hidden_float = _to_float(hidden_states)
with _tf32_matmul_enabled():
# F.linear computes input @ weight.T internally; no explicit .t() needed.
Expand Down
24 changes: 16 additions & 8 deletions tensorrt_llm/_torch/memory/gpu_memory_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,17 @@
CUDA memory pool. After loading, weights are committed for read-only
access by other workers and the client transitions to RO mode in place.
- **RO (Read-Only)**: Subsequent workers zero-copy import already-committed
weights from the GMS pool. `post_load_weights()` must run BEFORE
materialization so that module aliases are set up correctly.
weights from the GMS pool. `setup_aliases()` must run BEFORE
materialization so that module aliases are set up correctly, while derived
state is refreshed after real tensors are bound. RO is validated for models
whose `post_load_weights()` is pure alias wiring; models that additionally
rely on plain Python attributes set inside `post_load_weights()` (rather
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
those side effects to `cache_derived_state()` or another hook that runs on
RO readers. One-shot tensor layout changes belong in `transform_weights()`
on the writer; the GMS RO reader runs `setup_aliases()` before
`materialize_module()`, then `cache_derived_state()` afterward. It does not
run `transform_weights()`.
"""

from contextlib import contextmanager
Expand Down Expand Up @@ -477,7 +486,7 @@ def materialize_module(self, model: nn.Module) -> None:
by GPU pointers from the shared memory region — no data copies,
no disk I/O, just CUDA VMM remapping. The model's submodule
layout must already match the writer's at commit time, including
any aliases / derived buffers introduced by `post_load_weights`.
any aliases introduced by `setup_aliases`.

Args:
model: The `nn.Module` to materialize. Walks the full
Expand All @@ -489,11 +498,10 @@ def materialize_module(self, model: nn.Module) -> None:
RuntimeError: If `connect()` has not been called yet.

Note:
`post_load_weights()` must be called on the model BEFORE
this method. The order ensures that any aliases / derived
parameters created by post-load hooks are present on the
module tree at materialization time, so they are bound to
the same GMS storage as their primary tensor.
`setup_aliases()` must be called on the model BEFORE this method.
The order ensures that any structural aliases created by post-load
hooks are present on the module tree at materialization time, so
they are bound to the same GMS storage as their primary tensor.
"""
if self._client is None:
raise RuntimeError("GMS client not connected. Call connect() first.")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Any

Expand Down Expand Up @@ -69,6 +72,17 @@ def is_weights_preloaded(self) -> bool:
"""Whether the last load wrote weights directly into the model."""
return False

def is_post_transform_weights_preloaded(self) -> bool:
    """Report whether the last direct preload supplied post-transform weights.

    This is a stricter signal than :meth:`is_weights_preloaded`. A loader
    writing bytes straight into the model does not by itself mean those
    bytes are in the post-transform layout; they may still be in the raw
    checkpoint layout. ``True`` is reserved for the case where the source
    identity was verified and the incoming bytes can safely skip module
    ``transform_weights()`` hooks.

    Returns:
        ``False`` in this default implementation.
    """
    return False

def post_load_apply(self,
model: nn.Module,
*,
Expand Down
Loading
Loading