Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/contextual_retrieval/bm25_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
when collection data changes.
"""

from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Set
from loguru import logger
from rank_bm25 import BM25Okapi
import re
import asyncio
from contextual_retrieval.contextual_retrieval_api_client import get_http_client_manager
from contextual_retrieval.error_handler import SecureErrorHandler
from contextual_retrieval.constants import (
Expand All @@ -33,6 +34,11 @@ def __init__(
self.chunk_mapping: Dict[int, Dict[str, Any]] = {}
self.last_collection_stats: Dict[str, Any] = {}
self.tokenizer_pattern = re.compile(r"\w+") # Simple word tokenizer
# Background refresh state - prevents blocking queries during index rebuild
self._refresh_in_progress: bool = False
self._refresh_lock: asyncio.Lock = asyncio.Lock()
# Strong references to background tasks to prevent premature GC
self._background_tasks: Set[asyncio.Task[None]] = set()

async def _get_http_client_manager(self):
"""Get the HTTP client manager instance."""
Expand Down Expand Up @@ -103,10 +109,24 @@ async def search_bm25(
limit = self._config.search.topk_bm25

try:
# Check if index needs refresh
# Check if index needs refresh (non-blocking: schedule background rebuild,
# current query continues with the existing index to avoid latency).
if await self._should_refresh_index():
logger.info("Collection data changed - refreshing BM25 index")
await self.initialize_index()
# Avoid scheduling multiple concurrent refresh tasks; coalesce while a
# refresh is already in progress.
if not self._refresh_in_progress:
logger.info(
"Collection data changed - scheduling background BM25 refresh "
"(current query uses existing index)"
)
task = asyncio.create_task(self._background_refresh_index())
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
else:
logger.debug(
"BM25 refresh already in progress; skipping scheduling of a "
"new background refresh task"
)

if not self.bm25_index:
logger.error("BM25 index not initialized")
Expand Down Expand Up @@ -162,6 +182,31 @@ async def search_bm25(
logger.error(f"BM25 search failed: {e}")
return []

async def _background_refresh_index(self) -> None:
    """
    Rebuild the BM25 index in the background without blocking in-flight queries.

    Coalesces duplicate refresh requests: ``_refresh_in_progress`` is set
    before the first ``await``, so within asyncio's single-threaded event
    loop any refresh task that starts while a rebuild is pending or running
    observes the flag and returns immediately. The lock additionally
    serializes the actual rebuild as a defence against re-entrancy. A
    discarded duplicate is harmless — the in-progress rebuild captures the
    latest collection data anyway.
    """
    # Fast path: another refresh task is already pending or running.
    if self._refresh_in_progress:
        logger.debug("BM25 background refresh already running - skipping duplicate")
        return
    # Set the flag BEFORE any await point. There is no suspension between the
    # check above and this assignment, so the check-and-set pair is atomic in
    # asyncio. (Checking only inside the lock would let duplicate tasks queue
    # on the lock and rebuild again as soon as it is released, because the
    # flag is cleared before the lock is dropped.)
    self._refresh_in_progress = True
    try:
        async with self._refresh_lock:
            logger.info("Starting background BM25 index refresh...")
            await self.initialize_index()
            logger.info("Background BM25 index refresh complete")
    except Exception as e:
        # Best-effort: a failed rebuild leaves the previous index usable.
        logger.error(f"Background BM25 refresh failed: {e}")
    finally:
        # Always clear the flag so future collection changes can refresh.
        self._refresh_in_progress = False

async def _fetch_all_contextual_chunks(self) -> List[Dict[str, Any]]:
"""Fetch all chunks from contextual collections."""
all_chunks: List[Dict[str, Any]] = []
Expand Down
30 changes: 25 additions & 5 deletions src/contextual_retrieval/contextual_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
connection_id: Optional[str] = None,
config_path: Optional[str] = None,
llm_service: Optional["LLMOrchestrationService"] = None,
shared_bm25: Optional[SmartBM25Search] = None,
):
"""
Initialize contextual retriever.
Expand All @@ -52,6 +53,10 @@ def __init__(
connection_id: Optional connection ID
config_path: Optional config file path
llm_service: Optional LLM service instance (prevents circular dependency)
shared_bm25: Optional pre-warmed SmartBM25Search singleton. When
provided the retriever skips the expensive index-build step during
initialize() and reuses the already-ready index, eliminating the
cold-start latency on the first query.
"""
self.qdrant_url = qdrant_url
self.environment = environment
Expand All @@ -70,7 +75,14 @@ def __init__(
# Initialize components with configuration
self.provider_detection = DynamicProviderDetection(qdrant_url, self.config)
self.qdrant_search = QdrantContextualSearch(qdrant_url, self.config)
self.bm25_search = SmartBM25Search(qdrant_url, self.config)
# Use the injected pre-warmed singleton when available; create a fresh
# instance only as a fallback (avoids duplicate Qdrant scroll on startup).
self.bm25_search: SmartBM25Search = (
shared_bm25
if shared_bm25 is not None
else SmartBM25Search(qdrant_url, self.config)
)
self._bm25_is_shared: bool = shared_bm25 is not None
self.rank_fusion = DynamicRankFusion(self.config)

# State
Expand All @@ -87,10 +99,18 @@ async def initialize(self) -> bool:
try:
logger.info("Initializing Contextual Retriever...")

# Initialize BM25 index
bm25_success = await self.bm25_search.initialize_index()
if not bm25_success:
logger.warning("BM25 initialization failed - will skip BM25 search")
# If received a pre-warmed shared BM25 index, reuse it directly.
# This is the normal startup path and adds zero latency to the first query.
if self._bm25_is_shared and self.bm25_search.bm25_index is not None:
logger.info(
"Using pre-warmed shared BM25 index - skipping BM25 build "
f"({len(self.bm25_search.chunk_mapping)} chunks ready)"
)
else:
# No shared index available - build it now (fallback path).
bm25_success = await self.bm25_search.initialize_index()
if not bm25_success:
logger.warning("BM25 initialization failed - will skip BM25 search")

self.initialized = True
logger.info("Contextual Retriever initialized successfully")
Expand Down
41 changes: 15 additions & 26 deletions src/guardrails/nemo_rails_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ def __init__(
self._rails: Optional[LLMRails] = None
self._initialized = False

logger.info(f"Initializing NeMoRailsAdapter for environment: {environment}")
logger.debug(f"NeMoRailsAdapter created for environment: {environment}")

def _register_custom_provider(self) -> None:
"""Register DSPy custom LLM provider with NeMo Guardrails."""
try:
from src.guardrails.dspy_nemo_adapter import DSPyLLMProviderFactory

logger.info("Registering DSPy custom LLM provider with NeMo Guardrails")
logger.debug("Registering DSPy custom LLM provider with NeMo Guardrails")

# NeMo Guardrails' register_llm_provider accepts callable factories at runtime.
# We instantiate DSPyLLMProviderFactory first, then register the instance.
Expand All @@ -74,7 +74,7 @@ def _register_custom_provider(self) -> None:
# We use cast to satisfy the type checker while maintaining runtime correctness.
factory = DSPyLLMProviderFactory()
register_llm_provider("dspy-custom", cast(Type[BaseLLM], factory))
logger.info("DSPy custom LLM provider registered successfully")
logger.debug("DSPy custom LLM provider registered successfully")

except Exception as e:
logger.error(f"Failed to register DSPy custom provider: {str(e)}")
Expand All @@ -86,8 +86,8 @@ def _ensure_initialized(self) -> None:
return

try:
logger.info(
"Initializing NeMo Guardrails with DSPy LLM and streaming support"
logger.debug(
f"Initializing NeMo Guardrails with DSPy LLM (env={self.environment})"
)

from llm_orchestrator_config.llm_manager import LLMManager
Expand All @@ -106,33 +106,24 @@ def _ensure_initialized(self) -> None:
guardrails_loader = get_guardrails_loader()
config_path, metadata = guardrails_loader.get_optimized_config_path()

logger.info(f"Loading guardrails config from: {config_path}")
logger.debug(f"Loading guardrails config from: {config_path}")

rails_config = RailsConfig.from_path(str(config_path.parent))

rails_config.streaming = True

logger.info("Streaming configuration:")
logger.info(f" Global streaming: {rails_config.streaming}")

if hasattr(rails_config, "rails") and hasattr(rails_config.rails, "output"):
if metadata.get("optimized", False):
version = metadata.get("version", "unknown")
metrics = metadata.get("metrics", {})
accuracy = metrics.get("weighted_accuracy", "N/A") if metrics else "N/A"
logger.info(
f" Output rails config exists: {rails_config.rails.output}"
f"Guardrails ready: OPTIMIZED config v={version}, "
f"weighted_accuracy={accuracy}, env={self.environment}"
)
else:
logger.info(" Output rails config will be loaded from YAML")

if metadata.get("optimized", False):
logger.info(
f"Loaded OPTIMIZED guardrails config (version: {metadata.get('version', 'unknown')})"
f"Guardrails ready: BASE config (no optimization), env={self.environment}"
)
metrics = metadata.get("metrics", {})
if metrics:
logger.info(
f" Optimization metrics: weighted_accuracy={metrics.get('weighted_accuracy', 'N/A')}"
)
else:
logger.info("Loaded BASE guardrails config (no optimization)")

from src.guardrails.dspy_nemo_adapter import DSPyNeMoLLM

Expand All @@ -144,18 +135,16 @@ def _ensure_initialized(self) -> None:
verbose=False,
)

if (
if not (
hasattr(self._rails.config, "streaming")
and self._rails.config.streaming
):
logger.info("✓ Streaming enabled in NeMo Guardrails configuration")
else:
logger.warning(
"Streaming not enabled in configuration - this may cause issues"
)

self._initialized = True
logger.info("NeMo Guardrails initialized successfully with DSPy LLM")
logger.debug("NeMo Guardrails initialized successfully with DSPy LLM")

except Exception as e:
logger.error(f"Failed to initialize NeMo Guardrails: {str(e)}")
Expand Down
22 changes: 14 additions & 8 deletions src/intent_data_enrichment/main_enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,20 @@ async def enrich_service(service_data: ServiceData) -> EnrichmentResult:
qdrant.ensure_collection()

# Delete old points before inserting new ones
qdrant.delete_service_points(service_data.service_id)

# Step 5: Bulk upsert all points (examples + summary)
logger.info(
f"Step 5: Storing {len(enriched_points)} points in Qdrant "
f"({len(service_data.examples)} examples + 1 summary)"
)
success = qdrant.upsert_service_points(enriched_points)
deleted = qdrant.delete_service_points(service_data.service_id)
if not deleted:
logger.error(
f"Failed to delete existing points for service_id={service_data.service_id}; "
"aborting upsert to avoid stale data."
)
success = False
else:
# Step 5: Bulk upsert all points (examples + summary)
logger.info(
f"Step 5: Storing {len(enriched_points)} points in Qdrant "
f"({len(service_data.examples)} examples + 1 summary)"
)
success = qdrant.upsert_service_points(enriched_points)
finally:
qdrant.close()

Expand Down
66 changes: 64 additions & 2 deletions src/llm_orchestration_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from src.utils.query_validator import validate_query_basic
from src.guardrails import NeMoRailsAdapter, GuardrailCheckResult
from src.contextual_retrieval import ContextualRetriever
from src.contextual_retrieval.bm25_search import SmartBM25Search
from src.llm_orchestrator_config.exceptions import (
ContextualRetrieverInitializationError,
ContextualRetrievalFailureError,
Expand Down Expand Up @@ -133,6 +134,13 @@ def __init__(self) -> None:
# This allows components to be initialized per-request with proper context
self.tool_classifier = None

# Shared BM25 search index pre-warmed at startup.
# Populated by _prewarm_shared_bm25() which is called from the FastAPI
# lifespan so it runs inside the async event loop. Until then it is None
# and each ContextualRetriever will build the index on first query (graceful
# degradation path).
self.shared_bm25_search: Optional[SmartBM25Search] = None

# Initialize shared guardrails adapters at startup (production and testing)
self.shared_guardrails_adapters = (
self._initialize_shared_guardrails_at_startup()
Expand Down Expand Up @@ -168,10 +176,17 @@ def _initialize_shared_guardrails_at_startup(self) -> Dict[str, NeMoRailsAdapter
connection_id=None, # Shared configuration, not user-specific
)

# Eagerly trigger the full internal initialization (NeMo config
# loading, LLMRails creation, embedding model download) so that
# the first user query is not penalised by the cold-start cost.
# Without this, _ensure_initialized() runs lazily on the first user query.
guardrails_adapter._ensure_initialized()

elapsed_time = time.time() - start_time
adapters[env] = guardrails_adapter
logger.info(
f" Guardrails for '{env}' initialized successfully in {elapsed_time:.3f}s"
f" Guardrails for '{env}' fully initialized in {elapsed_time:.3f}s "
f"(NeMo Rails + embedding model loaded)"
)

except Exception as e:
Expand All @@ -197,6 +212,53 @@ def _initialize_shared_guardrails_at_startup(self) -> Dict[str, NeMoRailsAdapter

return adapters

async def _prewarm_shared_bm25(self) -> None:
    """
    Pre-warm the shared BM25 index at application startup.

    Must be called from an async context (e.g. FastAPI lifespan) so that
    asyncio is available for the HTTP calls to Qdrant. Absorbs the
    cold-start cost (fetching all chunks + building the BM25 corpus) at
    deploy time instead of on the first real user query.

    On any failure this logs a warning and leaves ``self.shared_bm25_search``
    as ``None``; each ContextualRetriever then falls back to building the
    index on its first query (graceful degradation).
    """
    url = os.getenv("QDRANT_URL", "http://qdrant:6333")
    logger.info("Pre-warming shared BM25 index at startup...")
    started = time.time()
    try:
        candidate = SmartBM25Search(qdrant_url=url)
        built = await candidate.initialize_index()
        if not built:
            # Empty/failed build: keep shared_bm25_search unset so the
            # per-retriever fallback path kicks in later.
            logger.warning(
                "BM25 pre-warming produced an empty index - "
                "index will be built on first query instead"
            )
            return
        self.shared_bm25_search = candidate
        logger.info(
            f"Shared BM25 index pre-warmed in {time.time() - started:.2f}s "
            f"({len(candidate.chunk_mapping)} chunks indexed)"
        )
    except Exception as e:
        logger.warning(
            f"BM25 pre-warming failed: {e} - "
            f"index will be built on first query (graceful degradation)"
        )

async def aclose(self) -> None:
    """Release all long-lived async resources held by the service.

    Must be awaited during application shutdown (FastAPI lifespan teardown)
    to avoid connection leaks from the ToolClassifier's httpx client.
    """
    classifier = self.tool_classifier
    if classifier is None:
        # Nothing was ever initialized; nothing to release.
        return
    await classifier.aclose()
    logger.debug("LLMOrchestrationService async resources closed")

@observe(name="orchestration_request", as_type="agent")
async def process_orchestration_request(
self, request: OrchestrationRequest
Expand Down Expand Up @@ -1786,7 +1848,6 @@ def _initialize_guardrails(
environment=environment, connection_id=connection_id
)

logger.info("Guardrails adapter initialized successfully")
return guardrails_adapter

except Exception as e:
Expand Down Expand Up @@ -2322,6 +2383,7 @@ def _initialize_contextual_retriever(
environment=environment,
connection_id=connection_id,
llm_service=self, # Inject self to eliminate circular dependency
shared_bm25=self.shared_bm25_search, # Inject pre-warmed BM25 index
)

logger.info("Contextual retriever initialized successfully")
Expand Down
Loading
Loading