Skip to content

Commit bf22b46

Browse files
committed
Optimize chained string concatenation with counted CONCAT
1 parent 904592a commit bf22b46

6 files changed

Lines changed: 202 additions & 19 deletions

File tree

src/jmh/java/io/jawk/backend/AVMExpressionBenchmark.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ public class AVMExpressionBenchmark {
6161
private AwkExpression fieldConcatenation;
6262
private AwkExpression fieldRegexMatch;
6363
private AwkExpression multiStringConcatenation;
64+
private AwkExpression constantStringConcatenation;
65+
private AwkExpression stringConstantStringConstantConcatenation;
66+
private AwkExpression fourStringConcatenation;
6467
private AwkExpression mixedExpression;
6568

6669
/**
@@ -78,6 +81,9 @@ public void setup() throws IOException {
7881
this.fieldConcatenation = awk.compileExpression("$1 \" test\"");
7982
this.fieldRegexMatch = awk.compileExpression("$1 ~ /test/");
8083
this.multiStringConcatenation = awk.compileExpression("$1 \" test1\" \" test2\" \" test3\"");
84+
this.constantStringConcatenation = awk.compileExpression("\"constant\" \"constant\" \"constant\" \"constant\"");
85+
this.stringConstantStringConstantConcatenation = awk.compileExpression("$1 \"constant\" $2 \"constant\"");
86+
this.fourStringConcatenation = awk.compileExpression("$1 $2 $3 $4");
8187
this.mixedExpression = awk.compileExpression("($1 + $2) \":\" ($3 ~ /test/) \":\" $4");
8288
this.avm = new AVM(new AwkSettings(), Collections.emptyMap());
8389
this.avm.prepareForEval("42 3.14 test-value suffix");
@@ -159,6 +165,39 @@ public Object multiStringConcatenation() throws IOException {
159165
return this.avm.eval(this.multiStringConcatenation);
160166
}
161167

168+
/**
169+
* Measures concatenation of four constant string operands.
170+
*
171+
* @return expression result
172+
* @throws IOException if input preparation or evaluation fails
173+
*/
174+
@Benchmark
175+
public Object constantStringConcatenation() throws IOException {
176+
return this.avm.eval(this.constantStringConcatenation);
177+
}
178+
179+
/**
180+
* Measures alternating field and constant string concatenation.
181+
*
182+
* @return expression result
183+
* @throws IOException if input preparation or evaluation fails
184+
*/
185+
@Benchmark
186+
public Object stringConstantStringConstantConcatenation() throws IOException {
187+
return this.avm.eval(this.stringConstantStringConstantConcatenation);
188+
}
189+
190+
/**
191+
* Measures concatenation of four field string operands.
192+
*
193+
* @return expression result
194+
* @throws IOException if input preparation or evaluation fails
195+
*/
196+
@Benchmark
197+
public Object fourStringConcatenation() throws IOException {
198+
return this.avm.eval(this.fourStringConcatenation);
199+
}
200+
162201
/**
163202
* Measures mixed numeric, string, field, and regular expression operations.
164203
*

src/main/java/io/jawk/backend/AVM.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,12 +1168,25 @@ private void executeTuples(PositionTracker position)
11681168
break;
11691169
}
11701170
case CONCAT: {
1171-
// stack[0] = string1
1172-
// stack[1] = string2
1173-
String s2 = jrt.toAwkString(pop());
1174-
String s1 = jrt.toAwkString(pop());
1175-
String resultString = s1 + s2;
1176-
push(resultString);
1171+
// arg[0] = number of stack items to concatenate
1172+
// stack[0] = last concatenation operand
1173+
CountTuple countTuple = (CountTuple) tuple;
1174+
int count = (int) countTuple.getCount();
1175+
// Store String references so appends run left-to-right. Converting
1176+
// operands to char[] would copy them once before StringBuilder
1177+
// copies them again, and front-inserting would shift existing
1178+
// content on each operand.
1179+
String[] values = new String[count];
1180+
int resultLength = 0;
1181+
for (int i = count - 1; i >= 0; i--) {
1182+
values[i] = jrt.toAwkString(pop());
1183+
resultLength += values[i].length();
1184+
}
1185+
StringBuilder resultString = new StringBuilder(resultLength);
1186+
for (String value : values) {
1187+
resultString.append(value);
1188+
}
1189+
push(resultString.toString());
11771190
position.next();
11781191
break;
11791192
}

src/main/java/io/jawk/frontend/AwkParser.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3628,12 +3628,25 @@ private ConcatExpressionAst(AST lhs, AST rhs) {
36283628
@Override
36293629
public int populateTuples(AwkTuples tuples) {
36303630
pushSourceLineNumber(tuples);
3631-
getAst1().populateTuples(tuples);
3632-
getAst2().populateTuples(tuples);
3633-
tuples.concat();
3631+
List<AST> operands = new ArrayList<>();
3632+
collectConcatOperands(this, operands);
3633+
for (AST operand : operands) {
3634+
operand.populateTuples(tuples);
3635+
}
3636+
tuples.concat(operands.size());
36343637
popSourceLineNumber(tuples);
36353638
return 1;
36363639
}
3640+
3641+
private void collectConcatOperands(AST ast, List<AST> operands) {
3642+
if (ast instanceof ConcatExpressionAst) {
3643+
ConcatExpressionAst concat = (ConcatExpressionAst) ast;
3644+
collectConcatOperands(concat.getAst1(), operands);
3645+
collectConcatOperands(concat.getAst2(), operands);
3646+
} else {
3647+
operands.add(ast);
3648+
}
3649+
}
36373650
}
36383651

36393652
private final class NegativeExpressionAst extends ScalarExpressionAst {

src/main/java/io/jawk/intermediate/AwkTuples.java

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.HashMap;
3333
import java.util.HashSet;
3434
import java.util.IdentityHashMap;
35+
import java.util.List;
3536
import java.util.Map;
3637
import java.util.Set;
3738
import java.util.function.Supplier;
@@ -69,7 +70,7 @@ public class AwkTuples implements Serializable {
6970
* can be serialized and patched efficiently. A linked list would make every
7071
* lookup O(n) and complicate address reassignment.
7172
*/
72-
private java.util.List<Tuple> queue = new ArrayList<Tuple>(100) {
73+
private List<Tuple> queue = new ArrayList<Tuple>(100) {
7374
private static final long serialVersionUID = -6334362156408598578L;
7475

7576
@Override
@@ -296,11 +297,25 @@ public void length(int numExprs) {
296297

297298
/**
298299
* <p>
299-
* concat.
300+
* Concatenates two stack items.
300301
* </p>
301302
*/
302303
public void concat() {
303-
queue.add(new Tuple.NoOperandTuple(Opcode.CONCAT));
304+
concat(2);
305+
}
306+
307+
/**
308+
* <p>
309+
* Concatenates the requested number of stack items.
310+
* </p>
311+
*
312+
* @param count the number of stack items to concatenate
313+
*/
314+
public void concat(int count) {
315+
if (count < 2) {
316+
throw new IllegalArgumentException("CONCAT requires at least two stack items");
317+
}
318+
queue.add(new Tuple.CountTuple(Opcode.CONCAT, count));
304319
}
305320

306321
/**
@@ -1878,10 +1893,10 @@ private boolean peepholeOptimizePass() {
18781893
return false;
18791894
}
18801895

1881-
java.util.List<Tuple> original = new ArrayList<Tuple>(queue);
1896+
List<Tuple> original = new ArrayList<Tuple>(queue);
18821897
int[] indexMapping = new int[originalSize];
18831898
Arrays.fill(indexMapping, -1);
1884-
java.util.List<Tuple> optimizedQueue = new ArrayList<Tuple>(originalSize);
1899+
List<Tuple> optimizedQueue = new ArrayList<Tuple>(originalSize);
18851900
boolean[] isAddressTarget = addressTargets(original, originalSize);
18861901

18871902
boolean modified = false;
@@ -1929,6 +1944,20 @@ private boolean peepholeOptimizePass() {
19291944
continue;
19301945
}
19311946
}
1947+
Tuple countedStringConcat = foldCountedStringConcat(original, oldIndex);
1948+
if (countedStringConcat != null) {
1949+
int count = countStringPushes(original, oldIndex);
1950+
// Fold a counted literal-only concatenation into one string push,
1951+
// e.g. PUSH_STRING "a", PUSH_STRING "b", CONCAT 2 -> PUSH_STRING
1952+
// "ab". Numeric literals are deliberately not folded here because
1953+
// their string representation depends on runtime formatting state.
1954+
optimizedQueue.add(countedStringConcat);
1955+
mapFoldedRange(indexMapping, oldIndex, count + 1, newIndex);
1956+
oldIndex += count + 1;
1957+
newIndex++;
1958+
modified = true;
1959+
continue;
1960+
}
19321961
if ((oldIndex + 2) < originalSize) {
19331962
Tuple nextTuple = original.get(oldIndex + 1);
19341963
Tuple opTuple = original.get(oldIndex + 2);
@@ -1987,7 +2016,7 @@ private boolean peepholeOptimizePass() {
19872016
return true;
19882017
}
19892018

1990-
private boolean[] addressTargets(java.util.List<Tuple> tuples, int tupleCount) {
2019+
private boolean[] addressTargets(List<Tuple> tuples, int tupleCount) {
19912020
boolean[] targets = new boolean[tupleCount];
19922021
for (Tuple tuple : tuples) {
19932022
Address address = tuple.getAddress();
@@ -2020,6 +2049,47 @@ private Object literalValue(Tuple tuple) {
20202049
}
20212050
}
20222051

2052+
private Tuple foldCountedStringConcat(List<Tuple> original, int oldIndex) {
2053+
Tuple firstTuple = original.get(oldIndex);
2054+
if (firstTuple.getOpcode() != Opcode.PUSH_STRING) {
2055+
return null;
2056+
}
2057+
2058+
int tupleCount = original.size();
2059+
StringBuilder folded = new StringBuilder();
2060+
int itemCount = 0;
2061+
int currentIndex = oldIndex;
2062+
while (currentIndex < tupleCount && original.get(currentIndex).getOpcode() == Opcode.PUSH_STRING) {
2063+
folded.append(((Tuple.PushStringTuple) original.get(currentIndex)).getValue());
2064+
itemCount++;
2065+
currentIndex++;
2066+
}
2067+
2068+
if (itemCount < 2 || currentIndex >= tupleCount) {
2069+
return null;
2070+
}
2071+
2072+
Tuple operation = original.get(currentIndex);
2073+
if (operation.getOpcode() != Opcode.CONCAT || !(operation instanceof Tuple.CountTuple)) {
2074+
return null;
2075+
}
2076+
2077+
long count = ((Tuple.CountTuple) operation).getCount();
2078+
if (count != itemCount) {
2079+
return null;
2080+
}
2081+
2082+
return createLiteralPush(folded.toString(), firstTuple.getLineNumber());
2083+
}
2084+
2085+
private int countStringPushes(List<Tuple> original, int oldIndex) {
2086+
int currentIndex = oldIndex;
2087+
while (currentIndex < original.size() && original.get(currentIndex).getOpcode() == Opcode.PUSH_STRING) {
2088+
currentIndex++;
2089+
}
2090+
return currentIndex - oldIndex;
2091+
}
2092+
20232093
private Object foldBinary(Object left, Object right, Tuple operation) {
20242094
Opcode opcode = operation.getOpcode();
20252095
if (opcode == null) {

src/main/java/io/jawk/intermediate/Opcode.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,13 @@ public enum Opcode {
189189
*/
190190
LENGTH,
191191
/**
192-
* Pop and concatenate two strings from the top-of-stack; push the result onto
193-
* the stack.
192+
* Pops and concatenates N strings from the top-of-stack; pushes the result onto
193+
* the stack. The number of items is passed in as a tuple argument.
194+
* <p>
195+
* Argument: # of items (N)
194196
* <p>
195-
* Stack before: x y ...<br/>
196-
* Stack after: x-concatenated-with-y ...
197+
* Stack before: x1 x2 x3 ... xN ...<br/>
198+
* Stack after: x1-concatenated-through-xN ...
197199
*/
198200
CONCAT,
199201
/**

src/test/java/io/jawk/AwkTupleOptimizationTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,34 @@ public void foldsLiteralStringConcatenation() throws Exception {
139139
assertTrue("Expected folded literal push of foobar", hasLiteralPush(tuples, "foobar"));
140140
}
141141

142+
@Test
143+
public void foldsChainedLiteralStringConcatenation() throws Exception {
144+
String script = "BEGIN { print \"foo\" \"bar\" \"baz\" \"qux\" }\n";
145+
AwkTestSupport
146+
.awkTest("folds chained literal string concatenation")
147+
.script(script)
148+
.expect("foobarbazqux\n")
149+
.runAndAssert();
150+
151+
AwkProgram tuples = new Awk().compile(script);
152+
List<Opcode> opcodes = collectOpcodes(tuples);
153+
assertFalse("Chained literal concatenation should eliminate CONCAT tuple", opcodes.contains(Opcode.CONCAT));
154+
assertTrue("Expected folded literal push of foobarbazqux", hasLiteralPush(tuples, "foobarbazqux"));
155+
}
156+
157+
@Test
158+
public void compilesChainedStringConcatenationAsSingleCountedConcat() throws Exception {
159+
String script = "BEGIN { s1 = \"alpha\"; s2 = \"beta\"; print s1 \"-\" s2 \":\" }\n";
160+
AwkTestSupport
161+
.awkTest("counted chained string concatenation")
162+
.script(script)
163+
.expect("alpha-beta:\n")
164+
.runAndAssert();
165+
166+
AwkProgram tuples = new Awk().compile(script);
167+
assertEquals("Expected one counted CONCAT for the mixed chain", 1, countOpcodeWithCount(tuples, Opcode.CONCAT, 4));
168+
}
169+
142170
@Test
143171
public void foldsScalarAssignmentPopIntoNonPushingAssignment() throws Exception {
144172
String script = "BEGIN { a = -2; b = 2; c = 4; print a + b + c }\n";
@@ -204,6 +232,10 @@ public void doesNotFoldNumericConcatenation() throws Exception {
204232
AwkProgram tuples = new Awk().compile(script);
205233
List<Opcode> opcodes = collectOpcodes(tuples);
206234
assertTrue("Numeric literal concatenation should preserve CONCAT tuple", opcodes.contains(Opcode.CONCAT));
235+
assertEquals(
236+
"Numeric literal concatenation should remain binary",
237+
1,
238+
countOpcodeWithCount(tuples, Opcode.CONCAT, 2));
207239
assertFalse("Optimizer should not fold numeric/string concatenation", hasLiteralPush(tuples, "1x"));
208240
}
209241

@@ -570,6 +602,20 @@ private static int countOpcode(AwkProgram tuples, Opcode opcode) {
570602
return count;
571603
}
572604

605+
private static int countOpcodeWithCount(AwkProgram tuples, Opcode opcode, long expectedCount) {
606+
int count = 0;
607+
PositionTracker tracker = rawTuples(tuples).top();
608+
while (!tracker.isEOF()) {
609+
if (tracker.opcode() == opcode
610+
&& tracker.current() instanceof Tuple.CountTuple
611+
&& ((Tuple.CountTuple) tracker.current()).getCount() == expectedCount) {
612+
count++;
613+
}
614+
tracker.next();
615+
}
616+
return count;
617+
}
618+
573619
private static String dumpTuples(AwkProgram tuples) throws Exception {
574620
ByteArrayOutputStream out = new ByteArrayOutputStream();
575621
try (PrintStream ps = new PrintStream(out, true, StandardCharsets.UTF_8.name())) {

0 commit comments

Comments
 (0)