Skip to content

Commit bf0cd1c

Browse files
authored
Android/fix evalue string serialization (#17609)
toByteArray() used toString() (returning "EValue@...") instead of toStr(), and allocated the buffer using char count instead of byte count. fromByteArray() threw an exception for strings instead of deserializing them. Fix both directions using a length-prefixed format, and also fix the misleading "Unknown Tensor dtype" error message. Add round-trip tests for ASCII, empty, and Unicode strings.
1 parent 850d76d commit bf0cd1c

2 files changed

Lines changed: 50 additions & 4 deletions

File tree

  • extension/android/executorch_android/src

extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import com.facebook.jni.annotations.DoNotStrip;
1212
import java.nio.ByteBuffer;
13+
import java.nio.charset.StandardCharsets;
1314
import java.util.Arrays;
1415
import java.util.Locale;
1516
import org.pytorch.executorch.annotations.Experimental;
@@ -202,12 +203,14 @@ public byte[] toByteArray() {
202203
} else if (isDouble()) {
203204
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array();
204205
} else if (isString()) {
205-
return ByteBuffer.allocate(1 + toString().length())
206+
byte[] strBytes = toStr().getBytes(StandardCharsets.UTF_8);
207+
return ByteBuffer.allocate(1 + 4 + strBytes.length)
206208
.put((byte) TYPE_CODE_STRING)
207-
.put(toString().getBytes())
209+
.putInt(strBytes.length)
210+
.put(strBytes)
208211
.array();
209212
} else {
210-
throw new IllegalArgumentException("Unknown Tensor dtype");
213+
throw new IllegalArgumentException("Unknown EValue type code: " + mTypeCode);
211214
}
212215
}
213216

@@ -234,7 +237,10 @@ public static EValue fromByteArray(byte[] bytes) {
234237
byte[] bufferArray = buffer.array();
235238
return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length)));
236239
case TYPE_CODE_STRING:
237-
throw new IllegalArgumentException("TYPE_CODE_STRING is not supported");
240+
int strLen = buffer.getInt();
241+
byte[] strBytes = new byte[strLen];
242+
buffer.get(strBytes);
243+
return from(new String(strBytes, StandardCharsets.UTF_8));
238244
case TYPE_CODE_DOUBLE:
239245
return from(buffer.getDouble());
240246
case TYPE_CODE_INT:

extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,46 @@ class EValueTest {
167167
assertEquals(1.345e-2, deser.toDouble(), 1e-6)
168168
}
169169

170+
@Test
171+
fun testStringSerde() {
172+
val evalue = EValue.from("hello")
173+
val bytes = evalue.toByteArray()
174+
175+
val deser = EValue.fromByteArray(bytes)
176+
assertTrue(deser.isString)
177+
assertEquals("hello", deser.toStr())
178+
}
179+
180+
@Test
181+
fun testEmptyStringSerde() {
182+
val evalue = EValue.from("")
183+
val bytes = evalue.toByteArray()
184+
185+
val deser = EValue.fromByteArray(bytes)
186+
assertTrue(deser.isString)
187+
assertEquals("", deser.toStr())
188+
}
189+
190+
@Test
191+
fun testChineseStringSerde() {
192+
val evalue = EValue.from("你好世界")
193+
val bytes = evalue.toByteArray()
194+
195+
val deser = EValue.fromByteArray(bytes)
196+
assertTrue(deser.isString)
197+
assertEquals("你好世界", deser.toStr())
198+
}
199+
200+
@Test
201+
fun testEmojiStringSerde() {
202+
val evalue = EValue.from("👋🌍")
203+
val bytes = evalue.toByteArray()
204+
205+
val deser = EValue.fromByteArray(bytes)
206+
assertTrue(deser.isString)
207+
assertEquals("👋🌍", deser.toStr())
208+
}
209+
170210
@Test
171211
fun testLongTensorSerde() {
172212
val data = longArrayOf(1, 2, 3, 4)

0 commit comments

Comments
 (0)