Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,17 @@ def process_loaded_weights(self, layer, weights) -> None:
layer.weight.set_value(weights)

def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
linear_out = paddle.matmul(x, layer.weight)
if layer.with_bias:
linear_out = paddle.add(linear_out, layer.bias)
return linear_out
bias = layer.bias
assert bias.dim() == 1 and bias.shape[-1] == layer.weight.shape[-1], (
f"bias must be 1D with size equal to the last dim of weight, "
f"but got bias.shape={bias.shape}, weight.shape[-1]={layer.weight.shape[-1]}"
)
out = paddle.nn.functional.linear(x, layer.weight, bias)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug 对于 paddle 格式的模型,layer.weight 的形状为 [input_size, output_size],而 paddle.nn.functional.linear 期望的 weight 形状为 [output_size, input_size]

根据代码分析:

  • torch 格式:layer.weight 形状为 [output_size, input_size](在 create_weights 中转置)
  • paddle 格式:layer.weight 形状为 [input_size, output_size](未转置)
  • UnquantizedLinearMethod.process_weights_after_loading 会被跳过

对于 paddle 格式的模型,直接使用 paddle.nn.functional.linear(x, layer.weight, bias) 可能会导致形状不匹配错误。

建议:

  1. 验证 paddle 格式模型的兼容性
  2. 如果不支持,考虑添加条件判断或注释说明

else:
out = paddle.matmul(x, layer.weight)

return out


class LinearBase(nn.Layer):
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/utils/rollout_routing_replay_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, mode
model_path = os.getenv("MODEL_PATH")
if model_path:
baseline_path = os.path.join(
model_path, f"R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}"
model_path, f"R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}"
)
else:
baseline_path = f"./R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}"
baseline_path = f"./R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}"
stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream")

nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream")
Expand Down
Loading