diff --git a/docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md b/docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md new file mode 100644 index 0000000..8a67e84 --- /dev/null +++ b/docs/CONTEXT_WORKFLOW_GREETING_DETECTION.md @@ -0,0 +1,323 @@ +# Context Workflow: Greeting Detection and Conversation History Analysis + +## Overview + +The **Context Workflow (Layer 2)** intercepts user queries that can be answered without searching the knowledge base. It handles two categories: + +1. **Greetings** — Detects and responds to social exchanges (hello, goodbye, thanks) in multiple languages +2. **Conversation history references** — Answers follow-up questions that refer to information already discussed in the session + +When the context workflow can answer, a response is returned immediately, bypassing the RAG pipeline entirely. When it cannot answer, the query falls through to the RAG workflow (Layer 3). + +--- + +## Architecture + +### Position in the Classifier Chain + +``` +User Query + ↓ +Layer 1: SERVICE → External API calls + ↓ (cannot handle) +Layer 2: CONTEXT → Greetings + conversation history ←── This document + ↓ (cannot handle) +Layer 3: RAG → Knowledge base retrieval + ↓ (cannot handle) +Layer 4: OOD → Out-of-domain fallback +``` + +### Key Components + +| Component | File | Responsibility | +|-----------|------|----------------| +| `ContextAnalyzer` | `src/tool_classifier/context_analyzer.py` | LLM-based greeting detection and context analysis | +| `ContextWorkflowExecutor` | `src/tool_classifier/workflows/context_workflow.py` | Orchestrates the workflow, handles streaming/non-streaming | +| `ToolClassifier` | `src/tool_classifier/classifier.py` | Invokes `ContextAnalyzer` during classification and routes to `ContextWorkflowExecutor` | +| `greeting_constants.py` | `src/tool_classifier/greeting_constants.py` | Fallback greeting responses for Estonian and English | + +--- + +## Full Request Flow + +``` +User Query + Conversation History + ↓ +ToolClassifier.classify() + ├─ 
Layer 1 (SERVICE): Embedding-based intent routing + │ └─ If no service tool matches → route to CONTEXT workflow + │ + └─ ClassificationResult(workflow=CONTEXT) + +ToolClassifier.route_to_workflow() + ├─ Non-streaming → ContextWorkflowExecutor.execute_async() + │ ├─ Phase 1: _detect() → context_analyzer.detect_context() [classification only] + │ ├─ If greeting → return greeting OrchestrationResponse + │ ├─ If can_answer → _generate_response_async() → context_analyzer.generate_context_response() + │ └─ Otherwise → return None (RAG fallback) + │ + └─ Streaming → ContextWorkflowExecutor.execute_streaming() + ├─ Phase 1: _detect() → context_analyzer.detect_context() [classification only] + ├─ If greeting → _stream_greeting() async generator + ├─ If can_answer → _create_history_stream() → context_analyzer.stream_context_response() + └─ Otherwise → return None (RAG fallback) +``` + +--- + +## Phase 1: Detection (Classify Only) + +### LLM Task + +Every query is checked against the **most recent 10 conversation turns** using a single LLM call (`detect_context()`). This phase **does not generate an answer** — it only classifies the query and extracts a relevant context snippet for Phase 2. + +The `ContextDetectionSignature` DSPy signature instructs the LLM to: + +1. Detect if the query is a greeting in any supported language +2. Check if the query references something discussed in the last 10 turns +3. If the query can be answered from history, extract the relevant snippet +4. Do **not** generate the final answer here — detection only + +### LLM Output Format + +The LLM returns a JSON object parsed into `ContextDetectionResult`: + +```json +{ + "is_greeting": false, + "can_answer_from_context": true, + "reasoning": "User is asking about tax rate discussed earlier", + "context_snippet": "Bot confirmed the flat rate is 20%, applying equally to all income brackets." 
+} +``` + +| Field | Type | Description | +|-------|------|-------------| +| `is_greeting` | `bool` | Whether the query is a greeting | +| `can_answer_from_context` | `bool` | Whether the query can be answered from conversation history | +| `reasoning` | `str` | Brief explanation of the detection decision | +| `context_snippet` | `str \| null` | Relevant excerpt from history for use in Phase 2, or `null` | + +> **Internal field**: `answered_from_summary` (bool, default `False`) is reserved for future summary-based detection paths. + +### Decision After Phase 1 + +``` +is_greeting=True → Phase 2: return greeting response (no LLM call) +can_answer_from_context=True AND snippet set → Phase 2: generate answer from snippet +Otherwise → Fall back to RAG +``` + +--- + +## Phase 2: Response Generation + +### Non-Streaming (`_generate_response_async`) + +Calls `generate_context_response(query, context_snippet)` which uses `ContextResponseGenerationSignature` to produce a complete answer in a single LLM call. Output guardrails are applied before returning the `OrchestrationResponse`. + +### Streaming (`_create_history_stream` → `stream_context_response`) + +Calls `stream_context_response(query, context_snippet)` which uses DSPy native streaming (`dspy.streamify`) with `ContextResponseGenerationSignature`. Tokens are yielded in real time and passed through NeMo Guardrails before being SSE-formatted. 
+ +--- + +## Greeting Detection + +### Supported Languages + +| Language | Code | +|----------|------| +| Estonian | `et` | +| English | `en` | + +### Supported Greeting Types + +| Type | Estonian Examples | English Examples | +|------|-------------------|-----------------| +| `hello` | Tere, Hei, Tervist, Moi | Hello, Hi, Hey, Good morning | +| `goodbye` | Nägemist, Tšau | Bye, Goodbye, See you, Good night | +| `thanks` | Tänan, Aitäh, Tänud | Thank you, Thanks | +| `casual` | Tere, Tervist | Hey | + +### Greeting Response Generation + +Greeting detection is handled in **Phase 1 (`detect_context`)**, where the LLM classifies whether the query is a greeting and, if so, identifies the language and greeting type. This phase does **not** generate the final natural-language reply. +In **Phase 2**, `ContextWorkflowExecutor` calls `get_greeting_response(...)`, which returns a response based on predefined static templates in `greeting_constants.py`, ensuring the reply is in the detected language. If greeting detection fails or the greeting type is unsupported, the query falls through to the next workflow layer instead of attempting LLM-based greeting generation. +**Greeting response templates (`greeting_constants.py`):** + +```python +GREETINGS_ET = { + "hello": "Tere! Kuidas ma saan sind aidata?", + "goodbye": "Nägemist! Head päeva!", + "thanks": "Palun! Kui on veel küsimusi, küsi julgelt.", + "casual": "Tere! Mida ma saan sinu jaoks teha?", +} + +GREETINGS_EN = { + "hello": "Hello! How can I help you?", + "goodbye": "Goodbye! Have a great day!", + "thanks": "You're welcome! Feel free to ask if you have more questions.", + "casual": "Hey! What can I do for you?", +} +``` + +The fallback greeting type is determined by keyword matching in `_detect_greeting_type()` — checking for `thank/tänan/aitäh`, `bye/goodbye/nägemist/tšau`, before defaulting to `hello`. 
+ +--- + +## Streaming Support + +The context workflow supports both response modes: + +### Non-Streaming (`execute_async`) + +Returns a complete `OrchestrationResponse` object with the answer as a single string. Output guardrails are applied before the response is returned. + +### Streaming (`execute_streaming`) + +Returns an `AsyncIterator[str]` that yields SSE (Server-Sent Events) chunks. + +**Greeting responses** are yielded as a single SSE chunk followed by `END`. + +**History responses** use DSPy native streaming (`dspy.streamify`) with `ContextResponseGenerationSignature`. Tokens are emitted in real time as they arrive from the LLM, then passed through NeMo Guardrails (`stream_with_guardrails`) before being SSE-formatted. If a guardrail violation is detected in a chunk, streaming stops and the violation message is sent instead. + +**SSE Format:** +``` +data: {"chatId": "abc123", "payload": {"content": "Tere! Kuidas ma"}, "timestamp": "...", "sentTo": []} + +data: {"chatId": "abc123", "payload": {"content": " saan sind aidata?"}, "timestamp": "...", "sentTo": []} + +data: {"chatId": "abc123", "payload": {"content": "END"}, "timestamp": "...", "sentTo": []} +``` + +--- + +## Cost Tracking + +LLM token usage and cost is tracked via `get_lm_usage_since()` and stored in `costs_metric` within the workflow executor. Costs are logged via `orchestration_service.log_costs()` at the end of each execution path. 
+ +Two cost keys are tracked separately: + +```python +costs_metric = { + "context_detection": { + # Phase 1: detect_context() — single LLM call + "total_cost": 0.0012, + "total_tokens": 180, + "total_prompt_tokens": 150, + "total_completion_tokens": 30, + "num_calls": 1, + }, + "context_response": { + # Phase 2: generate_context_response() or stream_context_response() + "total_cost": 0.003, + "total_tokens": 140, + "total_prompt_tokens": 100, + "total_completion_tokens": 40, + "num_calls": 1, + }, +} +``` + +Greeting responses skip Phase 2, so only `"context_detection"` cost is populated. + +--- + +## Error Handling and Fallback + +| Failure Point | Behaviour | +|---------------|-----------| +| Phase 1 LLM call raises exception | `can_answer_from_context=False` → falls back to RAG | +| Phase 1 returns invalid JSON | Logged as warning, all flags default to `False` → falls back to RAG | +| Phase 2 LLM call raises exception | Logged as error, `_generate_response_async` returns `None` → falls back to RAG | +| Phase 2 returns empty answer | Logged as warning → falls back to RAG | +| Output guardrails fail | Logged as warning, response returned without guardrail check | +| Guardrail violation in streaming | `OUTPUT_GUARDRAIL_VIOLATION_MESSAGE` sent, stream terminated | +| `orchestration_service` unavailable | History streaming skipped → `None` returned → RAG fallback | +| `guardrails_adapter` not a `NeMoRailsAdapter` | Logged as warning → cannot stream → RAG fallback | +| Any unhandled exception in executor | Error logged, `execute_async/execute_streaming` returns `None` → RAG fallback via classifier | + +--- + +## Logging + +Key log entries emitted during a request: + +| Level | Message | When | +|-------|---------|------| +| `INFO` | `CONTEXT WORKFLOW (NON-STREAMING) \| Query: '...'` | `execute_async()` entry | +| `INFO` | `CONTEXT WORKFLOW (STREAMING) \| Query: '...'` | `execute_streaming()` entry | +| `INFO` | `CONTEXT DETECTOR: Phase 1 \| Query: '...' 
\| History: N turns` | `detect_context()` entry | +| `INFO` | `DETECTION RESULT \| Greeting: ... \| Can Answer: ... \| Has snippet: ...` | Phase 1 LLM response parsed | +| `INFO` | `Detection cost \| Total: $... \| Tokens: N` | After Phase 1 cost tracked | +| `INFO` | `Detection: greeting=... can_answer=...` | After `_detect()` returns in executor | +| `INFO` | `CONTEXT GENERATOR: Phase 2 non-streaming \| Query: '...'` | `generate_context_response()` entry | +| `INFO` | `CONTEXT GENERATOR: Phase 2 streaming \| Query: '...'` | `stream_context_response()` entry | +| `INFO` | `Context response streaming complete (final Prediction received)` | DSPy streaming finished | +| `WARNING` | `[chatId] Phase 2 empty answer — fallback to RAG` | Phase 2 returned no content | +| `WARNING` | `[chatId] Guardrails violation in context streaming` | Violation detected mid-stream | +| `WARNING` | `[chatId] Cannot answer from context — falling back to RAG` | Neither phase could answer | + +--- + +## Data Models + +### `ContextDetectionResult` (Phase 1 output) + +```python +class ContextDetectionResult(BaseModel): + is_greeting: bool # True if query is a greeting + greeting_type: str # Type of greeting: hello, goodbye, thanks, or casual (defaults to "hello") + can_answer_from_context: bool # True if query can be answered from last 10 turns + reasoning: str # LLM's brief explanation + answered_from_summary: bool # Reserved; always False in current workflow + context_snippet: Optional[str] # Relevant excerpt for Phase 2 generation, or None +``` + +### `ContextDetectionSignature` (DSPy — Phase 1) + +| Field | Type | Description | +|-------|------|-------------| +| `conversation_history` | Input | Last 10 turns formatted as JSON | +| `user_query` | Input | Current user query | +| `detection_result` | Output | JSON with `is_greeting`, `greeting_type`, `can_answer_from_context`, `reasoning`, `context_snippet` | + +> Detection only — **no answer generated here**. 
+ +### `ContextResponseGenerationSignature` (DSPy — Phase 2) + +| Field | Type | Description | +|-------|------|-------------| +| `context_snippet` | Input | Relevant excerpt from Phase 1 | +| `user_query` | Input | Current user query | +| `answer` | Output | Natural language response in the same language as the query | + +--- + +## Decision Summary Table + +| Scenario | Phase 1 LLM Calls | Phase 2 LLM Calls | Outcome | +|----------|--------------------|--------------------|---------| +| Greeting detected | 1 (`detect_context`) | 0 (static response) | Context responds (greeting) | +| Follow-up answerable from last 10 turns | 1 (`detect_context`) | 1 (`generate_context_response` or `stream_context_response`) | Context responds | +| Cannot answer from last 10 turns | 1 (`detect_context`) | 0 | Falls back to RAG | +| Phase 1 LLM error / JSON parse failure | — | 0 | Falls back to RAG | +| Phase 2 LLM error or empty answer | 1 | — | Falls back to RAG | + +--- + +## File Reference + +| File | Purpose | +|------|---------| +| `src/tool_classifier/context_analyzer.py` | Core LLM analysis logic (Phase 1 detection, Phase 2 generation, summarization helpers) | +| `src/tool_classifier/workflows/context_workflow.py` | Workflow executor (streaming + non-streaming) | +| `src/tool_classifier/classifier.py` | Classification layer that invokes context analysis | +| `src/tool_classifier/greeting_constants.py` | Static fallback greeting responses (ET/EN) | +| `tests/test_context_analyzer.py` | Unit tests for `ContextAnalyzer` | +| `tests/test_context_workflow.py` | Unit tests for `ContextWorkflowExecutor` | +| `tests/test_context_workflow_integration.py` | Integration tests for the full classify → route → execute chain | \ No newline at end of file diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index 7f7432f..7889987 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -639,11 +639,13 @@ async def stream_orchestration_response( ) # Classify query to determine
workflow + start_time = time.time() classification = await self.tool_classifier.classify( query=request.message, conversation_history=request.conversationHistory, language=detected_language, ) + time_metric["classifier.classify"] = time.time() - start_time logger.info( f"[{request.chatId}] [{stream_ctx.stream_id}] Classification: {classification.workflow.value} " @@ -652,11 +654,14 @@ async def stream_orchestration_response( # Route to appropriate workflow (streaming) # route_to_workflow returns AsyncIterator[str] when is_streaming=True + start_time = time.time() stream_result = await self.tool_classifier.route_to_workflow( classification=classification, request=request, is_streaming=True, + time_metric=time_metric, ) + time_metric["classifier.route"] = time.time() - start_time async for sse_chunk in stream_result: yield sse_chunk diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py index 0e9b127..110c299 100644 --- a/src/llm_orchestration_service_api.py +++ b/src/llm_orchestration_service_api.py @@ -71,7 +71,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: if StreamConfig.RATE_LIMIT_ENABLED: app.state.rate_limiter = RateLimiter( requests_per_minute=StreamConfig.RATE_LIMIT_REQUESTS_PER_MINUTE, - tokens_per_second=StreamConfig.RATE_LIMIT_TOKENS_PER_SECOND, + tokens_per_minute=StreamConfig.RATE_LIMIT_TOKENS_PER_MINUTE, ) logger.info("Rate limiter initialized successfully") else: diff --git a/src/llm_orchestrator_config/stream_config.py b/src/llm_orchestrator_config/stream_config.py index ad19338..84e5edd 100644 --- a/src/llm_orchestrator_config/stream_config.py +++ b/src/llm_orchestrator_config/stream_config.py @@ -21,8 +21,7 @@ class StreamConfig: # Rate Limiting Configuration RATE_LIMIT_ENABLED: bool = True # Enable/disable rate limiting - RATE_LIMIT_REQUESTS_PER_MINUTE: int = 10 # Max requests per user per minute - RATE_LIMIT_TOKENS_PER_SECOND: int = ( - 100 # Max tokens per user per second (burst control) - 
) + RATE_LIMIT_REQUESTS_PER_MINUTE: int = 20 # Max requests per user per minute + RATE_LIMIT_TOKENS_PER_MINUTE: int = 40_000 # Max tokens per user per minute RATE_LIMIT_CLEANUP_INTERVAL: int = 300 # Cleanup old entries every 5 minutes + RATE_LIMIT_TOKEN_WINDOW_SECONDS: int = 60 # Sliding window size for token tracking diff --git a/src/tool_classifier/classifier.py b/src/tool_classifier/classifier.py index f18ef3e..1ada894 100644 --- a/src/tool_classifier/classifier.py +++ b/src/tool_classifier/classifier.py @@ -57,9 +57,9 @@ class ToolClassifier: def __init__( self, - llm_manager: Any, - orchestration_service: Any, - ): + llm_manager: Any, # noqa: ANN401 + orchestration_service: Any, # noqa: ANN401 + ) -> None: """ Initialize tool classifier with required dependencies. @@ -88,6 +88,7 @@ def __init__( ) self.context_workflow = ContextWorkflowExecutor( llm_manager=llm_manager, + orchestration_service=orchestration_service, ) self.rag_workflow = RAGWorkflowExecutor( orchestration_service=orchestration_service, @@ -622,7 +623,7 @@ def _get_workflow_executor(self, workflow_type: WorkflowType) -> Any: async def _execute_with_fallback_async( self, - workflow: Any, + workflow: Any, # noqa: ANN401 request: OrchestrationRequest, context: Dict[str, Any], start_layer: WorkflowType, @@ -696,11 +697,11 @@ async def _execute_with_fallback_async( if rag_result is not None: return rag_result else: - raise RuntimeError("RAG workflow returned None unexpectedly") + raise RuntimeError("RAG workflow returned None unexpectedly") from e async def _execute_with_fallback_streaming( self, - workflow: Any, + workflow: Any, # noqa: ANN401 request: OrchestrationRequest, context: Dict[str, Any], start_layer: WorkflowType, @@ -782,4 +783,4 @@ async def _execute_with_fallback_streaming( async for chunk in streaming_result: yield chunk else: - raise RuntimeError("RAG workflow returned None unexpectedly") + raise RuntimeError("RAG workflow returned None unexpectedly") from e diff --git 
a/src/tool_classifier/constants.py b/src/tool_classifier/constants.py index 65f3033..d839e2c 100644 --- a/src/tool_classifier/constants.py +++ b/src/tool_classifier/constants.py @@ -70,13 +70,15 @@ DENSE_SEARCH_TOP_K = 3 """Number of top results from dense-only search for relevance scoring.""" -DENSE_MIN_THRESHOLD = 0.38 +# DENSE_MIN_THRESHOLD = 0.38 +DENSE_MIN_THRESHOLD = 0.5 """Minimum dense cosine similarity to consider a result as a potential match. Below this → skip SERVICE entirely, go to CONTEXT/RAG. Note: Multilingual embeddings (Estonian/short queries) typically yield lower cosine scores (0.25-0.40) than English. Tune based on observed scores.""" -DENSE_HIGH_CONFIDENCE_THRESHOLD = 0.40 +# DENSE_HIGH_CONFIDENCE_THRESHOLD = 0.40 +DENSE_HIGH_CONFIDENCE_THRESHOLD = 0.55 """Dense cosine similarity for high-confidence service classification. Above this AND score gap is large → SERVICE without LLM confirmation.""" diff --git a/src/tool_classifier/context_analyzer.py b/src/tool_classifier/context_analyzer.py new file mode 100644 index 0000000..4572aef --- /dev/null +++ b/src/tool_classifier/context_analyzer.py @@ -0,0 +1,900 @@ +"""Context analyzer for greeting detection and conversation history analysis.""" + +from __future__ import annotations + +from typing import Any, AsyncIterator, Dict, List, Optional +import json +import dspy +import dspy.streaming +from dspy.streaming import StreamListener +from loguru import logger +from pydantic import BaseModel, Field + +from src.utils.cost_utils import get_lm_usage_since +from src.tool_classifier.greeting_constants import get_greeting_response + + +class ContextAnalysisResult(BaseModel): + """Result of context analysis.""" + + is_greeting: bool = Field( + ..., description="Whether the query is a greeting (hello, goodbye, thanks)" + ) + can_answer_from_context: bool = Field( + ..., description="Whether the query can be answered from conversation history" + ) + answer: Optional[str] = Field( + None, 
description="Generated response (greeting or context-based answer)" + ) + reasoning: str = Field(..., description="Brief explanation of the analysis") + answered_from_summary: bool = Field( + default=False, + description="Whether the answer was derived from a conversation summary (older turns beyond recent 10)", + ) + + +class ContextAnalysisSignature(dspy.Signature): + """Analyze user query for greeting detection and conversation history references. + + This signature instructs the LLM to: + 1. Detect greetings in multiple languages (Estonian, English) + 2. Check if query references conversation history + 3. Generate appropriate responses or extract answers from history + + Supported greeting types: + - hello: Tere, Hello, Hi, Hei, Hey, Moi, Good morning, Good afternoon, Good evening + - goodbye: Nägemist, Bye, Goodbye, See you, Good night + - thanks: Tänan, Aitäh, Thank you, Thanks, Much appreciated + - casual: Tervist, Tšau, Moikka + + The LLM should respond in the SAME language as the user's query. + """ + + conversation_history: str = dspy.InputField( + desc="Recent conversation history (last 10 turns) formatted as JSON" + ) + user_query: str = dspy.InputField( + desc="Current user query to analyze for greetings or context references" + ) + analysis_result: str = dspy.OutputField( + desc='JSON object with: {"is_greeting": bool, "can_answer_from_context": bool, "answer": str|null, "reasoning": str}. ' + "For greetings, generate a friendly response in the same language. " + "For context references, extract the answer from conversation history if available." + ) + + +class ConversationSummarySignature(dspy.Signature): + """Generate a concise summary of conversation history. + + Summarize the key topics, facts, decisions, and information discussed + in the conversation. Preserve specific details like numbers, names, + dates, and other factual information that might be referenced later. + + The summary should be in the SAME language as the conversation. 
+ """ + + conversation_history: str = dspy.InputField( + desc="Conversation history formatted as JSON to summarize" + ) + summary: str = dspy.OutputField( + desc="Concise summary capturing key topics, facts, and information discussed. " + "Preserve specific details (numbers, names, dates) that could be referenced later." + ) + + +class SummaryAnalysisSignature(dspy.Signature): + """Analyze if a user query can be answered from a conversation summary. + + Given a summary of earlier conversation and the current user query, + determine if the query references information from the summarized conversation. + If yes, generate an appropriate answer based on the summary. + + The response should be in the SAME language as the user's query. + """ + + conversation_summary: str = dspy.InputField( + desc="Summary of earlier conversation history" + ) + user_query: str = dspy.InputField( + desc="Current user query to check against the conversation summary" + ) + analysis_result: str = dspy.OutputField( + desc='JSON object with: {"can_answer_from_context": bool, "answer": str|null, "reasoning": str}. ' + "If the query references information from the summary, extract/generate the answer. " + "If the summary does not contain relevant information, set can_answer_from_context to false." 
+ ) + + +class ContextDetectionResult(BaseModel): + """Result of Phase 1 context detection (classify only, no answer generation).""" + + is_greeting: bool = Field(..., description="Whether the query is a greeting") + greeting_type: str = Field( + default="hello", + description="Type of greeting: hello, goodbye, thanks, or casual", + ) + can_answer_from_context: bool = Field( + ..., description="Whether the query can be answered from conversation history" + ) + reasoning: str = Field(..., description="Brief explanation of the detection") + answered_from_summary: bool = Field( + default=False, + description="Whether summary analysis was used for detection", + ) + # Relevant context snippet extracted for use in Phase 2 generation + context_snippet: Optional[str] = Field( + default=None, + description="The relevant part of history/summary to answer from, for Phase 2", + ) + + +class ContextDetectionSignature(dspy.Signature): + """Detect if a user query is a greeting or can be answered from conversation history. + + Phase 1 (detection only): classify the query WITHOUT generating the answer. + + Supported greeting types: + - hello: Tere, Hello, Hi, Hei, Hey, Moi, Good morning/afternoon/evening + - goodbye: Nägemist, Bye, Goodbye, See you, Good night + - thanks: Tänan, Aitäh, Thank you, Thanks, Much appreciated + - casual: Tervist, Tšau, Moikka + + Do NOT generate the answer here — only detect and extract a relevant context snippet. + """ + + conversation_history: str = dspy.InputField( + desc="Recent conversation history (last 10 turns) formatted as JSON" + ) + user_query: str = dspy.InputField(desc="Current user query to classify") + detection_result: str = dspy.OutputField( + desc='JSON object with: {"is_greeting": bool, "greeting_type": str, "can_answer_from_context": bool, ' + '"reasoning": str, "context_snippet": str|null}. 
' + 'greeting_type must be one of: "hello", "goodbye", "thanks", "casual" — ' + 'set it only when is_greeting is true, defaulting to "hello" otherwise. ' + "context_snippet should contain the relevant excerpt from history if can_answer_from_context is true, " + "or null otherwise. Do NOT generate the final answer — only detect and extract." + ) + + +class ContextResponseGenerationSignature(dspy.Signature): + """Generate a response to a user query based on conversation history context. + + Phase 2 (generation): given the user query and relevant context, generate a helpful answer. + Respond in the SAME language as the user query. + """ + + context_snippet: str = dspy.InputField( + desc="Relevant excerpt from conversation history or summary that contains the answer" + ) + user_query: str = dspy.InputField(desc="Current user query to answer") + answer: str = dspy.OutputField( + desc="A helpful, natural response to the user query based on the provided context. " + "Respond in the same language as the user query." + ) + + +class ContextAnalyzer: + """ + Analyzer for greeting detection and context-based question answering. + + This class uses an LLM to intelligently detect: + - Greetings in multiple languages (Estonian, English) + - Questions that reference conversation history + - Generate appropriate responses based on context + + Example Usage: + analyzer = ContextAnalyzer(llm_manager) + result = await analyzer.analyze_context( + query="Tere!", + conversation_history=[], + language="et" + ) + # result.is_greeting = True + # result.answer = "Tere! Kuidas ma saan sind aidata?" + """ + + def __init__(self, llm_manager: Any) -> None: # noqa: ANN401 + """ + Initialize the context analyzer. 
+ + Args: + llm_manager: LLM manager instance for making LLM calls + """ + self.llm_manager = llm_manager + self._module: Optional[dspy.Module] = None + self._summary_module: Optional[dspy.Module] = None + self._summary_analysis_module: Optional[dspy.Module] = None + # Phase 1 & 2 modules for two-phase detection+generation flow + self._detection_module: Optional[dspy.Module] = None + self._response_generation_module: Optional[dspy.Module] = None + self._stream_predictor: Optional[Any] = None + logger.info("Context analyzer initialized") + + def _format_conversation_history( + self, conversation_history: List[Dict[str, Any]], max_turns: int = 10 + ) -> str: + """ + Format conversation history for LLM consumption. + + Args: + conversation_history: List of conversation items with authorRole, message, timestamp + max_turns: Maximum number of turns to include (default: 10) + + Returns: + Formatted conversation history as JSON string + """ + # Take last N turns + recent_history = ( + conversation_history[-max_turns:] if conversation_history else [] + ) + + # Format as readable JSON + formatted_history = [ + { + "role": item.get("authorRole", "unknown"), + "message": item.get("message", ""), + "timestamp": item.get("timestamp", ""), + } + for item in recent_history + ] + + if not formatted_history: + return "[]" + + return json.dumps(formatted_history, ensure_ascii=False, indent=2) + + @staticmethod + def _merge_cost_dicts( + cost1: Dict[str, Any], cost2: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Merge two cost dictionaries by summing numeric values. 
+ + Args: + cost1: First cost dictionary + cost2: Second cost dictionary + + Returns: + Merged cost dictionary with summed values + """ + return { + "total_cost": cost1.get("total_cost", 0) + cost2.get("total_cost", 0), + "total_tokens": cost1.get("total_tokens", 0) + cost2.get("total_tokens", 0), + "total_prompt_tokens": cost1.get("total_prompt_tokens", 0) + + cost2.get("total_prompt_tokens", 0), + "total_completion_tokens": cost1.get("total_completion_tokens", 0) + + cost2.get("total_completion_tokens", 0), + "num_calls": cost1.get("num_calls", 0) + cost2.get("num_calls", 0), + } + + async def detect_context( + self, + query: str, + conversation_history: List[Dict[str, Any]], + ) -> tuple[ContextDetectionResult, Dict[str, Any]]: + """ + Phase 1: Detect if query is a greeting or can be answered from history. + + Classify-only — no answer generated here. Returns a ContextDetectionResult + with is_greeting/can_answer_from_context flags and a context_snippet for + Phase 2 generation. + + Args: + query: User query to classify + conversation_history: Full conversation history + + Returns: + Tuple of (ContextDetectionResult, cost_dict) + """ + total_turns = len(conversation_history) + logger.info( + f"CONTEXT DETECTOR: Phase 1 | Query: '{query[:100]}' | " + f"History: {total_turns} turns" + ) + + history_length_before = 0 + try: + lm = dspy.settings.lm + if lm and hasattr(lm, "history"): + history_length_before = len(lm.history) + except Exception as e: + logger.warning(f"Failed to get LM history length for detection: {e}") + + formatted_history = self._format_conversation_history(conversation_history) + + self.llm_manager.ensure_global_config() + try: + with self.llm_manager.use_task_local(): + if self._detection_module is None: + self._detection_module = dspy.ChainOfThought( + ContextDetectionSignature + ) + response = self._detection_module( + conversation_history=formatted_history, + user_query=query, + ) + + try: + detection_data = 
json.loads(response.detection_result) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse detection response: {response.detection_result[:100]}" + ) + detection_data = { + "is_greeting": False, + "can_answer_from_context": False, + "reasoning": "Failed to parse detection response", + "context_snippet": None, + } + + result = ContextDetectionResult( + is_greeting=detection_data.get("is_greeting", False), + greeting_type=detection_data.get("greeting_type", "hello"), + can_answer_from_context=detection_data.get( + "can_answer_from_context", False + ), + reasoning=detection_data.get("reasoning", "Detection completed"), + context_snippet=detection_data.get("context_snippet"), + ) + logger.info( + f"DETECTION RESULT | Greeting: {result.is_greeting} | " + f"Can Answer: {result.can_answer_from_context} | " + f"Has snippet: {result.context_snippet is not None}" + ) + + except Exception as e: + logger.error(f"Context detection failed: {e}", exc_info=True) + result = ContextDetectionResult( + is_greeting=False, + can_answer_from_context=False, + reasoning=f"Detection error: {str(e)}", + ) + + cost_dict = get_lm_usage_since(history_length_before) + logger.info( + f"Detection cost | Total: ${cost_dict.get('total_cost', 0):.6f} | " + f"Tokens: {cost_dict.get('total_tokens', 0)}" + ) + return result, cost_dict + + async def stream_context_response( + self, + query: str, + context_snippet: str, + ) -> AsyncIterator[str]: + """ + Phase 2 (streaming): Stream a generated answer using DSPy native streaming. + + Uses ContextResponseGenerationSignature with DSPy's streamify() so tokens + are yielded in real time — same mechanism as ResponseGeneratorAgent.stream_response(). 
+ + Args: + query: The user query to answer + context_snippet: Relevant context extracted during Phase 1 detection + + Yields: + Token strings as they arrive from the LLM + """ + logger.info(f"CONTEXT GENERATOR: Phase 2 streaming | Query: '{query[:100]}'") + + self.llm_manager.ensure_global_config() + output_stream = None + stream_started = False + try: + with self.llm_manager.use_task_local(): + if self._stream_predictor is None: + answer_listener = StreamListener(signature_field_name="answer") + self._stream_predictor = dspy.streamify( + dspy.Predict(ContextResponseGenerationSignature), + stream_listeners=[answer_listener], + ) + output_stream = self._stream_predictor( + context_snippet=context_snippet, + user_query=query, + ) + + async for chunk in output_stream: + if isinstance(chunk, dspy.streaming.StreamResponse): + if chunk.signature_field_name == "answer": + stream_started = True + yield chunk.chunk + elif isinstance(chunk, dspy.Prediction): + logger.info( + "Context response streaming complete (final Prediction received)" + ) + + if not stream_started: + logger.warning( + "Context streaming finished but no 'answer' tokens received." + ) + except GeneratorExit: + raise + except Exception as e: + logger.error(f"Error during context response streaming: {e}") + raise + finally: + if output_stream is not None: + try: + await output_stream.aclose() + except Exception as cleanup_error: + logger.debug( + f"Error during context stream cleanup: {cleanup_error}" + ) + + async def generate_context_response( + self, + query: str, + context_snippet: str, + ) -> tuple[str, Dict[str, Any]]: + """ + Phase 2 (non-streaming): Generate a complete answer from context snippet. + + Used for non-streaming mode after Phase 1 detection confirms context can answer. 
+ + Args: + query: The user query to answer + context_snippet: Relevant context extracted during Phase 1 detection + + Returns: + Tuple of (answer_text, cost_dict) + """ + logger.info( + f"CONTEXT GENERATOR: Phase 2 non-streaming | Query: '{query[:100]}'" + ) + + history_length_before = 0 + try: + lm = dspy.settings.lm + if lm and hasattr(lm, "history"): + history_length_before = len(lm.history) + except Exception as e: + logger.warning(f"Failed to get LM history length for generation: {e}") + + self.llm_manager.ensure_global_config() + answer = "" + try: + with self.llm_manager.use_task_local(): + if self._response_generation_module is None: + self._response_generation_module = dspy.ChainOfThought( + ContextResponseGenerationSignature + ) + response = self._response_generation_module( + context_snippet=context_snippet, + user_query=query, + ) + answer = getattr(response, "answer", "") or "" + logger.info( + f"Context response generated: {len(answer)} chars | " + f"Preview: '{answer[:150]}'" + ) + except Exception as e: + logger.error(f"Context response generation failed: {e}", exc_info=True) + + cost_dict = get_lm_usage_since(history_length_before) + logger.info( + f"Generation cost | Total: ${cost_dict.get('total_cost', 0):.6f} | " + f"Tokens: {cost_dict.get('total_tokens', 0)}" + ) + return answer, cost_dict + + async def _generate_conversation_summary( + self, + older_history: List[Dict[str, Any]], + ) -> tuple[str, Dict[str, Any]]: + """ + Generate a concise summary of older conversation turns. 
+ + Args: + older_history: Conversation turns older than the recent 10 + + Returns: + Tuple of (summary_text, cost_dict) + """ + logger.info(f"SUMMARY GENERATION: Summarizing {len(older_history)} older turns") + + # Track costs + history_length_before = 0 + try: + lm = dspy.settings.lm + if lm and hasattr(lm, "history"): + history_length_before = len(lm.history) + except Exception as e: + logger.warning(f"Failed to get LM history length for summary: {e}") + + # Format older history + formatted_history = self._format_conversation_history( + older_history, max_turns=len(older_history) + ) + + # Initialize and run summary module within task-local LLM config + try: + self.llm_manager.ensure_global_config() + with self.llm_manager.use_task_local(): + if self._summary_module is None: + self._summary_module = dspy.ChainOfThought( + ConversationSummarySignature + ) + response = self._summary_module( + conversation_history=formatted_history, + ) + summary = response.summary + logger.info( + f"Summary generated: {len(summary)} chars | " + f"Preview: '{summary[:150]}...'" + ) + except Exception as e: + logger.error(f"Summary generation failed: {e}", exc_info=True) + summary = "" + + cost_dict = get_lm_usage_since(history_length_before) + logger.info( + f"Summary cost | Total: ${cost_dict.get('total_cost', 0):.6f} | " + f"Tokens: {cost_dict.get('total_tokens', 0)}" + ) + + return summary, cost_dict + + async def _analyze_from_summary( + self, + query: str, + summary: str, + ) -> tuple[ContextAnalysisResult, Dict[str, Any]]: + """ + Check if a query can be answered from a conversation summary. 
+ + Args: + query: User query to check + summary: Summary of older conversation turns + + Returns: + Tuple of (ContextAnalysisResult, cost_dict) + """ + logger.info( + f"SUMMARY ANALYSIS: Checking query against summary | Query: '{query[:100]}'" + ) + + # Ensure DSPy is configured and run analysis in a task-local LM context + self.llm_manager.ensure_global_config() + history_length_before = 0 + with self.llm_manager.use_task_local(): + # Track costs + try: + lm = dspy.settings.lm + if lm and hasattr(lm, "history"): + history_length_before = len(lm.history) + except Exception as e: + logger.warning( + f"Failed to get LM history length for summary analysis: {e}" + ) + # Initialize summary analysis module if needed + if self._summary_analysis_module is None: + self._summary_analysis_module = dspy.ChainOfThought( + SummaryAnalysisSignature + ) + try: + response = self._summary_analysis_module( + conversation_summary=summary, + user_query=query, + ) + # Parse JSON response + try: + analysis_data = json.loads(response.analysis_result) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse summary analysis response: " + f"{response.analysis_result[:100]}" + ) + analysis_data = { + "can_answer_from_context": False, + "answer": None, + "reasoning": "Failed to parse summary analysis response", + } + can_answer = analysis_data.get("can_answer_from_context", False) + answer = analysis_data.get("answer") + reasoning = analysis_data.get("reasoning", "Summary analysis completed") + logger.debug( + f"Raw summary analysis parsed | " + f"can_answer_from_context={can_answer} | " + f"has_answer={answer is not None}" + ) + # Only mark as answerable when both the LLM flag is True AND an answer exists + can_answer_from_context = bool(can_answer and answer) + result = ContextAnalysisResult( + is_greeting=False, + can_answer_from_context=can_answer_from_context, + answer=answer, + reasoning=reasoning, + answered_from_summary=can_answer_from_context, + ) + logger.info( + 
"SUMMARY ANALYSIS RESULT | " + f"Can answer from summary: {can_answer} | " + f"Can answer from context: {can_answer_from_context} | " + f"Has answer: {answer is not None} | Reasoning: {reasoning}" + ) + except Exception as e: + logger.error(f"Summary analysis failed: {e}", exc_info=True) + result = ContextAnalysisResult( + is_greeting=False, + can_answer_from_context=False, + answer=None, + reasoning=f"Summary analysis error: {str(e)}", + ) + + cost_dict = get_lm_usage_since(history_length_before) + logger.info( + f"Summary analysis cost | Total: ${cost_dict.get('total_cost', 0):.6f} | " + f"Tokens: {cost_dict.get('total_tokens', 0)}" + ) + + return result, cost_dict + + async def analyze_context( + self, + query: str, + conversation_history: List[Dict[str, Any]], + language: str = "et", + ) -> tuple[ContextAnalysisResult, Dict[str, Any]]: + """ + Analyze if query is a greeting or can be answered from conversation history. + + Implements a 3-step flow: + 1. Analyze recent 10 turns for greetings and history-answerable queries + 2. If cannot answer and total history > 10 turns, generate a summary of older turns + 3. Check if the query can be answered from the summary + 4. 
If still cannot answer, return cannot-answer result (falls through to RAG) + + Args: + query: User query to analyze + conversation_history: List of conversation items + language: Language code (et, en) for response generation + + Returns: + Tuple of (ContextAnalysisResult, cost_dict) + """ + total_turns = len(conversation_history) + logger.info( + f"CONTEXT ANALYZER: Starting analysis | Query: '{query[:100]}' | " + f"History: {total_turns} turns | Language: {language}" + ) + + # STEP 1: Analyze recent 10 turns (existing behavior) + result, cost_dict = await self._analyze_recent_history( + query=query, + conversation_history=conversation_history, + language=language, + ) + + # If greeting or can answer from recent history, return immediately + if (result.is_greeting or result.can_answer_from_context) and result.answer: + logger.info( + f"Answered from recent history | " + f"Greeting: {result.is_greeting} | From context: {result.can_answer_from_context}" + ) + return result, cost_dict + + # STEP 2 & 3: If history > 10 turns and couldn't answer from recent, try summary + if total_turns > 10: + logger.info( + f"History exceeds 10 turns ({total_turns} total) | " + f"Cannot answer from recent 10 | Attempting summary-based analysis" + ) + + # Get older turns (everything before the last 10) + older_history = conversation_history[:-10] + logger.info(f"Older history: {len(older_history)} turns to summarize") + + try: + # Generate summary of older turns + summary, summary_cost = await self._generate_conversation_summary( + older_history + ) + cost_dict = self._merge_cost_dicts(cost_dict, summary_cost) + + if summary: + # Analyze query against summary + summary_result, analysis_cost = await self._analyze_from_summary( + query=query, + summary=summary, + ) + cost_dict = self._merge_cost_dicts(cost_dict, analysis_cost) + + if summary_result.can_answer_from_context and summary_result.answer: + logger.info( + f"Answered from conversation summary | " + f"Reasoning: 
{summary_result.reasoning}" + ) + return summary_result, cost_dict + + logger.info( + "Cannot answer from summary either | Falling back to RAG" + ) + else: + logger.warning( + "Summary generation returned empty | Falling back to RAG" + ) + + except Exception as e: + logger.error(f"Summary-based analysis failed: {e}", exc_info=True) + else: + logger.info( + f"History has {total_turns} turns (<= 10) | " + f"No summary needed | Falling back to RAG" + ) + + # Cannot answer from context at all + logger.info( + f"CONTEXT ANALYZER FINAL DECISION | " + f"can_answer_from_context={result.can_answer_from_context} | " + f"is_greeting={result.is_greeting} | " + f"answered_from_summary={result.answered_from_summary} | " + f"has_answer={result.answer is not None} | " + f"action={'RESPOND' if (result.can_answer_from_context or result.is_greeting) and result.answer else 'FALLBACK_TO_RAG'}" + ) + return result, cost_dict + + async def _analyze_recent_history( + self, + query: str, + conversation_history: List[Dict[str, Any]], + language: str = "et", + ) -> tuple[ContextAnalysisResult, Dict[str, Any]]: + """ + Analyze the query against the most recent conversation turns. + + This is the original analysis logic extracted into its own method. + Checks for greetings and history-answerable queries in the last 10 turns. 
+ + Args: + query: User query to analyze + conversation_history: Full conversation history (last 10 will be used) + language: Language code for response generation + + Returns: + Tuple of (ContextAnalysisResult, cost_dict) + """ + logger.info("STEP 1: Analyzing recent history (last 10 turns)") + + # Track LLM history for cost calculation + history_length_before = 0 + try: + lm = dspy.settings.lm + if lm and hasattr(lm, "history"): + history_length_before = len(lm.history) + except Exception as e: + logger.warning(f"Failed to get LM history length: {e}") + + # Format conversation history (last 10 turns) + formatted_history = self._format_conversation_history(conversation_history) + + # Ensure LM is configured and use task-local context for DSPy operations + self.llm_manager.ensure_global_config() + try: + with self.llm_manager.use_task_local(): + # Initialize DSPy module if not already done + if self._module is None: + self._module = dspy.ChainOfThought(ContextAnalysisSignature) + # Call LLM for analysis + logger.info( + "Calling LLM for context analysis (greeting/history check)..." 
+ ) + response = self._module( + conversation_history=formatted_history, + user_query=query, + ) + + # Parse the analysis result + analysis_json = response.analysis_result + + # Try to parse JSON response + try: + analysis_data = json.loads(analysis_json) + logger.debug( + f"Raw LLM response parsed | " + f"can_answer_from_context={analysis_data.get('can_answer_from_context')} | " + f"is_greeting={analysis_data.get('is_greeting')} | " + f"has_answer={analysis_data.get('answer') is not None}" + ) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse LLM response as JSON: {analysis_json[:100]}" + ) + # Fallback: treat as cannot answer + analysis_data = { + "is_greeting": False, + "can_answer_from_context": False, + "answer": None, + "reasoning": "Failed to parse LLM response", + } + + # Create result object + result = ContextAnalysisResult( + is_greeting=analysis_data.get("is_greeting", False), + can_answer_from_context=analysis_data.get( + "can_answer_from_context", False + ), + answer=analysis_data.get("answer"), + reasoning=analysis_data.get("reasoning", "Analysis completed"), + ) + + logger.info( + f"ANALYSIS RESULT | Greeting: {result.is_greeting} | " + f"Can Answer from Context: {result.can_answer_from_context} | " + f"Answer: {result.answer[:100] if result.answer else None} | " + f"Reasoning: {result.reasoning}" + ) + + # If greeting detected but LLM didn't generate an answer, use fallback + if result.is_greeting and result.answer is None: + greeting_type = self._detect_greeting_type(query) + fallback_answer = get_greeting_response(greeting_type, language) + result = ContextAnalysisResult( + is_greeting=result.is_greeting, + can_answer_from_context=result.can_answer_from_context, + answer=fallback_answer, + reasoning=result.reasoning, + ) + + except Exception as e: + logger.error(f"Context analysis failed: {e}", exc_info=True) + # Fallback result + result = ContextAnalysisResult( + is_greeting=False, + can_answer_from_context=False, + 
answer=None, + reasoning=f"Analysis error: {str(e)}", + ) + + # Calculate costs + cost_dict = get_lm_usage_since(history_length_before) + logger.info( + f"Cost tracking | Total cost: ${cost_dict.get('total_cost', 0):.6f} | " + f"Tokens: {cost_dict.get('total_tokens', 0)} | " + f"Calls: {cost_dict.get('num_calls', 0)}" + ) + + return result, cost_dict + + def _detect_greeting_type(self, query: str) -> str: + """ + Detect the type of greeting from the query text. + + Args: + query: User query string + + Returns: + Greeting type: 'thanks', 'goodbye', 'casual', or 'hello' (default) + """ + query_lower = query.lower().strip() + thanks_keywords = ["thank", "thanks", "tänan", "aitäh", "tänud"] + goodbye_keywords = ["bye", "goodbye", "nägemist", "tsau", "tšau", "head aega"] + casual_keywords = ["hei", "hey", "moi", "moikka"] + for kw in thanks_keywords: + if kw in query_lower: + return "thanks" + for kw in goodbye_keywords: + if kw in query_lower: + return "goodbye" + for kw in casual_keywords: + if kw in query_lower: + return "casual" + return "hello" + + def get_fallback_greeting_response(self, language: str = "et") -> str: + """ + Get a fallback greeting response without LLM call. + + Used when LLM-based greeting detection fails but we still want + to provide a friendly response. + + Args: + language: Language code (et, en) + + Returns: + Greeting message in the specified language + """ + greetings = { + "et": "Tere! Kuidas ma saan sind aidata?", + "en": "Hello! How can I help you?", + } + return greetings.get(language, greetings["et"]) diff --git a/src/tool_classifier/greeting_constants.py b/src/tool_classifier/greeting_constants.py new file mode 100644 index 0000000..272d6a4 --- /dev/null +++ b/src/tool_classifier/greeting_constants.py @@ -0,0 +1,40 @@ +"""Constants for greeting responses in multiple languages.""" + +from typing import Dict + +# Estonian greeting responses +GREETINGS_ET: Dict[str, str] = { + "hello": "Tere! 
Kuidas ma saan sind aidata?", + "goodbye": "Nägemist! Head päeva!", + "thanks": "Palun! Kui on veel küsimusi, küsi julgelt.", + "casual": "Tere! Mida ma saan sinu jaoks teha?", +} + +# English greeting responses +GREETINGS_EN: Dict[str, str] = { + "hello": "Hello! How can I help you?", + "goodbye": "Goodbye! Have a great day!", + "thanks": "You're welcome! Feel free to ask if you have more questions.", + "casual": "Hey! What can I do for you?", +} + +# Language-specific greeting mappings +GREETINGS_BY_LANGUAGE: Dict[str, Dict[str, str]] = { + "et": GREETINGS_ET, + "en": GREETINGS_EN, +} + + +def get_greeting_response(greeting_type: str = "hello", language: str = "et") -> str: + """ + Get a greeting response for a specific type and language. + + Args: + greeting_type: Type of greeting (hello, goodbye, thanks, casual) + language: Language code (et, en) + + Returns: + Greeting message in the specified language + """ + language_greetings = GREETINGS_BY_LANGUAGE.get(language, GREETINGS_EN) + return language_greetings.get(greeting_type, language_greetings["hello"]) diff --git a/src/tool_classifier/workflows/context_workflow.py b/src/tool_classifier/workflows/context_workflow.py index dc23e8b..8d69675 100644 --- a/src/tool_classifier/workflows/context_workflow.py +++ b/src/tool_classifier/workflows/context_workflow.py @@ -1,10 +1,22 @@ """Context workflow executor - Layer 2: Conversation history and greetings.""" from typing import Any, AsyncIterator, Dict, Optional +import time +import dspy from loguru import logger from models.request_models import OrchestrationRequest, OrchestrationResponse from tool_classifier.base_workflow import BaseWorkflow +from tool_classifier.context_analyzer import ContextAnalyzer, ContextDetectionResult +from tool_classifier.workflows.service_workflow import LLMServiceProtocol +from src.guardrails.nemo_rails_adapter import NeMoRailsAdapter +from src.llm_orchestrator_config.llm_manager import LLMManager +from src.utils.cost_utils import 
get_lm_usage_since +from src.utils.language_detector import detect_language +from src.llm_orchestrator_config.llm_ochestrator_constants import ( + GUARDRAILS_BLOCKED_PHRASES, + OUTPUT_GUARDRAIL_VIOLATION_MESSAGE, +) class ContextWorkflowExecutor(BaseWorkflow): @@ -12,24 +24,222 @@ class ContextWorkflowExecutor(BaseWorkflow): Handles greetings and conversation history queries (Layer 2). Detects: - - Greetings: "Hello", "Thanks", "Goodbye" + - Greetings: "Hello", "Thanks", "Goodbye" (multilingual: Estonian, English) - History references: "What did you say earlier?", "Can you repeat that?" Uses LLM for semantic detection (multilingual), no regex patterns. - Status: SKELETON - Returns None (fallback to RAG) - TODO: Implement greeting/context detection, answer extraction, guardrails + Implementation Strategy: + 1. Detect language from user query + 2. Use ContextAnalyzer (LLM-based) to check if: + - Query is a greeting -> generate friendly response + - Query references conversation history -> extract answer + 3. If can answer -> return response + 4. Otherwise -> return None (fallback to RAG) + + Cost Tracking: + - Tracks LLM costs for context analysis + - Logs via orchestration_service.log_costs() (same as service/RAG workflows) """ - def __init__(self, llm_manager: Any): + def __init__( + self, + llm_manager: LLMManager, + orchestration_service: Optional[LLMServiceProtocol] = None, + ) -> None: """ Initialize context workflow executor. 
Args: llm_manager: LLM manager for context analysis + orchestration_service: Reference to LLMOrchestrationService for cost logging """ self.llm_manager = llm_manager - logger.info("Context workflow executor initialized (skeleton)") + self.orchestration_service = orchestration_service + self.context_analyzer = ContextAnalyzer(llm_manager) + logger.info("Context workflow executor initialized") + + @staticmethod + def _build_history(request: OrchestrationRequest) -> list[Dict[str, Any]]: + return [ + { + "authorRole": item.authorRole, + "message": item.message, + "timestamp": item.timestamp, + } + for item in request.conversationHistory + ] + + async def _detect( + self, + message: str, + history: list[Dict[str, Any]], + time_metric: Dict[str, float], + costs_metric: Dict[str, Dict[str, Any]], + ) -> Optional[ContextDetectionResult]: + """Phase 1: run context detection. Returns ContextDetectionResult or None on error.""" + try: + start = time.time() + result, cost = await self.context_analyzer.detect_context( + query=message, conversation_history=history + ) + time_metric["context.detection"] = time.time() - start + costs_metric["context_detection"] = cost + return result + except Exception as e: + logger.error(f"Phase 1 detection failed: {e}", exc_info=True) + return None + + def _log_costs(self, costs_metric: Dict[str, Dict[str, Any]]) -> None: + if self.orchestration_service: + self.orchestration_service.log_costs(costs_metric) + + @staticmethod + def _is_guardrail_violation(chunk: str) -> bool: + """Return True if the chunk matches a known guardrail blocked phrase.""" + chunk_lower = chunk.strip().lower() + return any( + phrase.lower() in chunk_lower + and len(chunk_lower) <= len(phrase.lower()) + 20 + for phrase in GUARDRAILS_BLOCKED_PHRASES + ) + + async def _generate_response_async( + self, + request: OrchestrationRequest, + context_snippet: str, + time_metric: Dict[str, float], + costs_metric: Dict[str, Dict[str, Any]], + ) -> Optional[OrchestrationResponse]: 
+ """Non-streaming: Generate response + apply output guardrails.""" + try: + start = time.time() + answer, cost = await self.context_analyzer.generate_context_response( + query=request.message, context_snippet=context_snippet + ) + time_metric["context.generation"] = time.time() - start + costs_metric["context_response"] = cost + except Exception as e: + logger.error(f"Phase 2 generation failed: {e}", exc_info=True) + self._log_costs(costs_metric) + return None + + if not answer: + logger.warning(f"[{request.chatId}] Phase 2 empty answer — fallback to RAG") + self._log_costs(costs_metric) + return None + + response = OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content=answer, + ) + if self.orchestration_service: + try: + components = self.orchestration_service._initialize_service_components( + request + ) + response = await self.orchestration_service.handle_output_guardrails( + guardrails_adapter=components.get("guardrails_adapter"), + generated_response=response, + request=request, + costs_metric=costs_metric, + ) + except Exception as e: + logger.warning( + f"[{request.chatId}] Output guardrails check failed: {e}" + ) + self._log_costs(costs_metric) + return response + + async def _stream_history_generator( + self, + chat_id: str, + query: str, + context_snippet: str, + history_length_before: int, + guardrails_adapter: NeMoRailsAdapter, + costs_metric: Dict[str, Dict[str, Any]], + ) -> AsyncIterator[str]: + """Async generator: stream history answer through NeMo Guardrails.""" + bot_generator = self.context_analyzer.stream_context_response( + query=query, context_snippet=context_snippet + ) + orchestration_service = self.orchestration_service + if orchestration_service is None: + return + async for validated_chunk in guardrails_adapter.stream_with_guardrails( + user_message=query, bot_message_generator=bot_generator + ): + if isinstance(validated_chunk, str) and 
self._is_guardrail_violation( + validated_chunk + ): + logger.warning(f"[{chat_id}] Guardrails violation in context streaming") + yield orchestration_service.format_sse( + chat_id, OUTPUT_GUARDRAIL_VIOLATION_MESSAGE + ) + yield orchestration_service.format_sse(chat_id, "END") + costs_metric["context_response"] = get_lm_usage_since( + history_length_before + ) + orchestration_service.log_costs(costs_metric) + return + yield orchestration_service.format_sse(chat_id, validated_chunk) + yield orchestration_service.format_sse(chat_id, "END") + logger.info(f"[{chat_id}] Context streaming complete") + costs_metric["context_response"] = get_lm_usage_since(history_length_before) + orchestration_service.log_costs(costs_metric) + + async def _create_history_stream( + self, + request: OrchestrationRequest, + context_snippet: str, + costs_metric: Dict[str, Dict[str, Any]], + ) -> Optional[AsyncIterator[str]]: + """Set up guardrails adapter and return the history streaming generator.""" + if not self.orchestration_service: + logger.warning( + f"[{request.chatId}] No orchestration_service — cannot stream with guardrails" + ) + return None + try: + components = self.orchestration_service._initialize_service_components( + request + ) + guardrails_adapter = components.get("guardrails_adapter") + except Exception as e: + logger.error( + f"[{request.chatId}] Failed to initialize components: {e}", + exc_info=True, + ) + self._log_costs(costs_metric) + return None + + if not isinstance(guardrails_adapter, NeMoRailsAdapter): + logger.warning( + f"[{request.chatId}] guardrails_adapter unavailable — cannot stream" + ) + self._log_costs(costs_metric) + return None + + history_length_before = 0 + try: + lm = dspy.settings.lm + if lm and hasattr(lm, "history"): + history_length_before = len(lm.history) + except Exception: + pass + + return self._stream_history_generator( + chat_id=request.chatId, + query=request.message, + context_snippet=context_snippet, + 
history_length_before=history_length_before, + guardrails_adapter=guardrails_adapter, + costs_metric=costs_metric, + ) async def execute_async( self, @@ -38,26 +248,64 @@ async def execute_async( time_metric: Optional[Dict[str, float]] = None, ) -> Optional[OrchestrationResponse]: """ - Execute context workflow in non-streaming mode. - - TODO: Check greeting (LLM) → generate response, OR check history (last 10 turns) - → extract answer → validate with guardrails. Return None if cannot answer. + Execute context workflow in non-streaming mode (two-phase). - Args: - request: Orchestration request with user query and history - context: Metadata with is_greeting, can_answer_from_history flags - time_metric: Optional timing dictionary for future timing tracking + Phase 1: Detect if query is a greeting or can be answered from history. + Phase 2: Generate response (greetings: pre-built; history: LLM + guardrails). Returns: - OrchestrationResponse with context-based answer or None to fallback + OrchestrationResponse or None to fallback to RAG """ - logger.debug( - f"[{request.chatId}] Context workflow execute_async called " - f"(not implemented - returning None)" + logger.info( + f"[{request.chatId}] CONTEXT WORKFLOW (NON-STREAMING) | " + f"Query: '{request.message[:100]}'" ) + costs_metric: Dict[str, Dict[str, Any]] = {} + if time_metric is None: + time_metric = {} + + language = detect_language(request.message) + history = self._build_history(request) - # TODO: Implement context workflow logic here - # For now, return None to trigger fallback to next layer (RAG) + detection_result = await self._detect( + request.message, history, time_metric, costs_metric + ) + if detection_result is None: + self._log_costs(costs_metric) + return None + + logger.info( + f"[{request.chatId}] Detection: greeting={detection_result.is_greeting} " + f"can_answer={detection_result.can_answer_from_context}" + ) + + if detection_result.is_greeting: + from src.tool_classifier.greeting_constants 
import get_greeting_response + + greeting = get_greeting_response( + greeting_type=detection_result.greeting_type, language=language + ) + self._log_costs(costs_metric) + return OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content=greeting, + ) + + if ( + detection_result.can_answer_from_context + and detection_result.context_snippet + ): + return await self._generate_response_async( + request, detection_result.context_snippet, time_metric, costs_metric + ) + + logger.warning( + f"[{request.chatId}] Cannot answer from context — falling back to RAG" + ) + self._log_costs(costs_metric) return None async def execute_streaming( @@ -67,24 +315,66 @@ async def execute_streaming( time_metric: Optional[Dict[str, float]] = None, ) -> Optional[AsyncIterator[str]]: """ - Execute context workflow in streaming mode. + Execute context workflow in streaming mode (two-phase). - TODO: Get answer (greeting/history) → validate BEFORE streaming → chunk and - yield as SSE. Return None if cannot answer. - - Args: - request: Orchestration request with user query and history - context: Metadata with is_greeting, can_answer_from_history flags - time_metric: Optional timing dictionary for future timing tracking + Phase 1: Detect context (blocking, fast — classification only). + Phase 2: Stream answer through NeMo Guardrails (same pipeline as RAG). 
Returns: - AsyncIterator yielding SSE strings or None to fallback + AsyncIterator yielding SSE strings or None to fallback to RAG """ - logger.debug( - f"[{request.chatId}] Context workflow execute_streaming called " - f"(not implemented - returning None)" + logger.info( + f"[{request.chatId}] CONTEXT WORKFLOW (STREAMING) | " + f"Query: '{request.message[:100]}'" + ) + costs_metric: Dict[str, Dict[str, Any]] = {} + if time_metric is None: + time_metric = {} + + language = detect_language(request.message) + history = self._build_history(request) + + detection_result = await self._detect( + request.message, history, time_metric, costs_metric + ) + if detection_result is None: + self._log_costs(costs_metric) + return None + + logger.info( + f"[{request.chatId}] Detection: greeting={detection_result.is_greeting} " + f"can_answer={detection_result.can_answer_from_context}" ) - # TODO: Implement context streaming logic here - # For now, return None to trigger fallback to next layer (RAG) + if detection_result.is_greeting: + from src.tool_classifier.greeting_constants import get_greeting_response + + greeting = get_greeting_response( + greeting_type=detection_result.greeting_type, language=language + ) + orchestration_service = self.orchestration_service + if orchestration_service is None: + self._log_costs(costs_metric) + return None + chat_id = request.chatId + + async def _stream_greeting() -> AsyncIterator[str]: + yield orchestration_service.format_sse(chat_id, greeting) + yield orchestration_service.format_sse(chat_id, "END") + orchestration_service.log_costs(costs_metric) + + return _stream_greeting() + + if ( + detection_result.can_answer_from_context + and detection_result.context_snippet + ): + return await self._create_history_stream( + request, detection_result.context_snippet, costs_metric + ) + + logger.warning( + f"[{request.chatId}] Cannot answer from context — falling back to RAG" + ) + self._log_costs(costs_metric) return None diff --git 
a/src/tool_classifier/workflows/rag_workflow.py b/src/tool_classifier/workflows/rag_workflow.py index b5da35b..9c983ce 100644 --- a/src/tool_classifier/workflows/rag_workflow.py +++ b/src/tool_classifier/workflows/rag_workflow.py @@ -64,7 +64,7 @@ async def execute_async( Args: request: Orchestration request with user query - context: Unused (RAG doesn't need classification metadata) + context: May contain pre-initialized "components" to avoid duplicate init time_metric: Optional timing dictionary from parent (for unified tracking) Returns: @@ -79,8 +79,12 @@ async def execute_async( if time_metric is None: time_metric = {} - # Initialize service components - components = self.orchestration_service._initialize_service_components(request) + # Reuse components from context if available, otherwise initialize + components = context.get("components") + if components is None: + components = self.orchestration_service._initialize_service_components( + request + ) # Call existing RAG pipeline with "rag" prefix for namespacing response = await self.orchestration_service._execute_orchestration_pipeline( @@ -105,6 +109,11 @@ async def execute_streaming( """ Execute RAG workflow in streaming mode. + Coroutine that returns an AsyncIterator so callers can safely use + ``await workflow.execute_streaming(...)`` and then iterate over the + returned stream without hitting a TypeError from awaiting an async + generator. 
+ Delegates to existing streaming pipeline which handles: - Prompt refinement (blocking) - Chunk retrieval (blocking) @@ -118,7 +127,7 @@ async def execute_streaming( Args: request: Orchestration request with user query - context: Unused (RAG doesn't need classification metadata) + context: May contain pre-initialized "components" and "stream_ctx" time_metric: Optional timing dictionary from parent (for unified tracking) Returns: @@ -143,8 +152,7 @@ async def execute_streaming( # Get stream context from context if provided, otherwise create minimal tracking stream_ctx = context.get("stream_ctx") if stream_ctx is None: - # Create minimal stream context when called via tool classifier - # In production flow, this is provided by stream_orchestration_response + class MinimalStreamContext: """Minimal stream context for RAG workflow when called directly.""" @@ -154,25 +162,29 @@ def __init__(self, chat_id: str) -> None: self.bot_generator = None def mark_completed(self) -> None: - """No-op: Tracking handled by orchestration service.""" + # Intentionally empty: lifecycle tracking is handled by the orchestration service, not this minimal context pass def mark_cancelled(self) -> None: - """No-op: Tracking handled by orchestration service.""" + # Intentionally empty: lifecycle tracking is handled by the orchestration service, not this minimal context pass def mark_error(self, error_id: str) -> None: - """No-op: Tracking handled by orchestration service.""" + # Intentionally empty: lifecycle tracking is handled by the orchestration service, not this minimal context pass stream_ctx = MinimalStreamContext(request.chatId) - # Delegate to core RAG pipeline (bypasses classifier to avoid recursion) - async for sse_chunk in self.orchestration_service._stream_rag_pipeline( - request=request, - components=components, - stream_ctx=stream_ctx, - costs_metric=costs_metric, - time_metric=time_metric, - ): - yield sse_chunk + # Return an inner async generator so this method stays a 
coroutine. + # This avoids the TypeError when callers do ``await execute_streaming(...)``. + async def _stream() -> AsyncIterator[str]: + async for sse_chunk in self.orchestration_service._stream_rag_pipeline( + request=request, + components=components, + stream_ctx=stream_ctx, + costs_metric=costs_metric, + time_metric=time_metric, + ): + yield sse_chunk + + return _stream() diff --git a/src/tool_classifier/workflows/service_workflow.py b/src/tool_classifier/workflows/service_workflow.py index bb72f78..7882550 100644 --- a/src/tool_classifier/workflows/service_workflow.py +++ b/src/tool_classifier/workflows/service_workflow.py @@ -6,6 +6,7 @@ import httpx from loguru import logger +from src.guardrails.nemo_rails_adapter import NeMoRailsAdapter from src.utils.cost_utils import get_lm_usage_since from models.request_models import ( @@ -73,6 +74,22 @@ def log_costs(self, costs_metric: Dict[str, Dict[str, Any]]) -> None: """ ... + def _initialize_service_components( + self, request: OrchestrationRequest + ) -> Dict[str, Any]: + """Initialize and return service components dictionary.""" + ... + + async def handle_output_guardrails( + self, + guardrails_adapter: Optional[NeMoRailsAdapter], + generated_response: OrchestrationResponse, + request: OrchestrationRequest, + costs_metric: Dict[str, Dict[str, Any]], + ) -> OrchestrationResponse: + """Apply output guardrails to the generated response.""" + ... 
+ class ServiceWorkflowExecutor(BaseWorkflow): """Executes external service calls via Ruuter endpoints (Layer 1).""" diff --git a/src/utils/rate_limiter.py b/src/utils/rate_limiter.py index 4b88d9d..d86829f 100644 --- a/src/utils/rate_limiter.py +++ b/src/utils/rate_limiter.py @@ -1,8 +1,8 @@ -"""Rate limiter for streaming endpoints with sliding window and token bucket algorithms.""" +"""Rate limiter for streaming endpoints with sliding window algorithms.""" import time from collections import defaultdict, deque -from typing import Dict, Deque, Tuple, Optional, Any +from typing import Dict, Deque, Optional, Any from threading import Lock from loguru import logger @@ -31,11 +31,11 @@ class RateLimitResult(BaseModel): class RateLimiter: """ - In-memory rate limiter with sliding window (requests/minute) and token bucket (tokens/second). + In-memory rate limiter using sliding windows for both requests and tokens. Features: - Sliding window for request rate limiting (e.g., 10 requests per minute) - - Token bucket for burst control (e.g., 100 tokens per second) + - Sliding window for token rate limiting (e.g., 40,000 tokens per minute) - Per-user tracking with authorId - Automatic cleanup of old entries to prevent memory leaks - Thread-safe operations @@ -43,7 +43,7 @@ class RateLimiter: Usage: rate_limiter = RateLimiter( requests_per_minute=10, - tokens_per_second=100 + tokens_per_minute=40_000, ) result = rate_limiter.check_rate_limit( @@ -59,28 +59,32 @@ class RateLimiter: def __init__( self, requests_per_minute: int = StreamConfig.RATE_LIMIT_REQUESTS_PER_MINUTE, - tokens_per_second: int = StreamConfig.RATE_LIMIT_TOKENS_PER_SECOND, + tokens_per_minute: int = StreamConfig.RATE_LIMIT_TOKENS_PER_MINUTE, cleanup_interval: int = StreamConfig.RATE_LIMIT_CLEANUP_INTERVAL, + token_window_seconds: int = StreamConfig.RATE_LIMIT_TOKEN_WINDOW_SECONDS, ): """ Initialize rate limiter. 
Args: requests_per_minute: Maximum requests per user per minute (sliding window) - tokens_per_second: Maximum tokens per user per second (token bucket) + tokens_per_minute: Maximum tokens per user per minute (sliding window) cleanup_interval: Seconds between automatic cleanup of old entries + token_window_seconds: Sliding window size in seconds for token tracking """ self.requests_per_minute = requests_per_minute - self.tokens_per_second = tokens_per_second + self.tokens_per_minute = tokens_per_minute self.cleanup_interval = cleanup_interval + self.token_window_seconds = token_window_seconds + # Scale the per-minute limit to the actual window size so the + # sliding-window comparison is consistent regardless of window length. + self.tokens_per_window = int(tokens_per_minute * token_window_seconds / 60) # Sliding window: Track request timestamps per user - # Format: {author_id: deque([timestamp1, timestamp2, ...])} self._request_history: Dict[str, Deque[float]] = defaultdict(deque) - # Token bucket: Track token consumption per user - # Format: {author_id: (last_refill_time, available_tokens)} - self._token_buckets: Dict[str, Tuple[float, float]] = {} + # Sliding window: Track token usage per user + self._token_history: Dict[str, Deque[tuple[float, int]]] = defaultdict(deque) # Thread safety self._lock = Lock() @@ -91,7 +95,7 @@ def __init__( logger.info( f"RateLimiter initialized - " f"requests_per_minute: {requests_per_minute}, " - f"tokens_per_second: {tokens_per_second}" + f"tokens_per_minute: {tokens_per_minute}" ) def check_rate_limit( @@ -121,7 +125,7 @@ def check_rate_limit( if not request_result.allowed: return request_result - # Check 2: Token bucket (tokens per second) + # Check 2: Sliding window (tokens per minute) if estimated_tokens > 0: token_result = self._check_token_limit( author_id, estimated_tokens, current_time @@ -186,12 +190,11 @@ def _check_token_limit( current_time: float, ) -> RateLimitResult: """ - Check token bucket limit. 
+ Check sliding window token limit. - Token bucket algorithm: - - Bucket refills at constant rate (tokens_per_second) - - Burst allowed up to bucket capacity - - Request denied if insufficient tokens + Sliding window algorithm: + - Track cumulative tokens consumed within the window + - Reject if adding estimated tokens would exceed the limit Args: author_id: User identifier @@ -201,38 +204,42 @@ def _check_token_limit( Returns: RateLimitResult for token limit check """ - bucket_capacity = self.tokens_per_second - - # Get or initialize bucket for user - if author_id not in self._token_buckets: - # New user - start with full bucket - self._token_buckets[author_id] = (current_time, bucket_capacity) - - last_refill, available_tokens = self._token_buckets[author_id] - - # Refill tokens based on time elapsed - time_elapsed = current_time - last_refill - refill_amount = time_elapsed * self.tokens_per_second - available_tokens = min(bucket_capacity, available_tokens + refill_amount) - - # Check if enough tokens available - if available_tokens < estimated_tokens: - # Calculate time needed to refill enough tokens - tokens_needed = estimated_tokens - available_tokens - retry_after = int(tokens_needed / self.tokens_per_second) + 1 + token_history = self._token_history[author_id] + window_start = current_time - self.token_window_seconds + + # Remove entries outside the sliding window + while token_history and token_history[0][0] < window_start: + token_history.popleft() + + # Sum tokens consumed in the current window + current_token_usage = sum(tokens for _, tokens in token_history) + + # Check if adding this request would exceed the scaled window limit + if current_token_usage + estimated_tokens > self.tokens_per_window: + # Calculate retry_after based on oldest entry in window + if token_history: + oldest_timestamp = token_history[0][0] + retry_after = ( + int(oldest_timestamp + self.token_window_seconds - current_time) + 1 + ) + else: + retry_after = 1 logger.warning( 
f"Token rate limit exceeded for {author_id} - " - f"needed: {estimated_tokens}, available: {available_tokens:.0f} " - f"(retry after {retry_after}s)" + f"needed: {estimated_tokens}, " + f"current_usage: {current_token_usage}/{self.tokens_per_window} " + f"(window: {self.token_window_seconds}s, " + f"rate: {self.tokens_per_minute}/min, " + f"retry after {retry_after}s)" ) return RateLimitResult( allowed=False, retry_after=retry_after, limit_type="tokens", - current_usage=int(bucket_capacity - available_tokens), - limit=self.tokens_per_second, + current_usage=current_token_usage, + limit=self.tokens_per_window, ) return RateLimitResult(allowed=True) @@ -254,20 +261,9 @@ def _record_request( # Record request timestamp for sliding window self._request_history[author_id].append(current_time) - # Deduct tokens from bucket - if tokens_consumed > 0 and author_id in self._token_buckets: - last_refill, available_tokens = self._token_buckets[author_id] - - # Refill before deducting - time_elapsed = current_time - last_refill - refill_amount = time_elapsed * self.tokens_per_second - available_tokens = min( - self.tokens_per_second, available_tokens + refill_amount - ) - - # Deduct tokens - available_tokens -= tokens_consumed - self._token_buckets[author_id] = (current_time, available_tokens) + # Record token usage for sliding window + if tokens_consumed > 0: + self._token_history[author_id].append((current_time, tokens_consumed)) def _cleanup_old_entries(self, current_time: float) -> None: """ @@ -294,23 +290,25 @@ def _cleanup_old_entries(self, current_time: float) -> None: for author_id in users_to_remove: del self._request_history[author_id] - # Clean up token buckets (remove entries inactive for 5 minutes) - inactive_threshold = current_time - 300 - buckets_to_remove: list[str] = [] + # Clean up token history (remove entries outside window + inactive users) + token_window_start = current_time - self.token_window_seconds + token_users_to_remove: list[str] = [] - for 
author_id, (last_refill, _) in self._token_buckets.items(): - if last_refill < inactive_threshold: - buckets_to_remove.append(author_id) + for author_id, token_history in self._token_history.items(): + while token_history and token_history[0][0] < token_window_start: + token_history.popleft() + if not token_history: + token_users_to_remove.append(author_id) - for author_id in buckets_to_remove: - del self._token_buckets[author_id] + for author_id in token_users_to_remove: + del self._token_history[author_id] self._last_cleanup = current_time - if users_to_remove or buckets_to_remove: + if users_to_remove or token_users_to_remove: logger.debug( f"Cleaned up {len(users_to_remove)} request histories and " - f"{len(buckets_to_remove)} token buckets" + f"{len(token_users_to_remove)} token histories" ) def get_stats(self) -> Dict[str, Any]: @@ -323,9 +321,9 @@ def get_stats(self) -> Dict[str, Any]: with self._lock: return { "total_users_tracked": len(self._request_history), - "total_token_buckets": len(self._token_buckets), + "total_token_histories": len(self._token_history), "requests_per_minute_limit": self.requests_per_minute, - "tokens_per_second_limit": self.tokens_per_second, + "tokens_per_minute_limit": self.tokens_per_minute, "last_cleanup": self._last_cleanup, } @@ -339,7 +337,7 @@ def reset_user(self, author_id: str) -> None: with self._lock: if author_id in self._request_history: del self._request_history[author_id] - if author_id in self._token_buckets: - del self._token_buckets[author_id] + if author_id in self._token_history: + del self._token_history[author_id] logger.info(f"Reset rate limits for user: {author_id}")