From e3e1651dfa36faf1d5969a43c48b579e29698b0c Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 10 Feb 2026 10:10:30 +0530 Subject: [PATCH] Sync changes from CDB_DiskANN repo - Refactored recall utilities in diskann-benchmark - Updated tokio utilities - Added attribute and format parser improvements in label-filter - Updated ground_truth utilities in diskann-tools --- diskann-benchmark/src/utils/recall.rs | 703 +----------------- diskann-benchmark/src/utils/tokio.rs | 20 +- diskann-label-filter/src/attribute.rs | 1 + diskann-label-filter/src/parser/format.rs | 2 + .../src/utils/flatten_utils.rs | 2 +- diskann-tools/Cargo.toml | 18 +- diskann-tools/src/utils/ground_truth.rs | 161 +++- 7 files changed, 196 insertions(+), 711 deletions(-) diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 5b7fd1594..bfaf46772 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -3,15 +3,13 @@ * Licensed under the MIT license. */ -use std::{collections::HashSet, hash::Hash}; - -use diskann_utils::strided::StridedView; -use diskann_utils::views::MatrixView; +use diskann_benchmark_core as benchmark_core; +pub(crate) use benchmark_core::recall::knn; use serde::Serialize; -use thiserror::Error; -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] pub(crate) struct RecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, @@ -25,278 +23,19 @@ pub(crate) struct RecallMetrics { pub(crate) minimum: usize, /// The maximum observed recall (max possible value: `recall_k`). pub(crate) maximum: usize, - /// Recall results by query - pub(crate) by_query: Option>, -} - -// impl RecallMetrics { -// pub(crate) fn num_queries(&self) -> usize { -// self.num_queries -// } - -// pub(crate) fn average(&self) -> f64 { -// self.average -// } -// } - -#[derive(Debug, Error)] -pub(crate) enum ComputeRecallError { - #[error("results matrix has {0} rows but ground truth has {1}")] - RowsMismatch(usize, usize), - #[error("distances matrix has {0} rows but ground truth has {1}")] - DistanceRowsMismatch(usize, usize), - #[error("recall k value {0} must be less than or equal to recall n {1}")] - RecallKAndNError(usize, usize), - #[error("number of results per query {0} must be at least the specified recall k {1}")] - NotEnoughResults(usize, usize), - #[error( - "number of groundtruth values per query {0} must be at least the specified recall n {1}" - )] - NotEnoughGroundTruth(usize, usize), - #[error("number of groundtruth distances {0} does not match groundtruth entries {1}")] - NotEnoughGroundTruthDistances(usize, usize), -} - -pub(crate) trait ComputeKnnRecall { - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result; -} - -impl ComputeKnnRecall for MatrixView<'_, T> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -impl ComputeKnnRecall for Vec> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -pub(crate) trait KnnRecall { - type Item; - - fn nrows(&self) -> usize; - fn ncols(&self) -> Option; - fn row(&self, i: usize) -> &[Self::Item]; -} - -impl KnnRecall for MatrixView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - MatrixView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(MatrixView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - MatrixView::<'_, T>::row(self, i) - } -} - -impl KnnRecall for Vec> { - type Item = T; - - fn nrows(&self) -> usize { - self.len() - } - fn ncols(&self) -> Option { - None - } - fn row(&self, i: usize) -> &[Self::Item] { - &self[i] - } } -impl KnnRecall for StridedView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - StridedView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(StridedView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - StridedView::<'_, T>::row(self, i) - } -} - -fn compute_knn_recall( - groundtruth: &K, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, - K: KnnRecall, -{ - if recall_k > recall_n { - return Err(ComputeRecallError::RecallKAndNError(recall_k, recall_n)); - } - - let nrows = results.nrows(); - if nrows != groundtruth.nrows() { - return Err(ComputeRecallError::RowsMismatch(nrows, groundtruth.nrows())); - } - - if results.ncols() < recall_n && !allow_insufficient_results { - return Err(ComputeRecallError::NotEnoughResults( - results.ncols(), - recall_n, - )); - } - - // Validate groundtruth size for fixed-size sources - match groundtruth.ncols() { - Some(ncols) if ncols < recall_k => { - return Err(ComputeRecallError::NotEnoughGroundTruth(ncols, recall_k)); - } - _ => {} - } - - if let Some(distances) = groundtruth_distances { - if nrows != distances.nrows() { - return Err(ComputeRecallError::DistanceRowsMismatch( - distances.nrows(), - nrows, - )); - } - - match groundtruth.ncols() { - Some(ncols) if distances.ncols() != ncols => { - return Err(ComputeRecallError::NotEnoughGroundTruthDistances( - distances.ncols(), - ncols, - )); - } - _ => {} +impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { + fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { + Self { + recall_k: m.recall_k, + recall_n: m.recall_n, + num_queries: m.num_queries, + average: m.average, + minimum: m.minimum, + maximum: m.maximum, } } - - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); - let mut this_groundtruth = HashSet::new(); - let mut this_results = HashSet::new(); - - for (i, result) in results.row_iter().enumerate() { - let gt_row = groundtruth.row(i); - - // Populate the groundtruth using the top-k - this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().copied().take(recall_k)); - - // If we have distances, then continue to append distances as long as the distance - // value is constant - if let Some(distances) = groundtruth_distances { - if recall_k > 0 { - let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let last_distance = distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { - if *d == last_distance { - this_groundtruth.insert(*g); - } else { - break; - } - } - } - } - } - - this_results.clear(); - this_results.extend(result.iter().copied().take(recall_n)); - - // Count the overlap - let r = this_groundtruth - .iter() - .filter(|i| this_results.contains(i)) - .count() - .min(recall_k); - - recall_values.push(r); - } - - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); - - let div = if groundtruth.ncols().is_some() { - recall_k * nrows - } else { - (0..groundtruth.nrows()) - .map(|i| groundtruth.row(i).len()) - .sum::() - .max(1) - }; - - let average = (total as f64) / (div as f64); - - Ok(RecallMetrics { - recall_k, - recall_n, - num_queries: nrows, - average, - minimum: *minimum, - maximum: *maximum, - by_query: if enhanced_metrics { - Some(recall_values) - } else { - None - }, - }) } /// Compute `k-recall-at-n` for all valid combinations of values in `recall_k` and @@ -309,14 +48,13 @@ where feature = "product-quantization" ))] pub(crate) fn compute_multiple_recalls( - results: StridedView<'_, T>, - groundtruth: StridedView<'_, T>, + results: &dyn benchmark_core::recall::Rows, + groundtruth: &dyn benchmark_core::recall::Rows, recall_k: &[usize], recall_n: &[usize], - enhanced_metrics: bool, -) -> Result, ComputeRecallError> +) -> Result, benchmark_core::recall::ComputeRecallError> where - T: Eq + Hash + Copy + std::fmt::Debug, + T: benchmark_core::recall::RecallCompatible, { let mut result = Vec::new(); for k in recall_k { @@ -325,414 +63,27 @@ where continue; } - result.push(compute_knn_recall( - &groundtruth, - None, - results, - *k, - *n, - false, - enhanced_metrics, - )?); + let recall = benchmark_core::recall::knn(groundtruth, None, results, *k, *n, false)?; + result.push((&recall).into()); } } Ok(result) } -#[derive(Debug, Serialize)] -pub(crate) struct APMetrics { +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] +pub(crate) struct AveragePrecisionMetrics { /// The number of queries. pub(crate) num_queries: usize, /// The average precision pub(crate) average_precision: f64, } -#[derive(Debug, Error)] -pub(crate) enum ComputeAPError { - #[error("results has {0} elements but ground truth has {1}")] - EntriesMismatch(usize, usize), -} - -/// Compute average precision of a range search result -pub(crate) fn compute_average_precision( - results: Vec>, - groundtruth: &[Vec], -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - if results.len() != groundtruth.len() { - return Err(ComputeAPError::EntriesMismatch( - results.len(), - groundtruth.len(), - )); - } - - // The actual recall computation. - let mut num_gt_results = 0; - let mut num_reported_results = 0; - - let mut scratch = HashSet::new(); - - std::iter::zip(results.iter(), groundtruth.iter()).for_each(|(result, gt)| { - scratch.clear(); - scratch.extend(result.iter().copied()); - num_reported_results += gt.iter().filter(|i| scratch.contains(i)).count(); - num_gt_results += gt.len(); - }); - - // Perform post-processing. - let average_precision = (num_reported_results as f64) / (num_gt_results as f64); - - Ok(APMetrics { - average_precision, - num_queries: results.len(), - }) -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use diskann_utils::views::Matrix; - - use super::*; - - pub(crate) fn compute_knn_recall( - results: StridedView<'_, u32>, - groundtruth: G, // StridedView - groundtruth_distances: Option>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result - where - G: ComputeKnnRecall + KnnRecall + Clone, - { - groundtruth.compute_knn_recall( - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } - - struct ExpectedRecall { - recall_k: usize, - recall_n: usize, - // Recall for each component. - components: Vec, - } - - impl ExpectedRecall { - fn new(recall_k: usize, recall_n: usize, components: Vec) -> Self { - assert!(recall_k <= recall_n); - components.iter().for_each(|x| { - assert!(*x <= recall_k); - }); - Self { - recall_k, - recall_n, - components, - } - } - - fn compute_recall(&self) -> f64 { - (self.components.iter().sum::() as f64) - / ((self.components.len() * self.recall_k) as f64) - } - } - - #[test] - fn test_happy_path() { - let groundtruth = Matrix::try_from( - vec![ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 0 - 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // row 1 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 2 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - let distances = Matrix::try_from( - vec![ - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 0 - 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 1 - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // row 2 - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - // Shift row 0 by one and row 1 by two. - let our_results = Matrix::try_from( - vec![ - 100, 0, 1, 2, 5, 6, // row 0 - 100, 101, 7, 8, 9, 10, // row 1 - 0, 1, 2, 3, 4, 5, // row 2 - 0, 1, 2, 3, 4, 5, // row 3 - ] - .into(), - 4, - 6, - ) - .unwrap(); - - //---------// - // No Ties // - //---------// - let expected_no_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]), - ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]), - // Unequal `k` and `n`. - ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]), - ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]), - ]; - let epsilon = 1e-6; // Define a small tolerance - - for (i, expected) in expected_no_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - None, - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - } - - //-----------// - // With Ties // - //-----------// - let expected_with_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), // tie-breaker kicks in - ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), // tie-breaker kicks in - // Unequal `k` and `n`. - ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]), - ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]), - ]; - - for (i, expected) in expected_with_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - assert_eq!(recall.by_query, Some(expected.components.clone())); - } - } - - #[test] - fn test_errors() { - // k greater than n - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 11, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RecallKAndNError(..))); - } - - // Unequal rows - { - let groundtruth = Matrix::::new(0, 11, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RowsMismatch(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::RowsMismatch(..) - )); - } - - // Not enough results - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 5); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - false, - false, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughResults(..))); - let _ = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - true, - false, - ); - } - - // Not enough groundtruth - { - let groundtruth = Matrix::::new(0, 10, 5); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) - )); - } - - // Distance Row Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 9, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..))); - } - - // Distance Cols Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 10, 9); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!( - err, - ComputeRecallError::NotEnoughGroundTruthDistances(..) - )); +impl From<&benchmark_core::recall::AveragePrecisionMetrics> for AveragePrecisionMetrics { + fn from(m: &benchmark_core::recall::AveragePrecisionMetrics) -> Self { + Self { + num_queries: m.num_queries, + average_precision: m.average_precision, } } } diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index a21d3f520..21c78abb2 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -/// Create a multi-threaded runtime with `num_threads`. +/// Create a generic multi-threaded runtime with `num_threads`. pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { Ok(tokio::runtime::Builder::new_multi_thread() .worker_threads(num_threads) @@ -18,21 +18,3 @@ pub(crate) fn block_on(future: F) -> F::Output { .expect("current thread runtime initialization failed") .block_on(future) } - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_runtimes() { - for num_threads in [1, 2, 4, 8] { - let rt = runtime(num_threads).unwrap(); - let metrics = rt.metrics(); - assert_eq!(metrics.num_workers(), num_threads); - } - } -} diff --git a/diskann-label-filter/src/attribute.rs b/diskann-label-filter/src/attribute.rs index 9eb7ff500..f0d99bfd9 100644 --- a/diskann-label-filter/src/attribute.rs +++ b/diskann-label-filter/src/attribute.rs @@ -5,6 +5,7 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; +use std::io::Write; use serde_json::Value; use thiserror::Error; diff --git a/diskann-label-filter/src/parser/format.rs b/diskann-label-filter/src/parser/format.rs index c042d8338..5e9e3a9c1 100644 --- a/diskann-label-filter/src/parser/format.rs +++ b/diskann-label-filter/src/parser/format.rs @@ -15,8 +15,10 @@ pub struct Document { /// label in raw json format #[serde(flatten)] pub label: serde_json::Value, + } + /// Represents a query expression as defined in the RFC. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryExpression { diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 16404af4b..83c9f80f9 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -154,7 +154,7 @@ fn flatten_json_pointer_inner( } Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { - flatten_recursive(item, stack.push(&i, separator), out, separator); + flatten_recursive(item, stack.push(&String::from(""), separator), out, separator); } } _ => { diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 7f0cb203a..1b4b3408e 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -5,14 +5,13 @@ version.workspace = true authors.workspace = true description.workspace = true documentation.workspace = true -license.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] byteorder.workspace = true clap = { workspace = true, features = ["derive"] } -diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` +diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` diskann-vector = { workspace = true } diskann-disk = { workspace = true } diskann-utils = { workspace = true } @@ -24,31 +23,24 @@ ordered-float = "4.2.0" rand_distr.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive"] } -toml = "0.8.13" +serde_json.workspace = true bincode.workspace = true opentelemetry.workspace = true -opentelemetry_sdk.workspace = true -csv.workspace = true -tokio = { workspace = true, features = ["full"] } -arc-swap.workspace = true diskann-quantization = { workspace = true } diskann = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } tracing.workspace = true bit-set.workspace = true anyhow.workspace = true -serde_json.workspace = true itertools.workspace = true diskann-label-filter.workspace = true [dev-dependencies] rstest.workspace = true -assert_ok = "1.0.2" -# Use virtual-storage for integration tests -diskann-disk = { path = "../diskann-disk", features = ["virtual_storage"] } vfs = { workspace = true } -ureq = { version = "3.0.11", default-features = false, features = ["native-tls", "gzip"] } -diskann-providers = { path = "../diskann-providers", default-features = false, features = ["testing", "virtual_storage"] } +diskann-providers = { workspace = true, default-features = false, features = [ + "virtual_storage", +] } diskann-utils = { workspace = true, features = ["testing"] } [features] diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index e96f7ae8f..31e69b2b2 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -4,7 +4,7 @@ */ use bit_set::BitSet; -use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels}; +use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels, ASTExpr}; use std::{io::Write, mem::size_of, str::FromStr}; @@ -25,18 +25,97 @@ use diskann_utils::views::Matrix; use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; +use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; +/// Expands a JSON object with array-valued fields into multiple objects with scalar values. +/// For example: {"country": ["AU", "NZ"], "year": 2007} +/// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] +/// +/// If multiple fields have arrays, all combinations are generated. +fn expand_array_fields(value: &Value) -> Vec { + match value { + Value::Object(map) => { + // Start with a single empty object + let mut results: Vec> = vec![Map::new()]; + + for (key, val) in map.iter() { + if let Value::Array(arr) = val { + // Expand: for each existing result, create copies for each array element + let mut new_results: Vec> = Vec::new(); + for existing in results.iter() { + for item in arr.iter() { + let mut new_map: Map = existing.clone(); + new_map.insert(key.clone(), item.clone()); + new_results.push(new_map); + } + } + // If array is empty, keep existing results without this key + if !arr.is_empty() { + results = new_results; + } + } else { + // Non-array field: add to all existing results + for existing in results.iter_mut() { + existing.insert(key.clone(), val.clone()); + } + } + } + + results.into_iter().map(Value::Object).collect() + } + // If not an object, return as-is + _ => vec![value.clone()], + } +} + +/// Evaluates a query expression against a label, expanding array fields first. +/// Returns true if any expanded variant matches the query. +fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + let expanded = expand_array_fields(label); + expanded.iter().any(|item| eval_query_expr(query_expr, item)) +} + pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, ) -> CMDResult> { // Read base labels let base_labels = read_baselabels(base_label_filename)?; + tracing::info!( + "Loaded {} base labels from {}", + base_labels.len(), + base_label_filename + ); + + // Print first few base labels for debugging + for (i, label) in base_labels.iter().take(3).enumerate() { + tracing::debug!( + "Base label sample [{}]: doc_id={}, label={}", + i, + label.doc_id, + label.label + ); + } // Parse queries and evaluate against labels let parsed_queries = read_and_parse_queries(query_label_filename)?; + tracing::info!( + "Loaded {} queries from {}", + parsed_queries.len(), + query_label_filename + ); + + // Print first few queries for debugging + for (i, (query_id, query_expr)) in parsed_queries.iter().take(3).enumerate() { + tracing::debug!( + "Query sample [{}]: query_id={}, expr={:?}", + i, + query_id, + query_expr + ); + } // using the global threadpool is fine here #[allow(clippy::disallowed_methods)] @@ -45,7 +124,15 @@ pub fn read_labels_and_compute_bitmap( .map(|(_query_id, query_expr)| { let mut bitmap = BitSet::new(); for base_label in base_labels.iter() { - if eval_query_expr(query_expr, &base_label.label) { + // Handle case where base_label.label is an array - check if any element matches + // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) + let matches = if let Some(array) = base_label.label.as_array() { + array.iter().any(|item| eval_query_with_array_expansion(query_expr, item)) + } else { + eval_query_with_array_expansion(query_expr, &base_label.label) + }; + + if matches { bitmap.insert(base_label.doc_id); } } @@ -53,6 +140,38 @@ pub fn read_labels_and_compute_bitmap( }) .collect(); + // Debug: Print match statistics for each query + let total_matches: usize = query_bitmaps.iter().map(|b| b.len()).sum(); + let queries_with_matches = query_bitmaps.iter().filter(|b| !b.is_empty()).count(); + tracing::info!( + "Filter matching summary: {} total matches across {} queries ({} queries have matches)", + total_matches, + query_bitmaps.len(), + queries_with_matches + ); + + // Print per-query match counts + for (i, bitmap) in query_bitmaps.iter().enumerate() { + if i < 10 || bitmap.is_empty() { + tracing::debug!( + "Query {}: {} base vectors matched the filter", + i, + bitmap.len() + ); + } + } + + // If no matches, print more diagnostic info + if total_matches == 0 { + tracing::warn!("WARNING: No base vectors matched any query filters!"); + tracing::warn!("This could indicate a format mismatch between base labels and query filters."); + + // Try to identify what keys exist in base labels vs queries + if let Some(first_label) = base_labels.first() { + tracing::warn!("First base label (full): doc_id={}, label={}", first_label.doc_id, first_label.label); + } + } + Ok(query_bitmaps) } @@ -195,6 +314,44 @@ pub fn compute_ground_truth_from_datafiles< assert_ne!(ground_truth.len(), 0, "No ground-truth results computed"); + // Debug: Print top K matches for each query + tracing::info!( + "Ground truth computed for {} queries with recall_at={}", + ground_truth.len(), + recall_at + ); + for (query_idx, npq) in ground_truth.iter().enumerate() { + let neighbors: Vec<_> = npq.iter().collect(); + let neighbor_count = neighbors.len(); + + if query_idx < 10 { + // Print top K IDs and distances for first 10 queries + let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); + let top_dists: Vec = neighbors.iter().take(10).map(|n| n.distance).collect(); + tracing::debug!( + "Query {}: {} neighbors found. Top IDs: {:?}, Top distances: {:?}", + query_idx, + neighbor_count, + top_ids, + top_dists + ); + } + + if neighbor_count == 0 { + tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); + } + } + + // Summary stats + let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); + let queries_with_neighbors = ground_truth.iter().filter(|npq| npq.iter().count() > 0).count(); + tracing::info!( + "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", + total_neighbors, + queries_with_neighbors, + ground_truth.len() - queries_with_neighbors + ); + if has_vector_filters || has_query_bitmaps { let ground_truth_collection = ground_truth .into_iter()