diff --git a/pyproject.toml b/pyproject.toml index 14fc521c8..8edc523fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,9 @@ dev = [ "types-PyYAML>=6.0.12.9", ] +[tool.pytest.ini_options] +pythonpath = ["."] + [tool.mypy] plugins = [] ignore_missing_imports = true diff --git a/pyrit/analytics/conversation_analytics.py b/pyrit/analytics/conversation_analytics.py index 9904e2b95..dd876ca58 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -6,7 +6,6 @@ from sklearn.metrics.pairwise import cosine_similarity from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import ConversationMessageWithSimilarity, EmbeddingMessageWithSimilarity -from pyrit.memory.memory_models import ConversationData, EmbeddingData class ConversationAnalytics: @@ -24,11 +23,11 @@ def __init__(self, *, memory_interface: MemoryInterface): """ self.memory_interface = memory_interface - def get_similar_chat_messages_by_content( + def get_prompt_entries_with_same_converted_content( self, *, chat_message_content: str ) -> list[ConversationMessageWithSimilarity]: """ - Retrieves chat messages that are similar to the given content based on exact matches. + Retrieves chat messages that have the same converted content Args: chat_message_content (str): The content of the chat message to find similar messages for. @@ -37,16 +36,16 @@ def get_similar_chat_messages_by_content( list[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on content. """ - all_memories = self.memory_interface.get_all_memory(ConversationData) + all_memories = self.memory_interface.get_all_prompt_entries() similar_messages = [] for memory in all_memories: - if memory.content == chat_message_content: + if memory.converted_prompt_text == chat_message_content: similar_messages.append( ConversationMessageWithSimilarity( score=1.0, role=memory.role, - content=memory.content, + content=memory.converted_prompt_text, metric="exact_match", # Exact match ) ) @@ -67,12 +66,13 @@ def get_similar_chat_messages_by_embedding( List[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on embedding similarity. 
""" - all_memories = self.memory_interface.get_all_memory(EmbeddingData) + + all_embdedding_memory = self.memory_interface.get_all_embeddings() similar_messages = [] target_embedding = np.array(chat_message_embedding).reshape(1, -1) - for memory in all_memories: + for memory in all_embdedding_memory: if not hasattr(memory, "embedding") or memory.embedding is None: continue @@ -82,7 +82,7 @@ def get_similar_chat_messages_by_embedding( if similarity_score >= threshold: similar_messages.append( EmbeddingMessageWithSimilarity( - score=similarity_score, uuid=memory.uuid, metric="cosine_similarity" # type: ignore + score=similarity_score, uuid=memory.id, metric="cosine_similarity" # type: ignore ) ) diff --git a/pyrit/common/print.py b/pyrit/common/print.py index f7f71142d..490c7e459 100644 --- a/pyrit/common/print.py +++ b/pyrit/common/print.py @@ -4,13 +4,13 @@ import textwrap import termcolor -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.models import ChatMessage from termcolor._types import Color def print_chat_messages_with_color( - messages: list[ChatMessage | ConversationData], + messages: list[ChatMessage | PromptMemoryEntry], max_content_character_width: int = 80, left_padding_width: int = 20, custom_colors: dict[str, Color] = None, diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 6b927adb2..226a45311 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory.duckdb_memory import DuckDBMemory from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter -__all__ = ["ConversationData", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"] +__all__ = ["PromptMemoryEntry", "EmbeddingData", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"] diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 854a57c83..8bb5efc29 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -12,10 +12,8 @@ from sqlalchemy.engine.base import Engine from contextlib import closing -from pyrit.memory.memory_models import ConversationData, Base -from pyrit.memory.memory_embedding import default_memory_embedding_factory +from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base from pyrit.memory.memory_interface import MemoryInterface -from pyrit.interfaces import EmbeddingSupport from pyrit.common.path import RESULTS_PATH from pyrit.common.singleton import Singleton @@ -33,15 +31,19 @@ class DuckDBMemory(MemoryInterface, metaclass=Singleton): DEFAULT_DB_FILE_NAME = "pyrit_duckdb_storage.db" def __init__( - self, *, db_path: Union[Path, str] = None, embedding_model: EmbeddingSupport = None, has_echo: bool = False + self, + *, + db_path: Union[Path, str] = None, + verbose: bool = False, ): super(DuckDBMemory, self).__init__() - self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) + if db_path == ":memory:": self.db_path: Union[Path, str] = ":memory:" else: self.db_path = Path(db_path or Path(RESULTS_PATH, self.DEFAULT_DB_FILE_NAME)).resolve() - self.engine = self._create_engine(has_echo=has_echo) + + self.engine = 
self._create_engine(has_echo=verbose)
         self.SessionFactory = sessionmaker(bind=self.engine)
         self._create_tables_if_not_exist()
@@ -76,6 +78,104 @@ def _create_tables_if_not_exist(self):
         except Exception as e:
             logger.error(f"Error during table creation: {e}")
 
+    def get_all_prompt_entries(self) -> list[PromptMemoryEntry]:
+        """
+        Fetches all prompt entries and returns them as model instances.
+        """
+        result = self.query_entries(PromptMemoryEntry)
+        return result
+
+    def get_all_embeddings(self) -> list[EmbeddingData]:
+        """
+        Fetches all embedding entries and returns them as model instances.
+        """
+        result = self.query_entries(EmbeddingData)
+        return result
+
+    def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]:
+        """
+        Retrieves a list of PromptMemoryEntry objects that have the specified conversation ID.
+
+        Args:
+            conversation_id (str): The conversation ID to filter the table.
+
+        Returns:
+            list[PromptMemoryEntry]: A list of PromptMemoryEntry objects matching the specified conversation ID.
+        """
+        try:
+            return self.query_entries(
+                PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
+            )
+        except Exception as e:
+            logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
+            return []
+
+    def get_prompt_entries_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]:
+        """
+        Retrieves a list of PromptMemoryEntry objects that have the specified normalizer ID.
+
+        Args:
+            normalizer_id (str): The normalizer ID to filter the table.
+
+        Returns:
+            list[PromptMemoryEntry]: A list of PromptMemoryEntry objects matching the specified normalizer ID.
+        """
+        try:
+            return self.query_entries(
+                PromptMemoryEntry, conditions=PromptMemoryEntry.labels.op("->>")("normalizer_id") == normalizer_id
+            )
+        except Exception as e:
+            logger.exception(
+                f"Unexpected error: Failed to retrieve PromptMemoryEntry with normalizer_id {normalizer_id}. {e}"
+            )
+            return []
+
+    def insert_prompt_entries(self, *, entries: list[PromptMemoryEntry]) -> None:
+        """
+        Inserts a list of prompt entries into the memory storage.
+        If necessary, generates embedding data for applicable entries.
+
+        Args:
+            entries (list[PromptMemoryEntry]): The list of database model instances to be inserted.
+        """
+        embedding_entries = []
+
+        if self.memory_embedding:
+            for chat_entry in entries:
+                embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry)
+                embedding_entries.append(embedding_entry)
+
+        # The ordering here matters: once the prompt entries are inserted we lose the reference to them,
+        # and they must be inserted before the embedding entries because of the foreign key constraint.
+        self.insert_entries(entries=entries)
+
+        if embedding_entries:
+            self.insert_entries(entries=embedding_entries)
+
+    def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool:
+        """
+        Updates entries for a given conversation ID with the specified field values.
+
+        Args:
+            conversation_id (str): The conversation ID of the entries to be updated.
+            update_fields (dict): A dictionary of field names and their new values.
+
+        Returns:
+            bool: True if the update was successful, False otherwise.
+ """ + # Fetch the relevant entries using query_entries + entries_to_update = self.query_entries( + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id + ) + + # Check if there are entries to update + if not entries_to_update: + logger.info(f"No entries found with conversation_id {conversation_id} to update.") + return False + + # Use the utility function to update the entries + return self.update_entries(entries=entries_to_update, update_fields=update_fields) + def get_all_table_models(self) -> list[Base]: # type: ignore """ Returns a list of all table models used in the database by inspecting the Base registry. @@ -159,70 +259,23 @@ def update_entries(self, *, entries: list[Base], update_fields: dict) -> bool: logger.exception(f"Error updating entries: {e}") return False - def get_all_memory(self, model: Base) -> list[Base]: # type: ignore - """ - Fetches all entries from the specified table and returns them as model instances. + def export_all_tables(self, *, export_type: str = "json"): """ - result = self.query_entries(model) - return result + Exports all table data using the specified exporter. - def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[ConversationData]: - """ - Retrieves a list of ConversationData objects that have the specified conversation ID. + Iterates over all tables, retrieves their data, and exports each to a file named after the table. Args: - conversation_id (str): The conversation ID to filter the table. - - Returns: - list[ConversationData]: A list of ConversationData objects matching the specified conversation ID. - """ - try: - return self.query_entries(ConversationData, conditions=ConversationData.conversation_id == conversation_id) - except Exception as e: - logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") - return [] - - def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[ConversationData]: + export_type (str): The format to export the data in (defaults to "json"). """ - Retrieves a list of ConversationData objects that have the specified normalizer ID. - - Args: - normalizer_id (str): The normalizer ID to filter the table. - - Returns: - list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID. - """ - try: - return self.query_entries(ConversationData, conditions=ConversationData.normalizer_id == normalizer_id) - except Exception as e: - logger.exception( - f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. {e}" - ) - return [] - - def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool: - """ - Updates entries for a given conversation ID with the specified field values. - - Args: - conversation_id (str): The conversation ID of the entries to be updated. - update_fields (dict): A dictionary of field names and their new values. - - Returns: - bool: True if the update was successful, False otherwise. 
- """ - # Fetch the relevant entries using query_entries - entries_to_update = self.query_entries( - ConversationData, conditions=ConversationData.conversation_id == conversation_id - ) - - # Check if there are entries to update - if not entries_to_update: - logger.info(f"No entries found with conversation_id {conversation_id} to update.") - return False - - # Use the utility function to update the entries - return self.update_entries(entries=entries_to_update, update_fields=update_fields) + table_models = self.get_all_table_models() + + for model in table_models: + data = self.query_entries(model) + table_name = model.__tablename__ + file_extension = f".{export_type}" + file_path = RESULTS_PATH / f"{table_name}{file_extension}" + self.exporter.export_data(data, file_path=file_path, export_type=export_type) def dispose_engine(self): """ diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index 9dfede46e..3dc03ccda 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -4,7 +4,7 @@ import os from pyrit.embedding.azure_text_embedding import AzureTextEmbedding from pyrit.interfaces import EmbeddingSupport -from pyrit.memory.memory_models import ConversationData, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData class MemoryEmbedding: @@ -20,7 +20,7 @@ def __init__(self, *, embedding_model: EmbeddingSupport): raise ValueError("embedding_model must be set.") self.embedding_model = embedding_model - def generate_embedding_memory_data(self, *, chat_memory: ConversationData) -> EmbeddingData: + def generate_embedding_memory_data(self, *, chat_memory: PromptMemoryEntry) -> EmbeddingData: """ Generates metadata for a chat memory entry. @@ -30,12 +30,17 @@ def generate_embedding_memory_data(self, *, chat_memory: ConversationData) -> Em Returns: ConversationMemoryEntryMetadata: The generated metadata. """ - embedding_data = EmbeddingData( - embedding=self.embedding_model.generate_text_embedding(text=chat_memory.content).data[0].embedding, - embedding_type_name=self.embedding_model.__class__.__name__, - uuid=chat_memory.uuid, - ) - return embedding_data + if chat_memory.converted_prompt_data_type == "text": + embedding_data = EmbeddingData( + embedding=self.embedding_model.generate_text_embedding(text=chat_memory.converted_prompt_text) + .data[0] + .embedding, + embedding_type_name=self.embedding_model.__class__.__name__, + id=chat_memory.id, + ) + return embedding_data + + raise ValueError("Only text data is supported for embedding.") def default_memory_embedding_factory(embedding_model: EmbeddingSupport = None) -> MemoryEmbedding | None: @@ -49,4 +54,6 @@ def default_memory_embedding_factory(embedding_model: EmbeddingSupport = None) - model = AzureTextEmbedding(api_key=api_key, endpoint=api_base, deployment=deployment) return MemoryEmbedding(embedding_model=model) else: - return None + raise ValueError( + "No embedding model was provided and no Azure OpenAI embedding model was found in the environment." + ) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c43cae33a..f147c046c 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2,13 +2,10 @@ # Licensed under the MIT license. 
import abc -from hashlib import sha256 -from typing import Optional from pathlib import Path -from uuid import uuid4 - -from pyrit.memory.memory_models import Base, ConversationData +from pyrit.memory.memory_embedding import default_memory_embedding_factory +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.models import ChatMessage @@ -32,14 +29,26 @@ def __init__(self, embedding_model=None): # Initialize the MemoryExporter instance self.exporter = MemoryExporter() + def enable_embedding(self, embedding_model=None): + self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) + + def disable_embedding(self): + self.memory_embedding = None + @abc.abstractmethod - def get_all_memory(self, model: Base) -> list[ConversationData]: # type: ignore + def get_all_prompt_entries(self) -> list[PromptMemoryEntry]: """ Loads all ConversationData from the memory storage handler. """ @abc.abstractmethod - def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[ConversationData]: + def get_all_embeddings(self) -> list[EmbeddingData]: + """ + Loads all EmbeddingData from the memory storage handler. + """ + + @abc.abstractmethod + def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified conversation ID. @@ -51,7 +60,7 @@ def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[Con """ @abc.abstractmethod - def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[ConversationData]: + def get_prompt_entries_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified normalizer ID. @@ -63,50 +72,46 @@ def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[Convers """ @abc.abstractmethod - def insert_entries(self, *, entries: list[Base]) -> None: # type: ignore + def insert_prompt_entries(self, *, entries: list[EmbeddingData]) -> None: """ Inserts a list of entries into the memory storage. + If necessary, generates embedding data for applicable entries + Args: entries (list[Base]): The list of database model instances to be inserted. """ @abc.abstractmethod - def get_all_table_models(self) -> list[Base]: # type: ignore + def dispose_engine(self): """ - Returns a list of all table models from the database. - - Returns: - list[Base]: A list of SQLAlchemy models. + Dispose the engine and clean up resources. """ - @abc.abstractmethod - def query_entries(self, model, *, conditions: Optional = None) -> list[Base]: # type: ignore + def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> list[ChatMessage]: """ - Fetches data from the specified table model with optional conditions. + Returns the memory for a given conversation_id. Args: - model: The SQLAlchemy model class corresponding to the table you want to query. - conditions: SQLAlchemy filter conditions (optional). + conversation_id (str): The conversation ID. Returns: - List of model instances representing the rows fetched from the table. - """ - - @abc.abstractmethod - def dispose_engine(self): - """ - Dispose the engine and clean up resources. + list[ChatMessage]: The list of chat messages. 
""" + memory_entries = self.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) + return [ChatMessage(role=me.role, content=me.converted_prompt_text) for me in memory_entries] # type: ignore def add_chat_message_to_memory( self, conversation: ChatMessage, conversation_id: str, normalizer_id: str = None, - labels: list[str] = None, + labels: dict[str, str] = {}, ): """ + Deprecated. Will be refactored and removed soon. It currently works incorrectly. + but is included so functionality is maintained. + Adds a single chat conversation entry to the ConversationStore table. If embddings are set, add corresponding embedding entry to the EmbeddingStore table. @@ -116,16 +121,10 @@ def add_chat_message_to_memory( normalizer_id (str): The normalizer ID, labels (list[str]): A list of labels to be added to the memory entry. """ - entries_to_persist = [] - chat_entry = self._create_chat_message_memory_entry( - conversation=conversation, conversation_id=conversation_id, normalizer_id=normalizer_id, labels=labels - ) - entries_to_persist.append(chat_entry) - if self.memory_embedding: - embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry) - entries_to_persist.append(embedding_entry) - self.insert_entries(entries=entries_to_persist) + self.add_chat_messages_to_memory( + conversations=[conversation], conversation_id=conversation_id, normalizer_id=normalizer_id, labels=labels + ) def add_chat_messages_to_memory( self, @@ -133,9 +132,12 @@ def add_chat_messages_to_memory( conversations: list[ChatMessage], conversation_id: str, normalizer_id: str = None, - labels: list[str] = None, + labels: dict[str, str] = {}, ): """ + Deprecated. Will be refactored and removed soon. It currently works incorrectly. + but is included so functionality is maintained. + Adds multiple chat conversation entries to the ConversationStore table. If embddings are set, add corresponding embedding entries to the EmbeddingStore table. @@ -145,82 +147,22 @@ def add_chat_messages_to_memory( normalizer_id (str): The normalizer ID labels (list[str]): A list of labels to be added to the memory entry. """ - entries_to_persist = [] + entries_to_add = [] for conversation in conversations: - chat_entry = self._create_chat_message_memory_entry( - conversation=conversation, conversation_id=conversation_id, normalizer_id=normalizer_id, labels=labels + entry = PromptMemoryEntry( + role=conversation.role, + conversation_id=conversation_id, + original_prompt_text=conversation.content, + converted_prompt_text=conversation.content, + labels=labels, ) - entries_to_persist.append(chat_entry) - if self.memory_embedding: - embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry) - entries_to_persist.append(embedding_entry) - self.insert_entries(entries=entries_to_persist) + entry.labels["normalizer_id"] = normalizer_id - def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> list[ChatMessage]: - """ - Returns the memory for a given conversation_id. - - Args: - conversation_id (str): The conversation ID. - - Returns: - list[ChatMessage]: The list of chat messages. 
- """ - memory_entries = self.get_memories_with_conversation_id(conversation_id=conversation_id) - return [ChatMessage(role=me.role, content=me.content) for me in memory_entries] # type: ignore - - def _create_chat_message_memory_entry( - self, - *, - conversation: ChatMessage, - conversation_id: str, - normalizer_id: str = None, - labels: list[str] = None, - ): - """ - Creates a new `ConversationData` instance representing a chat message entry. - - Args: - conversation (ChatMessage): The chat message to be stored. - conversation_id (str): Conversation ID. - normalizer_id (str): Normalizer ID. - labels (list[str]): Labels associated with the conversation. - - Returns: - ConversationData: A new instance ready to be persisted in the memory storage. - """ - uuid = uuid4() - new_chat_memory = ConversationData( - role=conversation.role, - content=conversation.content, - conversation_id=conversation_id, - normalizer_id=normalizer_id, - uuid=uuid, - labels=labels if labels else [], - sha256=sha256(conversation.content.encode()).hexdigest(), - ) - - return new_chat_memory - - def export_all_tables(self, *, export_type: str = "json"): - """ - Exports all table data using the specified exporter. - - Iterates over all tables, retrieves their data, and exports each to a file named after the table. - - Args: - export_type (str): The format to export the data in (defaults to "json"). - """ - table_models = self.get_all_table_models() + entries_to_add.append(entry) - for model in table_models: - data = self.query_entries(model) - table_name = model.__tablename__ - file_extension = f".{export_type}" - file_path = RESULTS_PATH / f"{table_name}{file_extension}" - self.exporter.export_data(data, file_path=file_path, export_type=export_type) + self.insert_prompt_entries(entries=entries_to_add) def export_conversation_by_id(self, *, conversation_id: str, file_path: Path = None, export_type: str = "json"): """ @@ -232,7 +174,7 @@ def export_conversation_by_id(self, *, conversation_id: str, file_path: Path = N If not provided, a default path using RESULTS_PATH will be constructed. export_type (str): The format of the export. Defaults to "json". """ - data = self.get_memories_with_conversation_id(conversation_id=conversation_id) + data = self.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) # If file_path is not provided, construct a default using the exporter's results_path if not file_path: diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 2cbd1620d..ea9c4450e 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -1,88 +1,154 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from uuid import uuid4 +import hashlib import uuid + from datetime import datetime +from typing import Dict, Literal +from uuid import uuid4 from pydantic import BaseModel, ConfigDict +from sqlalchemy import Column, String, DateTime, Float, JSON, ForeignKey, Index, INTEGER, ARRAY from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, String, DateTime, Float -from sqlalchemy.dialects.postgresql import ARRAY, UUID -from sqlalchemy import ForeignKey, Index +from sqlalchemy.dialects.postgresql import UUID Base = declarative_base() -class ConversationData(Base): # type: ignore - """ - Represents the conversation data. +PromptDataType = Literal["text", "image_url"] - conversation_id is used to group messages together within a prompt_target endpoint. 
- It's often needed so the prompt_target knows how to construct the messages. - normalizer_id is used to group messages together within a prompt_normalizer. - A prompt_normalizer is usually a single attack, and can contain multiple prompt_targets. - It's often needed to group all the prompts in an attack together. +class PromptMemoryEntry(Base): # type: ignore + """ + Represents the prompt data. + + Because of the nature of database and sql alchemy, type ignores are abundant :) Attributes: - uuid (UUID): A unique identifier for each conversation entry, serving as the primary key. - role (String): The role associated with the message, indicating its origin - within the conversation (e.g., "user", "assistant" or "system"). - content (String): The actual text content of the conversation entry. - conversation_id (String): An identifier used to group related conversation entries. - The conversation_id is linked to a specific LLM model, - aggregating all related conversations under a single identifier. - In scenarios involving multi-turn interactions that utilize two models, - there will be two distinct conversation_ids, one for each model. - timestamp (DateTime): The timestamp when the conversation entry was created or - logged. Defaults to the current UTC time. - normalizer_id (String): An identifier used to group messages together within a prompt_normalizer. - sha256 (String): An optional SHA-256 hash of the content. - labels (ARRAY(String)): An array of labels associated with the conversation entry, - useful for categorization or filtering the final data. - idx_conversation_id (Index): An index on the `conversation_id` column to improve - query performance for operations involving obtaining conversation history based - on conversation_id. + __tablename__ (str): The name of the database table. + __table_args__ (dict): Additional arguments for the database table. + id (UUID): The unique identifier for the memory entry. + role (PromptType): system, assistant, user + conversation_id (str): The identifier for the conversation which is associated with a single target. + sequence (int): The order of the conversation within a conversation_id. + Can be the same number for multi-part requests or multi-part responses. + timestamp (DateTime): The timestamp of the memory entry. + labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. + prompt_metadata (JSON): The metadata associated with the prompt. This can be specific to any scenarios. + Because memory is how components talk with each other, this can be component specific. + e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. + converters (list[PromptConverter]): The converters for the prompt. + prompt_target (PromptTarget): The target for the prompt. + orchestrator (Orchestrator): The orchestrator for the prompt. + original_prompt_data_type (PromptDataType): The data type of the original prompt (text, image) + original_prompt_text (str): The text of the original prompt. If prompt is an image, it's a link. + original_prompt_data_sha256 (str): The SHA256 hash of the original prompt data. + converted_prompt_data_type (PromptDataType): The data type of the converted prompt (text, image) + converted_prompt_text (str): The text of the converted prompt. If prompt is an image, it's a link. + converted_prompt_data_sha256 (str): The SHA256 hash of the original prompt data. + idx_conversation_id (Index): The index for the conversation ID. 
+ + Methods: + __str__(): Returns a string representation of the memory entry. """ - __tablename__ = "ConversationStore" + __tablename__ = "PromptMemoryEntries" __table_args__ = {"extend_existing": True} - uuid = Column(UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid4) - role = Column(String, nullable=False) - content = Column(String) + id = Column(UUID(as_uuid=True), nullable=False, primary_key=True) + role: "Column[ChatMessageRole]" = Column(String, nullable=False) # type: ignore # noqa conversation_id = Column(String, nullable=False) - timestamp = Column(DateTime, nullable=False, default=datetime.utcnow) - normalizer_id = Column(String) - sha256 = Column(String) - labels = Column(ARRAY(String)) # type: ignore + sequence = Column(INTEGER, nullable=False) + timestamp = Column(DateTime, nullable=False) + labels: Column[Dict[str, str]] = Column(JSON) # type: ignore + prompt_metadata = Column(JSON) + converters: "Column[list[PromptConverter]]" = Column(JSON) # type: ignore # noqa + prompt_target: "Column[PromptTarget]" = Column(JSON) # type: ignore # noqa + orchestrator: "Column[Orchestrator]" = Column(JSON) # type: ignore # noqa + + original_prompt_data_type: PromptDataType = Column(String, nullable=False) # type: ignore + original_prompt_text = Column(String, nullable=False) + original_prompt_data_sha256 = Column(String) + + converted_prompt_data_type: PromptDataType = Column(String, nullable=False) # type: ignore + converted_prompt_text = Column(String) + converted_prompt_data_sha256 = Column(String) + idx_conversation_id = Index("idx_conversation_id", "conversation_id") + def __init__( + self, + *, + role: str, + original_prompt_text: str, + converted_prompt_text: str, + id: uuid.UUID = None, + conversation_id: str = None, + sequence: int = -1, + labels: Dict[str, str] = None, + prompt_metadata: JSON = None, + converters: "PromptConverterList" = None, # type: ignore # noqa + prompt_target: "PromptTarget" = None, # type: ignore # noqa + orchestrator: "Orchestrator" = None, # type: ignore # noqa + original_prompt_data_type: PromptDataType = "text", + converted_prompt_data_type: PromptDataType = "text", + ): + + self.id = id if id else uuid4() # type: ignore + + self.role = role + self.conversation_id = conversation_id if conversation_id else str(uuid4()) + self.sequence = sequence + + self.timestamp = datetime.utcnow() + self.labels = labels + self.prompt_metadata = prompt_metadata # type: ignore + + self.converters = converters.to_json() if converters else None + self.prompt_target = prompt_target.to_json() if prompt_target else None + self.orchestrator = orchestrator.to_json() if orchestrator else None + + self.original_prompt_text = original_prompt_text + self.original_prompt_data_type = original_prompt_data_type + self.original_prompt_data_sha256 = self._create_sha256(original_prompt_text) + + self.converted_prompt_data_type = converted_prompt_data_type + self.converted_prompt_text = converted_prompt_text + self.converted_prompt_data_sha256 = self._create_sha256(converted_prompt_text) + + def is_sequence_set(self) -> bool: + return self.sequence != -1 + + def _create_sha256(self, text: str) -> str: + input_bytes = text.encode("utf-8") + hash_object = hashlib.sha256(input_bytes) + return hash_object.hexdigest() + def __str__(self): - return f"{self.role}: {self.content}" + return f"{self.role}: {self.converted_prompt_text}" class EmbeddingData(Base): # type: ignore """ Represents the embedding data associated with conversation entries in the database. 
- Each embedding is linked to a specific conversation entry via a 'uuid'. + Each embedding is linked to a specific conversation entry via an id Attributes: - uuid (UUID): The primary key, which is a foreign key referencing the UUID in the ConversationStore table. + uuid (UUID): The primary key, which is a foreign key referencing the UUID in the MemoryEntries table. embedding (ARRAY(Float)): An array of floats representing the embedding vector. embedding_type_name (String): The name or type of the embedding, indicating the model or method used. """ - __tablename__ = "EmbeddingStore" + __tablename__ = "EmbeddingData" # Allows table redefinition if already defined. __table_args__ = {"extend_existing": True} - uuid = Column(UUID(as_uuid=True), ForeignKey(f"{ConversationData.__tablename__}.uuid"), primary_key=True) + id = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True) embedding = Column(ARRAY(Float)) embedding_type_name = Column(String) def __str__(self): - return f"{self.uuid}" + return f"{self.id}" class ConversationMessageWithSimilarity(BaseModel): diff --git a/pyrit/models.py b/pyrit/models.py index e15e91ba9..f0e8c8a37 100644 --- a/pyrit/models.py +++ b/pyrit/models.py @@ -15,10 +15,8 @@ from pydantic import BaseModel, ConfigDict -# Originally derived from this: -# https://github.com/openai/openai-python/blob/7f9e85017a0959e3ba07834880d92c748f8f67ab/src/openai/types/chat/chat_completion_role.py#L4 -ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "tool", "function"] -ChatMessageRole = Literal["system", "user", "assistant", "tool", "function"] +ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant"] +ChatMessageRole = Literal["system", "user", "assistant"] @dataclass diff --git a/pyrit/orchestrator/benchmark_orchestrator.py b/pyrit/orchestrator/benchmark_orchestrator.py index 6b34b66a9..e9634e699 100644 --- a/pyrit/orchestrator/benchmark_orchestrator.py +++ b/pyrit/orchestrator/benchmark_orchestrator.py @@ -31,7 +31,7 @@ def __init__( chat_model_under_evaluation: PromptChatTarget, scorer: QuestionAnswerScorer, memory: MemoryInterface | None = None, - memory_labels: list[str] = ["question-answering-benchmark-orchestrator"], + memory_labels: dict[str, str] = None, evaluation_prompt: str | None = None, batch_size: int = 1, verbose: bool = False, diff --git a/pyrit/orchestrator/end_token_red_teaming_orchestrator.py b/pyrit/orchestrator/end_token_red_teaming_orchestrator.py index cceb99f2c..6bd9ca9fd 100644 --- a/pyrit/orchestrator/end_token_red_teaming_orchestrator.py +++ b/pyrit/orchestrator/end_token_red_teaming_orchestrator.py @@ -23,7 +23,7 @@ def __init__( end_token: Optional[str] = RED_TEAM_CONVERSATION_END_TOKEN, prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: list[str] = ["red-teaming-orchestrator"], + memory_labels: dict[str, str] = None, verbose: bool = False, ) -> None: """Creates an orchestrator to manage conversations between a red teaming target and a prompt target. diff --git a/pyrit/orchestrator/orchestrator_class.py b/pyrit/orchestrator/orchestrator_class.py index 04bc7154a..68e3c1d8e 100644 --- a/pyrit/orchestrator/orchestrator_class.py +++ b/pyrit/orchestrator/orchestrator_class.py @@ -2,10 +2,11 @@ # Licensed under the MIT license. 
 import abc
+import json
 import logging
-
 from typing import Optional
+from uuid import uuid4
 
 from pyrit.memory import MemoryInterface, DuckDBMemory
 from pyrit.prompt_converter import PromptConverter, NoOpConverter
@@ -22,14 +23,18 @@ def __init__(
         *,
         prompt_converters: Optional[list[PromptConverter]] = None,
         memory: Optional[MemoryInterface] = None,
-        memory_labels: list[str] = [],
+        memory_labels: dict[str, str] = {},
         verbose: bool = False,
     ):
         self._prompt_converters = prompt_converters if prompt_converters else [NoOpConverter()]
         self._memory = memory or DuckDBMemory()
-        self._global_memory_labels = memory_labels
         self._verbose = verbose
 
+        self._global_memory_labels = dict(memory_labels) if memory_labels else {}
+        self._global_memory_labels["orchestrator"] = str(self.__class__.__name__)
+
         if self._verbose:
             logging.basicConfig(level=logging.INFO)
@@ -63,3 +68,10 @@ def dispose_db_engine(self) -> None:
         """
         Dispose DuckDB database engine to release database connections and resources.
         """
         self._memory.dispose_engine()
+
+    def to_json(self):
+        s = {}
+        s["__type__"] = self.__class__.__name__
+        s["__module__"] = self.__class__.__module__
+        s["id"] = str(uuid4())
+        return json.dumps(s)
diff --git a/pyrit/orchestrator/prompt_sending_orchestrator.py b/pyrit/orchestrator/prompt_sending_orchestrator.py
index 7fa96d020..3af299129 100644
--- a/pyrit/orchestrator/prompt_sending_orchestrator.py
+++ b/pyrit/orchestrator/prompt_sending_orchestrator.py
@@ -114,4 +114,4 @@ def get_memory(self):
         Retrieves the memory associated with the prompt normalizer.
         """
         id = self._prompt_normalizer.id
-        return self._memory.get_memories_with_normalizer_id(normalizer_id=id)
+        return self._memory.get_prompt_entries_with_normalizer_id(normalizer_id=id)
diff --git a/pyrit/orchestrator/red_teaming_orchestrator.py b/pyrit/orchestrator/red_teaming_orchestrator.py
index 7350713f7..5ca9729e7 100644
--- a/pyrit/orchestrator/red_teaming_orchestrator.py
+++ b/pyrit/orchestrator/red_teaming_orchestrator.py
@@ -33,7 +33,7 @@ def __init__(
         initial_red_teaming_prompt: str = "Begin Conversation",
         prompt_converters: Optional[list[PromptConverter]] = None,
         memory: Optional[MemoryInterface] = None,
-        memory_labels: list[str] = ["red-teaming-orchestrator"],
+        memory_labels: dict[str, str] = None,
         verbose: bool = False,
     ) -> None:
         """Creates an orchestrator to manage conversations between a red teaming target and a prompt target.
@@ -76,7 +76,7 @@ def requires_one_to_one_converters(self) -> bool:
         return True
 
     def get_memory(self):
-        return self._memory.get_memories_with_normalizer_id(normalizer_id=self._prompt_normalizer.id)
+        return self._memory.get_prompt_entries_with_normalizer_id(normalizer_id=self._prompt_normalizer.id)
 
     @abc.abstractmethod
     def is_conversation_complete(self, messages: list[ChatMessage], *, red_teaming_chat_role: str) -> bool:
diff --git a/pyrit/orchestrator/scoring_red_teaming_orchestrator.py b/pyrit/orchestrator/scoring_red_teaming_orchestrator.py
index 19c920888..cef57f6d9 100644
--- a/pyrit/orchestrator/scoring_red_teaming_orchestrator.py
+++ b/pyrit/orchestrator/scoring_red_teaming_orchestrator.py
@@ -22,7 +22,7 @@ def __init__(
         scorer: SupportTextClassification,
         prompt_converters: Optional[list[PromptConverter]] = None,
         memory: Optional[MemoryInterface] = None,
-        memory_labels: list[str] = ["red-teaming-orchestrator"],
+        memory_labels: dict[str, str] = None,
         verbose: bool = False,
     ) -> None:
         """Creates an orchestrator to manage conversations between a red teaming bot and a prompt target.
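The orchestrator changes above move memory_labels from list[str] to dict[str, str] and add a to_json() helper. A minimal usage sketch follows (illustrative only, not part of the patch): the DemoOrchestrator subclass and the label value are hypothetical, and it assumes the base Orchestrator class can be subclassed without further overrides.

from pyrit.memory import DuckDBMemory
from pyrit.orchestrator.orchestrator_class import Orchestrator


class DemoOrchestrator(Orchestrator):
    """Hypothetical no-op orchestrator used only to exercise the base-class API."""


orchestrator = DemoOrchestrator(
    memory=DuckDBMemory(db_path=":memory:"),
    memory_labels={"operation": "demo-run"},  # a dict[str, str] instead of the old list[str]
)
print(orchestrator.to_json())  # JSON containing __type__, __module__, and a generated id
orchestrator.dispose_db_engine()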
diff --git a/pyrit/prompt_converter/__init__.py b/pyrit/prompt_converter/__init__.py index e6ec88189..1ddcef9dc 100644 --- a/pyrit/prompt_converter/__init__.py +++ b/pyrit/prompt_converter/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from pyrit.prompt_converter.prompt_converter import PromptConverter +from pyrit.prompt_converter.prompt_converter import PromptConverter, PromptConverterList from pyrit.prompt_converter.ascii_art_converter import AsciiArtConverter from pyrit.prompt_converter.base64_converter import Base64Converter @@ -19,6 +19,7 @@ "Base64Converter", "NoOpConverter", "PromptConverter", + "PromptConverterList", "ROT13Converter", "StringJoinConverter", "TranslationConverter", diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 05e7ad028..6808c14a5 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import abc +import json class PromptConverter(abc.ABC): @@ -26,3 +27,17 @@ def convert(self, prompts: list[str]) -> list[str]: def is_one_to_one_converter(self) -> bool: """Indicates if the conversion results in exactly one resulting prompt.""" pass + + def to_dict(self): + public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + public_attributes["__type__"] = self.__class__.__name__ + public_attributes["__module__"] = self.__class__.__module__ + return public_attributes + + +class PromptConverterList: + def __init__(self, converters: list[PromptConverter]) -> None: + self.converters = converters + + def to_json(self): + return json.dumps([converter.to_dict() for converter in self.converters]) diff --git a/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py index caa8dda05..caa41b861 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_ml_chat_target.py @@ -65,7 +65,7 @@ def __init__( self._repetition_penalty = repetition_penalty def set_system_prompt(self, *, prompt: str, conversation_id: str, normalizer_id: str) -> None: - messages = self._memory.get_memories_with_conversation_id(conversation_id=conversation_id) + messages = self._memory.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) if messages: raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") diff --git a/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py b/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py index 3d640fae2..67efbc35b 100644 --- a/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/openai_chat_target.py @@ -33,7 +33,7 @@ def __init__(self) -> None: pass def set_system_prompt(self, *, prompt: str, conversation_id: str, normalizer_id: str) -> None: - messages = self._memory.get_memories_with_conversation_id(conversation_id=conversation_id) + messages = self._memory.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) if messages: raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") diff --git a/pyrit/prompt_target/prompt_target.py b/pyrit/prompt_target/prompt_target.py index fe9b9f941..88536465d 100644 --- a/pyrit/prompt_target/prompt_target.py +++ b/pyrit/prompt_target/prompt_target.py @@ -2,6 +2,7 @@ # Licensed 
under the MIT license. import abc +import json from pyrit.memory import MemoryInterface, DuckDBMemory @@ -41,3 +42,9 @@ async def send_prompt_async( """ Sends a normalized prompt async to the prompt target. """ + + def to_json(self): + public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + public_attributes["__type__"] = self.__class__.__name__ + public_attributes["__module__"] = self.__class__.__module__ + return json.dumps(public_attributes) diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 64d234dd3..e5450acee 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -22,11 +22,12 @@ class TextTarget(PromptTarget): def __init__(self, *, text_stream: IO[str] = sys.stdout, memory: MemoryInterface = None) -> None: super().__init__(memory=memory) - self.text_stream = text_stream + self.stream_name = text_stream.name + self._text_stream = text_stream def send_prompt(self, *, normalized_prompt: str, conversation_id: str, normalizer_id: str) -> str: msg = ChatMessage(role="user", content=normalized_prompt) - self.text_stream.write(f"{str(msg)}\n") + self._text_stream.write(f"{str(msg)}\n") self._memory.add_chat_message_to_memory( conversation=msg, conversation_id=conversation_id, normalizer_id=normalizer_id diff --git a/tests/analytics/test_conversation_analytics.py b/tests/analytics/test_conversation_analytics.py index 6c31e2362..be4691fc3 100644 --- a/tests/analytics/test_conversation_analytics.py +++ b/tests/analytics/test_conversation_analytics.py @@ -2,12 +2,11 @@ # Licensed under the MIT license. import pytest -import uuid from unittest.mock import MagicMock from pyrit.memory.memory_interface import MemoryInterface from pyrit.analytics.conversation_analytics import ConversationAnalytics -from pyrit.memory.memory_models import ConversationData, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData @pytest.fixture @@ -19,14 +18,18 @@ def mock_memory_interface(): def test_get_similar_chat_messages_by_content(mock_memory_interface): # Mock data returned by the memory interface mock_data = [ - ConversationData(content="Hello, how are you?", role="user"), - ConversationData(content="I'm fine, thank you!", role="assistant"), - ConversationData(content="Hello, how are you?", role="assistant"), # Exact match + PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="Hello, how are you?", role="user"), + PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="I'm fine, thank you!", role="assistant"), + PromptMemoryEntry( + original_prompt_text="h", converted_prompt_text="Hello, how are you?", role="assistant" + ), # Exact match ] - mock_memory_interface.get_all_memory.return_value = mock_data + mock_memory_interface.get_all_prompt_entries.return_value = mock_data analytics = ConversationAnalytics(memory_interface=mock_memory_interface) - similar_messages = analytics.get_similar_chat_messages_by_content(chat_message_content="Hello, how are you?") + similar_messages = analytics.get_prompt_entries_with_same_converted_content( + chat_message_content="Hello, how are you?" 
+ ) # Expect one exact match assert len(similar_messages) == 2 @@ -39,8 +42,8 @@ def test_get_similar_chat_messages_by_content(mock_memory_interface): def test_get_similar_chat_messages_by_embedding(mock_memory_interface): # Mock ConversationData entries conversation_entries = [ - ConversationData(uuid=uuid.uuid4(), conversation_id="1", role="user", content="Similar message"), - ConversationData(uuid=uuid.uuid4(), conversation_id="2", role="assistant", content="Different message"), + PromptMemoryEntry(original_prompt_text="h", role="user", converted_prompt_text="Similar message"), + PromptMemoryEntry(original_prompt_text="h", role="assistant", converted_prompt_text="Different message"), ] # Mock EmbeddingData entries linked to the ConversationData entries @@ -48,15 +51,14 @@ def test_get_similar_chat_messages_by_embedding(mock_memory_interface): similar_embedding = [0.1, 0.2, 0.31] # Slightly different, but should be similar different_embedding = [0.9, 0.8, 0.7] - mock_data = [ - EmbeddingData(uuid=conversation_entries[0].uuid, embedding=similar_embedding, embedding_type_name="model1"), - EmbeddingData(uuid=conversation_entries[1].uuid, embedding=different_embedding, embedding_type_name="model2"), + mock_embeddings = [ + EmbeddingData(id=conversation_entries[0].id, embedding=similar_embedding, embedding_type_name="model1"), + EmbeddingData(id=conversation_entries[1].id, embedding=different_embedding, embedding_type_name="model2"), ] - # Mock the get_all_memory method to return the mock EmbeddingData entries - mock_memory_interface.get_all_memory.side_effect = lambda model: ( - mock_data if model == EmbeddingData else conversation_entries - ) + # Mock the get_all_prompt_entries method to return the mock EmbeddingData entries + mock_memory_interface.get_all_embeddings.return_value = mock_embeddings + mock_memory_interface.get_all_prompt_entries.return_value = conversation_entries analytics = ConversationAnalytics(memory_interface=mock_memory_interface) similar_messages = analytics.get_similar_chat_messages_by_embedding( diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index aa86c2979..d38a7de5e 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -1,36 +1,30 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+import json +import os +from typing import Generator import pytest import uuid -import datetime from unittest.mock import MagicMock from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import inspect -from sqlalchemy.dialects.postgresql import UUID, ARRAY +from sqlalchemy import String, DateTime, INTEGER, ARRAY +from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql.sqltypes import NullType -from sqlalchemy.types import String, DateTime -from pyrit.memory.memory_models import ConversationData, EmbeddingData -from pyrit.memory import DuckDBMemory +from pyrit.memory.memory_interface import MemoryInterface +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData +from pyrit.prompt_converter.base64_converter import Base64Converter +from pyrit.prompt_converter.prompt_converter import PromptConverterList +from pyrit.prompt_target.text_target import TextTarget +from tests.mocks import get_memory_interface @pytest.fixture -def setup_duckdb_database(): - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." - - yield duckdb_memory - duckdb_memory.dispose_engine() +def setup_duckdb_database() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() @pytest.fixture @@ -47,64 +41,83 @@ def mock_session(): def test_conversation_data_schema(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("ConversationStore") + columns = inspector.get_columns("PromptMemoryEntries") column_names = [col["name"] for col in columns] # Expected columns in ConversationData - expected_columns = ["uuid", "role", "content", "conversation_id", "timestamp", "normalizer_id", "sha256", "labels"] + expected_columns = [ + "id", + "role", + "conversation_id", + "sequence", + "timestamp", + "labels", + "prompt_metadata", + "converters", + "prompt_target", + "original_prompt_data_type", + "original_prompt_text", + "original_prompt_data_sha256", + "converted_prompt_data_type", + "converted_prompt_text", + "converted_prompt_data_sha256", + ] + for column in expected_columns: - assert column in column_names, f"{column} not found in ConversationStore schema." + assert column in column_names, f"{column} not found in PromptMemoryEntries schema." def test_embedding_data_schema(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("EmbeddingStore") + columns = inspector.get_columns("EmbeddingData") column_names = [col["name"] for col in columns] # Expected columns in EmbeddingData - expected_columns = ["uuid", "embedding", "embedding_type_name"] + expected_columns = ["id", "embedding", "embedding_type_name"] for column in expected_columns: - assert column in column_names, f"{column} not found in EmbeddingStore schema." + assert column in column_names, f"{column} not found in EmbeddingData schema." 
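The two schema tests above reflect the new one-to-one link between the tables: EmbeddingData.id is a foreign key to PromptMemoryEntries.id. A minimal sketch of that relationship (illustrative only, not part of the patch; memory stands in for a DuckDBMemory instance):

from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData

prompt_entry = PromptMemoryEntry(role="user", original_prompt_text="Hello", converted_prompt_text="Hello")
embedding_entry = EmbeddingData(id=prompt_entry.id, embedding=[0.1, 0.2, 0.3], embedding_type_name="demo")

# The prompt entry must be written before (or ahead of, within the same batch as) its embedding
# because of the foreign key on EmbeddingData.id.
memory.insert_entries(entries=[prompt_entry, embedding_entry])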
def test_conversation_data_column_types(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("ConversationStore") + columns = inspector.get_columns("PromptMemoryEntries") column_types = {col["name"]: type(col["type"]) for col in columns} # Expected column types in ConversationData expected_column_types = { - "uuid": UUID, + "id": UUID, "role": String, - "content": String, "conversation_id": String, + "sequence": INTEGER, "timestamp": DateTime, - "normalizer_id": String, - "sha256": String, - "labels": ARRAY, + "labels": String, + "prompt_metadata": String, + "converters": String, + "prompt_target": String, + "original_prompt_data_type": String, + "original_prompt_text": String, + "original_prompt_data_sha256": String, + "converted_prompt_data_type": String, + "converted_prompt_text": String, + "converted_prompt_data_sha256": String, } for column, expected_type in expected_column_types.items(): if column != "labels": - assert column in column_types, f"{column} not found in ConversationStore schema." + assert column in column_types, f"{column} not found in PromptMemoryEntries schema." assert issubclass( column_types[column], expected_type ), f"Expected {column} to be a subclass of {expected_type}, got {column_types[column]} instead." - # Handle 'labels' column separately - assert "labels" in column_types, "'labels' column not found in ConversationStore schema." - # Check if 'labels' column type is either NullType (due to reflection issue) or ARRAY - assert column_types["labels"] in [NullType, ARRAY], f"Unexpected type for 'labels' column: {column_types['labels']}" - def test_embedding_data_column_types(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("EmbeddingStore") + columns = inspector.get_columns("EmbeddingData") column_types = {col["name"]: col["type"].__class__ for col in columns} # Expected column types in EmbeddingData expected_column_types = { - "uuid": UUID, + "id": UUID, "embedding": ARRAY, "embedding_type_name": String, } @@ -117,7 +130,7 @@ def test_embedding_data_column_types(setup_duckdb_database): column_types[column], expected_type ), f"Expected {column} to be a subclass of {expected_type}, got {column_types[column]} instead." # Handle 'embedding' column separately - assert "embedding" in column_types, "'embedding' column not found in EmbeddingStore schema." + assert "embedding" in column_types, "'embedding' column not found in EmbeddingData schema." 
# Check if 'embedding' column type is either NullType (due to reflection issue) or ARRAY assert column_types["embedding"] in [ NullType, @@ -127,30 +140,64 @@ def test_embedding_data_column_types(setup_duckdb_database): def test_insert_entry(setup_duckdb_database): session = setup_duckdb_database.get_session() - entry = ConversationData( + entry = PromptMemoryEntry( + id=uuid.uuid4(), conversation_id="123", role="user", - content="Hello", - sha256="abc", + original_prompt_data_type="text", + original_prompt_text="Hello", + converted_prompt_text="Hello", ) # Use the insert_entry method to insert the entry into the database setup_duckdb_database.insert_entry(entry) # Now, get a new session to query the database and verify the entry was inserted with setup_duckdb_database.get_session() as session: - inserted_entry = session.query(ConversationData).filter_by(conversation_id="123").first() + inserted_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").first() assert inserted_entry is not None assert inserted_entry.role == "user" - assert inserted_entry.content == "Hello" - assert inserted_entry.sha256 == "abc" + assert inserted_entry.original_prompt_text == "Hello" + sha265 = "185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969" + assert inserted_entry.original_prompt_data_sha256 == sha265 + + +def test_insert_prompt_memories_inserts_embedding(setup_duckdb_database): + + embedding_mock = MagicMock() + setup_duckdb_database.enable_embedding(embedding_model=embedding_mock) + + with setup_duckdb_database.get_session(): + id = uuid.uuid4() + entry = PromptMemoryEntry( + id=id, + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + + setup_duckdb_database.insert_prompt_entries(entries=[entry]) + + setup_duckdb_database.dispose_engine() + + # Embedding data should be generated since we passed this model in + embedding_mock.generate_text_embedding.assert_called_once() def test_insert_entry_violates_constraint(setup_duckdb_database): # Generate a fixed UUID fixed_uuid = uuid.uuid4() # Create two entries with the same UUID - entry1 = ConversationData(uuid=fixed_uuid, conversation_id="123", role="user", content="Hello") - entry2 = ConversationData(uuid=fixed_uuid, conversation_id="456", role="user", content="Hello again") + entry1 = PromptMemoryEntry( + id=fixed_uuid, conversation_id="123", role="user", original_prompt_text="Hello", converted_prompt_text="Hello" + ) + + entry2 = PromptMemoryEntry( + id=fixed_uuid, + conversation_id="456", + role="user", + original_prompt_text="Hello again", + converted_prompt_text="Hello again", + ) # Insert the first entry with setup_duckdb_database.get_session() as session: @@ -166,7 +213,12 @@ def test_insert_entry_violates_constraint(setup_duckdb_database): def test_insert_entries(setup_duckdb_database): entries = [ - ConversationData(conversation_id=str(i), role="user", content=f"Message {i}", sha256=f"hash{i}") + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"CMessage {i}", + ) for i in range(5) ] @@ -174,18 +226,20 @@ def test_insert_entries(setup_duckdb_database): with setup_duckdb_database.get_session() as session: # Use the insert_entries method to insert multiple entries into the database setup_duckdb_database.insert_entries(entries=entries) - inserted_entries = session.query(ConversationData).all() + inserted_entries = session.query(PromptMemoryEntry).all() assert len(inserted_entries) == 5 for i, entry in 
enumerate(inserted_entries): assert entry.conversation_id == str(i) assert entry.role == "user" - assert entry.content == f"Message {i}" - assert entry.sha256 == f"hash{i}" + assert entry.original_prompt_text == f"Message {i}" + assert entry.converted_prompt_text == f"CMessage {i}" def test_insert_embedding_entry(setup_duckdb_database): # Create a ConversationData entry - conversation_entry = ConversationData(conversation_id="123", role="user", content="Hello", sha256="abc") + conversation_entry = PromptMemoryEntry( + conversation_id="123", role="user", original_prompt_text="Hello", converted_prompt_text="abc" + ) # Insert the ConversationData entry using the insert_entry method setup_duckdb_database.insert_entry(conversation_entry) @@ -193,80 +247,136 @@ def test_insert_embedding_entry(setup_duckdb_database): # Re-query the ConversationData entry within a new session to ensure it's attached with setup_duckdb_database.get_session() as session: # Assuming uuid is the primary key and is set upon insertion - reattached_conversation_entry = session.query(ConversationData).filter_by(conversation_id="123").one() - uuid = reattached_conversation_entry.uuid + reattached_conversation_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").one() + uuid = reattached_conversation_entry.id # Now that we have the uuid, we can create and insert the EmbeddingData entry - embedding_entry = EmbeddingData(uuid=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") + embedding_entry = EmbeddingData(id=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") setup_duckdb_database.insert_entry(embedding_entry) # Verify the EmbeddingData entry was inserted correctly with setup_duckdb_database.get_session() as session: - persisted_embedding_entry = session.query(EmbeddingData).filter_by(uuid=uuid).first() + persisted_embedding_entry = session.query(EmbeddingData).filter_by(id=uuid).first() assert persisted_embedding_entry is not None assert persisted_embedding_entry.embedding == [1, 2, 3] assert persisted_embedding_entry.embedding_type_name == "test_type" +def test_disable_embedding(setup_duckdb_database): + setup_duckdb_database.disable_embedding() + + assert ( + setup_duckdb_database.memory_embedding is None + ), "disable_memory flag was passed, so memory embedding should be disabled." + + +def test_default_enable_embedding(setup_duckdb_database): + os.environ["AZURE_OPENAI_EMBEDDING_KEY"] = "mock_key" + os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"] = "embedding" + os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] = "deployment" + + setup_duckdb_database.enable_embedding() + + assert ( + setup_duckdb_database.memory_embedding is not None + ), "Memory embedding should be enabled when set with environment variables." 
+ + +def test_default_embedding_raises(setup_duckdb_database): + os.environ["AZURE_OPENAI_EMBEDDING_KEY"] = "" + os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"] = "" + os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] = "" + + with pytest.raises(ValueError): + setup_duckdb_database.enable_embedding() + + def test_query_entries(setup_duckdb_database): # Insert some test data - entries = [ConversationData(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] + entries = [ + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"Message {i}", + ) + for i in range(3) + ] + setup_duckdb_database.insert_entries(entries=entries) # Query entries without conditions - queried_entries = setup_duckdb_database.query_entries(ConversationData) + queried_entries = setup_duckdb_database.query_entries(PromptMemoryEntry) assert len(queried_entries) == 3 # Query entries with a condition specific_entry = setup_duckdb_database.query_entries( - ConversationData, conditions=ConversationData.conversation_id == "1" + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == "1" ) assert len(specific_entry) == 1 - assert specific_entry[0].content == "Message 1" + assert specific_entry[0].original_prompt_text == "Message 1" def test_update_entries(setup_duckdb_database): # Insert a test entry - entry = ConversationData(conversation_id="123", role="user", content="Hello") + entry = PromptMemoryEntry( + conversation_id="123", role="user", original_prompt_text="Hello", converted_prompt_text="Hello" + ) + setup_duckdb_database.insert_entry(entry) # Fetch the entry to update and update its content entries_to_update = setup_duckdb_database.query_entries( - ConversationData, conditions=ConversationData.conversation_id == "123" + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == "123" + ) + setup_duckdb_database.update_entries( + entries=entries_to_update, update_fields={"original_prompt_text": "Updated Hello"} ) - setup_duckdb_database.update_entries(entries=entries_to_update, update_fields={"content": "Updated Hello"}) # Verify the entry was updated with setup_duckdb_database.get_session() as session: - updated_entry = session.query(ConversationData).filter_by(conversation_id="123").first() - assert updated_entry.content == "Updated Hello" + updated_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").first() + assert updated_entry.original_prompt_text == "Updated Hello" def test_get_all_memory(setup_duckdb_database): # Insert some test data - entries = [ConversationData(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] + entries = [ + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"Message {i}", + ) + for i in range(3) + ] + setup_duckdb_database.insert_entries(entries=entries) # Fetch all entries - all_entries = setup_duckdb_database.get_all_memory(ConversationData) + all_entries = setup_duckdb_database.get_all_prompt_entries() assert len(all_entries) == 3 -def test_get_memories_with_conversation_id(setup_duckdb_database): +def test_get_memories_with_json_properties(setup_duckdb_database): # Define a specific conversation_id specific_conversation_id = "test_conversation_id" + converters = PromptConverterList([Base64Converter()]) + target = TextTarget() + # Start a session with setup_duckdb_database.get_session() as session: # Create a ConversationData entry with all attributes 
filled - entry = ConversationData( + entry = PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", - content="Test content", - timestamp=datetime.datetime.utcnow(), - normalizer_id="test_normalizer_id", - sha256="test_sha256", - labels=["label1", "label2"], + sequence=1, + original_prompt_text="Test content", + converted_prompt_text="Test content", + labels={"normalizer_id": "id1"}, + converters=converters, + prompt_target=target, ) # Insert the ConversationData entry @@ -274,7 +384,7 @@ def test_get_memories_with_conversation_id(setup_duckdb_database): session.commit() # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = setup_duckdb_database.get_memories_with_conversation_id( + retrieved_entries = setup_duckdb_database.get_prompt_entries_with_conversation_id( conversation_id=specific_conversation_id ) @@ -283,40 +393,50 @@ def test_get_memories_with_conversation_id(setup_duckdb_database): retrieved_entry = retrieved_entries[0] assert retrieved_entry.conversation_id == specific_conversation_id assert retrieved_entry.role == "user" - assert retrieved_entry.content == "Test content" + assert retrieved_entry.original_prompt_text == "Test content" # For timestamp, you might want to check if it's close to the current time instead of an exact match assert abs((retrieved_entry.timestamp - entry.timestamp).total_seconds()) < 10 # Assuming the test runs quickly - assert retrieved_entry.normalizer_id == "test_normalizer_id" - assert retrieved_entry.sha256 == "test_sha256" - assert retrieved_entry.labels == ["label1", "label2"] + + converters = json.loads(retrieved_entry.converters) + assert len(converters) == 1 + assert converters[0]["__type__"] == "Base64Converter" + + prompt_target = json.loads(retrieved_entry.prompt_target) + assert prompt_target["__type__"] == "TextTarget" + + labels = retrieved_entry.labels + assert labels["normalizer_id"] == "id1" def test_get_memories_with_normalizer_id(setup_duckdb_database): # Define a specific normalizer_id specific_normalizer_id = "normalizer_test_id" + labels = {"normalizer_id": specific_normalizer_id} + other_labels = {"normalizer_id": "other_normalizer_id"} + # Create a list of ConversationData entries, some with the specific normalizer_id entries = [ - ConversationData( + PromptMemoryEntry( conversation_id="123", role="user", - content="Hello 1", - normalizer_id=specific_normalizer_id, - timestamp=datetime.datetime.utcnow(), + original_prompt_text="Hello 1", + converted_prompt_text="Hello 1", + labels=labels, ), - ConversationData( + PromptMemoryEntry( conversation_id="456", role="user", - content="Hello 2", - normalizer_id="other_normalizer_id", - timestamp=datetime.datetime.utcnow(), + original_prompt_text="Hello 2", + converted_prompt_text="Hello 2", + labels=other_labels, ), - ConversationData( + PromptMemoryEntry( conversation_id="789", role="user", - content="Hello 3", - normalizer_id=specific_normalizer_id, - timestamp=datetime.datetime.utcnow(), + original_prompt_text="Hello 3", + converted_prompt_text="Hello 1", + labels=labels, ), ] @@ -326,13 +446,15 @@ def test_get_memories_with_normalizer_id(setup_duckdb_database): session.commit() # Ensure all entries are committed to the database # Use the get_memories_with_normalizer_id method to retrieve entries with the specific normalizer_id - retrieved_entries = setup_duckdb_database.get_memories_with_normalizer_id(normalizer_id=specific_normalizer_id) + retrieved_entries = 
setup_duckdb_database.get_prompt_entries_with_normalizer_id( + normalizer_id=specific_normalizer_id + ) # Verify that the retrieved entries match the expected normalizer_id assert len(retrieved_entries) == 2 # Two entries should have the specific normalizer_id for retrieved_entry in retrieved_entries: - assert retrieved_entry.normalizer_id == specific_normalizer_id - assert "Hello" in retrieved_entry.content # Basic check to ensure content is as expected + assert retrieved_entry.labels["normalizer_id"] == specific_normalizer_id + assert "Hello" in retrieved_entry.original_prompt_text # Basic check to ensure content is as expected def test_update_entries_by_conversation_id(setup_duckdb_database): @@ -341,20 +463,23 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): # Create a list of ConversationData entries, some with the specific conversation_id entries = [ - ConversationData( + PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", - content="Original content 1", - timestamp=datetime.datetime.utcnow(), + original_prompt_text="Original content 1", + converted_prompt_text="Original content 1", ), - ConversationData( - conversation_id="other_id", role="user", content="Original content 2", timestamp=datetime.datetime.utcnow() + PromptMemoryEntry( + conversation_id="other_id", + role="user", + original_prompt_text="Original content 2", + converted_prompt_text="Original content 2", ), - ConversationData( + PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", - content="Original content 3", - timestamp=datetime.datetime.utcnow(), + original_prompt_text="Original content 3", + converted_prompt_text="Original content 3", ), ] @@ -364,7 +489,7 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): session.commit() # Ensure all entries are committed to the database # Define the fields to update for entries with the specific conversation_id - update_fields = {"content": "Updated content", "role": "assistant"} + update_fields = {"original_prompt_text": "Updated content", "role": "assistant"} # Use the update_entries_by_conversation_id method to update the entries update_result = setup_duckdb_database.update_entries_by_conversation_id( @@ -374,13 +499,13 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): # Verify that the entries with the specific conversation_id were updated updated_entries = setup_duckdb_database.query_entries( - ConversationData, conditions=ConversationData.conversation_id == specific_conversation_id + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == specific_conversation_id ) for entry in updated_entries: - assert entry.content == "Updated content" + assert entry.original_prompt_text == "Updated content" assert entry.role == "assistant" # Verify that the entry with a different conversation_id was not updated - other_entry = session.query(ConversationData).filter_by(conversation_id="other_id").first() - assert other_entry.content == "Original content 2" # Content should remain unchanged + other_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="other_id").first() + assert other_entry.original_prompt_text == "Original content 2" # Content should remain unchanged assert other_entry.role == "user" # Role should remain unchanged diff --git a/tests/memory/test_memory_embedding.py b/tests/memory/test_memory_embedding.py index 4c613b94b..8e459bf53 100644 --- a/tests/memory/test_memory_embedding.py +++ b/tests/memory/test_memory_embedding.py @@ -8,7 +8,7 @@ from 
pyrit.memory import MemoryEmbedding from pyrit.models import EmbeddingData, EmbeddingResponse, EmbeddingUsageInformation from pyrit.memory.memory_embedding import default_memory_embedding_factory -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry DEFAULT_EMBEDDING_DATA = EmbeddingData(embedding=[0.0], index=0, object="mock_object") @@ -50,13 +50,15 @@ def memory_encoder_w_mock_embedding_generator(): def test_memory_encoding_chat_message( memory_encoder_w_mock_embedding_generator: MemoryEmbedding, ): - chat_memory = ConversationData( - content="hello world!", + chat_memory = PromptMemoryEntry( + original_prompt_text="hello world!", + converted_prompt_text="hello world!", role="user", + converted_prompt_data_type="text", conversation_id="my_session", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) - assert metadata.uuid == chat_memory.uuid + assert metadata.id == chat_memory.id assert metadata.embedding == DEFAULT_EMBEDDING_DATA.embedding assert metadata.embedding_type_name == "MockEmbeddingGenerator" @@ -82,5 +84,5 @@ def test_default_memory_embedding_factory_without_embedding_model_and_environmen monkeypatch.delenv("AZURE_OPENAI_EMBEDDING_ENDPOINT", raising=False) monkeypatch.delenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", raising=False) - memory_embedding = default_memory_embedding_factory() - assert memory_embedding is None + with pytest.raises(ValueError): + default_memory_embedding_factory() diff --git a/tests/memory/test_memory_encoder.py b/tests/memory/test_memory_encoder.py index dfca720d9..f7d2bbee9 100644 --- a/tests/memory/test_memory_encoder.py +++ b/tests/memory/test_memory_encoder.py @@ -7,7 +7,7 @@ from pyrit.interfaces import EmbeddingSupport from pyrit.memory import MemoryEmbedding -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.models import EmbeddingData, EmbeddingResponse, EmbeddingUsageInformation @@ -53,12 +53,14 @@ def memory_encoder_w_mock_embedding_generator(): def test_memory_encoding_chat_message( memory_encoder_w_mock_embedding_generator: MemoryEmbedding, ): - chat_memory = ConversationData( - content="hello world!", + chat_memory = PromptMemoryEntry( + original_prompt_text="hello world!", + converted_prompt_text="hello world!", role="user", conversation_id="my_session", + converted_prompt_data_type="text", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) - assert metadata.uuid == chat_memory.uuid + assert metadata.id == chat_memory.id assert metadata.embedding == DEFAULT_EMBEDDING_DATA.embedding assert metadata.embedding_type_name == "MockEmbeddingGenerator" diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 8c0ada080..3d56aeccf 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -2,21 +2,18 @@ # Licensed under the MIT license. 
import json - import pytest from pyrit.memory.memory_exporter import MemoryExporter -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry + from sqlalchemy.inspection import inspect +from tests.mocks import get_sample_conversations @pytest.fixture -def sample_conversations(): - # Create some instances of ConversationStore with sample data - return [ - ConversationData(role="User", content="Hello, how are you?", conversation_id="12345"), - ConversationData(role="Bot", content="I'm fine, thank you!", conversation_id="12345"), - ] +def sample_conversations() -> list[PromptMemoryEntry]: + return get_sample_conversations() def model_to_dict(instance): @@ -34,14 +31,16 @@ def test_export_to_json_creates_file(tmp_path, sample_conversations): with open(file_path, "r") as f: content = json.load(f) # Perform more detailed checks on content if necessary - assert len(content) == 2 # Simple check for the number of items + assert len(content) == 3 # Simple check for the number of items # Convert each ConversationStore instance to a dictionary expected_content = [model_to_dict(conv) for conv in sample_conversations] for expected, actual in zip(expected_content, content): assert expected["role"] == actual["role"] - assert expected["content"] == actual["content"] + assert expected["converted_prompt_text"] == actual["converted_prompt_text"] assert expected["conversation_id"] == actual["conversation_id"] + assert expected["original_prompt_data_type"] == actual["original_prompt_data_type"] + assert expected["original_prompt_text"] == actual["original_prompt_text"] def test_export_data_with_conversations(tmp_path, sample_conversations): @@ -59,10 +58,10 @@ def test_export_data_with_conversations(tmp_path, sample_conversations): # Read the file and verify its contents with open(file_path, "r") as f: content = json.load(f) - assert len(content) == 2 # Check for the expected number of items - assert content[0]["role"] == "User" - assert content[0]["content"] == "Hello, how are you?" + assert len(content) == 3 # Check for the expected number of items + assert content[0]["role"] == "user" + assert content[0]["converted_prompt_text"] == "Hello, how are you?" assert content[0]["conversation_id"] == "12345" - assert content[1]["role"] == "Bot" - assert content[1]["content"] == "I'm fine, thank you!" + assert content[1]["role"] == "assistant" + assert content[1]["converted_prompt_text"] == "I'm fine, thank you!" assert content[1]["conversation_id"] == "12345" diff --git a/tests/memory/test_memory_interface.py b/tests/memory/test_memory_interface.py index 3306c5ea3..40520b1a6 100644 --- a/tests/memory/test_memory_interface.py +++ b/tests/memory/test_memory_interface.py @@ -1,71 +1,65 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from typing import Generator +import pytest import random from string import ascii_lowercase -import pytest -from sqlalchemy import inspect - -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory import MemoryInterface from pyrit.models import ChatMessage -from pyrit.memory.memory_models import ConversationData +from tests.mocks import get_memory_interface -@pytest.fixture -def memory() -> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." - - yield duckdb_memory - duckdb_memory.dispose_engine() +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() def generate_random_string(length: int = 10) -> str: return "".join(random.choice(ascii_lowercase) for _ in range(length)) -def test_memory(memory: MemoryInterface): - assert memory +def test_memory(memory_interface: MemoryInterface): + assert memory_interface -def test_conversation_memory_empty_by_default(memory: MemoryInterface): +def test_conversation_memory_empty_by_default(memory_interface: MemoryInterface): expected_count = 0 - c = memory.get_all_memory(ConversationData) + c = memory_interface.get_all_prompt_entries() assert len(c) == expected_count def test_count_of_memories_matches_number_of_conversations_added_1( - memory: MemoryInterface, + memory_interface: MemoryInterface, ): expected_count = 1 message = ChatMessage(role="user", content="Hello") - memory.add_chat_message_to_memory(conversation=message, conversation_id="1", labels=[]) - c = memory.get_all_memory(ConversationData) + memory_interface.add_chat_message_to_memory(conversation=message, conversation_id="1", labels={}) + c = memory_interface.get_all_prompt_entries() assert len(c) == expected_count -def test_add_chate_message_to_memory_added(memory: MemoryInterface): +def test_add_chat_message_to_memory_added(memory_interface: MemoryInterface): expected_count = 3 - memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1") - memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1") - memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 3"), conversation_id="1") - assert len(memory.get_all_memory(ConversationData)) == expected_count - - -def test_add_chate_messages_to_memory_added(memory: MemoryInterface): + memory_interface.add_chat_message_to_memory( + conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1" + ) + memory_interface.add_chat_message_to_memory( + conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1" + ) + memory_interface.add_chat_message_to_memory( + conversation=ChatMessage(role="user", content="Hello 3"), conversation_id="1" + ) + assert len(memory_interface.get_all_prompt_entries()) == expected_count + + +def test_add_chat_messages_to_memory_added(memory_interface: MemoryInterface): messages = [ ChatMessage(role="user", content="Hello 1"), ChatMessage(role="user", content="Hello 2"), ] - memory.add_chat_messages_to_memory(conversations=messages, 
conversation_id="1") - assert len(memory.get_all_memory(ConversationData)) == len(messages) + memory_interface.add_chat_messages_to_memory(conversations=messages, conversation_id="1") + assert len(memory_interface.get_all_prompt_entries()) == len(messages) diff --git a/tests/memory/test_memory_models.py b/tests/memory/test_memory_models.py new file mode 100644 index 000000000..f1e0fe76e --- /dev/null +++ b/tests/memory/test_memory_models.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import time + +from datetime import datetime +from unittest.mock import MagicMock +from pyrit.memory import PromptMemoryEntry +from pyrit.orchestrator import PromptSendingOrchestrator +from pyrit.prompt_converter import Base64Converter, PromptConverterList +from tests.mocks import MockPromptTarget + + +def test_id_set(): + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + assert entry.id is not None + + +def test_datetime_set(): + now = datetime.utcnow() + time.sleep(0.1) + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + assert entry.timestamp > now + + +def test_is_sequence_set_false(): + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + assert entry.is_sequence_set() is False + + +def test_is_sequence_set_true(): + entry = PromptMemoryEntry(role="user", original_prompt_text="Hello", converted_prompt_text="Hello", sequence=1) + assert entry.is_sequence_set() + + +def test_converters_serialize(): + converters = PromptConverterList([Base64Converter()]) + entry = PromptMemoryEntry( + role="user", original_prompt_text="Hello", converted_prompt_text="Hello", converters=converters + ) + assert ( + entry.converters == '[{"__type__": "Base64Converter", "__module__": "pyrit.prompt_converter.base64_converter"}]' + ) + + +def test_prompt_targets_serialize(): + target = MockPromptTarget() + entry = PromptMemoryEntry( + role="user", original_prompt_text="Hello", converted_prompt_text="Hello", prompt_target=target + ) + + j = json.loads(entry.prompt_target) + + assert j["__type__"] == "MockPromptTarget" + assert j["__module__"] == "tests.mocks" + + +def test_orchestrators_serialize(): + orchestrator = PromptSendingOrchestrator(prompt_target=MagicMock(), memory=MagicMock()) + + entry = PromptMemoryEntry( + role="user", original_prompt_text="Hello", converted_prompt_text="Hello", orchestrator=orchestrator + ) + + j = json.loads(entry.orchestrator) + + assert j["id"] is not None + assert j["__type__"] == "PromptSendingOrchestrator" + assert j["__module__"] == "pyrit.orchestrator.prompt_sending_orchestrator" + + +def test_hashes_generated(): + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello1", + converted_prompt_text="Hello2", + ) + + assert entry.original_prompt_data_sha256 == "948edbe7ede5aa7423476ae29dcd7d61e7711a071aea0d83698377effa896525" + assert entry.converted_prompt_data_sha256 == "be98c2510e417405647facb89399582fc499c3de4452b3014857f92e6baad9a9" diff --git a/tests/mocks.py b/tests/mocks.py index e274a5cf5..3e3cce529 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -2,7 +2,12 @@ # Licensed under the MIT license. 
from contextlib import AbstractAsyncContextManager +from typing import Generator +from sqlalchemy import inspect + +from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.prompt_target import PromptTarget @@ -61,3 +66,44 @@ def send_prompt(self, normalized_prompt: str, conversation_id: str, normalizer_i async def send_prompt_async(self, normalized_prompt: str, conversation_id: str, normalizer_id: str) -> None: self.prompt_sent.append(normalized_prompt) + + +def get_memory_interface() -> Generator[MemoryInterface, None, None]: + # Create an in-memory DuckDB engine + duckdb_memory = DuckDBMemory(db_path=":memory:") + + duckdb_memory.disable_embedding() + + # Reset the database to ensure a clean state + duckdb_memory.reset_database() + inspector = inspect(duckdb_memory.engine) + + # Verify that tables are created as expected + assert "PromptMemoryEntries" in inspector.get_table_names(), "PromptMemoryEntries table not created." + assert "EmbeddingData" in inspector.get_table_names(), "EmbeddingData table not created." + + yield duckdb_memory + duckdb_memory.dispose_engine() + + +def get_sample_conversations(): + return [ + PromptMemoryEntry( + role="user", + original_prompt_text="original prompt text", + converted_prompt_text="Hello, how are you?", + conversation_id="12345", + ), + PromptMemoryEntry( + role="assistant", + original_prompt_text="original prompt text", + converted_prompt_text="I'm fine, thank you!", + conversation_id="12345", + ), + PromptMemoryEntry( + role="assistant", + original_prompt_text="original prompt text", + converted_prompt_text="I'm fine, thank you!", + conversation_id="33333", + ), + ] diff --git a/tests/test_prompt_target.py b/tests/test_prompt_target.py index 6b688dd36..a25fdc899 100644 --- a/tests/test_prompt_target.py +++ b/tests/test_prompt_target.py @@ -1,16 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from typing import Generator from unittest.mock import patch import pytest -from sqlalchemy import inspect from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory.memory_interface import MemoryInterface from pyrit.prompt_target import AzureOpenAIChatTarget +from tests.mocks import get_memory_interface + + +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() + @pytest.fixture def openai_mock_return() -> ChatCompletion: @@ -36,40 +43,23 @@ def chat_completion_engine() -> AzureOpenAIChatTarget: @pytest.fixture -def memory() -> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." 
- - yield duckdb_memory - duckdb_memory.dispose_engine() - - -@pytest.fixture -def azure_openai_target(memory: DuckDBMemory): +def azure_openai_target(memory_interface: MemoryInterface): return AzureOpenAIChatTarget( deployment_name="test", endpoint="test", api_key="test", - memory=memory, + memory=memory_interface, ) def test_set_system_prompt(azure_openai_target: AzureOpenAIChatTarget): azure_openai_target.set_system_prompt(prompt="system prompt", conversation_id="1", normalizer_id="2") - chats = azure_openai_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_openai_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "system" - assert chats[0].content == "system prompt" + assert chats[0].converted_prompt_text == "system prompt" def test_send_prompt_user_no_system(azure_openai_target: AzureOpenAIChatTarget, openai_mock_return: ChatCompletion): @@ -79,7 +69,7 @@ def test_send_prompt_user_no_system(azure_openai_target: AzureOpenAIChatTarget, normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" ) - chats = azure_openai_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_openai_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 2, f"Expected 2 chats, got {len(chats)}" assert chats[0].role == "user" assert chats[1].role == "assistant" @@ -95,7 +85,7 @@ def test_send_prompt_with_system(azure_openai_target: AzureOpenAIChatTarget, ope normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" ) - chats = azure_openai_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_openai_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 3, f"Expected 3 chats, got {len(chats)}" assert chats[0].role == "system" assert chats[1].role == "user" diff --git a/tests/test_prompt_target_azure_blob_storage.py b/tests/test_prompt_target_azure_blob_storage.py index 808329cb6..44f5e5550 100644 --- a/tests/test_prompt_target_azure_blob_storage.py +++ b/tests/test_prompt_target_azure_blob_storage.py @@ -2,38 +2,28 @@ # Licensed under the MIT license. import os +from typing import Generator import pytest -from sqlalchemy import inspect from unittest.mock import patch -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory import MemoryInterface from pyrit.prompt_target import AzureBlobStorageTarget +from tests.mocks import get_memory_interface -@pytest.fixture -def memory() -> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." 
- - yield duckdb_memory - duckdb_memory.dispose_engine() +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() @pytest.fixture -def azure_blob_storage_target(memory: DuckDBMemory): +def azure_blob_storage_target(memory_interface: MemoryInterface): return AzureBlobStorageTarget( container_url="https://test.blob.core.windows.net/test", sas_token="valid_sas_token", - memory=memory, + memory=memory_interface, ) @@ -87,10 +77,10 @@ def test_send_prompt(mock_upload, azure_blob_storage_target: AzureBlobStorageTar assert blob_url.__contains__(azure_blob_storage_target._container_url) assert blob_url.__contains__(".txt") - chats = azure_blob_storage_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_blob_storage_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "user" - assert chats[0].content == __name__ + assert chats[0].converted_prompt_text == __name__ @patch("azure.storage.blob.aio.ContainerClient.upload_blob") @@ -103,7 +93,7 @@ async def test_send_prompt_async(mock_upload_async, azure_blob_storage_target: A assert blob_url.__contains__(azure_blob_storage_target._container_url) assert blob_url.__contains__(".txt") - chats = azure_blob_storage_target._memory.get_memories_with_conversation_id(conversation_id="2") + chats = azure_blob_storage_target._memory.get_prompt_entries_with_conversation_id(conversation_id="2") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "user" - assert chats[0].content == __name__ + assert chats[0].converted_prompt_text == __name__ diff --git a/tests/test_prompt_target_text.py b/tests/test_prompt_target_text.py index 56679bc13..506af4171 100644 --- a/tests/test_prompt_target_text.py +++ b/tests/test_prompt_target_text.py @@ -3,48 +3,37 @@ import os from tempfile import NamedTemporaryFile +from typing import Generator import pytest -from sqlalchemy import inspect - -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory import MemoryInterface from pyrit.prompt_target import TextTarget +from tests.mocks import get_memory_interface -@pytest.fixture -def memory() -> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." 
- - yield duckdb_memory - duckdb_memory.dispose_engine() +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() -def test_send_prompt_user_no_system(memory: DuckDBMemory): - no_op = TextTarget(memory=memory) +def test_send_prompt_user_no_system(memory_interface: MemoryInterface): + no_op = TextTarget(memory=memory_interface) no_op.send_prompt( normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" ) - chats = no_op._memory.get_memories_with_conversation_id(conversation_id="1") + chats = no_op._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "user" -def test_send_prompt_stream(memory: DuckDBMemory): +def test_send_prompt_stream(memory_interface: MemoryInterface): with NamedTemporaryFile(mode="w+", delete=False) as tmp_file: prompt = "hi, I am a victim chatbot, how can I help?" - no_op = TextTarget(memory=memory, text_stream=tmp_file) + no_op = TextTarget(memory=memory_interface, text_stream=tmp_file) no_op.send_prompt(normalized_prompt=prompt, conversation_id="1", normalizer_id="2") tmp_file.seek(0) diff --git a/tests/test_red_teaming_orchestrator.py b/tests/test_red_teaming_orchestrator.py index b653177cb..c90ad0184 100644 --- a/tests/test_red_teaming_orchestrator.py +++ b/tests/test_red_teaming_orchestrator.py @@ -2,38 +2,26 @@ # Licensed under the MIT license. import pathlib -from typing import Union +from typing import Generator, Union from unittest.mock import Mock, patch +from pyrit.memory.memory_interface import MemoryInterface from pyrit.prompt_converter.prompt_converter import PromptConverter from pyrit.prompt_target.prompt_target import PromptTarget import pytest -from sqlalchemy import inspect -from pyrit.memory.memory_models import ConversationData from pyrit.orchestrator import ScoringRedTeamingOrchestrator, EndTokenRedTeamingOrchestrator from pyrit.orchestrator.end_token_red_teaming_orchestrator import RED_TEAM_CONVERSATION_END_TOKEN from pyrit.prompt_target import AzureOpenAIChatTarget from pyrit.models import AttackStrategy, ChatMessage, Score -from pyrit.memory import DuckDBMemory from pyrit.common.path import DATASETS_PATH +from tests.mocks import get_memory_interface -@pytest.fixture -def memory() -> DuckDBMemory: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." 
- yield duckdb_memory - duckdb_memory.dispose_engine() +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() @pytest.fixture @@ -42,12 +30,12 @@ def chat_completion_engine() -> AzureOpenAIChatTarget: @pytest.fixture -def prompt_target(memory) -> AzureOpenAIChatTarget: +def prompt_target(memory_interface) -> AzureOpenAIChatTarget: return AzureOpenAIChatTarget( deployment_name="test", endpoint="test", api_key="test", - memory=memory, + memory=memory_interface, ) @@ -79,18 +67,21 @@ def check_conversations( # first conversation (with red teaming chat bot) assert conversations[0].conversation_id == conversations[1].conversation_id == conversations[2].conversation_id assert conversations[0].role == "system" - assert conversations[0].content == red_teaming_meta_prompt + assert conversations[0].converted_prompt_text == red_teaming_meta_prompt assert conversations[1].role == "user" - assert conversations[1].content == initial_red_teaming_prompt + assert conversations[1].converted_prompt_text == initial_red_teaming_prompt assert conversations[2].role == "assistant" - assert conversations[2].content == expected_red_teaming_responses[0] + assert conversations[2].converted_prompt_text == expected_red_teaming_responses[0] # second conversation (with prompt target) assert conversations[3 - index_offset].conversation_id == conversations[4 - index_offset].conversation_id - assert conversations[3 - index_offset].normalizer_id == conversations[4 - index_offset].normalizer_id + assert ( + conversations[3 - index_offset].labels["normalizer_id"] + == conversations[4 - index_offset].labels["normalizer_id"] + ) assert conversations[3 - index_offset].role == "user" - assert conversations[3 - index_offset].content == expected_red_teaming_responses[0] + assert conversations[3 - index_offset].converted_prompt_text == expected_red_teaming_responses[0] assert conversations[4 - index_offset].role == "assistant" - assert conversations[4 - index_offset].content == expected_target_responses[0] + assert conversations[4 - index_offset].converted_prompt_text == expected_target_responses[0] if stop_after_n_conversations == 2: return @@ -103,20 +94,23 @@ def check_conversations( # third conversation (with red teaming chatbot) assert conversations[5 - index_offset].conversation_id == conversations[6 - index_offset].conversation_id assert conversations[5 - index_offset].role == "user" - assert conversations[5 - index_offset].content == expected_target_responses[0] + assert conversations[5 - index_offset].converted_prompt_text == expected_target_responses[0] assert conversations[6 - index_offset].role == "assistant" - assert conversations[6 - index_offset].content == expected_red_teaming_responses[1] + assert conversations[6 - index_offset].converted_prompt_text == expected_red_teaming_responses[1] if stop_after_n_conversations == 3: return # fourth conversation (with prompt target) assert conversations[7 - index_offset].conversation_id == conversations[8 - index_offset].conversation_id - assert conversations[7 - index_offset].normalizer_id == conversations[8 - index_offset].normalizer_id + assert ( + conversations[7 - index_offset].labels["normalizer_id"] + == conversations[8 - index_offset].labels["normalizer_id"] + ) assert conversations[7 - index_offset].role == "user" - assert conversations[7 - index_offset].content == expected_red_teaming_responses[1] + assert conversations[7 - index_offset].converted_prompt_text == expected_red_teaming_responses[1] assert 
conversations[8 - index_offset].role == "assistant" - assert conversations[8 - index_offset].content == expected_target_responses[1] + assert conversations[8 - index_offset].converted_prompt_text == expected_target_responses[1] @pytest.mark.parametrize("attack_strategy_as_str", [True, False]) @@ -125,7 +119,7 @@ def test_send_prompt_twice( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -135,7 +129,7 @@ def test_send_prompt_twice( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory, + "memory": memory_interface, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -152,7 +146,7 @@ def test_send_prompt_twice( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 5, f"Expected 5 conversations, got {len(conversations)}" check_conversations( @@ -173,7 +167,7 @@ def test_send_prompt_twice( mock_target.return_value = expected_target_responses[1] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[1] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting another two conversation threads assert len(conversations) == 9, f"Expected 9 conversations, got {len(conversations)}" check_conversations( @@ -192,7 +186,7 @@ def test_send_fixed_prompt_then_generated_prompt( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -202,7 +196,7 @@ def test_send_fixed_prompt_then_generated_prompt( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory, + "memory": memory_interface, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -218,7 +212,7 @@ def test_send_fixed_prompt_then_generated_prompt( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt(prompt=fixed_input_prompt) assert target_response == expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 2, f"Expected 2 conversations, got {len(conversations)}" check_conversations( @@ -240,7 +234,7 @@ def test_send_fixed_prompt_then_generated_prompt( mock_target.return_value = expected_target_responses[1] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[1] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = 
red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting another two conversation threads assert len(conversations) == 7, f"Expected 7 conversations, got {len(conversations)}" check_conversations( @@ -259,7 +253,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -269,7 +263,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory, + "memory": memory_interface, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -285,7 +279,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt(prompt=fixed_input_prompt) assert target_response == expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 2, f"Expected 2 conversations, got {len(conversations)}" check_conversations( @@ -314,7 +308,7 @@ def test_reach_goal_after_two_turns_end_token( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, ): attack_strategy: Union[str | AttackStrategy] = ( @@ -323,7 +317,7 @@ def test_reach_goal_after_two_turns_end_token( red_teaming_orchestrator = EndTokenRedTeamingOrchestrator( red_teaming_chat=chat_completion_engine, - memory=memory, + memory=memory_interface, attack_strategy=attack_strategy, initial_red_teaming_prompt="how can I help you?", prompt_target=prompt_target, @@ -341,7 +335,7 @@ def test_reach_goal_after_two_turns_end_token( mock_target.return_value = expected_target_response target_response = red_teaming_orchestrator.apply_attack_strategy_until_completion() assert target_response == expected_target_response - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting three conversation threads (two with red teaming chat and one with prompt target) assert len(conversations) == 7, f"Expected 7 conversations, got {len(conversations)}" check_conversations(