Skip to content

Commit 48f875b

Browse files
committed
[#83887] Remove function/task arguments
Signed-off-by: Artur Bieniek <abieniek@internships.antmicro.com>
1 parent b59d9c3 commit 48f875b

8 files changed

Lines changed: 251 additions & 3 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ FetchContent_Declare(
1111
)
1212
FetchContent_MakeAvailable(slang)
1313

14-
add_executable(sv-bugpoint source/SvBugpoint.cpp source/Utils.cpp source/PairRemovers.cpp source/BodyRemover.cpp source/LabelRemover.cpp
14+
add_executable(sv-bugpoint source/SvBugpoint.cpp source/Utils.cpp source/PairRemovers.cpp source/SetRemovers.cpp source/BodyRemover.cpp source/LabelRemover.cpp
1515
source/BodyPartsRemover.cpp source/DeclRemover.cpp source/InstantationRemover.cpp source/BindRemover.cpp source/ModportRemover.cpp source/ContAssignRemover.cpp source/ParamAssignRemover.cpp source/ModuleRemover.cpp
1616
source/StatementsRemover.cpp source/MemberRemover.cpp source/ImportsRemover.cpp source/TypeSimplifier.cpp)
1717

source/PairRemovers.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "Utils.hpp"
55

66
class PairRemover : public SyntaxRewriter<PairRemover> {
7-
// each tranform yields removal of pair of nodes (based on locations in suplied pairs list)
7+
// each tranform yields removal of pair of nodes (based on locations in supplied pairs list)
88
public:
99
std::vector<std::pair<SourceRange, SourceRange>> pairs;
1010
std::pair<SourceRange, SourceRange> searchedPair;

source/SetRemovers.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
#include "SetRemovers.hpp"
3+
#include <slang/ast/ASTVisitor.h>
4+
#include <unordered_map>
5+
6+
class FunctionArgMapper : public ASTVisitor<FunctionArgMapper, true, true, true> {
7+
// Builds vector that maps the argument in definition to all calls
8+
public:
9+
std::vector<SetRemover::RemovalSet> removals;
10+
11+
void handle(const SubroutineSymbol& subroutine) {
12+
if (shouldProcess(subroutine)) {
13+
for (const FormalArgumentSymbol* arg : subroutine.getArguments()) {
14+
registerFormal(arg);
15+
}
16+
}
17+
18+
visitDefault(subroutine);
19+
}
20+
21+
void handle(const MethodPrototypeSymbol& prototype) {
22+
if (shouldProcess(prototype)) {
23+
for (const FormalArgumentSymbol* arg : prototype.getArguments()) {
24+
registerFormal(arg);
25+
}
26+
}
27+
28+
visitDefault(prototype);
29+
}
30+
31+
void handle(const CallExpression& call) {
32+
const SubroutineSymbol* subroutine = getSubroutine(call);
33+
if (!subroutine || !shouldProcess(*subroutine)) {
34+
visitDefault(call);
35+
return;
36+
}
37+
38+
auto formals = subroutine->getArguments();
39+
auto actuals = call.arguments();
40+
size_t count = std::min(formals.size(), actuals.size());
41+
for (size_t idx = 0; idx < count; idx++) {
42+
size_t removalIdx = registerFormal(formals[idx]);
43+
if (removalIdx == invalidIndex) {
44+
continue;
45+
}
46+
47+
if (const SyntaxNode* argNode = getArgumentNode(actuals[idx])) {
48+
removals[removalIdx].push_back(argNode);
49+
}
50+
}
51+
52+
visitDefault(call);
53+
}
54+
55+
private:
56+
static constexpr size_t invalidIndex = std::numeric_limits<size_t>::max();
57+
std::unordered_map<const FormalArgumentSymbol*, size_t> indexByFormal;
58+
59+
size_t registerFormal(const FormalArgumentSymbol* formal) {
60+
if (!formal) {
61+
return invalidIndex;
62+
}
63+
64+
if (auto it = indexByFormal.find(formal); it != indexByFormal.end()) {
65+
return it->second;
66+
}
67+
68+
const SyntaxNode* node = getFormalNode(*formal);
69+
if (!node) {
70+
indexByFormal.emplace(formal, invalidIndex);
71+
return invalidIndex;
72+
}
73+
74+
SetRemover::RemovalSet set;
75+
set.push_back(node);
76+
removals.push_back(std::move(set));
77+
size_t idx = removals.size() - 1;
78+
indexByFormal.emplace(formal, idx);
79+
return idx;
80+
}
81+
82+
static const SyntaxNode* getFormalNode(const FormalArgumentSymbol& formal) {
83+
if (auto syntax = formal.getSyntax()) {
84+
if (syntax->parent) {
85+
return syntax->parent;
86+
}
87+
return syntax;
88+
}
89+
return nullptr;
90+
}
91+
92+
static const SyntaxNode* getArgumentNode(const Expression* expr) {
93+
if (!expr || !expr->syntax) {
94+
return nullptr;
95+
}
96+
97+
const SyntaxNode* node = expr->syntax;
98+
while (node) {
99+
if (node->kind == SyntaxKind::OrderedArgument ||
100+
node->kind == SyntaxKind::NamedArgument) {
101+
return node;
102+
}
103+
node = node->parent;
104+
}
105+
return nullptr;
106+
}
107+
108+
static const SubroutineSymbol* getSubroutine(const CallExpression& call) {
109+
if (call.isSystemCall()) {
110+
return nullptr;
111+
}
112+
if (std::holds_alternative<const SubroutineSymbol*>(call.subroutine)) {
113+
return std::get<const SubroutineSymbol*>(call.subroutine);
114+
}
115+
return nullptr;
116+
}
117+
118+
static bool shouldProcess(const SubroutineSymbol& subroutine) {
119+
return subroutine.subroutineKind == SubroutineKind::Function ||
120+
subroutine.subroutineKind == SubroutineKind::Task;
121+
}
122+
123+
static bool shouldProcess(const MethodPrototypeSymbol& prototype) {
124+
return prototype.subroutineKind == SubroutineKind::Function ||
125+
prototype.subroutineKind == SubroutineKind::Task;
126+
}
127+
};
128+
129+
SetRemover makeFunctionArgRemover(std::shared_ptr<SyntaxTree> tree) {
130+
Compilation compilation;
131+
compilation.addSyntaxTree(tree);
132+
compilation.getAllDiagnostics();
133+
FunctionArgMapper mapper;
134+
compilation.getRoot().visit(mapper);
135+
return SetRemover(std::move(mapper.removals));
136+
}

source/SetRemovers.hpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
#pragma once
3+
#include <slang/syntax/SyntaxVisitor.h>
4+
#include "Utils.hpp"
5+
6+
class SetRemover : public SyntaxRewriter<SetRemover> {
7+
// each transform yields removal of a set of nodes (based on locations in supplied sets list)
8+
public:
9+
using RemovalSet = std::vector<const slang::syntax::SyntaxNode*>;
10+
11+
std::vector<RemovalSet> removals;
12+
std::unordered_set<const slang::syntax::SyntaxNode*> pendingNodes;
13+
std::string removedTypeInfo;
14+
15+
SetRemover(std::vector<RemovalSet>&& removals) : removals(removals) {}
16+
17+
std::shared_ptr<SyntaxTree> transform(const std::shared_ptr<SyntaxTree>& tree,
18+
bool& traversalDone,
19+
AttemptStats& stats) {
20+
if (removals.empty()) {
21+
traversalDone = true;
22+
return tree;
23+
}
24+
RemovalSet removal = removals.back();
25+
removals.pop_back();
26+
pendingNodes.clear();
27+
removedTypeInfo = "";
28+
for (const auto* node : removal) {
29+
if (node && node->sourceRange() != SourceRange::NoLocation) {
30+
pendingNodes.insert(node);
31+
}
32+
}
33+
if (pendingNodes.empty()) {
34+
return transform(tree, traversalDone, stats);
35+
}
36+
37+
auto tree2 = SyntaxRewriter<SetRemover>::transform(tree);
38+
39+
if (!pendingNodes.empty()) {
40+
return transform(tree, traversalDone, stats);
41+
}
42+
43+
stats.typeInfo = removedTypeInfo;
44+
traversalDone = removals.empty();
45+
return tree2;
46+
}
47+
48+
/// The default handler invoked when no visit() method is overridden for a particular type.
49+
/// Will visit all child nodes by default.
50+
template <typename T>
51+
void visitDefault(T&& node) {
52+
for (uint32_t i = 0; i < node.getChildCount(); i++) {
53+
auto child = node.childNode(i);
54+
if (child) {
55+
child->visit(*this, node.isChildOptional(i));
56+
}
57+
}
58+
}
59+
60+
template <typename T>
61+
void logType() {
62+
std::cerr << STRINGIZE_NODE_TYPE(T) << "\n";
63+
removedTypeInfo += (removedTypeInfo.empty() ? "" : ",") + STRINGIZE_NODE_TYPE(T);
64+
}
65+
66+
template <typename T>
67+
void visit(T&& node, bool isNodeRemovable = true) {
68+
const auto* ptr = &node;
69+
auto it = pendingNodes.find(ptr);
70+
if (it != pendingNodes.end() && isNodeRemovable &&
71+
node.sourceRange() != SourceRange::NoLocation) {
72+
logType<T>();
73+
std::cerr << prefixLines(node.toString(), "-") << "\n";
74+
remove(node);
75+
pendingNodes.erase(it);
76+
return;
77+
}
78+
visitDefault(node);
79+
}
80+
};
81+
82+
SetRemover makeFunctionArgRemover(std::shared_ptr<SyntaxTree> tree);

source/SvBugpoint.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
#include <iostream>
1111
#include "OneTimeRewritersFwd.hpp"
1212
#include "PairRemovers.hpp"
13+
#include "SetRemovers.hpp"
1314
#include "Utils.hpp"
1415

1516
using namespace slang::syntax;
1617
using namespace slang::ast;
1718
using namespace slang;
1819

19-
bool rewriteLoop(PairRemover rewriter,
20+
template <typename Rewriter>
21+
bool rewriteLoop(Rewriter rewriter,
2022
std::shared_ptr<SyntaxTree>& tree,
2123
std::string stageName,
2224
std::string passIdx,
@@ -224,6 +226,8 @@ bool SvBugpoint::pass(const std::string& passIdx) {
224226
commited |= rewriteLoop<ModportRemover>(tree, "modportRemover", passIdx, this);
225227
commited |= rewriteLoop(makePortsRemover(tree), tree, "portsRemover", passIdx, this);
226228
commited |= rewriteLoop(makeStructFieldRemover(tree), tree, "structRemover", passIdx, this);
229+
commited |=
230+
rewriteLoop(makeFunctionArgRemover(tree), tree, "functionArgRemover", passIdx, this);
227231
commited |= rewriteLoop<ModuleRemover>(tree, "moduleRemover", passIdx, this);
228232
commited |= rewriteLoop<TypeSimplifier>(tree, "typeSimplifier", passIdx, this);
229233
commited |= rewriteLoop<LabelRemover>(tree, "LabelRemover", passIdx, this);

tests/Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ test_labels:
6262
test_dpi:
6363
@./run_test dpi checkverilator_lint.sh ${INPUT_DIR}/dpi.sv
6464

65+
.PHONY: test_argument
66+
test_argument:
67+
@./run_test argument checkverilator_run_finish.sh ${INPUT_DIR}/argument.sv
68+
6569
.PHONY: test_caliptra
6670
test_caliptra: test_caliptra_exit0 test_caliptra_grep
6771

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module t;
2+
int i=0;
3+
function void f();
4+
if(i==1) $finish;
5+
i=i+1;
6+
endfunction;
7+
initial begin
8+
f();
9+
f();
10+
end;
11+
endmodule

tests/input_files/argument.sv

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module t;
2+
int i=0;
3+
function void f(int arg);
4+
if(i==1) $finish;
5+
i=i+1;
6+
endfunction;
7+
initial begin;
8+
f(1);
9+
f(2);
10+
end;
11+
endmodule

0 commit comments

Comments
 (0)