Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,14 @@ public byte[] toByteArray() {
} else if (isDouble()) {
return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array();
} else if (isString()) {
return ByteBuffer.allocate(1 + toString().length())
byte[] strBytes = toStr().getBytes();
return ByteBuffer.allocate(1 + 4 + strBytes.length)
.put((byte) TYPE_CODE_STRING)
.put(toString().getBytes())
.putInt(strBytes.length)
.put(strBytes)
.array();
} else {
throw new IllegalArgumentException("Unknown Tensor dtype");
throw new IllegalArgumentException("Unknown EValue type code: " + mTypeCode);
}
}

Expand All @@ -234,7 +236,10 @@ public static EValue fromByteArray(byte[] bytes) {
byte[] bufferArray = buffer.array();
return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length)));
case TYPE_CODE_STRING:
throw new IllegalArgumentException("TYPE_CODE_STRING is not supported");
int strLen = buffer.getInt();
byte[] strBytes = new byte[strLen];
buffer.get(strBytes);
return from(new String(strBytes));
case TYPE_CODE_DOUBLE:
return from(buffer.getDouble());
case TYPE_CODE_INT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,46 @@ class EValueTest {
assertEquals(1.345e-2, deser.toDouble(), 1e-6)
}

@Test
fun testStringSerde() {
val evalue = EValue.from("hello")
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
assertTrue(deser.isString)
assertEquals("hello", deser.toStr())
}

@Test
fun testEmptyStringSerde() {
val evalue = EValue.from("")
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
assertTrue(deser.isString)
assertEquals("", deser.toStr())
}

@Test
fun testChineseStringSerde() {
val evalue = EValue.from("你好世界")
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
assertTrue(deser.isString)
assertEquals("你好世界", deser.toStr())
}

@Test
fun testEmojiStringSerde() {
val evalue = EValue.from("👋🌍")
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
assertTrue(deser.isString)
assertEquals("👋🌍", deser.toStr())
}

@Test
fun testLongTensorSerde() {
val data = longArrayOf(1, 2, 3, 4)
Expand Down
Loading