Skip to content

Commit 878f2e7

Browse files
authored
Update AppVersionComputer to use registered workflow FQ name & method's byte code (#363)
fixes #328
1 parent 22eb5b8 commit 878f2e7

5 files changed

Lines changed: 339 additions & 72 deletions

File tree

gradle/libs.versions.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[versions]
2+
asm = "9.7.1"
23
aspectj = "1.9.22.1"
34
assertj = "3.27.3"
45
cron-utils = "9.2.1"
@@ -26,6 +27,7 @@ testcontainers = "2.0.4"
2627
versions = "0.53.0"
2728

2829
[libraries]
30+
asm = { module = "org.ow2.asm:asm", version.ref = "asm" }
2931
aspectjweaver = { module = "org.aspectj:aspectjweaver", version.ref = "aspectj" }
3032
assertj-core = { module = "org.assertj:assertj-core", version.ref = "assertj" }
3133
cron-utils = { module = "com.cronutils:cron-utils", version.ref = "cron-utils" }

transact/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies {
2929
api(libs.slf4j.api)
3030
api(libs.jspecify)
3131

32+
implementation(libs.asm)
3233
implementation(libs.postgresql)
3334
implementation(libs.hikaricp)
3435
implementation(libs.bundles.jackson)

transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,8 @@ public void start(
177177
this.alertHandler = alertHandler;
178178

179179
if (this.appVersion == null || this.appVersion.isEmpty()) {
180-
List<Class<?>> registeredClasses =
181-
workflowMap.values().stream()
182-
.map(wrapper -> wrapper.target().getClass())
183-
.collect(Collectors.toList());
184-
this.appVersion = AppVersionComputer.computeAppVersion(registeredClasses);
180+
this.appVersion =
181+
AppVersionComputer.computeAppVersion(DBOS.version(), workflowMap.values());
185182
}
186183

187184
if (config.conductorKey() != null) {
Lines changed: 163 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,191 @@
11
package dev.dbos.transact.internal;
22

3-
import dev.dbos.transact.DBOS;
3+
import dev.dbos.transact.execution.RegisteredWorkflow;
44

5-
import java.io.InputStream;
5+
import java.io.IOException;
6+
import java.nio.charset.StandardCharsets;
67
import java.security.MessageDigest;
8+
import java.security.NoSuchAlgorithmException;
79
import java.util.*;
8-
import java.util.stream.Collectors;
910

11+
import org.objectweb.asm.ClassReader;
12+
import org.objectweb.asm.ClassVisitor;
13+
import org.objectweb.asm.Handle;
14+
import org.objectweb.asm.Label;
15+
import org.objectweb.asm.MethodVisitor;
16+
import org.objectweb.asm.Opcodes;
17+
import org.objectweb.asm.Type;
1018
import org.slf4j.Logger;
1119
import org.slf4j.LoggerFactory;
1220

1321
public class AppVersionComputer {
1422

1523
private static Logger logger = LoggerFactory.getLogger(AppVersionComputer.class);
1624

17-
public static String computeAppVersion(List<Class<?>> registeredClasses) {
25+
public static String computeAppVersion(
26+
String dbosVersion, Collection<RegisteredWorkflow> workflows) {
1827
try {
19-
MessageDigest hasher = MessageDigest.getInstance("SHA-256");
20-
21-
// Sort by class name for deterministic ordering
22-
List<Class<?>> sortedClasses =
23-
registeredClasses.stream()
24-
.distinct()
25-
.sorted(Comparator.comparing(Class::getName))
26-
.collect(Collectors.toList());
27-
28-
// Hash each unique class
29-
for (Class<?> clazz : sortedClasses) {
30-
logger.debug("hashing {}", clazz.getName());
31-
String classHash = getClassBytecodeHash(clazz);
32-
hasher.update((clazz.getName() + ":" + classHash).getBytes("UTF-8"));
28+
final var hasher = MessageDigest.getInstance("SHA-256");
29+
hasher.update(dbosVersion.getBytes(StandardCharsets.UTF_8));
30+
31+
var sortedWorkflows =
32+
workflows.stream().sorted(Comparator.comparing(RegisteredWorkflow::fullyQualifiedName));
33+
var it = sortedWorkflows.iterator();
34+
35+
while (it.hasNext()) {
36+
var wf = it.next();
37+
hasher.update(wf.fullyQualifiedName().getBytes(StandardCharsets.UTF_8));
38+
39+
var klass = wf.workflowMethod().getDeclaringClass();
40+
var klassPath = klass.getName().replace('.', '/') + ".class";
41+
var methodDesc = Type.getMethodDescriptor(wf.workflowMethod());
42+
var methodName = wf.workflowMethod().getName();
43+
44+
try (var in = klass.getClassLoader().getResourceAsStream(klassPath)) {
45+
if (in == null) throw new IOException("%s class not found".formatted(klass.getName()));
46+
var reader = new ClassReader(in);
47+
reader.accept(
48+
new ClassVisitor(Opcodes.ASM9) {
49+
@Override
50+
public MethodVisitor visitMethod(
51+
int access, String name, String desc, String signature, String[] exceptions) {
52+
return (name.equals(methodName) && desc.equals(methodDesc))
53+
? new HashingMethodVisitor(hasher)
54+
: null;
55+
}
56+
},
57+
0);
58+
}
3359
}
3460

35-
// Different DBOS versions should produce different app versions
36-
hasher.update(DBOS.version().getBytes("UTF-8"));
37-
3861
return bytesToHex(hasher.digest());
39-
} catch (Exception e) {
40-
logger.warn("Failed to compute simplified app version", e);
41-
return getFallbackVersion();
62+
} catch (NoSuchAlgorithmException | IOException e) {
63+
logger.warn("Failed to compute app version", e);
64+
return "unknown-" + System.currentTimeMillis();
4265
}
4366
}
4467

45-
/** Gets a hash of the class bytecode. */
46-
private static String getClassBytecodeHash(Class<?> clazz) {
47-
try {
48-
// Get the class file as a resource
49-
String className = clazz.getName().replace('.', '/') + ".class";
50-
51-
try (InputStream is = clazz.getClassLoader().getResourceAsStream(className)) {
52-
if (is != null) {
53-
MessageDigest hasher = MessageDigest.getInstance("SHA-256");
54-
byte[] buffer = new byte[8192];
55-
int bytesRead;
56-
while ((bytesRead = is.read(buffer)) != -1) {
57-
hasher.update(buffer, 0, bytesRead);
58-
}
59-
return bytesToHex(hasher.digest());
60-
}
68+
static class HashingMethodVisitor extends MethodVisitor {
69+
private final MessageDigest md;
70+
private final Map<Label, Integer> labelOrdinals = new LinkedHashMap<>();
71+
private int nextLabelOrdinal = 0;
72+
73+
public HashingMethodVisitor(MessageDigest md) {
74+
super(Opcodes.ASM9);
75+
this.md = md;
76+
}
77+
78+
private int labelOrdinal(Label label) {
79+
return labelOrdinals.computeIfAbsent(label, l -> nextLabelOrdinal++);
80+
}
81+
82+
private void update(String... values) {
83+
for (var v : values) {
84+
if (v != null) md.update(v.getBytes(StandardCharsets.UTF_8));
6185
}
86+
}
6287

63-
// Fallback: use class hashCode and serialVersionUID if available
64-
long classHash = clazz.hashCode();
65-
try {
66-
java.lang.reflect.Field serialVersionUID = clazz.getDeclaredField("serialVersionUID");
67-
serialVersionUID.setAccessible(true);
68-
classHash ^= serialVersionUID.getLong(null);
69-
} catch (Exception ignored) {
70-
// serialVersionUID not available, that's ok
88+
private void update(int... values) {
89+
for (var v : values) {
90+
md.update((byte) (v >>> 24));
91+
md.update((byte) (v >>> 16));
92+
md.update((byte) (v >>> 8));
93+
md.update((byte) v);
7194
}
95+
}
7296

73-
return Long.toHexString(classHash);
97+
@Override
98+
public void visitLabel(Label label) {
99+
labelOrdinal(label);
100+
}
101+
102+
@Override
103+
public void visitInsn(int opcode) {
104+
update(opcode);
105+
}
106+
107+
@Override
108+
public void visitIntInsn(int opcode, int operand) {
109+
update(opcode, operand);
110+
}
74111

75-
} catch (Exception e) {
76-
logger.debug("Error getting class bytecode hash for {}", clazz.getName(), e);
77-
return Integer.toHexString(clazz.getName().hashCode());
112+
@Override
113+
public void visitVarInsn(int opcode, int varIndex) {
114+
update(opcode, varIndex);
115+
}
116+
117+
@Override
118+
public void visitTypeInsn(int opcode, String type) {
119+
update(opcode);
120+
update(type);
121+
}
122+
123+
@Override
124+
public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
125+
update(opcode);
126+
update(owner, name, descriptor);
127+
}
128+
129+
@Override
130+
public void visitMethodInsn(
131+
int opcode, String owner, String name, String descriptor, boolean isInterface) {
132+
update(opcode);
133+
update(owner, name, descriptor);
134+
update(isInterface ? 1 : 0);
135+
}
136+
137+
@Override
138+
public void visitInvokeDynamicInsn(
139+
String name,
140+
String descriptor,
141+
Handle bootstrapMethodHandle,
142+
Object... bootstrapMethodArguments) {
143+
update(name, descriptor);
144+
update(bootstrapMethodHandle.toString());
145+
for (var arg : bootstrapMethodArguments) {
146+
if (arg != null) update(arg.toString());
147+
}
148+
}
149+
150+
@Override
151+
public void visitJumpInsn(int opcode, Label label) {
152+
update(opcode, labelOrdinal(label));
153+
}
154+
155+
@Override
156+
public void visitLdcInsn(Object value) {
157+
update(Opcodes.LDC);
158+
if (value != null) update(value.toString());
159+
}
160+
161+
@Override
162+
public void visitIincInsn(int varIndex, int increment) {
163+
update(Opcodes.IINC, varIndex, increment);
164+
}
165+
166+
@Override
167+
public void visitTableSwitchInsn(int min, int max, Label dflt, Label... labels) {
168+
update(Opcodes.TABLESWITCH, min, max, labelOrdinal(dflt));
169+
for (var l : labels) update(labelOrdinal(l));
170+
}
171+
172+
@Override
173+
public void visitLookupSwitchInsn(Label dflt, int[] keys, Label[] labels) {
174+
update(Opcodes.LOOKUPSWITCH, labelOrdinal(dflt));
175+
update(keys);
176+
for (var l : labels) update(labelOrdinal(l));
177+
}
178+
179+
@Override
180+
public void visitMultiANewArrayInsn(String descriptor, int numDimensions) {
181+
update(Opcodes.MULTIANEWARRAY, numDimensions);
182+
update(descriptor);
183+
}
184+
185+
@Override
186+
public void visitTryCatchBlock(Label start, Label end, Label handler, String type) {
187+
update(labelOrdinal(start), labelOrdinal(end), labelOrdinal(handler));
188+
update(type);
78189
}
79190
}
80191

@@ -89,8 +200,4 @@ private static String bytesToHex(byte[] bytes) {
89200
}
90201
return hexString.toString();
91202
}
92-
93-
private static String getFallbackVersion() {
94-
return "unknown-" + System.currentTimeMillis();
95-
}
96203
}

0 commit comments

Comments
 (0)