Skip to content

Commit 1d383ef

Browse files
IR2Vec.cpp refactor
1 parent 76b3142 commit 1d383ef

1 file changed

Lines changed: 120 additions & 87 deletions

File tree

src/IR2Vec.cpp

Lines changed: 120 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,102 @@ void printVersion(raw_ostream &ostream) {
7373
cl::PrintVersionMessage();
7474
}
7575

76-
int main(int argc, char **argv) {
77-
cl::SetVersionPrinter(printVersion);
78-
cl::HideUnrelatedOptions(category);
76+
struct SymOutputs {
77+
std::ofstream out;
78+
};
79+
80+
struct FAOutputs : SymOutputs {
81+
std::ofstream miss;
82+
std::ofstream cyclic;
83+
};
84+
85+
inline SymOutputs openSymOutputs(const std::string &baseName) {
86+
SymOutputs f;
87+
f.out.open(baseName, std::ios_base::app);
88+
return f;
89+
}
90+
91+
inline FAOutputs openFAOutputs(const std::string &baseName) {
92+
FAOutputs f;
93+
f.out.open(baseName, std::ios_base::app);
94+
f.miss.open("missCount_" + baseName, std::ios_base::app);
95+
f.cyclic.open("cyclicCount_" + baseName, std::ios_base::app);
96+
return f;
97+
}
98+
99+
template <class F>
100+
inline void runMaybeTimed(bool shouldTime, const char *timingMsgFmt, F &&job) {
101+
if (shouldTime) {
102+
const clock_t start = clock();
103+
std::forward<F>(job)();
104+
const clock_t end = clock();
105+
const double elapsed = static_cast<double>(end - start) / CLOCKS_PER_SEC;
106+
std::printf(timingMsgFmt, elapsed);
107+
} else {
108+
std::forward<F>(job)();
109+
}
110+
}
111+
112+
template <class Encoder, class Outputs, class OutputsFactory, class Body>
113+
inline void executeEncoder(const char *timingMsgFmt, bool shouldTime,
114+
OutputsFactory &&makeOutputs, Body &&body) {
115+
auto M = getLLVMIR();
116+
auto vocabulary = VocabularyFactory::createVocabulary(DIM)->getVocabulary();
117+
Encoder encoder(*M, vocabulary);
118+
auto files = std::forward<OutputsFactory>(makeOutputs)(oname);
119+
120+
auto job = [&] { std::forward<Body>(body)(encoder, files); };
121+
runMaybeTimed(shouldTime, timingMsgFmt, job);
122+
}
123+
124+
void generateFAEncodingsFunction(std::string funcName) {
125+
executeEncoder<IR2Vec_FA, FAOutputs>(
126+
"Time taken by on-demand generation of flow-aware encodings is: %.6f "
127+
"seconds.\n",
128+
printTime, openFAOutputs, [&, funcName](IR2Vec_FA &FA, FAOutputs &files) {
129+
FA.generateFlowAwareEncodingsForFunction(&files.out, funcName,
130+
&files.miss, &files.cyclic);
131+
});
132+
}
133+
134+
void generateFAEncodings() {
135+
executeEncoder<IR2Vec_FA, FAOutputs>(
136+
"Time taken by normal generation of flow-aware encodings is: %.6f "
137+
"seconds.\n",
138+
printTime, openFAOutputs, [&](IR2Vec_FA &FA, FAOutputs &files) {
139+
FA.generateFlowAwareEncodings(&files.out, &files.miss, &files.cyclic);
140+
});
141+
}
142+
143+
void generateSymEncodingsFunction(std::string funcName) {
144+
executeEncoder<IR2Vec_Symbolic, SymOutputs>(
145+
"Time taken by on-demand generation of symbolic encodings is: %.6f "
146+
"seconds.\n",
147+
printTime, openSymOutputs,
148+
[&, funcName](IR2Vec_Symbolic &SYM, SymOutputs &files) {
149+
SYM.generateSymbolicEncodingsForFunction(&files.out, funcName);
150+
});
151+
}
152+
153+
void generateSYMEncodings() {
154+
executeEncoder<IR2Vec_Symbolic, SymOutputs>(
155+
"Time taken by normal generation of symbolic encodings is: %.6f "
156+
"seconds.\n",
157+
printTime, openSymOutputs, [&](IR2Vec_Symbolic &SYM, SymOutputs &files) {
158+
SYM.generateSymbolicEncodings(&files.out);
159+
});
160+
}
161+
162+
void collectIRfunc() {
163+
auto M = getLLVMIR();
164+
CollectIR cir(M);
165+
std::ofstream o;
166+
o.open(oname, std::ios_base::app);
167+
cir.generateTriplets(o);
168+
o.close();
169+
}
170+
171+
void setGlobalVars(int argc, char **argv) {
79172
cl::ParseCommandLineOptions(argc, argv);
80173

81174
fa = cl_fa;
@@ -92,113 +185,53 @@ int main(int argc, char **argv) {
92185
WT = cl_WT;
93186
debug = cl_debug;
94187
printTime = cl_printTime;
188+
}
95189

190+
void checkFailureConditions() {
96191
bool failed = false;
97-
if (!((sym ^ fa) ^ collectIR)) {
98-
errs() << "Either of sym, fa or collectIR should be specified\n";
192+
193+
if (!(sym || fa || collectIR)) {
194+
errs() << "Either of sym, fa, or collectIR should be specified\n";
99195
failed = true;
100196
}
101197

198+
if (failed)
199+
exit(1);
200+
102201
if (sym || fa) {
103202
if (level != 'p' && level != 'f') {
104203
errs() << "Invalid level specified: Use either p or f\n";
105204
failed = true;
106205
}
107206
} else {
108-
if (!collectIR) {
109-
errs() << "Either of sym, fa or collectIR should be specified\n";
110-
failed = true;
111-
} else if (level)
207+
assert(collectIR == true);
208+
209+
if (collectIR && level) {
112210
errs() << "[WARNING] level would not be used in collectIR mode\n";
211+
}
113212
}
114213

115214
if (failed)
116215
exit(1);
216+
}
117217

118-
auto M = getLLVMIR();
119-
auto vocabulary = VocabularyFactory::createVocabulary(DIM)->getVocabulary();
218+
int main(int argc, char **argv) {
219+
cl::SetVersionPrinter(printVersion);
220+
cl::HideUnrelatedOptions(category);
221+
setGlobalVars(argc, argv);
222+
checkFailureConditions();
120223

121-
// newly added
122224
if (sym && !(funcName.empty())) {
123-
IR2Vec_Symbolic SYM(*M, vocabulary);
124-
std::ofstream o;
125-
o.open(oname, std::ios_base::app);
126-
if (printTime) {
127-
clock_t start = clock();
128-
SYM.generateSymbolicEncodingsForFunction(&o, funcName);
129-
clock_t end = clock();
130-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
131-
printf("Time taken by on-demand generation of symbolic encodings "
132-
"is: %.6f "
133-
"seconds.\n",
134-
elapsed);
135-
} else {
136-
SYM.generateSymbolicEncodingsForFunction(&o, funcName);
137-
}
138-
o.close();
225+
generateSymEncodingsFunction(funcName);
139226
} else if (fa && !(funcName.empty())) {
140-
IR2Vec_FA FA(*M, vocabulary);
141-
std::ofstream o, missCount, cyclicCount;
142-
o.open(oname, std::ios_base::app);
143-
missCount.open("missCount_" + oname, std::ios_base::app);
144-
cyclicCount.open("cyclicCount_" + oname, std::ios_base::app);
145-
if (printTime) {
146-
clock_t start = clock();
147-
FA.generateFlowAwareEncodingsForFunction(&o, funcName, &missCount,
148-
&cyclicCount);
149-
clock_t end = clock();
150-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
151-
printf("Time taken by on-demand generation of flow-aware encodings "
152-
"is: %.6f "
153-
"seconds.\n",
154-
elapsed);
155-
} else {
156-
FA.generateFlowAwareEncodingsForFunction(&o, funcName, &missCount,
157-
&cyclicCount);
158-
}
159-
o.close();
227+
generateFAEncodingsFunction(funcName);
160228
} else if (fa) {
161-
IR2Vec_FA FA(*M, vocabulary);
162-
std::ofstream o, missCount, cyclicCount;
163-
o.open(oname, std::ios_base::app);
164-
missCount.open("missCount_" + oname, std::ios_base::app);
165-
cyclicCount.open("cyclicCount_" + oname, std::ios_base::app);
166-
if (printTime) {
167-
clock_t start = clock();
168-
FA.generateFlowAwareEncodings(&o, &missCount, &cyclicCount);
169-
clock_t end = clock();
170-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
171-
printf("Time taken by normal generation of flow-aware encodings "
172-
"is: %.6f "
173-
"seconds.\n",
174-
elapsed);
175-
} else {
176-
FA.generateFlowAwareEncodings(&o, &missCount, &cyclicCount);
177-
}
178-
o.close();
229+
generateFAEncodings();
179230
} else if (sym) {
180-
IR2Vec_Symbolic SYM(*M, vocabulary);
181-
std::ofstream o;
182-
o.open(oname, std::ios_base::app);
183-
if (printTime) {
184-
clock_t start = clock();
185-
SYM.generateSymbolicEncodings(&o);
186-
clock_t end = clock();
187-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
188-
printf("Time taken by normal generation of symbolic encodings is: "
189-
"%.6f "
190-
"seconds.\n",
191-
elapsed);
192-
} else {
193-
SYM.generateSymbolicEncodings(&o);
194-
}
195-
o.close();
231+
generateSYMEncodings();
196232
} else if (collectIR) {
197-
CollectIR cir(M);
198-
std::ofstream o;
199-
o.open(oname, std::ios_base::app);
200-
cir.generateTriplets(o);
201-
o.close();
233+
collectIRfunc();
202234
}
235+
203236
return 0;
204237
}

0 commit comments

Comments
 (0)