Skip to content

Commit f8b7f03

Browse files
committed
opt: 新增递归倍加模乘惯用法识别,折叠为 64 位宽乘取模
新增: ModMulIdiom pass 实现:新增专用指令 MulModInst(operands a,b + 模数字段,结果 i32), toString 展开为合法 LLVM IR(sext/mul/srem/trunc i64) 以兼容 llvmir 测试模式;后端 translate_mulmod 用 64 位 mul + li/rem 降落;识别 pass def-use 驱动匹配,注册在 Mem2Reg 之后、GVN 之前。
1 parent 79374fd commit f8b7f03

9 files changed

Lines changed: 479 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ set(IR_SRCS
182182
ir/Instructions/StoreInst.cpp
183183
ir/Instructions/BinaryInst.h
184184
ir/Instructions/BinaryInst.cpp
185+
ir/Instructions/MulModInst.h
186+
ir/Instructions/MulModInst.cpp
185187
ir/Instructions/FCmpInst.h
186188
ir/Instructions/FCmpInst.cpp
187189
ir/Instructions/ICmpInst.h
@@ -297,6 +299,8 @@ set(IR_SRCS
297299
ir/passes/modulePass/SmallFunctionInline.cpp
298300
ir/passes/functionPass/TailRecursionElim.h
299301
ir/passes/functionPass/TailRecursionElim.cpp
302+
ir/passes/functionPass/ModMulIdiom.h
303+
ir/passes/functionPass/ModMulIdiom.cpp
300304
ir/passes/functionPass/LateLoopCFGCleanup.h
301305
ir/passes/functionPass/LateLoopCFGCleanup.cpp
302306
ir/passes/functionPass/PhiLowering.h

backend/riscv64/InstSelectorRiscV64.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "AllocaInst.h"
2525
#include "BasicBlock.h"
2626
#include "BinaryInst.h"
27+
#include "MulModInst.h"
2728
#include "BranchInst.h"
2829
#include "CallInst.h"
2930
#include "ConstFloat.h"
@@ -504,6 +505,7 @@ InstSelectorRiscV64::InstSelectorRiscV64(
504505
translatorHandlers[IRInstOperator::IRINST_OP_MUL_I] = &InstSelectorRiscV64::translate_mul;
505506
translatorHandlers[IRInstOperator::IRINST_OP_DIV_I] = &InstSelectorRiscV64::translate_div;
506507
translatorHandlers[IRInstOperator::IRINST_OP_MOD_I] = &InstSelectorRiscV64::translate_mod;
508+
translatorHandlers[IRInstOperator::IRINST_OP_MULMOD_I] = &InstSelectorRiscV64::translate_mulmod;
507509
translatorHandlers[IRInstOperator::IRINST_OP_SHL_I] = &InstSelectorRiscV64::translate_shl;
508510
translatorHandlers[IRInstOperator::IRINST_OP_ASHR_I] = &InstSelectorRiscV64::translate_ashr;
509511
translatorHandlers[IRInstOperator::IRINST_OP_LSHR_I] = &InstSelectorRiscV64::translate_lshr;
@@ -1410,6 +1412,51 @@ void InstSelectorRiscV64::translate_mod(Instruction * inst)
14101412
translate_binary(inst, "remw");
14111413
}
14121414

1415+
/// @brief Translate the wide multiply-modulo instruction, i.e. (i64)a * b % m; emits nothing if inst is not a MulModInst.
1416+
///
1417+
/// Both i32 operands are assumed to be held sign-extended in 64-bit registers, so a 64-bit mul produces the exact
1418+
/// 64-bit product (per the original note, guards at the call site ensure 0 <= a < m and b >= 0, so the product is non-negative and stays below 2^61 without overflow; NOTE(review): those guards are not visible here, confirm the bound), which is then reduced by the positive constant m
1419+
/// with a signed 64-bit rem. The remainder lies in [0, m) and can be used directly as an already sign-extended i32.
1420+
void InstSelectorRiscV64::translate_mulmod(Instruction * inst)
1421+
{
1422+
auto * mulmod = dynamic_cast<MulModInst *>(inst);
1423+
if (mulmod == nullptr) {
1424+
return;
1425+
}
1426+
const int32_t modulus = mulmod->getModulus();
1427+
1428+
int dstReg = getResultReg(inst);
1429+
LocalTempManager::Lease dstLease;
1430+
if (dstReg < 0) {
1431+
dstLease = tempMgr.borrow(inst);
1432+
dstReg = dstLease.reg();
1433+
}
1434+
1435+
OperandReg lhs = loadOperand(mulmod->getA(), inst, dstReg);
1436+
const int rhsPreferredReg = lhs.reg != dstReg ? dstReg : -1;
1437+
OperandReg rhs = loadOperand(mulmod->getB(), inst, rhsPreferredReg < 0 ? dstReg : -1, rhsPreferredReg);
1438+
1439+
// 64-bit multiply: the full product is kept in dstReg, nothing is truncated to 32 bits
1440+
iloc.inst("mul",
1441+
PlatformRiscV64::regName[dstReg],
1442+
PlatformRiscV64::regName[lhs.reg],
1443+
PlatformRiscV64::regName[rhs.reg]);
1444+
1445+
releaseOperand(rhs);
1446+
releaseOperand(lhs);
1447+
1448+
// Signed 64-bit remainder by the constant: load m into a scratch register other than dstReg, then rem in place
1449+
auto modTmp = tempMgr.borrowExcluding(inst, {dstReg});
1450+
iloc.load_imm(modTmp.reg(), modulus);
1451+
iloc.inst("rem",
1452+
PlatformRiscV64::regName[dstReg],
1453+
PlatformRiscV64::regName[dstReg],
1454+
PlatformRiscV64::regName[modTmp.reg()]);
1455+
modTmp.release();
1456+
1457+
storeResult(inst, dstReg, inst);
1458+
}
1459+
14131460
/// @brief 翻译逻辑左移指令(shl)
14141461
void InstSelectorRiscV64::translate_shl(Instruction * inst)
14151462
{

backend/riscv64/InstSelectorRiscV64.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ class InstSelectorRiscV64 {
130130
void translate_div(Instruction * inst);
131131
/// @brief 翻译mod指令(取模)
132132
void translate_mod(Instruction * inst);
133+
/// @brief 翻译宽乘取模指令((i64)a*b % m,64 位无截断乘后对常量取模)
134+
void translate_mulmod(Instruction * inst);
133135
/// @brief 翻译逻辑左移指令(shl)
134136
void translate_shl(Instruction * inst);
135137
/// @brief 翻译算术右移指令(ashr,保留符号位)

ir/Instruction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ enum class IRInstOperator : std::int8_t {
1919
IRINST_OP_MUL_I,
2020
IRINST_OP_DIV_I,
2121
IRINST_OP_MOD_I,
22+
IRINST_OP_MULMOD_I, ///< 64 位宽乘后对常量取模:(i64)a * (i64)b % m,结果为 i32
2223
IRINST_OP_SHL_I, ///< 逻辑左移
2324
IRINST_OP_ASHR_I, ///< 算术右移(保留符号位)
2425
IRINST_OP_LSHR_I, ///< 逻辑右移(高位补 0)

ir/Instructions/MulModInst.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
///
2+
/// @file MulModInst.cpp
3+
/// @brief 宽乘取模指令实现
4+
///
5+
6+
#include "MulModInst.h"
7+
8+
#include <string>
9+
10+
#include "Function.h"
11+
#include "IntegerType.h"
12+
#include "Value.h"
13+
14+
/// @brief Build a wide multiply-modulo instruction: opcode IRINST_OP_MULMOD_I, i32 result type, modulus m stored
/// as a field, and a then b registered as operands. m is stored as given; no validation happens here.
MulModInst::MulModInst(Function * func, Value * a, Value * b, int32_t m)
15+
: Instruction(func, IRInstOperator::IRINST_OP_MULMOD_I, IntegerType::getTypeInt32()), modulus(m)
16+
{
17+
addOperand(a);
18+
addOperand(b);
19+
}
20+
21+
/// @brief Return operand 0, i.e. the first operand the constructor adds (the multiplicand a).
Value * MulModInst::getA()
22+
{
23+
return getOperand(0);
24+
}
25+
26+
/// @brief Return operand 1, i.e. the second operand the constructor adds (the multiplier b).
Value * MulModInst::getB()
27+
{
28+
return getOperand(1);
29+
}
30+
31+
/// @brief Overwrite str with a five-line LLVM IR expansion of this instruction (sext, sext, mul, srem, trunc).
/// @param str Output string; any previous content is replaced.
void MulModInst::toString(std::string & str)
32+
{
33+
// Expand into standard LLVM IR: sign-extend both operands to i64, multiply at 64-bit width, take the signed remainder by the modulus, then truncate back to i32
34+
// Temporary names are derived from this instruction's result name, which is intended to keep them unique within the function; every line after the first gets a leading indent (the original note says two spaces) to line up with the .ll output
35+
const std::string dst = getIRName();
36+
const std::string sa = dst + ".sea";
37+
const std::string sb = dst + ".seb";
38+
const std::string prod = dst + ".w64";
39+
const std::string rem = dst + ".r64";
40+
const std::string m = std::to_string(modulus);
41+
42+
str = sa + " = sext i32 " + getA()->getIRName() + " to i64\n";
43+
str += " " + sb + " = sext i32 " + getB()->getIRName() + " to i64\n";
44+
str += " " + prod + " = mul i64 " + sa + ", " + sb + "\n";
45+
str += " " + rem + " = srem i64 " + prod + ", " + m + "\n";
46+
str += " " + dst + " = trunc i64 " + rem + " to i32";
47+
}

ir/Instructions/MulModInst.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
///
2+
/// @file MulModInst.h
3+
/// @brief 宽乘取模指令 (i64)a * (i64)b % m
4+
///
5+
/// 表示对两个 i32 操作数做 64 位无截断乘法后,再对编译期正常量 m 取模,
6+
/// 结果回落到 i32。用于把"递归倍加模乘"惯用法折叠成 O(1) 的单条宽乘取模指令,
7+
/// 避免 32 位乘法溢出。语义上 a、b 均按有符号扩展到 64 位参与运算
8+
///
9+
10+
#pragma once
11+
12+
#include <cstdint>
13+
14+
#include "Instruction.h"
15+
16+
class Value;
17+
class Function;
18+
19+
/// @brief IR instruction for (i64)a * (i64)b % m: two value operands plus a compile-time int32 modulus field.
class MulModInst final : public Instruction {
20+
21+
public:
22+
/// @brief Construct a wide multiply-modulo instruction
23+
/// @param func Function that contains the instruction
24+
/// @param a Multiplicand (i32)
25+
/// @param b Multiplier (i32)
26+
/// @param m Positive constant to reduce by (known at compile time); not validated by the constructor
27+
MulModInst(Function * func, Value * a, Value * b, int32_t m);
28+
29+
/// @brief Get the multiplicand
30+
Value * getA();
31+
32+
/// @brief Get the multiplier
33+
Value * getB();
34+
35+
/// @brief Get the modulus constant
36+
[[nodiscard]] int32_t getModulus() const
37+
{
38+
return modulus;
39+
}
40+
41+
/// @brief Serialize as equivalent LLVM IR text (expanded into sext/mul/srem/trunc)
42+
void toString(std::string & str) override;
43+
44+
private:
45+
int32_t modulus = 0;
46+
};

ir/passes/PassManager.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "functionPass/LateLoopCFGCleanup.h"
3737
#include "functionPass/LoopRotate.h"
3838
#include "functionPass/Mem2Reg.h"
39+
#include "functionPass/ModMulIdiom.h"
3940
#include "functionPass/PhiToSelect.h"
4041
#include "functionPass/PhiLowering.h"
4142
#include "functionPass/PureCallCSE.h"
@@ -104,6 +105,13 @@ void PassManager::registerDefaultOptimizationPipeline(int32_t optLevel, bool ena
104105
return false;
105106
});
106107

108+
// 递归倍加模乘惯用法识别:须在 Mem2Reg 后(依赖 SSA 分支形态)、
109+
// GVN/InstCombine 前(避免 srem/sdiv 被变形破坏匹配)
110+
registerFunctionPass("ModMulIdiom", [this](Function * func) {
111+
ModMulIdiom pass(func, module);
112+
return pass.run();
113+
});
114+
107115
registerFunctionPass("GVN", [this](Function * func) {
108116
GVN pass(func, module);
109117
return pass.run();

0 commit comments

Comments
 (0)