Skip to content

Commit 937183e

Browse files
committed
fixes
1 parent 03a77cc commit 937183e

3 files changed

Lines changed: 271 additions & 65 deletions

File tree

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/LuaNativeLowering.java

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -76,59 +76,44 @@ private LuaNativeLowering() {}
7676
*
7777
* <p>Must be called <em>before</em> the optimizer so that the optimizer
7878
* can inline and eliminate the generated wrappers.
79+
*
80+
* <p>Stubs and wrappers are created lazily (on first call-site encounter) and added
81+
* to prog only after the traversal completes. This avoids the memory cost of
82+
* creating wrappers for every BJ function in the IM (common.j declares hundreds of
83+
* functions, most of which are unreachable in any given program).
7984
*/
8085
public static void transform(ImProg prog) {
81-
// Pre-scan: find which BJ functions are actually called, so we only create stubs/wrappers
82-
// for reachable functions. Creating wrappers for all BJ functions in the IM (common.j has
83-
// hundreds of them) would be extremely memory-intensive.
84-
Set<ImFunction> calledBjFuncs = new LinkedHashSet<>();
85-
prog.accept(new Element.DefaultVisitor() {
86-
@Override
87-
public void visit(ImFunctionCall call) {
88-
super.visit(call);
89-
if (call.getFunc().isBj()) {
90-
calledBjFuncs.add(call.getFunc());
91-
}
92-
}
93-
});
94-
95-
if (calledBjFuncs.isEmpty()) {
96-
return;
97-
}
98-
99-
// Maps original BJ function → replacement (either a IS_NATIVE stub or a nil-safety wrapper)
86+
// Maps original BJ function → replacement (IS_NATIVE stub or nil-safety wrapper).
87+
// Populated lazily during the traversal.
10088
Map<ImFunction, ImFunction> replacements = new LinkedHashMap<>();
101-
// Nil-safety wrappers are collected separately and added to prog AFTER the traversal,
102-
// so the traversal does not visit their bodies and replace their internal BJ delegate calls.
103-
List<ImFunction> deferredWrappers = new ArrayList<>();
104-
105-
for (ImFunction f : calledBjFuncs) {
106-
String name = f.getName();
89+
// BJ functions that don't need a replacement (not GetHandleId, not hashtable/callback,
90+
// no handle params). Cached to avoid rechecking the same function at every call site.
91+
Set<ImFunction> noReplacement = new HashSet<>();
92+
// All generated functions (stubs and wrappers) are deferred until after the traversal:
93+
// - Stubs: deferred so ConcurrentModificationException is avoided on prog.getFunctions()
94+
// - Wrappers: deferred so the visitor doesn't see their internal BJ delegate calls and
95+
// recursively wrap them, which would cause infinite wrapping.
96+
List<ImFunction> deferredAdditions = new ArrayList<>();
10797

108-
if ("GetHandleId".equals(name)) {
109-
replacements.put(f, createNativeStub("__wurst_GetHandleId", f, prog));
110-
} else if (HASHTABLE_NATIVE_NAMES.contains(name)) {
111-
replacements.put(f, createNativeStub("__wurst_" + name, f, prog));
112-
} else if (CONTEXT_CALLBACK_NATIVE_NAMES.contains(name)) {
113-
replacements.put(f, createNativeStub("__wurst_" + name, f, prog));
114-
} else if (hasHandleParam(f)) {
115-
ImFunction wrapper = createNilSafeWrapper(f);
116-
replacements.put(f, wrapper);
117-
deferredWrappers.add(wrapper);
118-
}
119-
}
120-
121-
if (replacements.isEmpty()) {
122-
return;
123-
}
124-
125-
// Replace all call sites in the existing IM (before adding wrappers).
126-
// Wrappers are deferred so their internal BJ delegate calls are not replaced.
12798
prog.accept(new Element.DefaultVisitor() {
12899
@Override
129100
public void visit(ImFunctionCall call) {
130101
super.visit(call);
131-
ImFunction replacement = replacements.get(call.getFunc());
102+
ImFunction f = call.getFunc();
103+
if (!f.isBj()) return;
104+
if (noReplacement.contains(f)) return;
105+
106+
if (!replacements.containsKey(f)) {
107+
ImFunction r = computeReplacement(f);
108+
if (r != null) {
109+
replacements.put(f, r);
110+
deferredAdditions.add(r);
111+
} else {
112+
noReplacement.add(f);
113+
}
114+
}
115+
ImFunction replacement = replacements.get(f);
116+
132117
if (replacement != null) {
133118
call.replaceBy(JassIm.ImFunctionCall(
134119
call.attrTrace(), replacement,
@@ -137,30 +122,45 @@ public void visit(ImFunctionCall call) {
137122
false, CallType.NORMAL));
138123
}
139124
}
125+
126+
private ImFunction computeReplacement(ImFunction bj) {
127+
String name = bj.getName();
128+
if ("GetHandleId".equals(name)) {
129+
return createNativeStub("__wurst_GetHandleId", bj);
130+
} else if (HASHTABLE_NATIVE_NAMES.contains(name)) {
131+
return createNativeStub("__wurst_" + name, bj);
132+
} else if (CONTEXT_CALLBACK_NATIVE_NAMES.contains(name)) {
133+
return createNativeStub("__wurst_" + name, bj);
134+
} else if (hasHandleParam(bj)) {
135+
return createNilSafeWrapper(bj);
136+
}
137+
return null;
138+
}
140139
});
141140

142-
// Add nil-safety wrapper functions AFTER traversal so their own bodies are not traversed.
143-
prog.getFunctions().addAll(deferredWrappers);
141+
// Add all generated functions after the traversal so their bodies are not visited
142+
// by the replacement visitor above.
143+
prog.getFunctions().addAll(deferredAdditions);
144144
}
145145

146146
/**
147147
* Creates a new IS_NATIVE (non-BJ) IM function stub with the same signature as
148148
* {@code original}. The Lua translator will fill in the body via
149149
* {@code LuaNatives.get()} when it encounters the stub.
150+
*
151+
* <p>The caller is responsible for adding the stub to prog.getFunctions().
150152
*/
151-
private static ImFunction createNativeStub(String name, ImFunction original, ImProg prog) {
153+
private static ImFunction createNativeStub(String name, ImFunction original) {
152154
ImVars params = JassIm.ImVars();
153155
for (ImVar p : original.getParameters()) {
154156
params.add(JassIm.ImVar(p.attrTrace(), p.getType().copy(), p.getName(), false));
155157
}
156-
ImFunction stub = JassIm.ImFunction(
158+
return JassIm.ImFunction(
157159
original.attrTrace(), name,
158160
JassIm.ImTypeVars(), params,
159161
original.getReturnType().copy(),
160162
JassIm.ImVars(), JassIm.ImStmts(),
161163
Collections.singletonList(FunctionFlagEnum.IS_NATIVE));
162-
prog.getFunctions().add(stub);
163-
return stub;
164164
}
165165

166166
/**

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaAssertions.java

Lines changed: 206 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,49 +63,241 @@ public static void assertNoLeakedHashtableNativeCalls(String luaCode) {
6363
}
6464
}
6565

66+
/**
67+
* Collects all function names that appear as CALLS in the Lua source.
68+
*
69+
* Skips string literals, comments, and function declaration names (including
70+
* method-syntax declarations like {@code function Foo:bar()} or
71+
* {@code function Foo.bar()}) to avoid false positives.
72+
*/
6673
static Set<String> collectCalledFunctionNames(String text) {
6774
Set<String> result = new HashSet<>();
6875
int length = text.length();
6976
int index = 0;
7077
while (index < length) {
71-
if (!isIdentifierStart(text.charAt(index))) {
72-
index++;
78+
char ch = text.charAt(index);
79+
80+
// Skip Lua comments: -- short or --[[ long ]]
81+
if (ch == '-' && index + 1 < length && text.charAt(index + 1) == '-') {
82+
int longLevel = countLongBracketLevel(text, index + 2);
83+
if (longLevel >= 0) {
84+
index = skipLongString(text, index + 2, longLevel);
85+
} else {
86+
// Short comment: skip to end of line
87+
while (index < length && text.charAt(index) != '\n') {
88+
index++;
89+
}
90+
}
7391
continue;
7492
}
75-
int end = scanIdentifierEnd(text, index + 1);
76-
int next = skipWhitespace(text, end);
77-
if (next < length && text.charAt(next) == '(') {
78-
result.add(text.substring(index, end));
93+
94+
// Skip string literals: "..." or '...'
95+
if (ch == '"' || ch == '\'') {
96+
index = skipQuotedString(text, index, ch);
97+
continue;
98+
}
99+
100+
// Skip long strings: [[...]] or [=[...]=]
101+
if (ch == '[') {
102+
int longLevel = countLongBracketLevel(text, index);
103+
if (longLevel >= 0) {
104+
index = skipLongString(text, index, longLevel);
105+
continue;
106+
}
79107
}
80-
index = end;
108+
109+
// Skip function declarations: after the 'function' keyword the name tokens
110+
// (including A.B or A:B method syntax) are NOT calls.
111+
if (matchesWord(text, index, "function")) {
112+
index = skipFunctionDeclarationName(text, index + "function".length());
113+
continue;
114+
}
115+
116+
// Check identifier followed by '(' → function call
117+
if (isIdentifierStart(ch)) {
118+
int end = scanIdentifierEnd(text, index + 1);
119+
int next = skipWhitespace(text, end);
120+
if (next < length && text.charAt(next) == '(') {
121+
result.add(text.substring(index, end));
122+
}
123+
index = end;
124+
continue;
125+
}
126+
127+
index++;
81128
}
82129
return result;
83130
}
84131

132+
/**
133+
* Collects function names that appear as DEFINITIONS in the Lua source.
134+
*
135+
* Handles both simple ({@code function name(}) and method-syntax
136+
* ({@code function A:name(} or {@code function A.name(}) declarations.
137+
* Skips string literals and comments.
138+
*/
85139
static Set<String> collectDefinedFunctionNames(String text) {
86140
Set<String> result = new HashSet<>();
87141
int length = text.length();
88142
int index = 0;
89143
while (index < length) {
144+
char ch = text.charAt(index);
145+
146+
// Skip comments
147+
if (ch == '-' && index + 1 < length && text.charAt(index + 1) == '-') {
148+
int longLevel = countLongBracketLevel(text, index + 2);
149+
if (longLevel >= 0) {
150+
index = skipLongString(text, index + 2, longLevel);
151+
} else {
152+
while (index < length && text.charAt(index) != '\n') {
153+
index++;
154+
}
155+
}
156+
continue;
157+
}
158+
159+
// Skip string literals
160+
if (ch == '"' || ch == '\'') {
161+
index = skipQuotedString(text, index, ch);
162+
continue;
163+
}
164+
if (ch == '[') {
165+
int longLevel = countLongBracketLevel(text, index);
166+
if (longLevel >= 0) {
167+
index = skipLongString(text, index, longLevel);
168+
continue;
169+
}
170+
}
171+
90172
if (!matchesWord(text, index, "function")) {
91173
index++;
92174
continue;
93175
}
94-
int nameStart = skipWhitespace(text, index + "function".length());
95-
if (nameStart >= length || !isIdentifierStart(text.charAt(nameStart))) {
176+
177+
// Skip past 'function', then scan the name
178+
int pos = skipWhitespace(text, index + "function".length());
179+
if (pos >= length || !isIdentifierStart(text.charAt(pos))) {
96180
index++;
97181
continue;
98182
}
99-
int nameEnd = scanIdentifierEnd(text, nameStart + 1);
100-
int next = skipWhitespace(text, nameEnd);
101-
if (next < length && text.charAt(next) == '(') {
102-
result.add(text.substring(nameStart, nameEnd));
183+
184+
// Walk A.B.C or A:B chains, keeping track of the last identifier
185+
String lastName = null;
186+
while (pos < length && isIdentifierStart(text.charAt(pos))) {
187+
int nameEnd = scanIdentifierEnd(text, pos + 1);
188+
lastName = text.substring(pos, nameEnd);
189+
pos = nameEnd;
190+
if (pos < length && (text.charAt(pos) == '.' || text.charAt(pos) == ':')) {
191+
pos++; // consume '.' or ':'
192+
} else {
193+
break;
194+
}
195+
}
196+
197+
int next = skipWhitespace(text, pos);
198+
if (lastName != null && next < length && text.charAt(next) == '(') {
199+
result.add(lastName);
103200
}
104-
index = nameEnd;
201+
index = pos;
105202
}
106203
return result;
107204
}
108205

206+
/**
207+
* After the {@code function} keyword, skip past the declaration name
208+
* (which may include {@code A.B} or {@code A:B} qualifiers) and return
209+
* the position after the opening {@code (}.
210+
*
211+
* If there is no valid name, returns the position just after the keyword.
212+
*/
213+
private static int skipFunctionDeclarationName(String text, int index) {
214+
int length = text.length();
215+
int pos = skipWhitespace(text, index);
216+
217+
if (pos >= length || !isIdentifierStart(text.charAt(pos))) {
218+
// Anonymous function: 'function(' — no name to skip
219+
return pos;
220+
}
221+
222+
// Walk A.B.C or A:B chains
223+
while (pos < length && isIdentifierStart(text.charAt(pos))) {
224+
pos = scanIdentifierEnd(text, pos + 1);
225+
if (pos < length && (text.charAt(pos) == '.' || text.charAt(pos) == ':')) {
226+
pos++; // consume '.' or ':'
227+
} else {
228+
break;
229+
}
230+
}
231+
232+
// Skip to just after '(' so the outer loop doesn't re-examine the '('
233+
pos = skipWhitespace(text, pos);
234+
if (pos < length && text.charAt(pos) == '(') {
235+
pos++;
236+
}
237+
return pos;
238+
}
239+
240+
/**
241+
* Returns the long-bracket level of a {@code [=..=[} opener at {@code index},
242+
* or -1 if there is no valid long-bracket opener at that position.
243+
*/
244+
private static int countLongBracketLevel(String text, int index) {
245+
int length = text.length();
246+
if (index >= length || text.charAt(index) != '[') {
247+
return -1;
248+
}
249+
int level = 0;
250+
int pos = index + 1;
251+
while (pos < length && text.charAt(pos) == '=') {
252+
level++;
253+
pos++;
254+
}
255+
if (pos < length && text.charAt(pos) == '[') {
256+
return level;
257+
}
258+
return -1;
259+
}
260+
261+
/**
262+
* Skips past a long string starting with {@code [=..=[} at {@code index}.
263+
* The {@code level} is the number of {@code =} signs in the bracket.
264+
* Returns the index after the closing {@code ]=..=]}.
265+
*/
266+
private static int skipLongString(String text, int index, int level) {
267+
int length = text.length();
268+
// Skip the opening bracket [=..=[ (1 + level + 1 chars)
269+
int pos = index + 1 + level + 1;
270+
String close = "]" + "=".repeat(level) + "]";
271+
int closeIdx = text.indexOf(close, pos);
272+
if (closeIdx < 0) {
273+
return length;
274+
}
275+
return closeIdx + close.length();
276+
}
277+
278+
/**
279+
* Skips a quoted string starting at {@code index} with quote character {@code quote}.
280+
* Handles backslash escapes. Returns the index after the closing quote.
281+
*/
282+
private static int skipQuotedString(String text, int index, char quote) {
283+
int length = text.length();
284+
int pos = index + 1; // skip opening quote
285+
while (pos < length) {
286+
char ch = text.charAt(pos);
287+
if (ch == '\\') {
288+
pos += 2; // skip escaped character
289+
} else if (ch == quote) {
290+
return pos + 1;
291+
} else if (ch == '\n') {
292+
// Unfinished string literal — treat as ended
293+
return pos;
294+
} else {
295+
pos++;
296+
}
297+
}
298+
return pos;
299+
}
300+
109301
private static int skipWhitespace(String text, int index) {
110302
while (index < text.length() && Character.isWhitespace(text.charAt(index))) {
111303
index++;

0 commit comments

Comments
 (0)