diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java index ab3b77ff1fb..8d7faf69768 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java @@ -202,12 +202,14 @@ public byte[] toByteArray() { } else if (isDouble()) { return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array(); } else if (isString()) { - return ByteBuffer.allocate(1 + toString().length()) + byte[] strBytes = toStr().getBytes(); + return ByteBuffer.allocate(1 + 4 + strBytes.length) .put((byte) TYPE_CODE_STRING) - .put(toString().getBytes()) + .putInt(strBytes.length) + .put(strBytes) .array(); } else { - throw new IllegalArgumentException("Unknown Tensor dtype"); + throw new IllegalArgumentException("Unknown EValue type code: " + mTypeCode); } } @@ -234,7 +236,10 @@ public static EValue fromByteArray(byte[] bytes) { byte[] bufferArray = buffer.array(); return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length))); case TYPE_CODE_STRING: - throw new IllegalArgumentException("TYPE_CODE_STRING is not supported"); + int strLen = buffer.getInt(); + byte[] strBytes = new byte[strLen]; + buffer.get(strBytes); + return from(new String(strBytes)); case TYPE_CODE_DOUBLE: return from(buffer.getDouble()); case TYPE_CODE_INT: diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt index 7e9fea9a699..c73053de6ed 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt @@ -167,6 +167,46 @@ class EValueTest { assertEquals(1.345e-2, deser.toDouble(), 1e-6) } + @Test + fun testStringSerde() { + val evalue = EValue.from("hello") + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertTrue(deser.isString) + assertEquals("hello", deser.toStr()) + } + + @Test + fun testEmptyStringSerde() { + val evalue = EValue.from("") + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertTrue(deser.isString) + assertEquals("", deser.toStr()) + } + + @Test + fun testChineseStringSerde() { + val evalue = EValue.from("你好世界") + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertTrue(deser.isString) + assertEquals("你好世界", deser.toStr()) + } + + @Test + fun testEmojiStringSerde() { + val evalue = EValue.from("👋🌍") + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertTrue(deser.isString) + assertEquals("👋🌍", deser.toStr()) + } + @Test fun testLongTensorSerde() { val data = longArrayOf(1, 2, 3, 4)