diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index 8dffdc0ce6..550b8bf384 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -2,6 +2,7 @@ from __future__ import annotations +import copy import json import logging import re @@ -263,6 +264,26 @@ def __init__(self, **kwargs): DEFAULT_EXCLUDE: ClassVar[set[str]] = set() INJECTABLE: ClassVar[set[str]] = set() + _SHALLOW_COPY_FIELDS: ClassVar[set[str]] = {"raw_representation"} + + def __deepcopy__(self, memo: dict[int, Any]) -> SerializationMixin: + """Create a deep copy, preserving ``_SHALLOW_COPY_FIELDS`` by reference. + + Fields listed in ``_SHALLOW_COPY_FIELDS`` may contain LLM SDK objects + (e.g., proto/gRPC responses) that are not safe to deep-copy. They are + kept as shallow references in the copy; all other attributes are + deep-copied normally. + """ + cls = type(self) + result = cls.__new__(cls) + memo[id(self)] = result + shallow = cls._SHALLOW_COPY_FIELDS + for k, v in self.__dict__.items(): + if k in shallow: + object.__setattr__(result, k, v) + else: + object.__setattr__(result, k, copy.deepcopy(v, memo)) + return result def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: """Convert the instance and any nested objects to a dictionary. diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 7ae9dbaa3d..f058fc14c1 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -445,6 +445,8 @@ class Content: `Content.from_uri()`, etc. to create instances. """ + _SHALLOW_COPY_FIELDS: ClassVar[set[str]] = {"raw_representation"} + def __init__( self, type: ContentType, @@ -546,6 +548,23 @@ def __init__( self.approved = approved self.consent_link = consent_link + def __deepcopy__(self, memo: dict[int, Any]) -> Content: + """Create a deep copy, preserving ``_SHALLOW_COPY_FIELDS`` by reference. + + Fields listed in ``_SHALLOW_COPY_FIELDS`` may contain LLM SDK objects + (e.g., proto/gRPC responses) that are not safe to deep-copy. + """ + cls = type(self) + result = cls.__new__(cls) + memo[id(self)] = result + shallow = cls._SHALLOW_COPY_FIELDS + for k, v in self.__dict__.items(): + if k in shallow: + object.__setattr__(result, k, v) + else: + object.__setattr__(result, k, deepcopy(v, memo)) + return result + @classmethod def from_text( cls: type[ContentT], diff --git a/python/packages/core/tests/core/test_serializable_mixin.py b/python/packages/core/tests/core/test_serializable_mixin.py index 05ece1072b..8134e14680 100644 --- a/python/packages/core/tests/core/test_serializable_mixin.py +++ b/python/packages/core/tests/core/test_serializable_mixin.py @@ -427,3 +427,103 @@ def __init__(self, value: str, options: dict | None = None): assert obj.options["existing"] == "value" assert obj.options["injected"] == "option" + + def test_deepcopy_preserves_shallow_copy_fields_by_reference(self): + """Test that deepcopy keeps _SHALLOW_COPY_FIELDS fields as shallow references.""" + import copy + + class NonCopyable: + def __deepcopy__(self, memo): + raise TypeError("cannot deepcopy") + + class TestClass(SerializationMixin): + _SHALLOW_COPY_FIELDS = {"raw_representation", "other_opaque"} + + def __init__(self, items: list, raw_representation: Any = None, other_opaque: Any = None): + self.items = items + self.raw_representation = raw_representation + self.other_opaque = other_opaque + + raw = NonCopyable() + opaque = NonCopyable() + original_items = ["a", "b"] + obj = TestClass(items=original_items, raw_representation=raw, other_opaque=opaque) + cloned = copy.deepcopy(obj) + + # _SHALLOW_COPY_FIELDS fields should be the same object (shallow copy) + assert cloned.raw_representation is raw + assert cloned.other_opaque is opaque + # Normal attributes should be independent copies + assert cloned.items is not original_items + assert cloned.items == ["a", "b"] + + def test_deepcopy_deep_copies_non_shallow_copy_fields(self): + """Test that deepcopy fully copies fields not in _SHALLOW_COPY_FIELDS.""" + import copy + + class TestClass(SerializationMixin): + _SHALLOW_COPY_FIELDS = {"raw_representation"} + + def __init__(self, items: list, raw_representation: Any = None): + self.items = items + self.raw_representation = raw_representation + + original_list = ["a", "b"] + obj = TestClass(items=original_list, raw_representation="raw") + cloned = copy.deepcopy(obj) + + # list should be a new object + assert cloned.items is not original_list + assert cloned.items == ["a", "b"] + # raw_representation should be the same object + assert cloned.raw_representation is obj.raw_representation + + def test_deepcopy_deep_copies_default_exclude_fields(self): + """Test that DEFAULT_EXCLUDE fields are deep-copied unless also in _SHALLOW_COPY_FIELDS.""" + import copy + + class TestClass(SerializationMixin): + DEFAULT_EXCLUDE = {"additional_properties"} + + def __init__(self, items: list, additional_properties: dict | None = None): + self.items = items + self.additional_properties = additional_properties or {} + + original_props = {"key": "value"} + obj = TestClass(items=["a"], additional_properties=original_props) + cloned = copy.deepcopy(obj) + + # DEFAULT_EXCLUDE field should be deep-copied (independent copy) + assert cloned.additional_properties is not original_props + assert cloned.additional_properties == {"key": "value"} + + def test_deepcopy_shallow_copy_fields_override_default_exclude(self): + """Test that _SHALLOW_COPY_FIELDS controls deepcopy independently of DEFAULT_EXCLUDE.""" + import copy + + class NonCopyable: + def __deepcopy__(self, memo): + raise TypeError("cannot deepcopy") + + class TestClass(SerializationMixin): + DEFAULT_EXCLUDE = {"opaque", "additional_properties"} + _SHALLOW_COPY_FIELDS = {"opaque"} + + def __init__(self, items: list, opaque: Any = None, additional_properties: dict | None = None): + self.items = items + self.opaque = opaque + self.additional_properties = additional_properties or {} + + opaque = NonCopyable() + original_props = {"key": "value"} + obj = TestClass(items=["a"], opaque=opaque, additional_properties=original_props) + cloned = copy.deepcopy(obj) + + # Field in both DEFAULT_EXCLUDE and _SHALLOW_COPY_FIELDS: shallow-copied + assert cloned.opaque is opaque + # Field in DEFAULT_EXCLUDE only: deep-copied + assert cloned.additional_properties is not original_props + assert cloned.additional_properties == {"key": "value"} + # Normal field: deep-copied + assert cloned.items is not obj.items + assert cloned.items == ["a"] diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 0d314c1aa5..bdbe869394 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -1860,6 +1860,170 @@ def test_agent_run_response_update_all_content_types(): assert update_str.role == "user" +# region DeepCopy + + +class _NonCopyableRaw: + """Simulates an LLM SDK response object that cannot be deep-copied (e.g., proto/gRPC).""" + + def __deepcopy__(self, memo: dict) -> Any: + raise TypeError("Cannot deepcopy this object") + + +def test_content_deepcopy_preserves_raw_representation(): + """Test that deepcopy of Content keeps raw_representation by reference.""" + import copy + + raw = _NonCopyableRaw() + content = Content.from_text("hello", raw_representation=raw) + + cloned = copy.deepcopy(content) + + assert cloned.text == "hello" + assert cloned.raw_representation is raw + assert cloned.additional_properties is not content.additional_properties + + +def test_message_deepcopy_preserves_raw_representation(): + """Test that deepcopy of Message keeps raw_representation by reference.""" + import copy + + raw = _NonCopyableRaw() + msg = Message("assistant", ["hello"], raw_representation=raw) + + cloned = copy.deepcopy(msg) + + assert cloned.text == "hello" + assert cloned.raw_representation is raw + assert cloned.contents is not msg.contents + + +def test_agent_response_deepcopy_preserves_raw_representation(): + """Test that deepcopy of AgentResponse keeps raw_representation by reference.""" + import copy + + raw = _NonCopyableRaw() + response = AgentResponse( + messages=[Message("assistant", ["test"])], + raw_representation=raw, + ) + + cloned = copy.deepcopy(response) + + assert cloned.text == "test" + assert cloned.raw_representation is raw + assert cloned.messages is not response.messages + + +def test_chat_response_deepcopy_preserves_raw_representation(): + """Test that deepcopy of ChatResponse keeps raw_representation by reference.""" + import copy + + raw = _NonCopyableRaw() + response = ChatResponse( + messages=[Message("assistant", ["test"])], + raw_representation=raw, + ) + + cloned = copy.deepcopy(response) + + assert cloned.text == "test" + assert cloned.raw_representation is raw + assert cloned.messages is not response.messages + + +def test_chat_response_update_deepcopy_preserves_raw_representation(): + """Test that deepcopy of ChatResponseUpdate keeps raw_representation by reference.""" + import copy + + raw = _NonCopyableRaw() + update = ChatResponseUpdate( + contents=[Content.from_text("hello")], + role="assistant", + raw_representation=raw, + ) + + cloned = copy.deepcopy(update) + + assert cloned.text == "hello" + assert cloned.raw_representation is raw + assert cloned.contents is not update.contents + + +def test_agent_response_update_deepcopy_preserves_raw_representation(): + """Test that deepcopy of AgentResponseUpdate keeps raw_representation by reference.""" + import copy + + raw = _NonCopyableRaw() + update = AgentResponseUpdate( + contents=[Content.from_text("hello")], + role="assistant", + raw_representation=raw, + ) + + cloned = copy.deepcopy(update) + + assert cloned.text == "hello" + assert cloned.raw_representation is raw + assert cloned.contents is not update.contents + + +def test_nested_deepcopy_preserves_raw_representation(): + """Test that deepcopy of an AgentResponse with nested Message raw_representations works.""" + import copy + + raw_msg = _NonCopyableRaw() + raw_response = _NonCopyableRaw() + response = AgentResponse( + messages=[Message("assistant", ["hello"], raw_representation=raw_msg)], + raw_representation=raw_response, + ) + + cloned = copy.deepcopy(response) + + assert cloned.raw_representation is raw_response + assert cloned.messages[0].raw_representation is raw_msg + assert cloned.messages is not response.messages + assert cloned.text == "hello" + + +def test_content_deepcopy_shallow_copy_fields_identity(): + """Test that Content._SHALLOW_COPY_FIELDS fields are identity-preserved while others are deep-copied.""" + import copy + + raw = _NonCopyableRaw() + content = Content.from_text("hello", raw_representation=raw) + content.additional_properties["key"] = "value" + + cloned = copy.deepcopy(content) + + # _SHALLOW_COPY_FIELDS (raw_representation) should be same object + assert cloned.raw_representation is raw + # Non-shallow fields should be independent deep copies + assert cloned.additional_properties is not content.additional_properties + assert cloned.additional_properties == {"key": "value"} + + +def test_chat_response_deepcopy_deep_copies_additional_properties(): + """Test that ChatResponse deepcopy deep-copies additional_properties despite it being in DEFAULT_EXCLUDE.""" + import copy + + response = ChatResponse( + messages=[Message("assistant", ["test"])], + additional_properties={"key": [1, 2, 3]}, + ) + + cloned = copy.deepcopy(response) + + # additional_properties is in DEFAULT_EXCLUDE for serialization but not in _SHALLOW_COPY_FIELDS, + # so it should be deep-copied (independent copy) + assert cloned.additional_properties is not response.additional_properties + assert cloned.additional_properties == {"key": [1, 2, 3]} + + +# endregion + + # region Serialization diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 599e62d635..059e683745 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -383,3 +383,60 @@ async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None: result = await workflow.run("hello", messages=["stale"]) assert result is not None assert agent.call_count == 1 + + +class _NonCopyableRaw: + """Simulates an LLM SDK response object that cannot be deep-copied (e.g., proto/gRPC).""" + + def __deepcopy__(self, memo: dict) -> Any: + raise TypeError("Cannot deepcopy this object") + + +class _AgentWithRawRepr(BaseAgent): + """Agent that returns responses with a non-copyable raw_representation.""" + + def __init__(self, raw: Any, **kwargs: Any): + super().__init__(**kwargs) + self._raw = raw + + def run( + self, + messages: str | Message | list[str] | list[Message] | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + async def _run() -> AgentResponse: + return AgentResponse( + messages=[Message("assistant", [f"reply from {self.name}"])], + raw_representation=self._raw, + ) + + return _run() + + +async def test_agent_executor_workflow_with_non_copyable_raw_representation() -> None: + """Workflow should complete when AgentResponse contains a raw_representation that cannot be deep-copied.""" + raw = _NonCopyableRaw() + + agent_a = _AgentWithRawRepr(raw=raw, id="a", name="AgentA") + agent_b = _CountingAgent(id="b", name="AgentB") + + exec_a = AgentExecutor(agent_a, id="exec_a") + exec_b = AgentExecutor(agent_b, id="exec_b") + + workflow = SequentialBuilder(participants=[exec_a, exec_b]).build() + events = await workflow.run("hello") + + completed = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] + completed_a = [e for e in completed if e.executor_id == "exec_a"] + + assert len(completed_a) == 1 + assert completed_a[0].data is not None + + # The yielded AgentResponse should preserve its raw_representation reference + agent_responses = [d for d in completed_a[0].data if isinstance(d, AgentResponse)] + assert len(agent_responses) > 0 + assert agent_responses[0].text == "reply from AgentA" + assert agent_responses[0].raw_representation is raw