diff --git a/docs/decisions/0019-python-context-compaction-strategy.md b/docs/decisions/0019-python-context-compaction-strategy.md index 11e1c091e5..8fffb185d1 100644 --- a/docs/decisions/0019-python-context-compaction-strategy.md +++ b/docs/decisions/0019-python-context-compaction-strategy.md @@ -1240,3 +1240,10 @@ class AttributionAwareStrategy(CompactionStrategy): - [ADR-0016: Unifying Context Management with ContextPlugin](0016-python-context-middleware.md) — Parent ADR that established `ContextProvider`, `HistoryProvider`, and `AgentSession` architecture. - [Context Compaction Limitations Analysis](https://gist.github.com/victordibia/ec3f3baf97345f7e47da025cf55b999f) — Detailed analysis of why current architecture cannot support in-run compaction, with attempted solutions and their failure modes. Option 4 in this ADR corresponds to "Option A: Middleware Access to Mutable Message Source" from that analysis; Options 1-3 correspond to "Option B: Tool Loop Hook", adapted here to a `BaseChatClient` hook instead of `FunctionInvocationConfiguration`. + +### Implementation Rollout Note + +Implementation is split into two phases: + +1. **Phase 1 (PR 1):** runtime compaction foundation in `agent_framework/_compaction.py`, in-run integration, and extensive core tests, plus in-run compaction samples (`basics`, `advanced`, `custom`). +2. **Phase 2 (PR 2):** history/storage compaction (`upsert`-based full replacement), provider support, storage tests, and storage-focused sample (`storage`). 
diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 1cbcc7a8cb..72aa11fb13 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -29,6 +29,31 @@ SupportsMCPTool, SupportsWebSearchTool, ) +from ._compaction import ( + EXCLUDE_REASON_KEY, + EXCLUDED_KEY, + GROUP_ANNOTATION_KEY, + GROUP_HAS_REASONING_KEY, + GROUP_ID_KEY, + GROUP_INDEX_KEY, + GROUP_KIND_KEY, + GROUP_TOKEN_COUNT_KEY, + SUMMARIZED_BY_SUMMARY_ID_KEY, + SUMMARY_OF_GROUP_IDS_KEY, + SUMMARY_OF_MESSAGE_IDS_KEY, + CharacterEstimatorTokenizer, + CompactionStrategy, + SelectiveToolCallCompactionStrategy, + SlidingWindowStrategy, + SummarizationStrategy, + TokenBudgetComposedStrategy, + TokenizerProtocol, + TruncationStrategy, + annotate_message_groups, + apply_compaction, + included_messages, + included_token_count, +) from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool from ._middleware import ( AgentContext, @@ -191,6 +216,17 @@ "AGENT_FRAMEWORK_USER_AGENT", "APP_INFO", "DEFAULT_MAX_ITERATIONS", + "EXCLUDED_KEY", + "EXCLUDE_REASON_KEY", + "GROUP_ANNOTATION_KEY", + "GROUP_HAS_REASONING_KEY", + "GROUP_ID_KEY", + "GROUP_INDEX_KEY", + "GROUP_KIND_KEY", + "GROUP_TOKEN_COUNT_KEY", + "SUMMARIZED_BY_SUMMARY_ID_KEY", + "SUMMARY_OF_GROUP_IDS_KEY", + "SUMMARY_OF_MESSAGE_IDS_KEY", "USER_AGENT_KEY", "USER_AGENT_TELEMETRY_DISABLED_ENV_VAR", "Agent", @@ -205,9 +241,6 @@ "AgentResponseUpdate", "AgentRunInputs", "AgentSession", - "Skill", - "SkillResource", - "SkillsProvider", "Annotation", "BaseAgent", "BaseChatClient", @@ -215,6 +248,7 @@ "BaseEmbeddingClient", "BaseHistoryProvider", "Case", + "CharacterEstimatorTokenizer", "ChatAndFunctionMiddlewareTypes", "ChatContext", "ChatMiddleware", @@ -224,6 +258,7 @@ "ChatResponse", "ChatResponseUpdate", "CheckpointStorage", + "CompactionStrategy", "Content", "ContinuationToken", "Default", @@ -270,10 +305,16 @@ "Runner", 
"RunnerContext", "SecretString", + "SelectiveToolCallCompactionStrategy", "SessionContext", "SingleEdgeGroup", + "Skill", + "SkillResource", + "SkillsProvider", + "SlidingWindowStrategy", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", + "SummarizationStrategy", "SupportsAgentRun", "SupportsChatGetResponse", "SupportsCodeInterpreterTool", @@ -286,8 +327,11 @@ "SwitchCaseEdgeGroupCase", "SwitchCaseEdgeGroupDefault", "TextSpanRegion", + "TokenBudgetComposedStrategy", + "TokenizerProtocol", "ToolMode", "ToolTypes", + "TruncationStrategy", "TypeCompatibilityError", "UpdateT", "UsageDetails", @@ -314,12 +358,16 @@ "__version__", "add_usage_details", "agent_middleware", + "annotate_message_groups", + "apply_compaction", "chat_middleware", "create_edge_runner", "detect_media_type_from_base64", "executor", "function_middleware", "handler", + "included_messages", + "included_token_count", "load_settings", "map_chat_to_agent_update", "merge_chat_options", diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a0c998757c..4a3cf22f3a 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -68,6 +68,7 @@ from typing_extensions import Self, TypedDict # pragma: no cover if TYPE_CHECKING: + from ._compaction import CompactionStrategy, TokenizerProtocol from ._types import ChatOptions logger = logging.getLogger("agent_framework") @@ -649,6 +650,8 @@ def __init__( tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, default_options: OptionsCoT | None = None, context_providers: Sequence[BaseContextProvider] | None = None, + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, **kwargs: Any, ) -> None: """Initialize a Agent instance. @@ -672,6 +675,8 @@ def __init__( Note: response_format typing does not flow into run outputs when set via default_options. 
These can be overridden at runtime via the ``options`` parameter of ``run()``. tools: The tools to use for the request. + compaction_strategy: Optional in-run compaction strategy for function-calling loops. + tokenizer: Optional tokenizer for token-aware compaction strategies. kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. """ opts = dict(default_options) if default_options else {} @@ -689,6 +694,12 @@ def __init__( **kwargs, ) self.client = client + self.compaction_strategy = compaction_strategy or getattr(client, "compaction_strategy", None) + self.tokenizer = tokenizer or getattr(client, "tokenizer", None) + if hasattr(self.client, "compaction_strategy"): + self.client.compaction_strategy = self.compaction_strategy + if hasattr(self.client, "tokenizer"): + self.client.tokenizer = self.tokenizer # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) @@ -1379,6 +1390,8 @@ def __init__( default_options: OptionsCoT | None = None, context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, **kwargs: Any, ) -> None: """Initialize a Agent instance.""" @@ -1392,5 +1405,7 @@ def __init__( default_options=default_options, context_providers=context_providers, middleware=middleware, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, **kwargs, ) diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 278657a154..ee19930e4e 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -52,6 +52,7 @@ if TYPE_CHECKING: from ._agents import Agent + from ._compaction import CompactionStrategy, TokenizerProtocol from ._middleware import ( MiddlewareTypes, ) @@ -252,7 
+253,11 @@ async def _stream(): """ OTEL_PROVIDER_NAME: ClassVar[str] = "unknown" - DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"} + DEFAULT_EXCLUDE: ClassVar[set[str]] = { + "additional_properties", + "compaction_strategy", + "tokenizer", + } STORES_BY_DEFAULT: ClassVar[bool] = False """Whether this client stores conversation history server-side by default. @@ -267,15 +272,21 @@ def __init__( self, *, additional_properties: dict[str, Any] | None = None, + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, **kwargs: Any, ) -> None: """Initialize a BaseChatClient instance. Keyword Args: additional_properties: Additional properties for the client. + compaction_strategy: Optional compaction strategy to apply before model calls. + tokenizer: Optional tokenizer used by token-aware compaction strategies. kwargs: Additional keyword arguments (merged into additional_properties). """ self.additional_properties = additional_properties or {} + self.compaction_strategy = compaction_strategy + self.tokenizer = tokenizer super().__init__(**kwargs) def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: @@ -334,6 +345,23 @@ def _build_response_stream( finalizer=lambda updates: self._finalize_response_updates(updates, response_format=response_format), ) + async def _prepare_messages_for_model_call( + self, + messages: Sequence[Message], + ) -> list[Message]: + prepared_messages = list(messages) + strategy = getattr(self, "compaction_strategy", None) + if strategy is None: + return prepared_messages + tokenizer = getattr(self, "tokenizer", None) + from ._compaction import apply_compaction + + return await apply_compaction( + prepared_messages, + strategy=strategy, + tokenizer=tokenizer, + ) + # region Internal method to be implemented by derived classes @abstractmethod @@ -413,12 +441,43 @@ def get_response( Returns: When streaming a response stream of ChatResponseUpdates, 
otherwise an Awaitable ChatResponse. """ - return self._inner_get_response( - messages=messages, - stream=stream, - options=options or {}, # type: ignore[arg-type] - **kwargs, - ) + if getattr(self, "compaction_strategy", None) is None: + return self._inner_get_response( + messages=messages, + stream=stream, + options=options or {}, + **kwargs, + ) + + if stream: + + async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: + prepared_messages = await self._prepare_messages_for_model_call(messages) + stream_response = self._inner_get_response( + messages=prepared_messages, + stream=True, + options=options or {}, + **kwargs, + ) + if isinstance(stream_response, ResponseStream): + return stream_response + awaited_stream_response = await stream_response + if isinstance(awaited_stream_response, ResponseStream): + return awaited_stream_response + raise ValueError("Streaming responses must return a ResponseStream.") + + return ResponseStream.from_awaitable(_get_stream()) + + async def _get_response() -> ChatResponse[Any]: + prepared_messages = await self._prepare_messages_for_model_call(messages) + return await self._inner_get_response( + messages=prepared_messages, + stream=False, + options=options or {}, + **kwargs, + ) + + return _get_response() def service_url(self) -> str: """Get the URL of the service. @@ -443,6 +502,8 @@ def as_agent( context_providers: Sequence[Any] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, + compaction_strategy: CompactionStrategy | None = None, + tokenizer: TokenizerProtocol | None = None, **kwargs: Any, ) -> Agent[OptionsCoT]: """Create a Agent with this client. @@ -465,6 +526,8 @@ def as_agent( context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. 
function_invocation_configuration: Optional function invocation configuration override. + compaction_strategy: Optional in-run compaction strategy used by function-calling loops. + tokenizer: Optional tokenizer used by token-aware compaction strategies. kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. Returns: @@ -490,6 +553,9 @@ def as_agent( """ from ._agents import Agent + strategy = getattr(self, "compaction_strategy", None) if compaction_strategy is None else compaction_strategy + resolved_tokenizer = getattr(self, "tokenizer", None) if tokenizer is None else tokenizer + return Agent( client=self, id=id, @@ -501,6 +567,8 @@ def as_agent( context_providers=context_providers, middleware=middleware, function_invocation_configuration=function_invocation_configuration, + compaction_strategy=strategy, + tokenizer=resolved_tokenizer, **kwargs, ) diff --git a/python/packages/core/agent_framework/_compaction.py b/python/packages/core/agent_framework/_compaction.py new file mode 100644 index 0000000000..5a79763dba --- /dev/null +++ b/python/packages/core/agent_framework/_compaction.py @@ -0,0 +1,1086 @@ +# Copyright (c) Microsoft. All rights reserved. 
"""Runtime context-compaction primitives: group annotations, tokenizer and
strategy protocols, and the metadata keys stored on message annotations."""

from __future__ import annotations

import json
import logging
from collections.abc import Mapping, Sequence
from typing import (
    TYPE_CHECKING,
    Any,
    Final,
    Literal,
    Protocol,
    TypeAlias,
    runtime_checkable,
)

from ._types import Content, Message

if TYPE_CHECKING:
    from ._clients import SupportsChatGetResponse

# Closed set of group kinds produced by message grouping.
GroupKind: TypeAlias = Literal["system", "user", "assistant_text", "tool_call"]

# Keys written into ``Message.additional_properties`` (and the nested group
# annotation dict) by the compaction machinery.
GROUP_ANNOTATION_KEY = "_group"
GROUP_ID_KEY = "id"
GROUP_KIND_KEY = "kind"
GROUP_INDEX_KEY = "index"
GROUP_HAS_REASONING_KEY = "has_reasoning"
GROUP_TOKEN_COUNT_KEY = "token_count"  # noqa: S105 - compaction metadata key, not a credential
EXCLUDED_KEY = "_excluded"
EXCLUDE_REASON_KEY = "_exclude_reason"
SUMMARY_OF_MESSAGE_IDS_KEY = "_summary_of_message_ids"
SUMMARY_OF_GROUP_IDS_KEY = "_summary_of_group_ids"
SUMMARIZED_BY_SUMMARY_ID_KEY = "_summarized_by_summary_id"


logger = logging.getLogger("agent_framework")


@runtime_checkable
class TokenizerProtocol(Protocol):
    """Protocol for token counters used by token-aware compaction strategies."""

    def count_tokens(self, text: str) -> int:
        """Count tokens for a serialized message payload."""
        ...


@runtime_checkable
class CompactionStrategy(Protocol):
    """Protocol for in-place message compaction strategies."""

    async def __call__(self, messages: list[Message]) -> bool:
        """Mutate message annotations and/or list contents in place.

        Assumes caller has already applied grouping annotations (and token
        annotations when required by the strategy).

        Returns:
            True if compaction changed message inclusion or content; otherwise False.
        """
        ...
+ + +class CharacterEstimatorTokenizer: + """Fast heuristic tokenizer using a 4-char/token estimate.""" + + def count_tokens(self, text: str) -> int: + return max(1, len(text) // 4) + + +def _has_content_type(message: Message, content_type: str) -> bool: + return any(content.type == content_type for content in message.contents) + + +def _has_function_call(message: Message) -> bool: + return _has_content_type(message, "function_call") + + +def _has_reasoning(message: Message) -> bool: + return _has_content_type(message, "text_reasoning") + + +def _is_tool_call_assistant(message: Message) -> bool: + return message.role == "assistant" and _has_function_call(message) + + +def _is_reasoning_only_assistant(message: Message) -> bool: + if message.role != "assistant" or not message.contents: + return False + return all(content.type == "text_reasoning" for content in message.contents) + + +def _ensure_message_ids(messages: list[Message]) -> None: + for index, message in enumerate(messages): + if not message.message_id: + message.message_id = f"msg_{index}" + + +def _group_id_for(message: Message, group_index: int) -> str: + if message.message_id: + return f"group_{message.message_id}" + return f"group_index_{group_index}" + + +def group_messages(messages: list[Message]) -> list[dict[str, Any]]: + """Compute group spans and metadata for annotation. + + Returns: + Ordered list of lightweight span dicts with keys: + ``group_id``, ``kind``, ``start_index``, ``end_index``, ``has_reasoning``. 
+ """ + _ensure_message_ids(messages) + spans: list[dict[str, Any]] = [] + i = 0 + group_index = 0 + + while i < len(messages): + current = messages[i] + + if current.role == "system": + spans.append({ + "group_id": _group_id_for(current, group_index), + "kind": "system", + "start_index": i, + "end_index": i, + "has_reasoning": _has_reasoning(current), + }) + i += 1 + group_index += 1 + continue + + if current.role == "user": + spans.append({ + "group_id": _group_id_for(current, group_index), + "kind": "user", + "start_index": i, + "end_index": i, + "has_reasoning": _has_reasoning(current), + }) + i += 1 + group_index += 1 + continue + + # Reasoning prefix before an assistant function_call joins the same tool_call group. + # This includes the OpenAI Responses shape where reasoning and function_call + # contents are co-located in the same assistant message. + if _is_reasoning_only_assistant(current): + prefix_start = i + j = i + while j < len(messages) and _is_reasoning_only_assistant(messages[j]): + j += 1 + if j < len(messages) and _is_tool_call_assistant(messages[j]): + k = j + 1 + has_reasoning = True + while k < len(messages) and _is_reasoning_only_assistant(messages[k]): + has_reasoning = True + k += 1 + while k < len(messages) and messages[k].role == "tool": + k += 1 + spans.append({ + "group_id": _group_id_for(messages[prefix_start], group_index), + "kind": "tool_call", + "start_index": prefix_start, + "end_index": k - 1, + "has_reasoning": has_reasoning or _has_reasoning(messages[j]), + }) + i = k + group_index += 1 + continue + + if _is_tool_call_assistant(current): + has_reasoning = _has_reasoning(current) + k = i + 1 + while k < len(messages) and _is_reasoning_only_assistant(messages[k]): + has_reasoning = True + k += 1 + while k < len(messages) and messages[k].role == "tool": + k += 1 + spans.append({ + "group_id": _group_id_for(current, group_index), + "kind": "tool_call", + "start_index": i, + "end_index": k - 1, + "has_reasoning": has_reasoning, + }) 
+ i = k + group_index += 1 + continue + + if current.role == "tool": + k = i + 1 + while k < len(messages) and messages[k].role == "tool": + k += 1 + spans.append({ + "group_id": _group_id_for(current, group_index), + "kind": "tool_call", + "start_index": i, + "end_index": k - 1, + "has_reasoning": False, + }) + i = k + group_index += 1 + continue + + spans.append({ + "group_id": _group_id_for(current, group_index), + "kind": "assistant_text", + "start_index": i, + "end_index": i, + "has_reasoning": _has_reasoning(current), + }) + i += 1 + group_index += 1 + + return spans + + +def _coerce_group_kind(value: object) -> GroupKind | None: + if value == "system": + return "system" + if value == "user": + return "user" + if value == "assistant_text": + return "assistant_text" + if value == "tool_call": + return "tool_call" + return None + + +def _read_group_annotation(message: Message) -> dict[str, Any] | None: + raw_annotation = _read_group_annotation_raw(message) + if raw_annotation is None: + return None + + group_id = raw_annotation.get(GROUP_ID_KEY) + group_kind = _coerce_group_kind(raw_annotation.get(GROUP_KIND_KEY)) + group_index = raw_annotation.get(GROUP_INDEX_KEY) + has_reasoning = raw_annotation.get(GROUP_HAS_REASONING_KEY) + token_count = raw_annotation.get(GROUP_TOKEN_COUNT_KEY) + if token_count is not None and not isinstance(token_count, int): + return None + if ( + not isinstance(group_id, str) + or group_kind is None + or not isinstance(group_index, int) + or not isinstance(has_reasoning, bool) + ): + return None + + return raw_annotation + + +def _read_group_annotation_raw(message: Message) -> dict[str, Any] | None: + raw_annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + if isinstance(raw_annotation, dict): + return raw_annotation + if isinstance(raw_annotation, Mapping): + annotation = dict(raw_annotation) + message.additional_properties[GROUP_ANNOTATION_KEY] = annotation + return annotation + return None + + +def 
_set_group_summarized_by_summary_id(message: Message, summary_id: str) -> None: + annotation = _read_group_annotation_raw(message) + if annotation is None: + annotation = {} + message.additional_properties[GROUP_ANNOTATION_KEY] = annotation + annotation[SUMMARIZED_BY_SUMMARY_ID_KEY] = summary_id + + +def _write_group_annotation( + message: Message, + *, + group_id: str, + kind: GroupKind, + index: int, + has_reasoning: bool, +) -> None: + existing_raw_annotation = _read_group_annotation_raw(message) + unknown_fields: dict[str, Any] = {} + token_count: int | None = None + if existing_raw_annotation is not None: + raw_token_count = existing_raw_annotation.get(GROUP_TOKEN_COUNT_KEY) + if isinstance(raw_token_count, int) or raw_token_count is None: + token_count = raw_token_count + unknown_fields = { + key: value + for key, value in existing_raw_annotation.items() + if key + not in { + GROUP_ID_KEY, + GROUP_KIND_KEY, + GROUP_INDEX_KEY, + GROUP_HAS_REASONING_KEY, + GROUP_TOKEN_COUNT_KEY, + } + } + + annotation = { + GROUP_ID_KEY: group_id, + GROUP_KIND_KEY: kind, + GROUP_INDEX_KEY: index, + GROUP_HAS_REASONING_KEY: has_reasoning, + GROUP_TOKEN_COUNT_KEY: token_count, + } + annotation.update(unknown_fields) + message.additional_properties[GROUP_ANNOTATION_KEY] = annotation + + +def _group_id(message: Message) -> str | None: + annotation = _read_group_annotation(message) + if annotation is None: + return None + group_id = annotation.get(GROUP_ID_KEY) + return group_id if isinstance(group_id, str) else None + + +def _group_kind(message: Message) -> GroupKind | None: + annotation = _read_group_annotation(message) + if annotation is None: + return None + return _coerce_group_kind(annotation.get(GROUP_KIND_KEY)) + + +def _group_index(message: Message) -> int | None: + annotation = _read_group_annotation(message) + if annotation is None: + return None + group_index = annotation.get(GROUP_INDEX_KEY) + return group_index if isinstance(group_index, int) else None + + +def 
_token_count(message: Message) -> int | None: + annotation = _read_group_annotation(message) + if annotation is None: + return None + token_count = annotation.get(GROUP_TOKEN_COUNT_KEY) + return token_count if isinstance(token_count, int) else None + + +def _write_token_count(message: Message, token_count: int) -> None: + annotation = _read_group_annotation_raw(message) + if annotation is None: + return + annotation[GROUP_TOKEN_COUNT_KEY] = token_count + message.additional_properties[GROUP_ANNOTATION_KEY] = annotation + + +def _ordered_group_ids_from_annotations(messages: Sequence[Message]) -> list[str]: + ordered_group_ids: list[str] = [] + seen: set[str] = set() + for message in messages: + group_id = _group_id(message) + if group_id is not None and group_id not in seen: + seen.add(group_id) + ordered_group_ids.append(group_id) + return ordered_group_ids + + +def _first_unannotated_index(messages: Sequence[Message]) -> int | None: + for index, message in enumerate(messages): + if _group_id(message) is None: + return index + return None + + +def _first_untokenized_index(messages: Sequence[Message]) -> int | None: + for index, message in enumerate(messages): + if _token_count(message) is None: + return index + return None + + +def _first_annotation_gaps( + messages: Sequence[Message], + *, + include_tokens: bool, +) -> tuple[int | None, int | None]: + first_unannotated: int | None = None + first_untokenized: int | None = None + for index, message in enumerate(messages): + missing_group_annotation = first_unannotated is None and _group_id(message) is None + missing_token_annotation = include_tokens and first_untokenized is None and _token_count(message) is None + + if missing_group_annotation: + first_unannotated = index + if missing_token_annotation: + first_untokenized = index + + if missing_group_annotation or missing_token_annotation: + break + return first_unannotated, first_untokenized + + +def _reannotation_start(messages: Sequence[Message], index: int) -> 
int: + if index <= 0: + return 0 + previous_index = index - 1 + previous_group_id = _group_id(messages[previous_index]) + if previous_group_id is None: + return previous_index + while previous_index > 0: + prior_group_id = _group_id(messages[previous_index - 1]) + if prior_group_id != previous_group_id: + break + previous_index -= 1 + return previous_index + + +def annotate_message_groups( + messages: list[Message], + *, + from_index: int | None = None, + force_reannotate: bool = False, + tokenizer: TokenizerProtocol | None = None, +) -> list[str]: + """Annotate message groups while reusing existing annotations when possible. + + By default, the function re-annotates only the suffix that contains new + messages and keeps previously annotated prefixes untouched. When a + ``tokenizer`` is provided, token-count annotations are also populated + incrementally. + """ + if not messages: + return [] + + if force_reannotate: + start_index = 0 + elif from_index is not None: + start_index = max(0, min(from_index, len(messages) - 1)) + else: + first_unannotated_index, first_untokenized_index = _first_annotation_gaps( + messages, + include_tokens=tokenizer is not None, + ) + candidate_starts = [index for index in (first_unannotated_index, first_untokenized_index) if index is not None] + if not candidate_starts: + return _ordered_group_ids_from_annotations(messages) + start_index = min(candidate_starts) + + start_index = _reannotation_start(messages, start_index) + + # Continue group indices from the preserved prefix when only re-annotating a suffix. 
+ group_index_offset = 0 + if start_index > 0: + previous_group_index = _group_index(messages[start_index - 1]) + if previous_group_index is not None: + group_index_offset = previous_group_index + 1 + + spans = group_messages(messages[start_index:]) + for span_index, span in enumerate(spans): + group_id = str(span["group_id"]) + kind = _coerce_group_kind(span["kind"]) + if kind is None: + raise ValueError(f"Unexpected group kind in span: {span['kind']}") + local_start_index = int(span["start_index"]) + local_end_index = int(span["end_index"]) + has_reasoning = bool(span["has_reasoning"]) + for idx in range(start_index + local_start_index, start_index + local_end_index + 1): + message = messages[idx] + _write_group_annotation( + message, + group_id=group_id, + kind=kind, + index=group_index_offset + span_index, + has_reasoning=has_reasoning, + ) + message.additional_properties.setdefault(EXCLUDED_KEY, False) + if tokenizer is not None and _token_count(message) is None: + _write_token_count(message, tokenizer.count_tokens(_serialize_message(message))) + return _ordered_group_ids_from_annotations(messages) + + +def _serialize_content(content: Content) -> dict[str, Any]: + payload = content.to_dict(exclude_none=True) + payload.pop("raw_representation", None) + return payload + + +def _serialize_message(message: Message) -> str: + serialized_contents = [_serialize_content(content) for content in message.contents] + payload = { + "role": message.role, + "message_id": message.message_id, + "contents": serialized_contents, + } + return json.dumps(payload, ensure_ascii=True, sort_keys=True, default=str) + + +def annotate_token_counts( + messages: list[Message], + *, + tokenizer: TokenizerProtocol, + from_index: int | None = None, + force_retokenize: bool = False, +) -> None: + """Annotate token-count metadata, incrementally by default.""" + if not messages: + return + + # Token counts are stored inside group annotations. 
+ annotate_message_groups(messages, from_index=from_index) + + if force_retokenize: + start_index = 0 + elif from_index is not None: + start_index = max(0, min(from_index, len(messages) - 1)) + else: + first_untokenized_index = _first_untokenized_index(messages) + if first_untokenized_index is None: + return + start_index = first_untokenized_index + + for message in messages[start_index:]: + _write_token_count(message, tokenizer.count_tokens(_serialize_message(message))) + + +def extend_compaction_messages( + messages: list[Message], + new_messages: Sequence[Message], + *, + tokenizer: TokenizerProtocol | None = None, +) -> None: + """Append a batch of messages and annotate only the appended tail.""" + if not new_messages: + return + + start_index = len(messages) + messages.extend(new_messages) + annotate_message_groups( + messages, + from_index=start_index, + tokenizer=tokenizer, + ) + + +def append_compaction_message( + messages: list[Message], + message: Message, + *, + tokenizer: TokenizerProtocol | None = None, +) -> None: + """Append a single message and incrementally annotate metadata.""" + extend_compaction_messages(messages, [message], tokenizer=tokenizer) + + +def included_messages(messages: list[Message]) -> list[Message]: + return [message for message in messages if not message.additional_properties.get(EXCLUDED_KEY, False)] + + +def included_token_count(messages: list[Message]) -> int: + total = 0 + for message in included_messages(messages): + token_count = _token_count(message) + if token_count is not None: + total += token_count + return total + + +def set_excluded(message: Message, *, excluded: bool, reason: str | None = None) -> bool: + changed = bool(message.additional_properties.get(EXCLUDED_KEY, False)) != excluded + if changed: + message.additional_properties[EXCLUDED_KEY] = excluded + if reason is not None: + message.additional_properties[EXCLUDE_REASON_KEY] = reason + return changed + + +def exclude_group_ids(messages: list[Message], 
group_ids: set[str], *, reason: str) -> bool: + changed = False + for message in messages: + group_id = _group_id(message) + if group_id is not None and group_id in group_ids: + changed = set_excluded(message, excluded=True, reason=reason) or changed + return changed + + +def project_included_messages(messages: list[Message]) -> list[Message]: + return included_messages(messages) + + +def _group_messages_by_id(messages: list[Message]) -> dict[str, list[Message]]: + grouped: dict[str, list[Message]] = {} + for message in messages: + group_id = _group_id(message) + if group_id is None: + continue + grouped.setdefault(group_id, []).append(message) + return grouped + + +def _group_kind_map(messages: list[Message]) -> dict[str, GroupKind]: + kinds: dict[str, GroupKind] = {} + for message in messages: + group_id = _group_id(message) + group_kind = _group_kind(message) + if group_id is not None and group_kind is not None and group_id not in kinds: + kinds[group_id] = group_kind + return kinds + + +def _group_start_indices(messages: list[Message]) -> dict[str, int]: + starts: dict[str, int] = {} + for idx, message in enumerate(messages): + group_id = _group_id(message) + if group_id is not None and group_id not in starts: + starts[group_id] = idx + return starts + + +def _included_group_ids(messages: list[Message], ordered_group_ids: list[str]) -> list[str]: + grouped = _group_messages_by_id(messages) + included_ids: list[str] = [] + for group_id in ordered_group_ids: + if any(not m.additional_properties.get(EXCLUDED_KEY, False) for m in grouped.get(group_id, [])): + included_ids.append(group_id) + return included_ids + + +def _count_included_messages(messages: list[Message]) -> int: + return len(included_messages(messages)) + + +def _count_included_tokens(messages: list[Message]) -> int: + return included_token_count(messages) + + +class TruncationStrategy: + """Oldest-first compaction using a single metric threshold. 
+ + This strategy runs after group annotations are computed and excludes whole + groups (never partial tool-call groups). The metric is: + - token count when ``tokenizer`` is provided + - included message count when ``tokenizer`` is not provided + Compaction triggers when the metric exceeds ``max_n`` and trims to + ``compact_to``. + """ + + def __init__( + self, + *, + max_n: int, + compact_to: int, + tokenizer: TokenizerProtocol | None = None, + preserve_system: bool = True, + ) -> None: + """Create a truncation strategy. + + Keyword Args: + max_n: Trigger threshold measured in tokens when ``tokenizer`` is + provided, otherwise measured in included messages. + compact_to: Target value for the same metric used by ``max_n``. + This argument is required and must be explicitly set. + tokenizer: Optional tokenizer used for token-based truncation. + preserve_system: When True, system groups remain included and only + non-system groups are eligible for exclusion. + """ + if max_n <= 0: + raise ValueError("max_n must be greater than 0.") + if compact_to <= 0: + raise ValueError("compact_to must be greater than 0.") + if compact_to > max_n: + raise ValueError("compact_to must be less than or equal to max_n.") + self.max_n = max_n + self.compact_to = compact_to + self.tokenizer = tokenizer + self.preserve_system = preserve_system + + async def __call__(self, messages: list[Message]) -> bool: + ordered_group_ids = _ordered_group_ids_from_annotations(messages) + if self.tokenizer is not None: + over_limit = _count_included_tokens(messages) > self.max_n + else: + over_limit = _count_included_messages(messages) > self.max_n + if not over_limit: + return False + + grouped = _group_messages_by_id(messages) + kinds = _group_kind_map(messages) + protected_ids = set() + if self.preserve_system: + protected_ids = {group_id for group_id in ordered_group_ids if kinds.get(group_id) == "system"} + + changed = False + for group_id in ordered_group_ids: + if self.tokenizer is not None: + 
target_met = _count_included_tokens(messages) <= self.compact_to + else: + target_met = _count_included_messages(messages) <= self.compact_to + if target_met: + break + if group_id in protected_ids: + continue + for message in grouped.get(group_id, []): + changed = set_excluded(message, excluded=True, reason="truncation") or changed + return changed + + +class SlidingWindowStrategy: + """Windowed compaction that keeps the most recent non-system groups. + + The strategy preserves recency by retaining only the last + ``keep_last_groups`` included non-system groups. System groups can be kept + as stable anchors when ``preserve_system`` is enabled. + + This can remove older user and assistant groups while keeping system + instructions, which is useful when directives must persist but conversation + history grows. Use ``SelectiveToolCallCompactionStrategy`` when only tool + groups should be reduced. + """ + + def __init__(self, *, keep_last_groups: int, preserve_system: bool = True) -> None: + """Create a sliding-window strategy. + + Args: + keep_last_groups: Number of most-recent non-system groups to keep. + preserve_system: Whether system groups should always remain included. 
+ """ + self.keep_last_groups = keep_last_groups + self.preserve_system = preserve_system + + async def __call__(self, messages: list[Message]) -> bool: + ordered_group_ids = _ordered_group_ids_from_annotations(messages) + grouped = _group_messages_by_id(messages) + kinds = _group_kind_map(messages) + + included_group_ids = _included_group_ids(messages, ordered_group_ids) + non_system_group_ids = [group_id for group_id in included_group_ids if kinds.get(group_id) != "system"] + keep_non_system_ids = set(non_system_group_ids[-self.keep_last_groups :]) + keep_ids = set(keep_non_system_ids) + if self.preserve_system: + keep_ids.update(group_id for group_id in ordered_group_ids if kinds.get(group_id) == "system") + + changed = False + for group_id in included_group_ids: + if group_id in keep_ids: + continue + for message in grouped.get(group_id, []): + changed = set_excluded(message, excluded=True, reason="sliding_window") or changed + return changed + + +class SelectiveToolCallCompactionStrategy: + """Compaction focused on reducing tool-call history growth. + + This strategy only targets groups annotated as ``tool_call`` and keeps the + latest ``keep_last_tool_call_groups`` included tool-call groups. It is + useful when tool chatter dominates token usage. + + It does not change non-tool-call groups, so it can be combined with other + strategies that target different aspects of the message history. + """ + + def __init__(self, *, keep_last_tool_call_groups: int = 1) -> None: + """Create a tool-call-focused compaction strategy. + + Args: + keep_last_tool_call_groups: Number of newest included tool-call + groups to retain. Set to 0 to remove all included tool-call + groups. + + Raises: + ValueError: If ``keep_last_tool_call_groups`` is negative. 
+ """ + if keep_last_tool_call_groups < 0: + raise ValueError("keep_last_tool_call_groups must be greater than or equal to 0.") + self.keep_last_tool_call_groups = keep_last_tool_call_groups + + async def __call__(self, messages: list[Message]) -> bool: + ordered_group_ids = _ordered_group_ids_from_annotations(messages) + grouped = _group_messages_by_id(messages) + kinds = _group_kind_map(messages) + + included_tool_group_ids = [ + group_id + for group_id in _included_group_ids(messages, ordered_group_ids) + if kinds.get(group_id) == "tool_call" + ] + if len(included_tool_group_ids) <= self.keep_last_tool_call_groups: + return False + + keep_ids = ( + set(included_tool_group_ids[-self.keep_last_tool_call_groups :]) + if self.keep_last_tool_call_groups > 0 + else set() + ) + changed = False + for group_id in included_tool_group_ids: + if group_id in keep_ids: + continue + for message in grouped.get(group_id, []): + changed = set_excluded(message, excluded=True, reason="tool_call_compaction") or changed + return changed + + +def _format_messages_for_summary(messages: list[Message]) -> str: + lines: list[str] = [] + for index, message in enumerate(messages, start=1): + content_text = message.text + if not content_text: + content_text = ", ".join(content.type for content in message.contents) + lines.append(f"{index}. 
[{message.role}] {content_text}") + return "\n".join(lines) + + +DEFAULT_SUMMARIZATION_PROMPT: Final[ + str +] = """**Generate a clear and complete summary of the entire conversation in no more than five sentences.** + +The summary must always: +- Reflect contributions from both the user and the assistant +- Preserve context to support ongoing dialogue +- Incorporate any previously provided summary +- Emphasize the most relevant and meaningful points + +The summary must never: +- Offer critique, correction, interpretation, or speculation +- Highlight errors, misunderstandings, or judgments of accuracy +- Comment on events or ideas not present in the conversation +- Omit any details included in an earlier summary +""" + + +class SummarizationStrategy: + """Summarize older included groups and replace them with linked summary text. + + The strategy monitors included non-system message count and triggers when + that count grows beyond ``target_count + threshold``. When triggered, it + summarizes the oldest groups and retains the newest content near + ``target_count`` (subject to atomic group boundaries). It writes trace + metadata in both directions: summary -> original message/group IDs and + original -> summary ID. + """ + + def __init__( + self, + *, + client: SupportsChatGetResponse[Any], + target_count: int = 4, + threshold: int | None = 2, + prompt: str | None = None, + ) -> None: + """Create a summarization strategy. + + Keyword Args: + client: A chat client compatible with ``SupportsChatGetResponse`` + used to generate summary text. + target_count: Target number of included non-system messages to + retain after summarization. Must be greater than 0. + threshold: Extra included non-system messages allowed above + ``target_count`` before summarization triggers. Must be greater + than or equal to 0 when provided. + prompt: Optional summarization instruction. If omitted, a default + prompt that preserves goals, decisions, and unresolved items is + used. 
+ + Raises: + ValueError: If ``target_count`` is less than 1. + ValueError: If ``threshold`` is provided and is negative. + """ + if target_count <= 0: + raise ValueError("target_count must be greater than 0.") + if threshold is not None and threshold < 0: + raise ValueError("threshold must be greater than or equal to 0.") + self.client = client + self.target_count = target_count + self.threshold = threshold if threshold is not None else 0 + self.prompt = prompt or DEFAULT_SUMMARIZATION_PROMPT + + async def __call__(self, messages: list[Message]) -> bool: + ordered_group_ids = _ordered_group_ids_from_annotations(messages) + grouped = _group_messages_by_id(messages) + kinds = _group_kind_map(messages) + starts = _group_start_indices(messages) + + included_non_system_groups: list[tuple[str, list[Message]]] = [] + included_non_system_message_count = 0 + for group_id in _included_group_ids(messages, ordered_group_ids): + if kinds.get(group_id) == "system": + continue + group_messages = [ + message + for message in grouped.get(group_id, []) + if not message.additional_properties.get(EXCLUDED_KEY, False) + ] + if not group_messages: + continue + included_non_system_groups.append((group_id, group_messages)) + included_non_system_message_count += len(group_messages) + + if included_non_system_message_count <= self.target_count + self.threshold: + return False + + keep_group_ids: list[str] = [] + retained_message_count = 0 + for group_id, group_messages in reversed(included_non_system_groups): + if retained_message_count >= self.target_count and keep_group_ids: + break + keep_group_ids.append(group_id) + retained_message_count += len(group_messages) + keep_group_id_set = set(keep_group_ids) + + group_ids_to_summarize = [ + group_id for group_id, _ in included_non_system_groups if group_id not in keep_group_id_set + ] + if not group_ids_to_summarize: + return False + + messages_to_summarize: list[Message] = [] + for group_id, group_messages in included_non_system_groups: + 
if group_id in keep_group_id_set: + continue + messages_to_summarize.extend(group_messages) + if not messages_to_summarize: + return False + + try: + summary_response = await self.client.get_response( + [ + Message(role="system", text=self.prompt), + Message( + role="user", + text=_format_messages_for_summary(messages_to_summarize), + ), + ], + stream=False, + options={}, + ) + except Exception as exc: + logger.warning( + "Skipping summarization compaction: summary generation failed (%s).", + exc, + ) + return False + + summary_text = summary_response.text.strip() if summary_response.text else "" + if not summary_text: + logger.warning("Skipping summarization compaction: summarizer returned no text.") + return False + summary_id = f"summary_{len(messages)}" + original_message_ids = [message.message_id for message in messages_to_summarize if message.message_id] + summary_of_group_ids = list(group_ids_to_summarize) + summary_annotation = { + SUMMARY_OF_MESSAGE_IDS_KEY: original_message_ids, + SUMMARY_OF_GROUP_IDS_KEY: summary_of_group_ids, + } + + summary_message = Message( + role="assistant", + text=summary_text, + message_id=summary_id, + additional_properties={ + GROUP_ANNOTATION_KEY: summary_annotation, + }, + ) + + for message in messages_to_summarize: + _set_group_summarized_by_summary_id(message, summary_id) + set_excluded(message, excluded=True, reason="summarized") + + insertion_index = min(starts[group_id] for group_id in group_ids_to_summarize if group_id in starts) + messages.insert(insertion_index, summary_message) + return True + + +class TokenBudgetComposedStrategy: + """Compose multiple strategies until an included-token budget is satisfied. + + Strategies run in the provided order over shared message annotations. After + each step, token counts are refreshed. If no strategy reaches budget, a + deterministic fallback excludes oldest groups (and finally anchors when + necessary) to enforce the limit. 
+ """ + + def __init__( + self, + *, + token_budget: int, + tokenizer: TokenizerProtocol, + strategies: Sequence[CompactionStrategy], + early_stop: bool = True, + ) -> None: + """Create a composed token-budget strategy. + + Args: + token_budget: Maximum included token count allowed after compaction. + tokenizer: Tokenizer implementation used for per-message token + annotation. + strategies: Ordered strategy sequence to execute before fallback. + early_stop: When True, stop as soon as budget is satisfied. + """ + self.token_budget = token_budget + self.tokenizer = tokenizer + self.strategies = list(strategies) + self.early_stop = early_stop + + async def __call__(self, messages: list[Message]) -> bool: + annotate_message_groups(messages) + annotate_token_counts(messages, tokenizer=self.tokenizer) + if included_token_count(messages) <= self.token_budget: + return False + + changed = False + for strategy in self.strategies: + changed = (await strategy(messages)) or changed + annotate_message_groups(messages) + annotate_token_counts(messages, tokenizer=self.tokenizer) + if self.early_stop and included_token_count(messages) <= self.token_budget: + return changed + + if included_token_count(messages) <= self.token_budget: + return changed + + ordered_group_ids = annotate_message_groups(messages) + grouped = _group_messages_by_id(messages) + kinds = _group_kind_map(messages) + for group_id in ordered_group_ids: + if kinds.get(group_id) == "system": + continue + for message in grouped.get(group_id, []): + changed = set_excluded(message, excluded=True, reason="token_budget_fallback") or changed + if included_token_count(messages) <= self.token_budget: + break + if included_token_count(messages) <= self.token_budget: + return changed + + # Strict budget enforcement fallback: if anchors alone exceed budget, exclude remaining groups. 
+ for group_id in ordered_group_ids: + if kinds.get(group_id) != "system": + continue + for message in grouped.get(group_id, []): + changed = set_excluded(message, excluded=True, reason="token_budget_fallback_strict") or changed + if included_token_count(messages) <= self.token_budget: + break + return changed + + +async def apply_compaction( + messages: list[Message], + *, + strategy: CompactionStrategy | None, + tokenizer: TokenizerProtocol | None = None, +) -> list[Message]: + """Apply configured compaction and return projected model-input messages.""" + if strategy is None: + return messages + annotate_message_groups(messages) + if tokenizer is not None: + annotate_token_counts(messages, tokenizer=tokenizer) + await strategy(messages) + return project_included_messages(messages) + + +__all__ = [ + "EXCLUDED_KEY", + "EXCLUDE_REASON_KEY", + "GROUP_ANNOTATION_KEY", + "GROUP_HAS_REASONING_KEY", + "GROUP_ID_KEY", + "GROUP_INDEX_KEY", + "GROUP_KIND_KEY", + "GROUP_TOKEN_COUNT_KEY", + "SUMMARIZED_BY_SUMMARY_ID_KEY", + "SUMMARY_OF_GROUP_IDS_KEY", + "SUMMARY_OF_MESSAGE_IDS_KEY", + "CharacterEstimatorTokenizer", + "CompactionStrategy", + "GroupKind", + "SelectiveToolCallCompactionStrategy", + "SlidingWindowStrategy", + "SummarizationStrategy", + "TokenBudgetComposedStrategy", + "TokenizerProtocol", + "TruncationStrategy", + "annotate_message_groups", + "annotate_token_counts", + "append_compaction_message", + "apply_compaction", + "extend_compaction_messages", + "group_messages", + "included_messages", + "included_token_count", + "project_included_messages", +] diff --git a/python/packages/core/agent_framework/_skills.py b/python/packages/core/agent_framework/_skills.py index 9e11ecbe96..49695c89e6 100644 --- a/python/packages/core/agent_framework/_skills.py +++ b/python/packages/core/agent_framework/_skills.py @@ -151,6 +151,7 @@ class Skill: content="Use this skill for DB tasks.", ) + @skill.resource def get_schema() -> str: return "CREATE TABLE ..." 
@@ -972,9 +973,7 @@ def _load_skills( if skills: for code_skill in skills: - error = _validate_skill_metadata( - code_skill.name, code_skill.description, "code skill" - ) + error = _validate_skill_metadata(code_skill.name, code_skill.description, "code skill") if error: logger.warning(error) continue diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index ee0e813d27..5f4011828f 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -278,6 +278,17 @@ def _serialize_value(value: Any, exclude_none: bool) -> Any: return value +def _restore_compaction_annotation_in_additional_properties( + additional_properties: MutableMapping[str, Any] | None, + *, + allow_none: bool = False, +) -> dict[str, Any] | None: + if additional_properties is None: + return None if allow_none else {} + + return dict(additional_properties) + + # endregion # region Constants and types @@ -512,7 +523,9 @@ def __init__( """ self.type = type self.annotations = annotations - self.additional_properties: dict[str, Any] = additional_properties or {} # type: ignore[assignment] + self.additional_properties: dict[str, Any] = ( + _restore_compaction_annotation_in_additional_properties(additional_properties) or {} + ) self.raw_representation = raw_representation # Set all content-specific attributes @@ -1680,7 +1693,9 @@ def __init__( self.contents = parsed_contents self.author_name = author_name self.message_id = message_id - self.additional_properties = additional_properties or {} + self.additional_properties = ( + _restore_compaction_annotation_in_additional_properties(additional_properties) or {} + ) self.raw_representation = raw_representation @property @@ -2034,7 +2049,9 @@ def __init__( self._value: ResponseModelT | None = value self._response_format: type[BaseModel] | None = response_format self._value_parsed: bool = value is not None - self.additional_properties = 
additional_properties or {} + self.additional_properties = ( + _restore_compaction_annotation_in_additional_properties(additional_properties) or {} + ) self.continuation_token = continuation_token self.raw_representation: Any | list[Any] | None = raw_representation @@ -2284,7 +2301,10 @@ def __init__( self.created_at = created_at self.finish_reason = finish_reason self.continuation_token = continuation_token - self.additional_properties = additional_properties + self.additional_properties = _restore_compaction_annotation_in_additional_properties( + additional_properties, + allow_none=True, + ) self.raw_representation = raw_representation @property @@ -2397,7 +2417,9 @@ def __init__( self._value: ResponseModelT | None = value self._response_format: type[BaseModel] | None = response_format self._value_parsed: bool = value is not None - self.additional_properties = additional_properties or {} + self.additional_properties = ( + _restore_compaction_annotation_in_additional_properties(additional_properties) or {} + ) self.continuation_token = continuation_token self.raw_representation = raw_representation @@ -2631,7 +2653,10 @@ def __init__( self.message_id = message_id self.created_at = created_at self.continuation_token = continuation_token - self.additional_properties = additional_properties + self.additional_properties = _restore_compaction_annotation_in_additional_properties( + additional_properties, + allow_none=True, + ) self.raw_representation: Any | list[Any] | None = raw_representation @property @@ -3414,7 +3439,9 @@ def __init__( self._dimensions = dimensions self.model_id = model_id self.created_at = created_at - self.additional_properties = additional_properties or {} + self.additional_properties = ( + _restore_compaction_annotation_in_additional_properties(additional_properties) or {} + ) @property def dimensions(self) -> int | None: @@ -3472,7 +3499,9 @@ def __init__( super().__init__(embeddings or []) self.options = options self.usage = usage - 
self.additional_properties = additional_properties or {} + self.additional_properties = ( + _restore_compaction_annotation_in_additional_properties(additional_properties) or {} + ) # endregion diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index a23b1d2a5f..4108b88d11 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from typing import Any from unittest.mock import patch from agent_framework import ( @@ -13,6 +14,7 @@ SupportsImageGenerationTool, SupportsMCPTool, SupportsWebSearchTool, + TruncationStrategy, ) @@ -48,6 +50,60 @@ async def test_base_client_get_response_streaming(chat_client_base: SupportsChat assert update.text == "update - Hello" or update.text == "another update" +async def test_base_client_applies_compaction_before_non_streaming_inner_call( + chat_client_base: SupportsChatGetResponse, +): + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] + chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined] + captured_roles: list[list[str]] = [] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + + async def _capture( + *, + messages: list[Message], + options: dict[str, Any], + **kwargs: Any, + ) -> ChatResponse: + captured_roles.append([message.role for message in messages]) + return await original(messages=messages, options=options, **kwargs) + + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + await chat_client_base.get_response([ + Message(role="user", text="Hello"), + Message(role="assistant", text="Previous response"), + ]) + assert captured_roles == [["assistant"]] + + +async def test_base_client_applies_compaction_before_streaming_inner_call( + chat_client_base: 
SupportsChatGetResponse, +): + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] + chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined] + captured_roles: list[list[str]] = [] + original = chat_client_base._get_streaming_response # type: ignore[attr-defined] + + def _capture( + *, + messages: list[Message], + options: dict[str, Any], + **kwargs: Any, + ): + captured_roles.append([message.role for message in messages]) + return original(messages=messages, options=options, **kwargs) + + chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined,method-assign] + async for _ in chat_client_base.get_response( + [ + Message(role="user", text="Hello"), + Message(role="assistant", text="Previous response"), + ], + stream=True, + ): + pass + assert captured_roles == [["assistant"]] + + async def test_chat_client_instructions_handling(chat_client_base: SupportsChatGetResponse): instructions = "You are a helpful assistant." diff --git a/python/packages/core/tests/core/test_compaction.py b/python/packages/core/tests/core/test_compaction.py new file mode 100644 index 0000000000..55e1d0abca --- /dev/null +++ b/python/packages/core/tests/core/test_compaction.py @@ -0,0 +1,450 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from __future__ import annotations + +import logging +from typing import Any + +from agent_framework import ChatResponse, Content, Message +from agent_framework._compaction import ( + EXCLUDED_KEY, + GROUP_ANNOTATION_KEY, + GROUP_HAS_REASONING_KEY, + GROUP_ID_KEY, + GROUP_KIND_KEY, + GROUP_TOKEN_COUNT_KEY, + SUMMARIZED_BY_SUMMARY_ID_KEY, + SUMMARY_OF_GROUP_IDS_KEY, + SUMMARY_OF_MESSAGE_IDS_KEY, + CharacterEstimatorTokenizer, + SelectiveToolCallCompactionStrategy, + SlidingWindowStrategy, + SummarizationStrategy, + TokenBudgetComposedStrategy, + TruncationStrategy, + annotate_message_groups, + append_compaction_message, + apply_compaction, + extend_compaction_messages, + included_messages, + included_token_count, +) + + +def _assistant_function_call(call_id: str) -> Message: + return Message( + role="assistant", + contents=[Content.from_function_call(call_id=call_id, name="tool", arguments='{"value":"x"}')], + ) + + +def _assistant_reasoning_and_function_calls(*call_ids: str) -> Message: + contents: list[Content] = [Content.from_text_reasoning(text="thinking")] + for call_id in call_ids: + contents.append( + Content.from_function_call( + call_id=call_id, + name="tool", + arguments='{"value":"x"}', + ) + ) + return Message(role="assistant", contents=contents) + + +def _tool_result(call_id: str, result: str) -> Message: + return Message( + role="tool", + contents=[Content.from_function_result(call_id=call_id, result=result)], + ) + + +def _group_id(message: Message) -> str | None: + annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + if not isinstance(annotation, dict): + return None + value = annotation.get(GROUP_ID_KEY) + return value if isinstance(value, str) else None + + +def _group_kind(message: Message) -> str | None: + annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + if not isinstance(annotation, dict): + return None + value = annotation.get(GROUP_KIND_KEY) + return value if isinstance(value, str) else None + + +def 
_group_has_reasoning(message: Message) -> bool | None: + annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + if not isinstance(annotation, dict): + return None + value = annotation.get(GROUP_HAS_REASONING_KEY) + return value if isinstance(value, bool) else None + + +def _token_count(message: Message) -> int | None: + annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + if not isinstance(annotation, dict): + return None + value = annotation.get(GROUP_TOKEN_COUNT_KEY) + return value if isinstance(value, int) else None + + +def _group_unknown_value(message: Message, key: str) -> Any: + annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + if not isinstance(annotation, dict): + return None + return annotation.get(key) + + +def test_group_annotations_keep_tool_call_and_tool_result_atomic() -> None: + messages = [ + Message(role="user", text="hello"), + _assistant_function_call("c1"), + _tool_result("c1", "ok"), + Message(role="assistant", text="final"), + ] + + annotate_message_groups(messages) + + call_group = _group_id(messages[1]) + assert call_group is not None + assert call_group == _group_id(messages[2]) + assert _group_id(messages[1]) != _group_id(messages[0]) + + +def test_group_annotations_include_reasoning_in_tool_call_group() -> None: + messages = [ + _assistant_reasoning_and_function_calls("c2"), + _tool_result("c2", "ok"), + ] + + annotate_message_groups(messages) + + first_group = _group_id(messages[0]) + assert first_group is not None + assert _group_id(messages[1]) == first_group + assert _group_has_reasoning(messages[0]) is True + assert _group_kind(messages[0]) == "tool_call" + + +def test_group_annotations_handle_same_message_reasoning_and_function_calls() -> None: + messages = [ + Message(role="user", text="hello"), + _assistant_reasoning_and_function_calls("c1", "c2"), + _tool_result("c1", "ok1"), + _tool_result("c2", "ok2"), + Message(role="assistant", text="final"), + ] + + 
annotate_message_groups(messages) + + call_group = _group_id(messages[1]) + assert call_group is not None + assert _group_id(messages[2]) == call_group + assert _group_id(messages[3]) == call_group + assert _group_kind(messages[1]) == "tool_call" + assert _group_has_reasoning(messages[1]) is True + + +def test_annotate_message_groups_with_tokenizer_adds_token_counts() -> None: + messages = [ + Message(role="user", text="hello"), + Message(role="assistant", text="world"), + ] + + annotate_message_groups( + messages, + tokenizer=CharacterEstimatorTokenizer(), + ) + + assert isinstance(_token_count(messages[0]), int) + assert isinstance(_token_count(messages[1]), int) + + +def test_extend_compaction_messages_preserves_existing_annotations_and_tokens() -> None: + tokenizer = CharacterEstimatorTokenizer() + messages = [_assistant_function_call("c3")] + annotate_message_groups(messages) + old_group_id = _group_id(messages[0]) + assert old_group_id is not None + old_token_count = tokenizer.count_tokens("precomputed") + annotation = messages[0].additional_properties.get(GROUP_ANNOTATION_KEY) + if isinstance(annotation, dict): + annotation[GROUP_TOKEN_COUNT_KEY] = old_token_count + + extend_compaction_messages(messages, [_tool_result("c3", "ok")], tokenizer=tokenizer) + + assert _group_id(messages[1]) == old_group_id + assert _token_count(messages[0]) == old_token_count + assert isinstance(_token_count(messages[1]), int) + + +def test_append_compaction_message_annotates_new_message() -> None: + messages = [Message(role="user", text="hello")] + annotate_message_groups(messages) + append_compaction_message(messages, Message(role="assistant", text="world")) + + assert len(messages) == 2 + assert isinstance(_group_id(messages[1]), str) + + +async def test_truncation_strategy_keeps_system_anchor() -> None: + messages = [ + Message(role="system", text="you are helpful"), + Message(role="user", text="u1"), + Message(role="assistant", text="a1"), + Message(role="user", text="u2"), 
+ Message(role="assistant", text="a2"), + ] + strategy = TruncationStrategy(max_n=3, compact_to=3, preserve_system=True) + annotate_message_groups(messages) + + changed = await strategy(messages) + + assert changed is True + projected = included_messages(messages) + assert projected[0].role == "system" + assert len(projected) <= 3 + + +async def test_truncation_strategy_compacts_when_token_limit_exceeded() -> None: + tokenizer = CharacterEstimatorTokenizer() + messages = [ + Message(role="system", text="you are helpful"), + Message(role="user", text="u1 " * 200), + Message(role="assistant", text="a1 " * 200), + ] + strategy = TruncationStrategy( + max_n=80, + compact_to=40, + tokenizer=tokenizer, + preserve_system=True, + ) + annotate_message_groups(messages, tokenizer=tokenizer) + + changed = await strategy(messages) + + assert changed is True + projected = included_messages(messages) + assert projected[0].role == "system" + assert included_token_count(messages) <= 40 + + +def test_truncation_strategy_validates_token_targets() -> None: + try: + TruncationStrategy(max_n=3, compact_to=4) + except ValueError as exc: + assert "compact_to must be less than or equal to max_n" in str(exc) + else: + raise AssertionError("Expected ValueError when compact_to is greater than max_n.") + + +async def test_selective_tool_call_strategy_excludes_older_tool_groups() -> None: + messages = [ + Message(role="user", text="u"), + _assistant_function_call("call-1"), + _tool_result("call-1", "r1"), + _assistant_function_call("call-2"), + _tool_result("call-2", "r2"), + Message(role="assistant", text="done"), + ] + strategy = SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=1) + annotate_message_groups(messages) + + changed = await strategy(messages) + + assert changed is True + assert messages[1].additional_properties.get(EXCLUDED_KEY) is True + assert messages[2].additional_properties.get(EXCLUDED_KEY) is True + assert messages[3].additional_properties.get(EXCLUDED_KEY) is 
not True + assert messages[4].additional_properties.get(EXCLUDED_KEY) is not True + + +async def test_selective_tool_call_strategy_with_zero_removes_assistant_tool_pair() -> None: + messages = [ + Message(role="user", text="u"), + _assistant_function_call("call-1"), + _tool_result("call-1", "r1"), + Message(role="assistant", text="done"), + ] + strategy = SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0) + annotate_message_groups(messages) + + changed = await strategy(messages) + + assert changed is True + assert messages[1].additional_properties.get(EXCLUDED_KEY) is True + assert messages[2].additional_properties.get(EXCLUDED_KEY) is True + assert messages[0].additional_properties.get(EXCLUDED_KEY) is not True + assert messages[3].additional_properties.get(EXCLUDED_KEY) is not True + + +def test_selective_tool_call_strategy_rejects_negative_keep_count() -> None: + try: + SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=-1) + except ValueError as exc: + assert "must be greater than or equal to 0" in str(exc) + else: + raise AssertionError("Expected ValueError for negative keep_last_tool_call_groups.") + + +class _FakeSummarizer: + async def get_response( + self, + messages: list[Message], + *, + stream: bool = False, + options: dict[str, Any] | None = None, + **kwargs: Any, + ) -> ChatResponse: + return ChatResponse(messages=[Message(role="assistant", text="summarized context")]) + + +class _FailingSummarizer: + async def get_response( + self, + messages: list[Message], + *, + stream: bool = False, + options: dict[str, Any] | None = None, + **kwargs: Any, + ) -> ChatResponse: + raise RuntimeError("summary failed") + + +class _EmptySummarizer: + async def get_response( + self, + messages: list[Message], + *, + stream: bool = False, + options: dict[str, Any] | None = None, + **kwargs: Any, + ) -> ChatResponse: + return ChatResponse(messages=[Message(role="assistant", text=" ")]) + + +async def 
test_summarization_strategy_adds_bidirectional_trace_links() -> None: + messages = [ + Message(role="user", text="u1"), + Message(role="assistant", text="a1"), + Message(role="user", text="u2"), + Message(role="assistant", text="a2"), + Message(role="user", text="u3"), + Message(role="assistant", text="a3"), + ] + strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0) + annotate_message_groups(messages) + + changed = await strategy(messages) + + assert changed is True + summary_messages = [ + message for message in messages if _group_unknown_value(message, SUMMARY_OF_MESSAGE_IDS_KEY) is not None + ] + assert len(summary_messages) == 1 + summary = summary_messages[0] + summary_id = summary.message_id + assert summary_id is not None + assert _group_unknown_value(summary, SUMMARY_OF_GROUP_IDS_KEY) + summarized_message_ids = _group_unknown_value(summary, SUMMARY_OF_MESSAGE_IDS_KEY) + assert isinstance(summarized_message_ids, list) + for message in messages: + if message.message_id in summarized_message_ids: + assert _group_unknown_value(message, SUMMARIZED_BY_SUMMARY_ID_KEY) == summary_id + assert message.additional_properties.get(EXCLUDED_KEY) is True + + +async def test_summarization_strategy_returns_false_when_summary_generation_fails( + caplog: Any, +) -> None: + messages = [ + Message(role="user", text="u1"), + Message(role="assistant", text="a1"), + Message(role="user", text="u2"), + Message(role="assistant", text="a2"), + Message(role="user", text="u3"), + Message(role="assistant", text="a3"), + ] + strategy = SummarizationStrategy(client=_FailingSummarizer(), target_count=2, threshold=0) + annotate_message_groups(messages) + + with caplog.at_level(logging.WARNING, logger="agent_framework"): + changed = await strategy(messages) + + assert changed is False + assert any("summary generation failed" in record.message for record in caplog.records) + assert all(message.additional_properties.get(EXCLUDED_KEY) is not True for message in 
messages) + + +async def test_summarization_strategy_returns_false_when_summary_is_empty( + caplog: Any, +) -> None: + messages = [ + Message(role="user", text="u1"), + Message(role="assistant", text="a1"), + Message(role="user", text="u2"), + Message(role="assistant", text="a2"), + Message(role="user", text="u3"), + Message(role="assistant", text="a3"), + ] + strategy = SummarizationStrategy(client=_EmptySummarizer(), target_count=2, threshold=0) + annotate_message_groups(messages) + + with caplog.at_level(logging.WARNING, logger="agent_framework"): + changed = await strategy(messages) + + assert changed is False + assert any("returned no text" in record.message for record in caplog.records) + assert all(message.additional_properties.get(EXCLUDED_KEY) is not True for message in messages) + + +async def test_token_budget_composed_strategy_meets_budget_or_falls_back() -> None: + messages = [ + Message(role="system", text="system"), + Message(role="user", text="user " * 200), + Message(role="assistant", text="assistant " * 200), + ] + strategy = TokenBudgetComposedStrategy( + token_budget=20, + tokenizer=CharacterEstimatorTokenizer(), + strategies=[SlidingWindowStrategy(keep_last_groups=1)], + ) + + changed = await strategy(messages) + + assert changed is True + assert included_token_count(messages) <= 20 + + +class _ExcludeOldestNonSystem: + async def __call__(self, messages: list[Message]) -> bool: + group_ids = annotate_message_groups(messages) + kinds: dict[str, str] = {} + for message in messages: + group_id = _group_id(message) + kind = _group_kind(message) + if group_id is not None and kind is not None and group_id not in kinds: + kinds[group_id] = kind + for group_id in group_ids: + if kinds.get(group_id) == "system": + continue + for message in messages: + if _group_id(message) == group_id: + message.additional_properties[EXCLUDED_KEY] = True + return True + return False + + +async def test_apply_compaction_projects_included_messages_only() -> None: + 
messages = [ + Message(role="system", text="sys"), + Message(role="user", text="hello"), + Message(role="assistant", text="world"), + ] + + projected = await apply_compaction(messages, strategy=_ExcludeOldestNonSystem()) + + assert len(projected) < len(messages) + assert projected[0].role == "system" diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 319d35f152..def12d0fe4 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -15,9 +15,27 @@ SupportsChatGetResponse, tool, ) +from agent_framework._compaction import ( + EXCLUDED_KEY, + GROUP_ANNOTATION_KEY, + GROUP_ID_KEY, + CharacterEstimatorTokenizer, + SlidingWindowStrategy, + TokenBudgetComposedStrategy, + annotate_message_groups, + included_token_count, +) from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination +def _group_id(message: Message) -> str | None: + annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + if not isinstance(annotation, dict): + return None + value = annotation.get(GROUP_ID_KEY) + return value if isinstance(value, str) else None + + async def test_base_client_with_function_calling(chat_client_base: SupportsChatGetResponse): exec_counter = 0 @@ -131,6 +149,127 @@ def ai_func(arg1: str) -> str: assert response.messages[3].contents[0].type == "function_result" +async def test_function_loop_applies_compaction_projection_each_model_call(chat_client_base: SupportsChatGetResponse): + @tool(name="test_function", approval_mode="never_require") + def ai_func(arg1: str) -> str: + return f"Processed {arg1}" + + class _ExcludeOldestGroupAfterFirstTurn: + async def __call__(self, messages: list[Message]) -> bool: + groups = annotate_message_groups(messages) + if len(groups) <= 1: + return False + oldest_group_id = groups[0] + changed = 
False + for message in messages: + if _group_id(message) == oldest_group_id: + if message.additional_properties.get(EXCLUDED_KEY) is not True: + changed = True + message.additional_properties[EXCLUDED_KEY] = True + return changed + + captured_roles: list[list[str]] = [] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + + async def _capture( + *, + messages: list[Message], + options: dict[str, Any], + **kwargs: Any, + ) -> ChatResponse: + captured_roles.append([message.role for message in messages]) + return await original(messages=messages, options=options, **kwargs) + + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base.compaction_strategy = _ExcludeOldestGroupAfterFirstTurn() # type: ignore[attr-defined] + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="done")), + ] + + await chat_client_base.get_response( + [Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]} + ) + + assert len(captured_roles) >= 2 + assert "user" in captured_roles[0] + assert "user" not in captured_roles[1] + + +async def test_function_loop_token_budget_strategy_caps_tokens_each_iteration( + chat_client_base: SupportsChatGetResponse, +): + exec_counter = 0 + token_budget = 500 + tokenizer = CharacterEstimatorTokenizer() + + @tool(name="test_function", approval_mode="never_require") + def ai_func(arg1: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Processed {arg1}. 
" + ("result " * 120) + + captured_token_counts: list[int] = [] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + + async def _capture( + *, + messages: list[Message], + options: dict[str, Any], + **kwargs: Any, + ) -> ChatResponse: + annotate_message_groups(messages, force_reannotate=True, tokenizer=tokenizer) + captured_token_counts.append(included_token_count(messages)) + return await original(messages=messages, options=options, **kwargs) + + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] + chat_client_base.compaction_strategy = TokenBudgetComposedStrategy( # type: ignore[attr-defined] + token_budget=token_budget, + tokenizer=tokenizer, + strategies=[SlidingWindowStrategy(keep_last_groups=2)], + ) + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], + ) + ), + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="2", name="test_function", arguments='{"arg1": "value2"}') + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="done")), + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="hello " * 160)], + options={"tool_choice": "auto", "tools": [ai_func]}, + ) + + assert response.messages[-1].text == "done" + assert exec_counter == 2 + assert len(captured_token_counts) >= 3 + assert all(token_count > 0 for token_count in captured_token_counts) + assert all(token_count <= token_budget for token_count in captured_token_counts) + + async def test_base_client_with_streaming_function_calling(chat_client_base: SupportsChatGetResponse): exec_counter = 0 diff 
--git a/python/packages/core/tests/core/test_skills.py b/python/packages/core/tests/core/test_skills.py index c572f4727b..e64691e655 100644 --- a/python/packages/core/tests/core/test_skills.py +++ b/python/packages/core/tests/core/test_skills.py @@ -10,7 +10,7 @@ import pytest -from agent_framework import Skill, SkillResource, SkillsProvider, SessionContext +from agent_framework import SessionContext, Skill, SkillResource, SkillsProvider from agent_framework._skills import ( DEFAULT_RESOURCE_EXTENSIONS, _create_instructions, @@ -1348,9 +1348,7 @@ class TestReadAndParseSkillFile: def test_valid_file(self, tmp_path: Path) -> None: skill_dir = tmp_path / "my-skill" skill_dir.mkdir() - (skill_dir / "SKILL.md").write_text( - "---\nname: my-skill\ndescription: A skill.\n---\nBody.", encoding="utf-8" - ) + (skill_dir / "SKILL.md").write_text("---\nname: my-skill\ndescription: A skill.\n---\nBody.", encoding="utf-8") result = _read_and_parse_skill_file(str(skill_dir)) assert result is not None name, desc, content = result @@ -1393,7 +1391,7 @@ def test_with_description(self) -> None: def test_xml_escapes_name(self) -> None: r = SkillResource(name='ref"special', content="data") elem = _create_resource_element(r) - assert '&quot;' in elem + assert "&quot;" in elem def test_xml_escapes_description(self) -> None: r = SkillResource(name="ref", description='Uses & "quotes"', content="data") diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index bcf3a6891b..a02310d99c 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -28,6 +28,12 @@ merge_chat_options, tool, ) +from agent_framework._compaction import ( + GROUP_ANNOTATION_KEY, + GROUP_HAS_REASONING_KEY, + GROUP_ID_KEY, + GROUP_TOKEN_COUNT_KEY, +) from agent_framework._types import ( _get_data_bytes, _get_data_bytes_as_str, @@ -1646,6 +1652,78 @@ def test_chat_message_complex_content_serialization(): assert
reconstructed.contents[2].type == "function_result" +def test_message_roundtrip_preserves_compaction_annotation_dict() -> None: + message = Message( + role="assistant", + contents=[Content.from_text("Hello")], + additional_properties={ + GROUP_ANNOTATION_KEY: { + "id": "group_1", + "kind": "assistant_text", + "index": 1, + "has_reasoning": False, + "token_count": 42, + } + }, + ) + + restored = Message.from_dict(message.to_dict()) + annotation = restored.additional_properties.get(GROUP_ANNOTATION_KEY) + + assert isinstance(annotation, dict) + assert annotation[GROUP_ID_KEY] == "group_1" + assert annotation[GROUP_TOKEN_COUNT_KEY] == 42 + + +def test_content_roundtrip_preserves_compaction_annotation_dict() -> None: + content = Content.from_text( + text="Hello", + additional_properties={ + GROUP_ANNOTATION_KEY: { + "id": "group_2", + "kind": "assistant_text", + "index": 2, + "has_reasoning": False, + "token_count": None, + } + }, + ) + + restored = Content.from_dict(content.to_dict()) + annotation = restored.additional_properties.get(GROUP_ANNOTATION_KEY) + + assert isinstance(annotation, dict) + assert annotation[GROUP_ID_KEY] == "group_2" + assert annotation[GROUP_TOKEN_COUNT_KEY] is None + + +def test_chat_response_roundtrip_preserves_compaction_annotation_dict() -> None: + response = ChatResponse( + messages=[ + Message( + role="assistant", + contents=[Content.from_text("Hello")], + additional_properties={ + GROUP_ANNOTATION_KEY: { + "id": "group_3", + "kind": "assistant_text", + "index": 3, + "has_reasoning": True, + "token_count": 15, + } + }, + ) + ] + ) + + restored = ChatResponse.from_dict(response.to_dict()) + annotation = restored.messages[0].additional_properties.get(GROUP_ANNOTATION_KEY) + + assert isinstance(annotation, dict) + assert annotation[GROUP_ID_KEY] == "group_3" + assert annotation[GROUP_HAS_REASONING_KEY] is True + + def test_usage_content_serialization_with_details(): """Test UsageContent from_dict and to_dict with UsageDetails conversion.""" 
diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index e049dbd16e..d5a9903b93 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -524,6 +524,58 @@ def test_response_content_creation_with_reasoning() -> None: assert response.messages[0].contents[0].text == "Reasoning step" +def test_response_content_keeps_reasoning_and_function_calls_in_one_message() -> None: + """Reasoning + function calls should parse into one assistant message.""" + client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") + + mock_response = MagicMock() + mock_response.output_parsed = None + mock_response.metadata = {} + mock_response.usage = None + mock_response.id = "test-id" + mock_response.model = "test-model" + mock_response.created_at = 1000000000 + + mock_reasoning_content = MagicMock() + mock_reasoning_content.text = "Reasoning step" + + mock_reasoning_item = MagicMock() + mock_reasoning_item.type = "reasoning" + mock_reasoning_item.id = "rs_123" + mock_reasoning_item.content = [mock_reasoning_content] + mock_reasoning_item.summary = [] + + mock_function_call_item_1 = MagicMock() + mock_function_call_item_1.type = "function_call" + mock_function_call_item_1.id = "fc_1" + mock_function_call_item_1.call_id = "call_1" + mock_function_call_item_1.name = "tool_1" + mock_function_call_item_1.arguments = '{"x": 1}' + + mock_function_call_item_2 = MagicMock() + mock_function_call_item_2.type = "function_call" + mock_function_call_item_2.id = "fc_2" + mock_function_call_item_2.call_id = "call_2" + mock_function_call_item_2.name = "tool_2" + mock_function_call_item_2.arguments = '{"y": 2}' + + mock_response.output = [ + mock_reasoning_item, + mock_function_call_item_1, + mock_function_call_item_2, + ] + + response = client._parse_response_from_openai(mock_response, options={}) # type: 
ignore + + assert len(response.messages) == 1 + assert response.messages[0].role == "assistant" + assert [content.type for content in response.messages[0].contents] == [ + "text_reasoning", + "function_call", + "function_call", + ] + + def test_response_content_creation_with_code_interpreter() -> None: """Test _parse_response_from_openai with code interpreter outputs.""" diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 788e96e61e..599e62d635 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload import pytest + from agent_framework import ( AgentExecutor, AgentResponse, @@ -59,30 +60,19 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> ( - Awaitable[AgentResponse[Any]] - | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] - ): + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.call_count += 1 if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( - contents=[ - Content.from_text( - text=f"Response #{self.call_count}: {self.name}" - ) - ] + contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")] ) return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) async def _run() -> AgentResponse: - return AgentResponse( - messages=[ - Message("assistant", [f"Response #{self.call_count}: {self.name}"]) - ] - ) + return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])]) return _run() @@ -120,10 +110,7 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> ( - Awaitable[AgentResponse[Any]] - | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] - ): + ) -> 
Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: @@ -138,9 +125,9 @@ async def _mark_result_hook_called( self.result_hook_called = True return response - return ResponseStream( - _stream(), finalizer=AgentResponse.from_updates - ).with_result_hook(_mark_result_hook_called) + return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook( + _mark_result_hook_called + ) async def _run() -> AgentResponse: return AgentResponse(messages=[Message("assistant", ["hook test"])]) @@ -148,9 +135,7 @@ async def _run() -> AgentResponse: return _run() -async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> ( - None -): +async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None: """AgentExecutor should call get_final_response() so stream result hooks execute.""" agent = _StreamingHookAgent(id="hook_agent", name="HookAgent") executor = AgentExecutor(agent, id="hook_exec") @@ -217,9 +202,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: executor_state = executor_states[executor.id] # type: ignore[index] assert "cache" in executor_state, "Checkpoint should store executor cache state" - assert "agent_session" in executor_state, ( - "Checkpoint should store executor session state" - ) + assert "agent_session" in executor_state, "Checkpoint should store executor session state" # Verify session state structure session_state = executor_state["agent_session"] # type: ignore[index] @@ -240,15 +223,11 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert restored_agent.call_count == 0 # Build new workflow with the restored executor - wf_resume = SequentialBuilder( - participants=[restored_executor], checkpoint_storage=storage - ).build() + wf_resume = SequentialBuilder(participants=[restored_executor], 
checkpoint_storage=storage).build() # Resume from checkpoint resumed_output: AgentExecutorResponse | None = None - async for ev in wf_resume.run( - checkpoint_id=restore_checkpoint.checkpoint_id, stream=True - ): + async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True): if ev.type == "output": resumed_output = ev.data # type: ignore[assignment] if ev.type == "status" and ev.state in ( @@ -391,11 +370,7 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( assert options is not None assert options["additional_function_arguments"]["custom"] == 1 - warned_keys = { - r.message.split("'")[1] - for r in caplog.records - if "reserved" in r.message.lower() - } + warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} assert warned_keys == {"session", "stream", "messages"} diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py index 07d1e64c08..633ba1072c 100644 --- a/python/packages/core/tests/workflow/test_agent_utils.py +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -16,10 +16,31 @@ def __init__(self, agent_id: str, name: str | None = None) -> None: self.description: str | None = None @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
- def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def create_session(self, **kwargs: Any) -> AgentSession: """Creates a new conversation session for the agent.""" diff --git a/python/packages/core/tests/workflow/test_edge.py b/python/packages/core/tests/workflow/test_edge.py index ecaa341726..422d530631 100644 --- a/python/packages/core/tests/workflow/test_edge.py +++ b/python/packages/core/tests/workflow/test_edge.py @@ -4,9 +4,8 @@ from typing import Any from unittest.mock import patch -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - import pytest +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from agent_framework import ( Executor, diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 77827c0634..77777e198b 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -3,6 +3,8 @@ from dataclasses import dataclass import pytest +from typing_extensions import Never + from agent_framework import ( Executor, Message, @@ -14,7 +16,6 @@ handler, response_handler, ) -from typing_extensions import Never # Module-level types for string forward reference tests @@ -155,11 +156,7 @@ async def handle(self, text: str, ctx: 
WorkflowContext) -> None: workflow = WorkflowBuilder(start_executor=upper).add_edge(upper, collector).build() events = await workflow.run("hello world") - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] assert len(invoked_events) == 2 @@ -193,16 +190,10 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: sender = MultiSenderExecutor(id="sender") collector = CollectorExecutor(id="collector") - workflow = ( - WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() - ) + workflow = WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() events = await workflow.run("hello") - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] # Sender should have completed with the sent messages sender_completed = next(e for e in completed_events if e.executor_id == "sender") @@ -210,9 +201,7 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: assert sender_completed.data == ["hello-first", "hello-second"] # Collector should have completed with no sent messages (None) - collector_completed_events = [ - e for e in completed_events if e.executor_id == "collector" - ] + collector_completed_events = [e for e in completed_events if e.executor_id == "collector"] # Collector is called twice (once per message from sender) assert len(collector_completed_events) == 2 for collector_completed in collector_completed_events: @@ -231,11 +220,7 @@ async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: workflow = WorkflowBuilder(start_executor=executor).build() events = await workflow.run("test") - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and 
e.type == "executor_completed" - ] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] assert len(completed_events) == 1 assert completed_events[0].executor_id == "yielder" @@ -263,9 +248,7 @@ class Response: class ProcessorExecutor(Executor): @handler - async def handle( - self, request: Request, ctx: WorkflowContext[Response] - ) -> None: + async def handle(self, request: Request, ctx: WorkflowContext[Response]) -> None: response = Response(results=[request.query.upper()] * request.limit) await ctx.send_message(response) @@ -277,23 +260,13 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: processor = ProcessorExecutor(id="processor") collector = CollectorExecutor(id="collector") - workflow = ( - WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() - ) + workflow = WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() input_request = Request(query="hello", limit=3) events = await workflow.run(input_request) - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] # Check processor invoked event has the Request object processor_invoked = next(e for e in invoked_events if e.executor_id == "processor") @@ -302,9 +275,7 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: assert processor_invoked.data.limit == 3 # Check processor completed event has the Response object - processor_completed = next( - e for e in completed_events if e.executor_id == "processor" - ) + processor_completed = next(e for e in completed_events if 
e.executor_id == "processor") assert processor_completed.data is not None assert len(processor_completed.data) == 1 assert isinstance(processor_completed.data[0], Response) @@ -390,9 +361,7 @@ async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: # Test executor with union workflow output types class UnionWorkflowOutputExecutor(Executor): @handler - async def handle( - self, text: str, ctx: WorkflowContext[int, str | bool] - ) -> None: + async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None: pass executor = UnionWorkflowOutputExecutor(id="union_workflow_output") @@ -403,15 +372,11 @@ async def handle( # Test executor with multiple handlers having different workflow output types class MultiHandlerWorkflowExecutor(Executor): @handler - async def handle_string( - self, text: str, ctx: WorkflowContext[int, str] - ) -> None: + async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None: pass @handler - async def handle_number( - self, num: int, ctx: WorkflowContext[bool, float] - ) -> None: + async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None: pass executor = MultiHandlerWorkflowExecutor(id="multi_workflow") @@ -465,9 +430,7 @@ async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: pass @response_handler - async def handle_response( - self, original_request: str, response: bool, ctx: WorkflowContext[float] - ) -> None: + async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None: pass executor = RequestResponseExecutor(id="request_response") @@ -574,9 +537,7 @@ async def test_executor_invoked_event_data_not_mutated_by_handler(): """Test that executor_invoked event (type='executor_invoked').data captures original input, not mutated input.""" @executor(id="Mutator") - async def mutator( - messages: list[Message], ctx: WorkflowContext[list[Message]] - ) -> None: + async def mutator(messages: list[Message], ctx: 
WorkflowContext[list[Message]]) -> None: # The handler mutates the input list by appending new messages original_len = len(messages) messages.append(Message(role="assistant", text="Added by executor")) @@ -591,11 +552,7 @@ async def mutator( events = await workflow.run(input_messages) # Find the invoked event for the Mutator executor - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] assert len(invoked_events) == 1 mutator_invoked = invoked_events[0] @@ -672,12 +629,8 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert handler_func._handler_spec["output_types"] == [list] # pyright: ignore[reportFunctionMemberAccess] # Verify can_handle - assert exec_instance.can_handle( - WorkflowMessage(data={"key": "value"}, source_id="mock") - ) - assert not exec_instance.can_handle( - WorkflowMessage(data="string", source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data={"key": "value"}, source_id="mock")) + assert not exec_instance.can_handle(WorkflowMessage(data="string", source_id="mock")) def test_handler_with_explicit_union_input_type(self): """Test that explicit union input_type is handled correctly.""" @@ -698,9 +651,7 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock")) assert exec_instance.can_handle(WorkflowMessage(data=42, source_id="mock")) # Cannot handle float - assert not exec_instance.can_handle( - WorkflowMessage(data=3.14, source_id="mock") - ) + assert not exec_instance.can_handle(WorkflowMessage(data=3.14, source_id="mock")) def test_handler_with_explicit_union_output_type(self): """Test that explicit union output is normalized to a list.""" @@ -776,9 +727,7 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: class 
OnlyWorkflowOutputExecutor(Executor): # pyright: ignore[reportUnusedClass] @handler(workflow_output=bool) - async def handle( - self, message: str, ctx: WorkflowContext[int, str] - ) -> None: + async def handle(self, message: str, ctx: WorkflowContext[int, str]) -> None: pass def test_handler_explicit_input_type_allows_no_message_annotation(self): @@ -803,9 +752,7 @@ async def handle_explicit(self, message, ctx: WorkflowContext) -> None: # type: pass @handler - async def handle_introspected( - self, message: float, ctx: WorkflowContext[bool] - ) -> None: + async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None: pass exec_instance = MixedExecutor(id="mixed") @@ -831,9 +778,7 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n # Should resolve the string to the actual type assert ForwardRefMessage in exec_instance._handlers # pyright: ignore[reportPrivateUsage] - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock")) def test_handler_with_string_forward_reference_union(self): """Test that string forward references work with union types.""" @@ -846,12 +791,8 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n exec_instance = StringUnionExecutor(id="string_union") # Should handle both types - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock") - ) - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock")) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock")) def test_handler_with_string_forward_reference_output_type(self): """Test that string forward references work for output_type.""" @@ -890,9 
+831,7 @@ def test_handler_with_explicit_workflow_output_and_output(self): class PrecedenceExecutor(Executor): @handler(input=int, output=float, workflow_output=str) - async def handle( - self, message: int, ctx: WorkflowContext[int, bool] - ) -> None: + async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None: pass exec_instance = PrecedenceExecutor(id="precedence") @@ -958,9 +897,7 @@ class StringUnionWorkflowOutputExecutor(Executor): async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass - exec_instance = StringUnionWorkflowOutputExecutor( - id="string_union_workflow_output" - ) + exec_instance = StringUnionWorkflowOutputExecutor(id="string_union_workflow_output") # Should resolve both types from string union assert ForwardRefTypeA in exec_instance.workflow_output_types @@ -971,14 +908,10 @@ def test_handler_fallback_to_introspection_for_workflow_output_type(self): class IntrospectedWorkflowOutputExecutor(Executor): @handler - async def handle( - self, message: str, ctx: WorkflowContext[int, bool] - ) -> None: + async def handle(self, message: str, ctx: WorkflowContext[int, bool]) -> None: pass - exec_instance = IntrospectedWorkflowOutputExecutor( - id="introspected_workflow_output" - ) + exec_instance = IntrospectedWorkflowOutputExecutor(id="introspected_workflow_output") # Should use introspected types from WorkflowContext[int, bool] assert int in exec_instance.output_types diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index b5a8bb9902..eacf70c6db 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -717,9 +717,23 @@ def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession return AgentSession() @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., 
*, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -813,9 +827,23 @@ def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession return AgentSession() @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
def run( self, diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 0850c6b060..d315f75f85 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -52,9 +52,23 @@ def __init__(self, name: str = "test_agent") -> None: self.captured_kwargs = [] @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -90,9 +104,23 @@ def __init__(self, name: str = "options_agent") -> None: self.captured_kwargs = [] @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
+ def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -475,9 +503,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -538,9 +580,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
def run( self, @@ -605,9 +661,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index 34c7e8c93f..bf2e277d10 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -38,7 +38,9 @@ async def test_executor_failed_and_workflow_failed_events_streaming(): events.append(ev) # executor_failed event (type='executor_failed') should be emitted before workflow failed event - executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [ + e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed" + ] assert executor_failed_events, "executor_failed event should be emitted when start executor fails" assert executor_failed_events[0].executor_id == "f" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK @@ -96,7 +98,9 @@ async def test_executor_failed_event_from_second_executor_in_chain(): 
events.append(ev) # executor_failed event should be emitted for the failing executor - executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [ + e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed" + ] assert executor_failed_events, "executor_failed event should be emitted when second executor fails" assert executor_failed_events[0].executor_id == "failing" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK diff --git a/python/samples/02-agents/compaction/README.md b/python/samples/02-agents/compaction/README.md new file mode 100644 index 0000000000..49c42b174d --- /dev/null +++ b/python/samples/02-agents/compaction/README.md @@ -0,0 +1,20 @@ +# Context Compaction Samples + +This folder demonstrates context compaction patterns introduced by ADR-0019. + +## Files + +- `basics.py` — builds a local message list and applies each built-in in-run strategy. +- `advanced.py` — composes multiple strategies with `TokenBudgetComposedStrategy`. +- `custom.py` — defines a custom strategy implementing the `CompactionStrategy` protocol. +- `tiktoken_tokenizer.py` — shows a `TokenizerProtocol` implementation backed by `tiktoken`. +- `storage.py` — planned for Phase 2 (history/storage compaction and `upsert` flow). + +Run samples with: + +```bash +uv run samples/02-agents/compaction/basics.py +uv run samples/02-agents/compaction/advanced.py +uv run samples/02-agents/compaction/custom.py +uv run samples/02-agents/compaction/tiktoken_tokenizer.py +``` diff --git a/python/samples/02-agents/compaction/advanced.py b/python/samples/02-agents/compaction/advanced.py new file mode 100644 index 0000000000..7cf1fc7f39 --- /dev/null +++ b/python/samples/02-agents/compaction/advanced.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import asyncio +from typing import Any + +from agent_framework import ( + CharacterEstimatorTokenizer, + ChatResponse, + Message, + SelectiveToolCallCompactionStrategy, + SlidingWindowStrategy, + SummarizationStrategy, + TokenBudgetComposedStrategy, + annotate_message_groups, + apply_compaction, + included_token_count, +) + +"""This sample demonstrates composed in-run compaction with a token budget. + +Key components: +- TokenBudgetComposedStrategy +- Sequential strategy composition +- Summarization with a SupportsChatGetResponse-compatible summarizer client +""" + + +class BudgetSummaryClient: + async def get_response( + self, + messages: list[Message], + *, + stream: bool = False, + options: dict[str, Any] | None = None, + **kwargs: Any, + ) -> ChatResponse: + summary_text = f"Budget summary generated from {len(messages)} prompt messages." + return ChatResponse(messages=[Message(role="assistant", text=summary_text)]) + + +def _build_long_history() -> list[Message]: + history = [Message(role="system", text="You are a migration copilot.")] + for i in range(1, 8): + history.append( + Message( + role="user", + text=f"Iteration {i}: capture migration requirements and edge cases.", + ) + ) + history.append( + Message( + role="assistant", + text=( + f"Iteration {i}: detailed plan with dependencies, rollback guidance, and testing details. " + "This sentence is intentionally long to create token pressure." + ), + ) + ) + return history + + +async def main() -> None: + # 1. Build synthetic history representing long-running in-run growth. + messages = _build_long_history() + + # 2. Configure tokenizer and measure token count before compaction. + tokenizer = CharacterEstimatorTokenizer() + annotate_message_groups(messages, tokenizer=tokenizer) + budget_before = included_token_count(messages) + + # 3. Configure composed strategy stack. 
+ composed = TokenBudgetComposedStrategy( + token_budget=200, + tokenizer=tokenizer, + strategies=[ + SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0), + SummarizationStrategy( + client=BudgetSummaryClient(), + target_count=3, + threshold=3, + ), + SlidingWindowStrategy(keep_last_groups=4), + ], + ) + + # 4. Apply compaction and inspect the budget result. + projected = await apply_compaction(messages, strategy=composed, tokenizer=tokenizer) + budget_after = included_token_count(messages) + + print(f"Projected messages after compaction: {len(projected)}") + print(f"Included token count before compaction: {budget_before}") + print(f"Included token count after compaction: {budget_after}") + print("Projected roles:", [m.role for m in projected]) + print("Projected messages with token counts:") + for msg in projected: + group = msg.additional_properties.get("_group") + token_count = group.get("token_count") if isinstance(group, dict) else None + text_preview = msg.text[:80] if msg.text else "" + print(f"- [{msg.role}] {text_preview} ({token_count} tokens)") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +Projected messages after compaction: 3 +Included token count before compaction: 793 +Included token count after compaction: 144 +Projected roles: ['system', 'user', 'assistant'] +Projected messages with token counts: +- [system] You are a migration copilot. (35 tokens) +- [user] Iteration 7: capture migration requirements and edge cases. (43 tokens) +- [assistant] Iteration 7: detailed plan with dependencies, rollback guidance, and testing det (66 tokens) +""" diff --git a/python/samples/02-agents/compaction/basics.py b/python/samples/02-agents/compaction/basics.py new file mode 100644 index 0000000000..6cbc65d112 --- /dev/null +++ b/python/samples/02-agents/compaction/basics.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import asyncio +from typing import Any + +from agent_framework import ( + CharacterEstimatorTokenizer, + ChatResponse, + Content, + Message, + SelectiveToolCallCompactionStrategy, + SlidingWindowStrategy, + SummarizationStrategy, + TokenBudgetComposedStrategy, + TruncationStrategy, + apply_compaction, +) + +"""This sample demonstrates selecting one compaction strategy at a time. + +How to use this sample: +- Keep one ``selected_strategy`` block active in ``main``. +- Comment the active block and uncomment one of the alternatives to switch strategies. +- Run again to compare behavior against the same "before" message list shown once. +""" + +SUMMARY_OF_MESSAGE_IDS_KEY = "_summary_of_message_ids" +SUMMARIZED_BY_SUMMARY_ID_KEY = "_summarized_by_summary_id" + +# Keep optional strategy classes imported for quick uncomment/switch in main(). +AVAILABLE_STRATEGY_TYPES = ( + TruncationStrategy, + CharacterEstimatorTokenizer, + SlidingWindowStrategy, + SelectiveToolCallCompactionStrategy, + SummarizationStrategy, + TokenBudgetComposedStrategy, +) + + +class LocalSummaryClient: + """Simple local summarizer compatible with SupportsChatGetResponse.""" + + async def get_response( + self, + messages: list[Message], + *, + stream: bool = False, + options: dict[str, Any] | None = None, + **kwargs: Any, + ) -> ChatResponse: + return ChatResponse( + messages=[ + Message(role="assistant", text=f"Summary for {len(messages)} messages.") + ] + ) + + +async def main() -> None: + # 1. Build one baseline history and print it once. 
+ messages = [ + Message(role="system", text="You are a helpful assistant."), + Message(role="user", text="Plan a data migration."), + Message(role="assistant", text="I will gather requirements."), + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="list_tables", + arguments='{"db":"legacy"}', + ) + ], + ), + Message( + role="tool", + contents=[ + Content.from_function_result( + call_id="call_1", + result="users, orders, events", + ) + ], + ), + Message(role="assistant", text="I found three core tables."), + Message(role="user", text="Estimate effort and risks."), + Message(role="assistant", text="Primary risk is schema drift."), + ] + print("\n--- Before compaction ---") + print(f"Message count: {len(messages)}") + for index, message in enumerate(messages, start=1): + message_text = message.text or ", ".join( + content.type for content in message.contents + ) + print(f"{index:02d}. [{message.role}] {message_text}") + + # 2. Select exactly one strategy (default shown below). + # Truncate when included history exceeds 5 messages, then keep 4. + # System remains anchored, so the oldest non-system messages are removed first. + # selected_strategy_name = "TruncationStrategy" + # selected_strategy = TruncationStrategy(max_n=5, compact_to=4, preserve_system=True) + + # Keep the most recent 4 non-system groups and preserve the system anchor. + # A group represents a user turn (and related assistant/tool follow-up). + # selected_strategy_name = "SlidingWindowStrategy" + # selected_strategy = SlidingWindowStrategy(keep_last_groups=4, preserve_system=True) + + # This means all tool-call groups are removed (assistant function_call message + # plus matching tool result messages). In this example, setting to 0 removes + # the single assistant+tool pair. 
+ selected_strategy_name = "SelectiveToolCallCompactionStrategy" + selected_strategy = SelectiveToolCallCompactionStrategy( + keep_last_tool_call_groups=0 + ) + + # Summarize older messages so only recent context remains, and attach summary + # trace metadata linking summary -> originals and originals -> summary. + # summary_client = LocalSummaryClient() + # selected_strategy_name = "SummarizationStrategy" + # selected_strategy = SummarizationStrategy( + # client=summary_client, target_count=3, threshold=2 + # ) + + # tokenizer = CharacterEstimatorTokenizer() + # selected_strategy_name = "TokenBudgetComposedStrategy" + # selected_strategy = TokenBudgetComposedStrategy( + # token_budget=150, + # tokenizer=tokenizer, + # strategies=[ + # SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0), + # SlidingWindowStrategy(keep_last_groups=2), + # ], + # ) + + # 3. Apply the selected strategy and print projected output. + projected = await apply_compaction(messages, strategy=selected_strategy) + print(f"\n--- After compaction ({selected_strategy_name}) ---") + print(f"Message count: {len(projected)}") + for index, message in enumerate(projected, start=1): + message_text = message.text or ", ".join( + content.type for content in message.contents + ) + print(f"{index:02d}. 
[{message.role}] {message_text}") + + summaries = [] + summarized = [] + for message in messages: + group_annotation = message.additional_properties.get("_group") + if not isinstance(group_annotation, dict): + continue + if group_annotation.get(SUMMARY_OF_MESSAGE_IDS_KEY): + summaries.append(message) + if group_annotation.get(SUMMARIZED_BY_SUMMARY_ID_KEY): + summarized.append(message) + if summaries or summarized: + print("Summary trace metadata present:") + for message in summaries: + group_annotation = message.additional_properties.get("_group") + summarized_ids = ( + group_annotation.get(SUMMARY_OF_MESSAGE_IDS_KEY) + if isinstance(group_annotation, dict) + else None + ) + print(f" summary_id={message.message_id} summarizes={summarized_ids}") + for message in summarized: + group_annotation = message.additional_properties.get("_group") + summarized_by = ( + group_annotation.get(SUMMARIZED_BY_SUMMARY_ID_KEY) + if isinstance(group_annotation, dict) + else None + ) + print(f" original_id={message.message_id} summarized_by={summarized_by}") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output (always present): +--- Before compaction --- +Message count: 8 +01. [system] You are a helpful assistant. +02. [user] Plan a data migration. +03. [assistant] I will gather requirements. +04. [assistant] function_call +05. [tool] function_result +06. [assistant] I found three core tables. +07. [user] Estimate effort and risks. +08. [assistant] Primary risk is schema drift. +""" + +""" +Sample output (varies based on selected strategy): +--- After compaction (TruncationStrategy) --- +Message count: 4 +01. [system] You are a helpful assistant. +02. [assistant] I found three core tables. +03. [user] Estimate effort and risks. +04. [assistant] Primary risk is schema drift. + +--- After compaction (SlidingWindowStrategy) --- +Message count: 6 +01. [system] You are a helpful assistant. +02. [assistant] function_call +03. [tool] function_result +04. 
[assistant] I found three core tables. +05. [user] Estimate effort and risks. +06. [assistant] Primary risk is schema drift. + +--- After compaction (SelectiveToolCallCompactionStrategy) --- +Message count: 6 +01. [system] You are a helpful assistant. +02. [user] Plan a data migration. +03. [assistant] I will gather requirements. +04. [assistant] I found three core tables. +05. [user] Estimate effort and risks. +06. [assistant] Primary risk is schema drift. + +--- After compaction (SummarizationStrategy) --- +Message count: 5 +01. [system] You are a helpful assistant. +02. [assistant] Summary for 2 messages. +03. [assistant] I found three core tables. +04. [user] Estimate effort and risks. +05. [assistant] Primary risk is schema drift. +Summary trace metadata present: + summary_id=summary_8 summarizes=['msg_1', 'msg_2', 'msg_3', 'msg_4'] + original_id=msg_1 summarized_by=summary_8 + original_id=msg_2 summarized_by=summary_8 + original_id=msg_3 summarized_by=summary_8 + original_id=msg_4 summarized_by=summary_8 + +--- After compaction (TokenBudgetComposedStrategy) --- +Message count: 3 +01. [system] You are a helpful assistant. +02. [user] Estimate effort and risks. +03. [assistant] Primary risk is schema drift. +""" diff --git a/python/samples/02-agents/compaction/custom.py b/python/samples/02-agents/compaction/custom.py new file mode 100644 index 0000000000..ea9647b9ae --- /dev/null +++ b/python/samples/02-agents/compaction/custom.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio + +from agent_framework import ( + Message, + annotate_message_groups, + apply_compaction, + included_messages, +) + +"""This sample demonstrates authoring a custom compaction strategy. + +The custom strategy keeps system messages and the most recent user turn while +excluding older non-system groups. 
+""" + +EXCLUDED_KEY = "_excluded" +GROUP_ANNOTATION_KEY = "_group" + + +class KeepLastUserTurnStrategy: + async def __call__(self, messages: list[Message]) -> bool: + group_ids = annotate_message_groups(messages) + group_kinds: dict[str, str] = {} + for message in messages: + group_annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + group_id = group_annotation.get("id") if isinstance(group_annotation, dict) else None + kind = group_annotation.get("kind") if isinstance(group_annotation, dict) else None + if ( + isinstance(group_id, str) + and isinstance(kind, str) + and group_id not in group_kinds + ): + group_kinds[group_id] = kind + user_group_ids = [ + group_id for group_id in group_ids if group_kinds.get(group_id) == "user" + ] + if not user_group_ids: + return False + keep_user_group_id = user_group_ids[-1] + + changed = False + for message in messages: + group_annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY) + group_id = group_annotation.get("id") if isinstance(group_annotation, dict) else None + if message.role == "system": + continue + if group_id == keep_user_group_id: + continue + if message.additional_properties.get(EXCLUDED_KEY) is not True: + changed = True + message.additional_properties[EXCLUDED_KEY] = True + return changed + + +def _messages() -> list[Message]: + return [ + Message(role="system", text="You are concise."), + Message(role="user", text="first request"), + Message(role="assistant", text="first response"), + Message(role="user", text="second request"), + Message(role="assistant", text="second response"), + ] + + +async def main() -> None: + # 1. Build a short conversation. + messages = _messages() + print(f"Number of messages before compaction: {len(messages)}") + # 2. Apply custom strategy. + await apply_compaction(messages, strategy=KeepLastUserTurnStrategy()) + # 3. Print projected messages. 
+ projected = included_messages(messages) + print(f"Number of messages after compaction: {len(projected)}") + for msg in projected: + print(f"[{msg.role}] {msg.text}") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +Number of messages before compaction: 5 +Number of messages after compaction: 2 +[system] You are concise. +[user] second request +""" diff --git a/python/samples/02-agents/compaction/tiktoken_tokenizer.py b/python/samples/02-agents/compaction/tiktoken_tokenizer.py new file mode 100644 index 0000000000..ac282db338 --- /dev/null +++ b/python/samples/02-agents/compaction/tiktoken_tokenizer.py @@ -0,0 +1,124 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "tiktoken", +# ] +# /// +# Run with: uv run samples/02-agents/compaction/tiktoken_tokenizer.py + +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Any + +import tiktoken +from agent_framework import ( + Message, + TokenizerProtocol, + TruncationStrategy, + annotate_message_groups, + apply_compaction, + included_token_count, +) + +"""This sample demonstrates a custom TokenizerProtocol implementation with tiktoken. 
+ +Key components: +- `TiktokenTokenizer` backed by `tiktoken` +- Token-based `TruncationStrategy` (`max_n` / `compact_to`) +- Inspecting projected roles and remaining included token count +""" + + +class TiktokenTokenizer(TokenizerProtocol): + """TokenizerProtocol implementation backed by tiktoken's o200k_base (gpt-4.1 and up default) encoding.""" + + def __init__( + self, *, encoding_name: str = "o200k_base", model_name: str | None = None + ) -> None: + if model_name is not None: + self._encoding = tiktoken.encoding_for_model(model_name) + else: + self._encoding: Any = tiktoken.get_encoding(encoding_name) + + def count_tokens(self, text: str) -> int: + return len(self._encoding.encode(text)) + + +def _build_messages() -> list[Message]: + return [ + Message(role="system", text="You are a migration assistant."), + Message( + role="user", + text="List all migration risks and include detailed mitigations for each risk category.", + ), + Message( + role="assistant", + text=( + "Primary risks include schema drift, missing foreign key constraints, " + "and data quality regressions. Mitigations include staged validation, " + "shadow writes, and replay-based verification." + ), + ), + Message( + role="user", + text=( + "Now provide a detailed checklist with owners, rollback " + "gates, and validation criteria." + ), + ), + Message( + role="assistant", + text=( + "Checklist: baseline snapshots, migration dry-run, production " + "canary, progressive deployment, automated integrity checks, and " + "post-migration reconciliation." + ), + ), + ] + + +async def main() -> None: + # 1. Create a tokenizer implementation that uses tiktoken. + tokenizer = TiktokenTokenizer() + + # 2. Configure token-based truncation. + strategy = TruncationStrategy( + max_n=250, + compact_to=150, + tokenizer=tokenizer, + preserve_system=True, + ) + + # 3. Build conversation and measure token count before compaction. 
+ messages = _build_messages() + annotate_message_groups(messages, tokenizer=tokenizer) + token_count_before = included_token_count(messages) + + # 4. Apply compaction and measure token count after compaction. + projected = await apply_compaction(messages, strategy=strategy, tokenizer=tokenizer) + token_count_after = included_token_count(messages) + + # 5. Print before/after token counts and projected conversation. + print(f"Projected messages: {len(projected)}") + print(f"Included token count before compaction: {token_count_before}") + print(f"Included token count after compaction: {token_count_after}") + print("Projected roles:", [message.role for message in projected]) + for message in projected: + token_count = message.additional_properties.get("_group", {}).get("token_count") + print(f"- [{message.role}] {message.text} ({token_count} tokens)") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Projected messages: 3 +Included token count before compaction: 263 +Included token count after compaction: 149 +Projected roles: ['system', 'user', 'assistant'] +- [system] You are a migration assistant. (40 tokens) +- [user] Now provide a detailed checklist with owners, rollback gates, and validation criteria. (49 tokens) +- [assistant] Checklist: baseline snapshots, migration dry-run, production canary, + progressive deployment, automated integrity checks, and post-migration reconciliation. (60 tokens) +"""