Skip to content

Commit 849eb3d

Browse files
authored
[Cherry-Pick][Optimization] merge matmul and add (#6986) (#7191)
* merge matmul and add
* modify format
* using paddle.nn.functional.linear
* using _C_ops.linear
* using paddle.nn.functional.linear
* add FLAGS_use_legacy_linear env var in test case
* fix format
* add assert and remove env
* modify format
* using matmul for no bias
* modify accurate baseline
1 parent 098dd2c commit 849eb3d

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

fastdeploy/model_executor/layers/linear.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -82,10 +82,17 @@ def process_loaded_weights(self, layer, weights) -> None:
8282
layer.weight.set_value(weights)
8383

8484
def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
85-
linear_out = paddle.matmul(x, layer.weight)
8685
if layer.with_bias:
87-
linear_out = paddle.add(linear_out, layer.bias)
88-
return linear_out
86+
bias = layer.bias
87+
assert bias.dim() == 1 and bias.shape[-1] == layer.weight.shape[-1], (
88+
f"bias must be 1D with size equal to the last dim of weight, "
89+
f"but got bias.shape={bias.shape}, weight.shape[-1]={layer.weight.shape[-1]}"
90+
)
91+
out = paddle.nn.functional.linear(x, layer.weight, bias)
92+
else:
93+
out = paddle.matmul(x, layer.weight)
94+
95+
return out
8996

9097

9198
class LinearBase(nn.Layer):

tests/e2e/utils/rollout_routing_replay_test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -157,10 +157,10 @@ def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, mode
157157
model_path = os.getenv("MODEL_PATH")
158158
if model_path:
159159
baseline_path = os.path.join(
160-
model_path, f"R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}"
160+
model_path, f"R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}"
161161
)
162162
else:
163-
baseline_path = f"./R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}"
163+
baseline_path = f"./R3_BaseLine_dev_uint8_0403/routing_replay_output_baseline_{model_name}"
164164
stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream")
165165

166166
nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream")

0 commit comments

Comments
 (0)