diff --git a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java index 9d36acb..10abe96 100644 --- a/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java +++ b/src/test/java/ru/rt/restream/reindexer/connector/FloatVectorIvfTest.java @@ -35,8 +35,10 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; +import java.util.stream.Collectors; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -113,45 +115,43 @@ public void testSearchWithBaseParamK_isOk() { @Test public void testSearchWithBaseParamRadius_isOk() { - List list = db.query(namespaceName, VectorItem.class) + List foundIds = db.query(namespaceName, VectorItem.class) .selectAllFields() .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, radius(0.1f)) - .toList(); + .toList() + .stream() + .map(VectorItem::getId) + .collect(Collectors.toList()); - assertThat(list.size(), is(4)); - assertThat(list.get(0).getId(), is(18)); - assertThat(list.get(1).getId(), is(6)); - assertThat(list.get(2).getId(), is(7)); - assertThat(list.get(3).getId(), is(8)); + assertThat(foundIds, containsInAnyOrder(18, 6, 7, 8)); } @Test public void testSearchWithBaseParamsKAndRadius_isOk() { // by k (3 records) + by radius (4 records) = 3 records - List list = db.query(namespaceName, VectorItem.class) + List foundIds = db.query(namespaceName, VectorItem.class) .selectAllFields() - .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, - KnnParams.base(3, 0.1f)) - .toList(); + .whereKnn("vector", new float[]{0.1f, 0.2f, 0.1f}, + KnnParams.base(3, 0.2f)) + .toList() + .stream() + .map(VectorItem::getId) + .collect(Collectors.toList()); - assertThat(list.size(), is(3)); - assertThat(list.get(0).getId(), is(18)); - assertThat(list.get(1).getId(), is(7)); - assertThat(list.get(2).getId(), is(8)); + assertThat(foundIds, containsInAnyOrder(18, 6, 8)); // by k (5 records) + by radius (4 records) = 4 records - list = db.query(namespaceName, VectorItem.class) + foundIds = db.query(namespaceName, VectorItem.class) .selectAllFields() - .whereKnn("vector", new float[]{0.1f, 0.1f, 0.1f}, - KnnParams.base(5, 0.1f)) - .toList(); - - assertThat(list.size(), is(4)); - assertThat(list.get(0).getId(), is(18)); - assertThat(list.get(1).getId(), is(6)); - assertThat(list.get(2).getId(), is(7)); - assertThat(list.get(3).getId(), is(8)); + .whereKnn("vector", new float[]{0.1f, 0.2f, 0.1f}, + KnnParams.base(5, 0.2f)) + .toList() + .stream() + .map(VectorItem::getId) + .collect(Collectors.toList()); + + assertThat(foundIds, containsInAnyOrder(18, 6, 8, 19)); } @Test @@ -170,16 +170,16 @@ public void testSearchWithIvfParams_isOk() { assertThat(list.get(1).getVector(), is(testItems.get(18).getVector())); // only radius - 3 records - list = db.query(namespaceName, VectorItem.class) + List foundIds = db.query(namespaceName, VectorItem.class) .selectAllFields() .whereKnn("vector", new float[]{0.23f, 0.23f, 0.0f}, KnnParams.ivf(KnnParams.radius(0.3f), 3)) - .toList(); + .toList() + .stream() + .map(VectorItem::getId) + .collect(Collectors.toList()); - assertThat(list.size(), is(3)); - assertThat(list.get(0).getId(), is(8)); - assertThat(list.get(1).getId(), is(18)); - assertThat(list.get(2).getId(), is(19)); + assertThat(foundIds, containsInAnyOrder(8, 18, 19)); // by k (2 records) + by radius (3 records) = 2 records list = db.query(namespaceName, VectorItem.class)