Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 21 additions & 46 deletions fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@
from .utils import _set_var_distributed, divide, get_tensor, modules_to_convert


def may_be_do_cast(loaded_weight, param):

assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 兼容性 Breaking Change:原 else 分支 .cast() 改为 assert,破坏已有隐式 dtype 转换

原代码:

else:
    loaded_weight = loaded_weight.cast(param.dtype)

允许隐式 dtype 转换(如 float16 → float32),兼容多种权重文件格式。新代码改为 assert / raise 后,除 int8→float8_e4m3fn 外的所有 dtype 不匹配均会在模型加载时直接报错,属 Breaking Change。

需在 PR 描述中明确说明此语义变更的意图,并确认所有已支持模型/权重格式均不依赖隐式 dtype 转换,再合入。

assert (

This comment was marked as outdated.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug:assert 用于运行时 dtype 校验,在 Python -O 优化模式下会被跳过,导致错误静默

assert 语句在 Python 以 -O 参数运行时会被完全忽略,不适合用于运行时参数校验。

建议修复方式:

else:
    raise ValueError(
        f"loaded_weight.dtype: {loaded_weight.dtype}, param.dtype: {param.dtype}"
    )

loaded_weight.dtype == param.dtype

This comment was marked as outdated.

), f"loaded_weight.dtype: {loaded_weight.dtype}, param.dtype: {param.dtype}"
return loaded_weight


class UnquantizedLinearMethod(QuantMethodBase):
"""Linear method without quantization."""

Expand Down Expand Up @@ -407,15 +423,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
start=param_shard_offset,
end=param_shard_offset + param_shard_size,
)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
loaded_weight = may_be_do_cast(loaded_weight, param)
# (bukejiyu) After this fix, the early H2D copy for non-GPU devices is no longer needed and can be safely removed.
h2d_copy(param, loaded_weight)

Expand Down Expand Up @@ -592,16 +600,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
if hasattr(param, "tensor_track"):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)

loaded_weight = may_be_do_cast(loaded_weight, param)
h2d_copy(param, loaded_weight)

def load_state_dict(self, state_dict: dict):
Expand Down Expand Up @@ -753,15 +752,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
loaded_weight = may_be_do_cast(loaded_weight, param)
h2d_copy(param, loaded_weight)

def load_weight(self, state_dict: dict):
Expand Down Expand Up @@ -1279,15 +1270,7 @@ def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
loaded_weight = may_be_do_cast(loaded_weight, param)
h2d_copy(param, loaded_weight)

def gate_weight_loader(self, param, loaded_weight):
Expand Down Expand Up @@ -1319,15 +1302,7 @@ def gate_weight_loader(self, param, loaded_weight):
param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)

param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
# Ensure loaded weight dtype matches model param dtype
if loaded_weight.dtype != param.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
loaded_weight = may_be_do_cast(loaded_weight, param)
h2d_copy(param, loaded_weight)

def load_weight(self, state_dict: dict):
Expand Down
6 changes: 3 additions & 3 deletions tests/model_executor/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,12 @@ def test_merged_and_column_weight_paths():
layer_merge = MergedReplicatedLinear.__new__(MergedReplicatedLinear)
layer_merge.__dict__.update(fd_config=make_fd_config(model_format="paddle"), output_sizes=[2, 2])
param = TinyParam(paddle.zeros([2, 4], dtype="float32"), initialized=False, with_track=True)
loaded_weight = paddle.ones([2, 4], dtype="float16")
loaded_weight = paddle.ones([2, 4], dtype="float32")
layer_merge.weight_loader(param, loaded_weight, loaded_shard_id=None)
assert param.tensor_track.calls == [(0, loaded_weight.shape[-1])]
np.testing.assert_allclose(param._tensor.numpy(), np.ones((2, 4), dtype="float32"))
param_shard = TinyParam(paddle.zeros([2, 4], dtype="float32"), initialized=False)
layer_merge.weight_loader(param_shard, paddle.ones([2, 2], dtype="int8"), loaded_shard_id="gate")
layer_merge.weight_loader(param_shard, paddle.ones([2, 2], dtype="float32"), loaded_shard_id="gate")
assert param_shard._is_initialized() is True
assert not np.allclose(param_shard._tensor.numpy()[..., :2], 0)
assert np.allclose(param_shard._tensor.numpy()[..., 2:], 0)
Expand All @@ -213,7 +213,7 @@ def test_merged_and_column_weight_paths():
param_gate = TinyParam(paddle.zeros([2, 4], dtype="float32"), initialized=True)
param_gate.output_dim = True
param_gate.weight_need_transpose = True
layer_mc.weight_loader(param_gate, paddle.ones([4, 2], dtype="int8"), loaded_shard_id="gate")
layer_mc.weight_loader(param_gate, paddle.ones([4, 2], dtype="float32"), loaded_shard_id="gate")
assert not np.allclose(param_gate._tensor.numpy()[..., :2], 0)
assert np.allclose(param_gate._tensor.numpy()[..., 2:], 0)
layer_mc.local_rank = 1
Expand Down
Loading