From e69792f502c4861e03143d67abf1aec1ae4ef92c Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 25 Mar 2024 18:22:49 -0700 Subject: [PATCH 01/15] refactoring schema --- pyrit/analytics/conversation_analytics.py | 6 +- pyrit/common/print.py | 4 +- pyrit/memory/__init__.py | 4 +- pyrit/memory/duckdb_memory.py | 12 +-- pyrit/memory/memory_embedding.py | 6 +- pyrit/memory/memory_interface.py | 10 +- pyrit/memory/memory_models.py | 93 +++++++++++-------- .../analytics/test_conversation_analytics.py | 16 ++-- tests/memory/test_duckdb_memory.py | 52 +++++------ tests/memory/test_memory_embedding.py | 6 +- tests/memory/test_memory_encoder.py | 6 +- tests/memory/test_memory_exporter.py | 6 +- tests/memory/test_memory_interface.py | 10 +- tests/test_red_teaming_orchestrator.py | 14 +-- 14 files changed, 130 insertions(+), 115 deletions(-) diff --git a/pyrit/analytics/conversation_analytics.py b/pyrit/analytics/conversation_analytics.py index 9904e2b95..a813320c5 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -6,7 +6,7 @@ from sklearn.metrics.pairwise import cosine_similarity from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import ConversationMessageWithSimilarity, EmbeddingMessageWithSimilarity -from pyrit.memory.memory_models import ConversationData, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData class ConversationAnalytics: @@ -37,7 +37,7 @@ def get_similar_chat_messages_by_content( list[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on content. 
""" - all_memories = self.memory_interface.get_all_memory(ConversationData) + all_memories = self.memory_interface.get_all_memory(PromptMemoryEntry) similar_messages = [] for memory in all_memories: @@ -82,7 +82,7 @@ def get_similar_chat_messages_by_embedding( if similarity_score >= threshold: similar_messages.append( EmbeddingMessageWithSimilarity( - score=similarity_score, uuid=memory.uuid, metric="cosine_similarity" # type: ignore + score=similarity_score, uuid=memory.id, metric="cosine_similarity" # type: ignore ) ) diff --git a/pyrit/common/print.py b/pyrit/common/print.py index f7f71142d..490c7e459 100644 --- a/pyrit/common/print.py +++ b/pyrit/common/print.py @@ -4,13 +4,13 @@ import textwrap import termcolor -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.models import ChatMessage from termcolor._types import Color def print_chat_messages_with_color( - messages: list[ChatMessage | ConversationData], + messages: list[ChatMessage | PromptMemoryEntry], max_content_character_width: int = 80, left_padding_width: int = 20, custom_colors: dict[str, Color] = None, diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 6b927adb2..c6af289a9 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.memory.duckdb_memory import DuckDBMemory from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter -__all__ = ["ConversationData", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"] +__all__ = ["PromptMemoryEntry", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"] diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 854a57c83..158f9ddd1 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -12,7 +12,7 @@ from sqlalchemy.engine.base import Engine from contextlib import closing -from pyrit.memory.memory_models import ConversationData, Base +from pyrit.memory.memory_models import PromptMemoryEntry, Base from pyrit.memory.memory_embedding import default_memory_embedding_factory from pyrit.memory.memory_interface import MemoryInterface from pyrit.interfaces import EmbeddingSupport @@ -166,7 +166,7 @@ def get_all_memory(self, model: Base) -> list[Base]: # type: ignore result = self.query_entries(model) return result - def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[ConversationData]: + def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified conversation ID. @@ -177,12 +177,12 @@ def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[Con list[ConversationData]: A list of ConversationData objects matching the specified conversation ID. 
""" try: - return self.query_entries(ConversationData, conditions=ConversationData.conversation_id == conversation_id) + return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id) except Exception as e: logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") return [] - def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[ConversationData]: + def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified normalizer ID. @@ -193,7 +193,7 @@ def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[Convers list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID. """ try: - return self.query_entries(ConversationData, conditions=ConversationData.normalizer_id == normalizer_id) + return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.normalizer_id == normalizer_id) except Exception as e: logger.exception( f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. 
{e}" @@ -213,7 +213,7 @@ def update_entries_by_conversation_id(self, *, conversation_id: str, update_fiel """ # Fetch the relevant entries using query_entries entries_to_update = self.query_entries( - ConversationData, conditions=ConversationData.conversation_id == conversation_id + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id ) # Check if there are entries to update diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index 9dfede46e..a4c2cfd3d 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -4,7 +4,7 @@ import os from pyrit.embedding.azure_text_embedding import AzureTextEmbedding from pyrit.interfaces import EmbeddingSupport -from pyrit.memory.memory_models import ConversationData, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData class MemoryEmbedding: @@ -20,7 +20,7 @@ def __init__(self, *, embedding_model: EmbeddingSupport): raise ValueError("embedding_model must be set.") self.embedding_model = embedding_model - def generate_embedding_memory_data(self, *, chat_memory: ConversationData) -> EmbeddingData: + def generate_embedding_memory_data(self, *, chat_memory: PromptMemoryEntry) -> EmbeddingData: """ Generates metadata for a chat memory entry. 
@@ -33,7 +33,7 @@ def generate_embedding_memory_data(self, *, chat_memory: ConversationData) -> Em embedding_data = EmbeddingData( embedding=self.embedding_model.generate_text_embedding(text=chat_memory.content).data[0].embedding, embedding_type_name=self.embedding_model.__class__.__name__, - uuid=chat_memory.uuid, + uuid=chat_memory.id, ) return embedding_data diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c43cae33a..6179f5193 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -8,7 +8,7 @@ from uuid import uuid4 -from pyrit.memory.memory_models import Base, ConversationData +from pyrit.memory.memory_models import Base, PromptMemoryEntry from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.models import ChatMessage @@ -33,13 +33,13 @@ def __init__(self, embedding_model=None): self.exporter = MemoryExporter() @abc.abstractmethod - def get_all_memory(self, model: Base) -> list[ConversationData]: # type: ignore + def get_all_memory(self, model: Base) -> list[PromptMemoryEntry]: # type: ignore """ Loads all ConversationData from the memory storage handler. """ @abc.abstractmethod - def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[ConversationData]: + def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified conversation ID. @@ -51,7 +51,7 @@ def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[Con """ @abc.abstractmethod - def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[ConversationData]: + def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified normalizer ID. 
@@ -192,7 +192,7 @@ def _create_chat_message_memory_entry( ConversationData: A new instance ready to be persisted in the memory storage. """ uuid = uuid4() - new_chat_memory = ConversationData( + new_chat_memory = PromptMemoryEntry( role=conversation.role, content=conversation.content, conversation_id=conversation_id, diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 2cbd1620d..a5265f4bc 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -1,62 +1,77 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from uuid import uuid4 +import enum import uuid + from datetime import datetime +from typing import Dict +from uuid import uuid4 from pydantic import BaseModel, ConfigDict +from sqlalchemy import Column, String, DateTime, Float, Enum, JSON, ForeignKey, Index from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, String, DateTime, Float from sqlalchemy.dialects.postgresql import ARRAY, UUID -from sqlalchemy import ForeignKey, Index Base = declarative_base() +class PromptType(enum.Enum): + SYSTEM = 'system' + REQUEST_SEGMENT = 'request_segment' + RESPONSE = 'response' -class ConversationData(Base): # type: ignore - """ - Represents the conversation data. +class PromptDataType(enum.Enum): + TEXT = 'text' + IMAGE = 'image' - conversation_id is used to group messages together within a prompt_target endpoint. - It's often needed so the prompt_target knows how to construct the messages. - normalizer_id is used to group messages together within a prompt_normalizer. - A prompt_normalizer is usually a single attack, and can contain multiple prompt_targets. - It's often needed to group all the prompts in an attack together. +class PromptMemoryEntry(Base): # type: ignore + """ + Represents the prompt data. Attributes: - uuid (UUID): A unique identifier for each conversation entry, serving as the primary key. 
- role (String): The role associated with the message, indicating its origin - within the conversation (e.g., "user", "assistant" or "system"). - content (String): The actual text content of the conversation entry. - conversation_id (String): An identifier used to group related conversation entries. - The conversation_id is linked to a specific LLM model, - aggregating all related conversations under a single identifier. - In scenarios involving multi-turn interactions that utilize two models, - there will be two distinct conversation_ids, one for each model. - timestamp (DateTime): The timestamp when the conversation entry was created or - logged. Defaults to the current UTC time. - normalizer_id (String): An identifier used to group messages together within a prompt_normalizer. - sha256 (String): An optional SHA-256 hash of the content. - labels (ARRAY(String)): An array of labels associated with the conversation entry, - useful for categorization or filtering the final data. - idx_conversation_id (Index): An index on the `conversation_id` column to improve - query performance for operations involving obtaining conversation history based - on conversation_id. + __tablename__ (str): The name of the database table. + __table_args__ (dict): Additional arguments for the database table. + id (UUID): The unique identifier for the memory entry. + prompt_entry_type (PromptType): The type of the prompt entry (system, request_segment, response). + conversation_id (str): The identifier for the conversation which is associated with a single target. + timestamp (DateTime): The timestamp of the memory entry. + labels (Dict[str, str]): The labels associated with the memory entry. + prompt_metadata (JSON): The metadata associated with the prompt. + converters (list[PromptConverter]): The converters for the prompt. + prompt_target (PromptTarget): The target for the prompt. 
+ original_prompt_data_type (PromptDataType): The data type of the original prompt (text, image) + original_prompt_text (str): The text of the original prompt. If prompt is an image, it's a link. + original_prompt_data_sha256 (str): The SHA256 hash of the original prompt data. + converted_prompt_data_type (PromptDataType): The data type of the converted prompt (text, image) + converted_prompt_text (str): The text of the converted prompt. If prompt is an image, it's a link. + converted_prompt_data_sha256 (str): The SHA256 hash of the converted prompt data. + idx_conversation_id (Index): The index for the conversation ID. + + Methods: + __str__(): Returns a string representation of the memory entry. """ - __tablename__ = "ConversationStore" + __tablename__ = "MemoryEntries" __table_args__ = {"extend_existing": True} - uuid = Column(UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid4) - role = Column(String, nullable=False) - content = Column(String) + id = Column(UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid4) + prompt_entry_type = Column(Enum(PromptType)) conversation_id = Column(String, nullable=False) timestamp = Column(DateTime, nullable=False, default=datetime.utcnow) - normalizer_id = Column(String) - sha256 = Column(String) - labels = Column(ARRAY(String)) # type: ignore + labels: Column[Dict[str, str]] = Column(JSON) + prompt_metadata = Column(JSON) + converters: 'Column[list[PromptConverter]]' = Column(JSON) + prompt_target: 'Column[PromptTarget]' = Column(JSON) + + original_prompt_data_type = Column(Enum(PromptDataType)) + original_prompt_text = Column(String) + original_prompt_data_sha256 = Column(String) + + converted_prompt_data_type = Column(Enum(PromptDataType)) + converted_prompt_text = Column(String) + converted_prompt_data_sha256 = Column(String) + idx_conversation_id = Index("idx_conversation_id", "conversation_id") def __str__(self): @@ -69,15 +84,15 @@ class EmbeddingData(Base): # type: ignore Each embedding is
linked to a specific conversation entry via a 'uuid'. Attributes: - uuid (UUID): The primary key, which is a foreign key referencing the UUID in the ConversationStore table. + uuid (UUID): The primary key, which is a foreign key referencing the id in the MemoryEntries table. embedding (ARRAY(Float)): An array of floats representing the embedding vector. embedding_type_name (String): The name or type of the embedding, indicating the model or method used. """ - __tablename__ = "EmbeddingStore" + __tablename__ = "EmbeddingData" # Allows table redefinition if already defined. __table_args__ = {"extend_existing": True} - uuid = Column(UUID(as_uuid=True), ForeignKey(f"{ConversationData.__tablename__}.uuid"), primary_key=True) + uuid = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True) embedding = Column(ARRAY(Float)) embedding_type_name = Column(String) diff --git a/tests/analytics/test_conversation_analytics.py b/tests/analytics/test_conversation_analytics.py index 6c31e2362..ed10f5179 100644 --- a/tests/analytics/test_conversation_analytics.py +++ b/tests/analytics/test_conversation_analytics.py @@ -7,7 +7,7 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.analytics.conversation_analytics import ConversationAnalytics -from pyrit.memory.memory_models import ConversationData, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData @pytest.fixture @@ -19,9 +19,9 @@ def mock_memory_interface(): def test_get_similar_chat_messages_by_content(mock_memory_interface): # Mock data returned by the memory interface mock_data = [ - ConversationData(content="Hello, how are you?", role="user"), - ConversationData(content="I'm fine, thank you!", role="assistant"), - ConversationData(content="Hello, how are you?", role="assistant"), # Exact match + PromptMemoryEntry(content="Hello, how are you?", role="user"), + PromptMemoryEntry(content="I'm fine, thank you!", role="assistant"), +
PromptMemoryEntry(content="Hello, how are you?", role="assistant"), # Exact match ] mock_memory_interface.get_all_memory.return_value = mock_data @@ -39,8 +39,8 @@ def test_get_similar_chat_messages_by_content(mock_memory_interface): def test_get_similar_chat_messages_by_embedding(mock_memory_interface): # Mock ConversationData entries conversation_entries = [ - ConversationData(uuid=uuid.uuid4(), conversation_id="1", role="user", content="Similar message"), - ConversationData(uuid=uuid.uuid4(), conversation_id="2", role="assistant", content="Different message"), + PromptMemoryEntry(uuid=uuid.uuid4(), conversation_id="1", role="user", content="Similar message"), + PromptMemoryEntry(uuid=uuid.uuid4(), conversation_id="2", role="assistant", content="Different message"), ] # Mock EmbeddingData entries linked to the ConversationData entries @@ -49,8 +49,8 @@ def test_get_similar_chat_messages_by_embedding(mock_memory_interface): different_embedding = [0.9, 0.8, 0.7] mock_data = [ - EmbeddingData(uuid=conversation_entries[0].uuid, embedding=similar_embedding, embedding_type_name="model1"), - EmbeddingData(uuid=conversation_entries[1].uuid, embedding=different_embedding, embedding_type_name="model2"), + EmbeddingData(uuid=conversation_entries[0].id, embedding=similar_embedding, embedding_type_name="model1"), + EmbeddingData(uuid=conversation_entries[1].id, embedding=different_embedding, embedding_type_name="model2"), ] # Mock the get_all_memory method to return the mock EmbeddingData entries diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index aa86c2979..7cb56f56c 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -12,7 +12,7 @@ from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.types import String, DateTime -from pyrit.memory.memory_models import ConversationData, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory import 
DuckDBMemory @@ -127,7 +127,7 @@ def test_embedding_data_column_types(setup_duckdb_database): def test_insert_entry(setup_duckdb_database): session = setup_duckdb_database.get_session() - entry = ConversationData( + entry = PromptMemoryEntry( conversation_id="123", role="user", content="Hello", @@ -138,7 +138,7 @@ def test_insert_entry(setup_duckdb_database): # Now, get a new session to query the database and verify the entry was inserted with setup_duckdb_database.get_session() as session: - inserted_entry = session.query(ConversationData).filter_by(conversation_id="123").first() + inserted_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").first() assert inserted_entry is not None assert inserted_entry.role == "user" assert inserted_entry.content == "Hello" @@ -149,8 +149,8 @@ def test_insert_entry_violates_constraint(setup_duckdb_database): # Generate a fixed UUID fixed_uuid = uuid.uuid4() # Create two entries with the same UUID - entry1 = ConversationData(uuid=fixed_uuid, conversation_id="123", role="user", content="Hello") - entry2 = ConversationData(uuid=fixed_uuid, conversation_id="456", role="user", content="Hello again") + entry1 = PromptMemoryEntry(uuid=fixed_uuid, conversation_id="123", role="user", content="Hello") + entry2 = PromptMemoryEntry(uuid=fixed_uuid, conversation_id="456", role="user", content="Hello again") # Insert the first entry with setup_duckdb_database.get_session() as session: @@ -166,7 +166,7 @@ def test_insert_entry_violates_constraint(setup_duckdb_database): def test_insert_entries(setup_duckdb_database): entries = [ - ConversationData(conversation_id=str(i), role="user", content=f"Message {i}", sha256=f"hash{i}") + PromptMemoryEntry(conversation_id=str(i), role="user", content=f"Message {i}", sha256=f"hash{i}") for i in range(5) ] @@ -174,7 +174,7 @@ def test_insert_entries(setup_duckdb_database): with setup_duckdb_database.get_session() as session: # Use the insert_entries method to insert multiple entries 
into the database setup_duckdb_database.insert_entries(entries=entries) - inserted_entries = session.query(ConversationData).all() + inserted_entries = session.query(PromptMemoryEntry).all() assert len(inserted_entries) == 5 for i, entry in enumerate(inserted_entries): assert entry.conversation_id == str(i) @@ -185,7 +185,7 @@ def test_insert_entries(setup_duckdb_database): def test_insert_embedding_entry(setup_duckdb_database): # Create a ConversationData entry - conversation_entry = ConversationData(conversation_id="123", role="user", content="Hello", sha256="abc") + conversation_entry = PromptMemoryEntry(conversation_id="123", role="user", content="Hello", sha256="abc") # Insert the ConversationData entry using the insert_entry method setup_duckdb_database.insert_entry(conversation_entry) @@ -193,7 +193,7 @@ def test_insert_embedding_entry(setup_duckdb_database): # Re-query the ConversationData entry within a new session to ensure it's attached with setup_duckdb_database.get_session() as session: # Assuming uuid is the primary key and is set upon insertion - reattached_conversation_entry = session.query(ConversationData).filter_by(conversation_id="123").one() + reattached_conversation_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").one() uuid = reattached_conversation_entry.uuid # Now that we have the uuid, we can create and insert the EmbeddingData entry @@ -210,16 +210,16 @@ def test_insert_embedding_entry(setup_duckdb_database): def test_query_entries(setup_duckdb_database): # Insert some test data - entries = [ConversationData(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] + entries = [PromptMemoryEntry(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] setup_duckdb_database.insert_entries(entries=entries) # Query entries without conditions - queried_entries = setup_duckdb_database.query_entries(ConversationData) + queried_entries = 
setup_duckdb_database.query_entries(PromptMemoryEntry) assert len(queried_entries) == 3 # Query entries with a condition specific_entry = setup_duckdb_database.query_entries( - ConversationData, conditions=ConversationData.conversation_id == "1" + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == "1" ) assert len(specific_entry) == 1 assert specific_entry[0].content == "Message 1" @@ -227,28 +227,28 @@ def test_query_entries(setup_duckdb_database): def test_update_entries(setup_duckdb_database): # Insert a test entry - entry = ConversationData(conversation_id="123", role="user", content="Hello") + entry = PromptMemoryEntry(conversation_id="123", role="user", content="Hello") setup_duckdb_database.insert_entry(entry) # Fetch the entry to update and update its content entries_to_update = setup_duckdb_database.query_entries( - ConversationData, conditions=ConversationData.conversation_id == "123" + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == "123" ) setup_duckdb_database.update_entries(entries=entries_to_update, update_fields={"content": "Updated Hello"}) # Verify the entry was updated with setup_duckdb_database.get_session() as session: - updated_entry = session.query(ConversationData).filter_by(conversation_id="123").first() + updated_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").first() assert updated_entry.content == "Updated Hello" def test_get_all_memory(setup_duckdb_database): # Insert some test data - entries = [ConversationData(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] + entries = [PromptMemoryEntry(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] setup_duckdb_database.insert_entries(entries=entries) # Fetch all entries - all_entries = setup_duckdb_database.get_all_memory(ConversationData) + all_entries = setup_duckdb_database.get_all_memory(PromptMemoryEntry) assert len(all_entries) == 3 @@ -259,7 +259,7 @@ def 
test_get_memories_with_conversation_id(setup_duckdb_database): # Start a session with setup_duckdb_database.get_session() as session: # Create a ConversationData entry with all attributes filled - entry = ConversationData( + entry = PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", content="Test content", @@ -297,21 +297,21 @@ def test_get_memories_with_normalizer_id(setup_duckdb_database): # Create a list of ConversationData entries, some with the specific normalizer_id entries = [ - ConversationData( + PromptMemoryEntry( conversation_id="123", role="user", content="Hello 1", normalizer_id=specific_normalizer_id, timestamp=datetime.datetime.utcnow(), ), - ConversationData( + PromptMemoryEntry( conversation_id="456", role="user", content="Hello 2", normalizer_id="other_normalizer_id", timestamp=datetime.datetime.utcnow(), ), - ConversationData( + PromptMemoryEntry( conversation_id="789", role="user", content="Hello 3", @@ -341,16 +341,16 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): # Create a list of ConversationData entries, some with the specific conversation_id entries = [ - ConversationData( + PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", content="Original content 1", timestamp=datetime.datetime.utcnow(), ), - ConversationData( + PromptMemoryEntry( conversation_id="other_id", role="user", content="Original content 2", timestamp=datetime.datetime.utcnow() ), - ConversationData( + PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", content="Original content 3", @@ -374,13 +374,13 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): # Verify that the entries with the specific conversation_id were updated updated_entries = setup_duckdb_database.query_entries( - ConversationData, conditions=ConversationData.conversation_id == specific_conversation_id + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == specific_conversation_id ) for 
entry in updated_entries: assert entry.content == "Updated content" assert entry.role == "assistant" # Verify that the entry with a different conversation_id was not updated - other_entry = session.query(ConversationData).filter_by(conversation_id="other_id").first() + other_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="other_id").first() assert other_entry.content == "Original content 2" # Content should remain unchanged assert other_entry.role == "user" # Role should remain unchanged diff --git a/tests/memory/test_memory_embedding.py b/tests/memory/test_memory_embedding.py index 4c613b94b..73d8553d0 100644 --- a/tests/memory/test_memory_embedding.py +++ b/tests/memory/test_memory_embedding.py @@ -8,7 +8,7 @@ from pyrit.memory import MemoryEmbedding from pyrit.models import EmbeddingData, EmbeddingResponse, EmbeddingUsageInformation from pyrit.memory.memory_embedding import default_memory_embedding_factory -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry DEFAULT_EMBEDDING_DATA = EmbeddingData(embedding=[0.0], index=0, object="mock_object") @@ -50,13 +50,13 @@ def memory_encoder_w_mock_embedding_generator(): def test_memory_encoding_chat_message( memory_encoder_w_mock_embedding_generator: MemoryEmbedding, ): - chat_memory = ConversationData( + chat_memory = PromptMemoryEntry( content="hello world!", role="user", conversation_id="my_session", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) - assert metadata.uuid == chat_memory.uuid + assert metadata.uuid == chat_memory.id assert metadata.embedding == DEFAULT_EMBEDDING_DATA.embedding assert metadata.embedding_type_name == "MockEmbeddingGenerator" diff --git a/tests/memory/test_memory_encoder.py b/tests/memory/test_memory_encoder.py index dfca720d9..7cc16574f 100644 --- a/tests/memory/test_memory_encoder.py +++ b/tests/memory/test_memory_encoder.py @@ -7,7 +7,7 @@ from 
pyrit.interfaces import EmbeddingSupport from pyrit.memory import MemoryEmbedding -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.models import EmbeddingData, EmbeddingResponse, EmbeddingUsageInformation @@ -53,12 +53,12 @@ def memory_encoder_w_mock_embedding_generator(): def test_memory_encoding_chat_message( memory_encoder_w_mock_embedding_generator: MemoryEmbedding, ): - chat_memory = ConversationData( + chat_memory = PromptMemoryEntry( content="hello world!", role="user", conversation_id="my_session", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) - assert metadata.uuid == chat_memory.uuid + assert metadata.uuid == chat_memory.id assert metadata.embedding == DEFAULT_EMBEDDING_DATA.embedding assert metadata.embedding_type_name == "MockEmbeddingGenerator" diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 8c0ada080..3b87d9ff8 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -6,7 +6,7 @@ import pytest from pyrit.memory.memory_exporter import MemoryExporter -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry from sqlalchemy.inspection import inspect @@ -14,8 +14,8 @@ def sample_conversations(): # Create some instances of ConversationStore with sample data return [ - ConversationData(role="User", content="Hello, how are you?", conversation_id="12345"), - ConversationData(role="Bot", content="I'm fine, thank you!", conversation_id="12345"), + PromptMemoryEntry(role="User", content="Hello, how are you?", conversation_id="12345"), + PromptMemoryEntry(role="Bot", content="I'm fine, thank you!", conversation_id="12345"), ] diff --git a/tests/memory/test_memory_interface.py b/tests/memory/test_memory_interface.py index 3306c5ea3..8381bff3d 100644 --- 
a/tests/memory/test_memory_interface.py +++ b/tests/memory/test_memory_interface.py @@ -9,7 +9,7 @@ from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.models import ChatMessage -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry @pytest.fixture @@ -39,7 +39,7 @@ def test_memory(memory: MemoryInterface): def test_conversation_memory_empty_by_default(memory: MemoryInterface): expected_count = 0 - c = memory.get_all_memory(ConversationData) + c = memory.get_all_memory(PromptMemoryEntry) assert len(c) == expected_count @@ -49,7 +49,7 @@ def test_count_of_memories_matches_number_of_conversations_added_1( expected_count = 1 message = ChatMessage(role="user", content="Hello") memory.add_chat_message_to_memory(conversation=message, conversation_id="1", labels=[]) - c = memory.get_all_memory(ConversationData) + c = memory.get_all_memory(PromptMemoryEntry) assert len(c) == expected_count @@ -58,7 +58,7 @@ def test_add_chate_message_to_memory_added(memory: MemoryInterface): memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1") memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1") memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 3"), conversation_id="1") - assert len(memory.get_all_memory(ConversationData)) == expected_count + assert len(memory.get_all_memory(PromptMemoryEntry)) == expected_count def test_add_chate_messages_to_memory_added(memory: MemoryInterface): @@ -68,4 +68,4 @@ def test_add_chate_messages_to_memory_added(memory: MemoryInterface): ] memory.add_chat_messages_to_memory(conversations=messages, conversation_id="1") - assert len(memory.get_all_memory(ConversationData)) == len(messages) + assert len(memory.get_all_memory(PromptMemoryEntry)) == len(messages) diff --git a/tests/test_red_teaming_orchestrator.py 
b/tests/test_red_teaming_orchestrator.py index c477fa09a..b14f1f830 100644 --- a/tests/test_red_teaming_orchestrator.py +++ b/tests/test_red_teaming_orchestrator.py @@ -10,7 +10,7 @@ import pytest from sqlalchemy import inspect -from pyrit.memory.memory_models import ConversationData +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.orchestrator import ScoringRedTeamingOrchestrator, EndTokenRedTeamingOrchestrator from pyrit.orchestrator.end_token_red_teaming_orchestrator import RED_TEAM_CONVERSATION_END_TOKEN from pyrit.prompt_target import AzureOpenAIChatTarget @@ -152,7 +152,7 @@ def test_send_prompt_twice( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 5, f"Expected 5 conversations, got {len(conversations)}" check_conversations( @@ -173,7 +173,7 @@ def test_send_prompt_twice( mock_target.return_value = expected_target_responses[1] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[1] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) # Expecting another two conversation threads assert len(conversations) == 9, f"Expected 9 conversations, got {len(conversations)}" check_conversations( @@ -218,7 +218,7 @@ def test_send_fixed_prompt_then_generated_prompt( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt(prompt=fixed_input_prompt) assert target_response == expected_target_responses[0] - conversations = 
red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 2, f"Expected 2 conversations, got {len(conversations)}" check_conversations( @@ -240,7 +240,7 @@ def test_send_fixed_prompt_then_generated_prompt( mock_target.return_value = expected_target_responses[1] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[1] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) # Expecting another two conversation threads assert len(conversations) == 7, f"Expected 7 conversations, got {len(conversations)}" check_conversations( @@ -285,7 +285,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt(prompt=fixed_input_prompt) assert target_response == expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 2, f"Expected 2 conversations, got {len(conversations)}" check_conversations( @@ -341,7 +341,7 @@ def test_reach_goal_after_two_turns_end_token( mock_target.return_value = expected_target_response target_response = red_teaming_orchestrator.apply_attack_strategy_until_completion() assert target_response == expected_target_response - conversations = red_teaming_orchestrator._memory.get_all_memory(ConversationData) + conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) # Expecting three 
conversation threads (two with red teaming chat and one with prompt target) assert len(conversations) == 7, f"Expected 7 conversations, got {len(conversations)}" check_conversations( From d245ce6b120fa397c931e95d29ec47bd80abb794 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 26 Mar 2024 11:11:35 -0700 Subject: [PATCH 02/15] fixing piping --- pyrit/memory/memory_models.py | 14 ++++--- tests/memory/test_duckdb_memory.py | 60 ++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index a5265f4bc..1b1a0105c 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -9,9 +9,9 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict -from sqlalchemy import Column, String, DateTime, Float, Enum, JSON, ForeignKey, Index +from sqlalchemy import Column, String, DateTime, Float, Enum, JSON, ForeignKey, Index, INTEGER, ARRAY from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.dialects.postgresql import ARRAY, UUID +from sqlalchemy.dialects.postgresql import UUID Base = declarative_base() @@ -36,6 +36,7 @@ class PromptMemoryEntry(Base): # type: ignore id (UUID): The unique identifier for the memory entry. prompt_entry_type (PromptType): The type of the prompt entry (system, request_segment, response). conversation_id (str): The identifier for the conversation which is associated with a single target. + sequence (int): The order of the conversation within a conversation_id timestamp (DateTime): The timestamp of the memory entry. labels (Dict[str, str]): The labels associated with the memory entry. prompt_metadata (JSON): The metadata associated with the prompt. @@ -46,18 +47,19 @@ class PromptMemoryEntry(Base): # type: ignore original_prompt_data_sha256 (str): The SHA256 hash of the original prompt data. 
converted_prompt_data_type (PromptDataType): The data type of the converted prompt (text, image) converted_prompt_text (str): The text of the converted prompt. If prompt is an image, it's a link. - original_prompt_data_sha256 (str): The SHA256 hash of the original prompt data. + converted_prompt_data_sha256 (str): The SHA256 hash of the original prompt data. idx_conversation_id (Index): The index for the conversation ID. Methods: __str__(): Returns a string representation of the memory entry. """ - __tablename__ = "MemoryEntries" + __tablename__ = "PromptMemoryEntries" __table_args__ = {"extend_existing": True} id = Column(UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid4) prompt_entry_type = Column(Enum(PromptType)) conversation_id = Column(String, nullable=False) + sequence = Column(INTEGER, nullable=False, default=0) timestamp = Column(DateTime, nullable=False, default=datetime.utcnow) labels: Column[Dict[str, str]] = Column(JSON) prompt_metadata = Column(JSON) @@ -70,7 +72,7 @@ class PromptMemoryEntry(Base): # type: ignore converted_prompt_data_type = Column(Enum(PromptDataType)) converted_prompt_text = Column(String) - original_prompt_data_sha256 = Column(String) + converted_prompt_data_sha256 = Column(String) idx_conversation_id = Index("idx_conversation_id", "conversation_id") @@ -92,7 +94,7 @@ class EmbeddingData(Base): # type: ignore __tablename__ = "EmbeddingData" # Allows table redefinition if already defined. 
__table_args__ = {"extend_existing": True} - uuid = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.uuid"), primary_key=True) + uuid = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True) embedding = Column(ARRAY(Float)) embedding_type_name = Column(String) diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 7cb56f56c..8eaeadd1c 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -8,12 +8,14 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import inspect -from sqlalchemy.dialects.postgresql import UUID, ARRAY +from sqlalchemy import String, DateTime, Float, Enum, JSON, ForeignKey, Index, INTEGER, ARRAY +from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.types import String, DateTime from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory import DuckDBMemory +from pyrit.prompt_converter import PromptConverter @pytest.fixture @@ -26,8 +28,8 @@ def setup_duckdb_database(): inspector = inspect(duckdb_memory.engine) # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." + assert "PromptMemoryEntries" in inspector.get_table_names(), "PromptMemoryEntries table not created." + assert "EmbeddingData" in inspector.get_table_names(), "EmbeddingData table not created." 
yield duckdb_memory duckdb_memory.dispose_engine() @@ -47,54 +49,72 @@ def mock_session(): def test_conversation_data_schema(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("ConversationStore") + columns = inspector.get_columns("PromptMemoryEntries") column_names = [col["name"] for col in columns] # Expected columns in ConversationData - expected_columns = ["uuid", "role", "content", "conversation_id", "timestamp", "normalizer_id", "sha256", "labels"] + expected_columns = ["id", + "prompt_entry_type", + "conversation_id", + "sequence", + "timestamp", + "labels", + "prompt_metadata", + "converters", + "prompt_target", + "original_prompt_data_type", + "original_prompt_text", + "original_prompt_data_sha256", + "converted_prompt_data_type", + "converted_prompt_text", + "converted_prompt_data_sha256"] + for column in expected_columns: - assert column in column_names, f"{column} not found in ConversationStore schema." + assert column in column_names, f"{column} not found in PromptMemoryEntries schema." def test_embedding_data_schema(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("EmbeddingStore") + columns = inspector.get_columns("EmbeddingData") column_names = [col["name"] for col in columns] # Expected columns in EmbeddingData expected_columns = ["uuid", "embedding", "embedding_type_name"] for column in expected_columns: - assert column in column_names, f"{column} not found in EmbeddingStore schema." + assert column in column_names, f"{column} not found in EmbeddingData schema." 
def test_conversation_data_column_types(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("ConversationStore") + columns = inspector.get_columns("PromptMemoryEntries") column_types = {col["name"]: type(col["type"]) for col in columns} # Expected column types in ConversationData expected_column_types = { - "uuid": UUID, - "role": String, - "content": String, + "id": UUID, + "prompt_entry_type": Enum, "conversation_id": String, + "sequence": INTEGER, "timestamp": DateTime, - "normalizer_id": String, - "sha256": String, - "labels": ARRAY, + "labels": String, + "prompt_metadata": String, + "converters": String, + "prompt_target": String, + "original_prompt_data_type": String, + "original_prompt_text": String, + "original_prompt_data_sha256": String, + "converted_prompt_data_type": String, + "converted_prompt_text": String, + "converted_prompt_data_sha256": String, } for column, expected_type in expected_column_types.items(): if column != "labels": - assert column in column_types, f"{column} not found in ConversationStore schema." + assert column in column_types, f"{column} not found in PromptMemoryEntries schema." assert issubclass( column_types[column], expected_type ), f"Expected {column} to be a subclass of {expected_type}, got {column_types[column]} instead." - # Handle 'labels' column separately - assert "labels" in column_types, "'labels' column not found in ConversationStore schema." 
- # Check if 'labels' column type is either NullType (due to reflection issue) or ARRAY - assert column_types["labels"] in [NullType, ARRAY], f"Unexpected type for 'labels' column: {column_types['labels']}" def test_embedding_data_column_types(setup_duckdb_database): From bc1d1ccd720817698a4163d21ed6d6e51d284eb0 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 26 Mar 2024 19:28:23 -0700 Subject: [PATCH 03/15] going to try something --- pyrit/memory/duckdb_memory.py | 6 +- pyrit/memory/memory_models.py | 10 +- pyrit/models.py | 6 +- pyrit/prompt_converter/prompt_converter.py | 15 ++- pyrit/prompt_target/prompt_target.py | 7 ++ pyrit/prompt_target/text_target.py | 5 +- tests/memory/test_duckdb_memory.py | 104 +++++++++++++-------- tests/memory/test_memory_embedding.py | 2 +- tests/memory/test_memory_encoder.py | 2 +- 9 files changed, 100 insertions(+), 57 deletions(-) diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 158f9ddd1..8dc1ccfc6 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -117,6 +117,10 @@ def insert_entries(self, *, entries: list[Base]) -> None: # type: ignore session.rollback() logger.exception(f"Error inserting multiple entries into the table: {e}") + #def query_entries_by_label(self, model, *, key: str, value: str) -> list[Base]: # type: ignore + # + # return self.query_entries(PromptMemoryEntry, conditions=model.labels[key] == value) + def query_entries(self, model, *, conditions: Optional = None) -> list[Base]: # type: ignore """ Fetches data from the specified table model with optional conditions. @@ -193,7 +197,7 @@ def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptM list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID. 
""" try: - return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.normalizer_id == normalizer_id) + return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.labels["normalizer_id"] == normalizer_id) except Exception as e: logger.exception( f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. {e}" diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 1b1a0105c..67fd3c00c 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -16,10 +16,6 @@ Base = declarative_base() -class PromptType(enum.Enum): - SYSTEM = 'system' - REQUEST_SEGMENT = 'request_segment' - RESPONSE = 'response' class PromptDataType(enum.Enum): TEXT = 'text' @@ -57,7 +53,7 @@ class PromptMemoryEntry(Base): # type: ignore __tablename__ = "PromptMemoryEntries" __table_args__ = {"extend_existing": True} id = Column(UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid4) - prompt_entry_type = Column(Enum(PromptType)) + role: 'Column[ChatMessageRole]' = Column(String, nullable=False) conversation_id = Column(String, nullable=False) sequence = Column(INTEGER, nullable=False, default=0) timestamp = Column(DateTime, nullable=False, default=datetime.utcnow) @@ -94,12 +90,12 @@ class EmbeddingData(Base): # type: ignore __tablename__ = "EmbeddingData" # Allows table redefinition if already defined. 
__table_args__ = {"extend_existing": True} - uuid = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True) + id = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True) embedding = Column(ARRAY(Float)) embedding_type_name = Column(String) def __str__(self): - return f"{self.uuid}" + return f"{self.id}" class ConversationMessageWithSimilarity(BaseModel): diff --git a/pyrit/models.py b/pyrit/models.py index e15e91ba9..f0e8c8a37 100644 --- a/pyrit/models.py +++ b/pyrit/models.py @@ -15,10 +15,8 @@ from pydantic import BaseModel, ConfigDict -# Originally derived from this: -# https://github.com/openai/openai-python/blob/7f9e85017a0959e3ba07834880d92c748f8f67ab/src/openai/types/chat/chat_completion_role.py#L4 -ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "tool", "function"] -ChatMessageRole = Literal["system", "user", "assistant", "tool", "function"] +ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant"] +ChatMessageRole = Literal["system", "user", "assistant"] @dataclass diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 05e7ad028..2528ad980 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
import abc - +import json class PromptConverter(abc.ABC): """ @@ -26,3 +26,16 @@ def convert(self, prompts: list[str]) -> list[str]: def is_one_to_one_converter(self) -> bool: """Indicates if the conversion results in exactly one resulting prompt.""" pass + + def to_dict(self): + public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith('_')} + public_attributes['__type__'] = self.__class__.__name__ + public_attributes['__module__'] = self.__class__.__module__ + return public_attributes + +class PromptConverterList(): + def __init__(self, converters: list[PromptConverter]) -> None: + self.converters = converters + + def to_json(self): + return json.dumps([converter.to_dict() for converter in self.converters]) \ No newline at end of file diff --git a/pyrit/prompt_target/prompt_target.py b/pyrit/prompt_target/prompt_target.py index 90c1072bf..cc9f2b484 100644 --- a/pyrit/prompt_target/prompt_target.py +++ b/pyrit/prompt_target/prompt_target.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import abc +import json from pyrit.memory import MemoryInterface from pyrit.memory import DuckDBMemory @@ -29,3 +30,9 @@ async def send_prompt_async(self, *, normalized_prompt: str, conversation_id: st """ Sends a normalized prompt async to the prompt target. 
""" + + def to_dict(self): + public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith('_')} + public_attributes['__type__'] = self.__class__.__name__ + public_attributes['__module__'] = self.__class__.__module__ + return json.dumps(public_attributes) diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 64d234dd3..e5450acee 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -22,11 +22,12 @@ class TextTarget(PromptTarget): def __init__(self, *, text_stream: IO[str] = sys.stdout, memory: MemoryInterface = None) -> None: super().__init__(memory=memory) - self.text_stream = text_stream + self.stream_name = text_stream.name + self._text_stream = text_stream def send_prompt(self, *, normalized_prompt: str, conversation_id: str, normalizer_id: str) -> str: msg = ChatMessage(role="user", content=normalized_prompt) - self.text_stream.write(f"{str(msg)}\n") + self._text_stream.write(f"{str(msg)}\n") self._memory.add_chat_message_to_memory( conversation=msg, conversation_id=conversation_id, normalizer_id=normalizer_id diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 8eaeadd1c..17230595a 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+import json import pytest import uuid import datetime @@ -13,9 +14,12 @@ from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.types import String, DateTime -from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData, PromptDataType from pyrit.memory import DuckDBMemory from pyrit.prompt_converter import PromptConverter +from pyrit.prompt_converter.base64_converter import Base64Converter +from pyrit.prompt_converter.prompt_converter import PromptConverterList +from pyrit.prompt_target.text_target import TextTarget @pytest.fixture @@ -54,7 +58,7 @@ def test_conversation_data_schema(setup_duckdb_database): # Expected columns in ConversationData expected_columns = ["id", - "prompt_entry_type", + "role", "conversation_id", "sequence", "timestamp", @@ -92,7 +96,7 @@ def test_conversation_data_column_types(setup_duckdb_database): # Expected column types in ConversationData expected_column_types = { "id": UUID, - "prompt_entry_type": Enum, + "role": String, "conversation_id": String, "sequence": INTEGER, "timestamp": DateTime, @@ -119,12 +123,12 @@ def test_conversation_data_column_types(setup_duckdb_database): def test_embedding_data_column_types(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) - columns = inspector.get_columns("EmbeddingStore") + columns = inspector.get_columns("EmbeddingData") column_types = {col["name"]: col["type"].__class__ for col in columns} # Expected column types in EmbeddingData expected_column_types = { - "uuid": UUID, + "id": UUID, "embedding": ARRAY, "embedding_type_name": String, } @@ -137,7 +141,7 @@ def test_embedding_data_column_types(setup_duckdb_database): column_types[column], expected_type ), f"Expected {column} to be a subclass of {expected_type}, got {column_types[column]} instead." # Handle 'embedding' column separately - assert "embedding" in column_types, "'embedding' column not found in EmbeddingStore schema." 
+ assert "embedding" in column_types, "'embedding' column not found in EmbeddingData schema." # Check if 'embedding' column type is either NullType (due to reflection issue) or ARRAY assert column_types["embedding"] in [ NullType, @@ -148,10 +152,13 @@ def test_embedding_data_column_types(setup_duckdb_database): def test_insert_entry(setup_duckdb_database): session = setup_duckdb_database.get_session() entry = PromptMemoryEntry( + id=uuid.uuid4(), conversation_id="123", role="user", - content="Hello", - sha256="abc", + + original_prompt_data_type=PromptDataType.TEXT, + original_prompt_text="Hello", + original_prompt_data_sha256="abc", ) # Use the insert_entry method to insert the entry into the database setup_duckdb_database.insert_entry(entry) @@ -161,16 +168,16 @@ def test_insert_entry(setup_duckdb_database): inserted_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").first() assert inserted_entry is not None assert inserted_entry.role == "user" - assert inserted_entry.content == "Hello" - assert inserted_entry.sha256 == "abc" + assert inserted_entry.original_prompt_text == "Hello" + assert inserted_entry.original_prompt_data_sha256 == "abc" def test_insert_entry_violates_constraint(setup_duckdb_database): # Generate a fixed UUID fixed_uuid = uuid.uuid4() # Create two entries with the same UUID - entry1 = PromptMemoryEntry(uuid=fixed_uuid, conversation_id="123", role="user", content="Hello") - entry2 = PromptMemoryEntry(uuid=fixed_uuid, conversation_id="456", role="user", content="Hello again") + entry1 = PromptMemoryEntry(id=fixed_uuid, conversation_id="123", role="user", original_prompt_text="Hello") + entry2 = PromptMemoryEntry(id=fixed_uuid, conversation_id="456", role="user", original_prompt_text="Hello again") # Insert the first entry with setup_duckdb_database.get_session() as session: @@ -186,7 +193,7 @@ def test_insert_entry_violates_constraint(setup_duckdb_database): def test_insert_entries(setup_duckdb_database): entries = [ - 
PromptMemoryEntry(conversation_id=str(i), role="user", content=f"Message {i}", sha256=f"hash{i}") + PromptMemoryEntry(conversation_id=str(i), role="user", original_prompt_text=f"Message {i}", original_prompt_data_sha256=f"hash{i}") for i in range(5) ] @@ -199,13 +206,13 @@ def test_insert_entries(setup_duckdb_database): for i, entry in enumerate(inserted_entries): assert entry.conversation_id == str(i) assert entry.role == "user" - assert entry.content == f"Message {i}" - assert entry.sha256 == f"hash{i}" + assert entry.original_prompt_text == f"Message {i}" + assert entry.original_prompt_data_sha256 == f"hash{i}" def test_insert_embedding_entry(setup_duckdb_database): # Create a ConversationData entry - conversation_entry = PromptMemoryEntry(conversation_id="123", role="user", content="Hello", sha256="abc") + conversation_entry = PromptMemoryEntry(conversation_id="123", role="user", original_prompt_text="Hello", original_prompt_data_sha256="abc") # Insert the ConversationData entry using the insert_entry method setup_duckdb_database.insert_entry(conversation_entry) @@ -214,15 +221,15 @@ def test_insert_embedding_entry(setup_duckdb_database): with setup_duckdb_database.get_session() as session: # Assuming uuid is the primary key and is set upon insertion reattached_conversation_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").one() - uuid = reattached_conversation_entry.uuid + uuid = reattached_conversation_entry.id # Now that we have the uuid, we can create and insert the EmbeddingData entry - embedding_entry = EmbeddingData(uuid=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") + embedding_entry = EmbeddingData(id=uuid, embedding=[1, 2, 3], embedding_type_name="test_type") setup_duckdb_database.insert_entry(embedding_entry) # Verify the EmbeddingData entry was inserted correctly with setup_duckdb_database.get_session() as session: - persisted_embedding_entry = session.query(EmbeddingData).filter_by(uuid=uuid).first() + 
persisted_embedding_entry = session.query(EmbeddingData).filter_by(id=uuid).first() assert persisted_embedding_entry is not None assert persisted_embedding_entry.embedding == [1, 2, 3] assert persisted_embedding_entry.embedding_type_name == "test_type" @@ -230,7 +237,7 @@ def test_insert_embedding_entry(setup_duckdb_database): def test_query_entries(setup_duckdb_database): # Insert some test data - entries = [PromptMemoryEntry(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] + entries = [PromptMemoryEntry(conversation_id=str(i), role="user", original_prompt_text=f"Message {i}") for i in range(3)] setup_duckdb_database.insert_entries(entries=entries) # Query entries without conditions @@ -242,29 +249,29 @@ def test_query_entries(setup_duckdb_database): PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == "1" ) assert len(specific_entry) == 1 - assert specific_entry[0].content == "Message 1" + assert specific_entry[0].original_prompt_text == "Message 1" def test_update_entries(setup_duckdb_database): # Insert a test entry - entry = PromptMemoryEntry(conversation_id="123", role="user", content="Hello") + entry = PromptMemoryEntry(conversation_id="123", role="user", original_prompt_text="Hello") setup_duckdb_database.insert_entry(entry) # Fetch the entry to update and update its content entries_to_update = setup_duckdb_database.query_entries( PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == "123" ) - setup_duckdb_database.update_entries(entries=entries_to_update, update_fields={"content": "Updated Hello"}) + setup_duckdb_database.update_entries(entries=entries_to_update, update_fields={"original_prompt_text": "Updated Hello"}) # Verify the entry was updated with setup_duckdb_database.get_session() as session: updated_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="123").first() - assert updated_entry.content == "Updated Hello" + assert updated_entry.original_prompt_text == "Updated 
Hello" def test_get_all_memory(setup_duckdb_database): # Insert some test data - entries = [PromptMemoryEntry(conversation_id=str(i), role="user", content=f"Message {i}") for i in range(3)] + entries = [PromptMemoryEntry(conversation_id=str(i), role="user", original_prompt_text=f"Message {i}") for i in range(3)] setup_duckdb_database.insert_entries(entries=entries) # Fetch all entries @@ -272,21 +279,27 @@ def test_get_all_memory(setup_duckdb_database): assert len(all_entries) == 3 -def test_get_memories_with_conversation_id(setup_duckdb_database): +def test_get_memories_with_json_properties(setup_duckdb_database): # Define a specific conversation_id specific_conversation_id = "test_conversation_id" + converters = PromptConverterList([Base64Converter()]).to_json() + target = TextTarget().to_dict() + # Start a session with setup_duckdb_database.get_session() as session: # Create a ConversationData entry with all attributes filled entry = PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", - content="Test content", + sequence=1, + + original_prompt_text="Test content", timestamp=datetime.datetime.utcnow(), - normalizer_id="test_normalizer_id", - sha256="test_sha256", - labels=["label1", "label2"], + original_prompt_data_sha256="test_sha256", + labels={"normalizer_id": "id1"}, + converters=converters, + prompt_target=target, ) # Insert the ConversationData entry @@ -303,39 +316,50 @@ def test_get_memories_with_conversation_id(setup_duckdb_database): retrieved_entry = retrieved_entries[0] assert retrieved_entry.conversation_id == specific_conversation_id assert retrieved_entry.role == "user" - assert retrieved_entry.content == "Test content" + assert retrieved_entry.original_prompt_text == "Test content" # For timestamp, you might want to check if it's close to the current time instead of an exact match assert abs((retrieved_entry.timestamp - entry.timestamp).total_seconds()) < 10 # Assuming the test runs quickly - assert 
retrieved_entry.normalizer_id == "test_normalizer_id" - assert retrieved_entry.sha256 == "test_sha256" - assert retrieved_entry.labels == ["label1", "label2"] + assert retrieved_entry.original_prompt_data_sha256 == "test_sha256" + + converters = json.loads(retrieved_entry.converters) + assert len(converters) == 1 + assert converters[0]["__type__"] == "Base64Converter" + + prompt_target = json.loads(retrieved_entry.prompt_target) + assert prompt_target["__type__"] == "TextTarget" + + labels = retrieved_entry.labels + assert labels["normalizer_id"] == "id1" def test_get_memories_with_normalizer_id(setup_duckdb_database): # Define a specific normalizer_id specific_normalizer_id = "normalizer_test_id" + labels = {"normalizer_id": specific_normalizer_id} + other_labels = {"normalizer_id": "other_normalizer_id"} + # Create a list of ConversationData entries, some with the specific normalizer_id entries = [ PromptMemoryEntry( conversation_id="123", role="user", - content="Hello 1", - normalizer_id=specific_normalizer_id, + original_prompt_text="Hello 1", + labels=labels, timestamp=datetime.datetime.utcnow(), ), PromptMemoryEntry( conversation_id="456", role="user", - content="Hello 2", - normalizer_id="other_normalizer_id", + original_prompt_text="Hello 2", + labels=other_labels, timestamp=datetime.datetime.utcnow(), ), PromptMemoryEntry( conversation_id="789", role="user", - content="Hello 3", - normalizer_id=specific_normalizer_id, + original_prompt_text="Hello 3", + labels=labels, timestamp=datetime.datetime.utcnow(), ), ] diff --git a/tests/memory/test_memory_embedding.py b/tests/memory/test_memory_embedding.py index 73d8553d0..c34b47319 100644 --- a/tests/memory/test_memory_embedding.py +++ b/tests/memory/test_memory_embedding.py @@ -56,7 +56,7 @@ def test_memory_encoding_chat_message( conversation_id="my_session", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) - assert metadata.uuid == chat_memory.id + 
assert metadata.id == chat_memory.id assert metadata.embedding == DEFAULT_EMBEDDING_DATA.embedding assert metadata.embedding_type_name == "MockEmbeddingGenerator" diff --git a/tests/memory/test_memory_encoder.py b/tests/memory/test_memory_encoder.py index 7cc16574f..92784b9b0 100644 --- a/tests/memory/test_memory_encoder.py +++ b/tests/memory/test_memory_encoder.py @@ -59,6 +59,6 @@ def test_memory_encoding_chat_message( conversation_id="my_session", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) - assert metadata.uuid == chat_memory.id + assert metadata.id == chat_memory.id assert metadata.embedding == DEFAULT_EMBEDDING_DATA.embedding assert metadata.embedding_type_name == "MockEmbeddingGenerator" From 63f8b01870d29ea5ddba9ee31d10e0dd7ad9afec Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 27 Mar 2024 09:45:04 -0700 Subject: [PATCH 04/15] fetching json attributes --- pyrit/memory/duckdb_memory.py | 2 +- tests/memory/test_duckdb_memory.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 8dc1ccfc6..9b31a981e 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -197,7 +197,7 @@ def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptM list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID. """ try: - return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.labels["normalizer_id"] == normalizer_id) + return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.labels.op('->>')('normalizer_id') == normalizer_id) except Exception as e: logger.exception( f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. 
{e}" diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 17230595a..b127d1bf4 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -375,8 +375,8 @@ def test_get_memories_with_normalizer_id(setup_duckdb_database): # Verify that the retrieved entries match the expected normalizer_id assert len(retrieved_entries) == 2 # Two entries should have the specific normalizer_id for retrieved_entry in retrieved_entries: - assert retrieved_entry.normalizer_id == specific_normalizer_id - assert "Hello" in retrieved_entry.content # Basic check to ensure content is as expected + assert retrieved_entry.labels["normalizer_id"] == specific_normalizer_id + assert "Hello" in retrieved_entry.original_prompt_text # Basic check to ensure content is as expected def test_update_entries_by_conversation_id(setup_duckdb_database): From b9d69345015df1e69e590030e2fb424970e51ed0 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 27 Mar 2024 09:48:25 -0700 Subject: [PATCH 05/15] Cleaning up docs --- pyrit/memory/memory_models.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 67fd3c00c..5eb2b9ef3 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -30,11 +30,12 @@ class PromptMemoryEntry(Base): # type: ignore __tablename__ (str): The name of the database table. __table_args__ (dict): Additional arguments for the database table. id (UUID): The unique identifier for the memory entry. - prompt_entry_type (PromptType): The type of the prompt entry (system, request_segment, response). + role (PromptType): system, request_segment, response conversation_id (str): The identifier for the conversation which is associated with a single target. - sequence (int): The order of the conversation within a conversation_id + sequence (int): The order of the conversation within a conversation_id. 
+ Can be the same number for multi-part requests or responses. timestamp (DateTime): The timestamp of the memory entry. - labels (Dict[str, str]): The labels associated with the memory entry. + labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. prompt_metadata (JSON): The metadata associated with the prompt. converters (list[PromptConverter]): The converters for the prompt. prompt_target (PromptTarget): The target for the prompt. From 27fae4e7912acb9e03dc018110a48d3437965c2d Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 27 Mar 2024 11:15:04 -0700 Subject: [PATCH 06/15] saving work --- tests/memory/test_duckdb_memory.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index b127d1bf4..d5dcd30c0 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -83,7 +83,7 @@ def test_embedding_data_schema(setup_duckdb_database): column_names = [col["name"] for col in columns] # Expected columns in EmbeddingData - expected_columns = ["uuid", "embedding", "embedding_type_name"] + expected_columns = ["id", "embedding", "embedding_type_name"] for column in expected_columns: assert column in column_names, f"{column} not found in EmbeddingData schema." 
@@ -388,16 +388,19 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", - content="Original content 1", + original_prompt_text="Original content 1", timestamp=datetime.datetime.utcnow(), ), PromptMemoryEntry( - conversation_id="other_id", role="user", content="Original content 2", timestamp=datetime.datetime.utcnow() + conversation_id="other_id", + role="user", + original_prompt_text="Original content 2", + timestamp=datetime.datetime.utcnow() ), PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", - content="Original content 3", + original_prompt_text="Original content 3", timestamp=datetime.datetime.utcnow(), ), ] @@ -408,7 +411,7 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): session.commit() # Ensure all entries are committed to the database # Define the fields to update for entries with the specific conversation_id - update_fields = {"content": "Updated content", "role": "assistant"} + update_fields = {"original_prompt_text": "Updated content", "role": "assistant"} # Use the update_entries_by_conversation_id method to update the entries update_result = setup_duckdb_database.update_entries_by_conversation_id( @@ -421,10 +424,10 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == specific_conversation_id ) for entry in updated_entries: - assert entry.content == "Updated content" + assert entry.original_prompt_text == "Updated content" assert entry.role == "assistant" # Verify that the entry with a different conversation_id was not updated other_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="other_id").first() - assert other_entry.content == "Original content 2" # Content should remain unchanged + assert other_entry.original_prompt_text == "Original content 2" # Content should remain unchanged assert other_entry.role == "user" 
# Role should remain unchanged From 28fbac53d0ea708a88c74400262ced788ce0facd Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 28 Mar 2024 09:28:16 -0700 Subject: [PATCH 07/15] continuting to plum, still a ways to go --- pyrit/analytics/conversation_analytics.py | 12 +- pyrit/memory/duckdb_memory.py | 167 ++++++++++-------- pyrit/memory/memory_embedding.py | 15 +- pyrit/memory/memory_interface.py | 125 +++++-------- pyrit/memory/memory_models.py | 73 ++++++-- .../prompt_sending_orchestrator.py | 2 +- .../orchestrator/red_teaming_orchestrator.py | 2 +- pyrit/prompt_target/azure_ml_chat_target.py | 2 +- pyrit/prompt_target/openai_chat_target.py | 2 +- .../analytics/test_conversation_analytics.py | 22 +-- tests/memory/test_duckdb_memory.py | 84 ++++++--- tests/memory/test_memory_embedding.py | 4 +- tests/memory/test_memory_encoder.py | 4 +- tests/memory/test_memory_interface.py | 30 +--- tests/mocks.py | 20 +++ tests/test_prompt_target.py | 27 +-- .../test_prompt_target_azure_blob_storage.py | 24 +-- tests/test_prompt_target_text.py | 19 +- tests/test_red_teaming_orchestrator.py | 2 +- 19 files changed, 333 insertions(+), 303 deletions(-) diff --git a/pyrit/analytics/conversation_analytics.py b/pyrit/analytics/conversation_analytics.py index a813320c5..f3bb5b1ed 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -24,11 +24,11 @@ def __init__(self, *, memory_interface: MemoryInterface): """ self.memory_interface = memory_interface - def get_similar_chat_messages_by_content( + def get_prompt_entries_with_same_converted_content( self, *, chat_message_content: str ) -> list[ConversationMessageWithSimilarity]: """ - Retrieves chat messages that are similar to the given content based on exact matches. + Retrieves chat messages that have the same converted content Args: chat_message_content (str): The content of the chat message to find similar messages for. 
@@ -37,16 +37,16 @@ def get_similar_chat_messages_by_content( list[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on content. """ - all_memories = self.memory_interface.get_all_memory(PromptMemoryEntry) + all_memories = self.memory_interface.get_all_prompt_entries(PromptMemoryEntry) similar_messages = [] for memory in all_memories: - if memory.content == chat_message_content: + if memory.converted_prompt_text == chat_message_content: similar_messages.append( ConversationMessageWithSimilarity( score=1.0, role=memory.role, - content=memory.content, + content=memory.converted_prompt_text, metric="exact_match", # Exact match ) ) @@ -67,7 +67,7 @@ def get_similar_chat_messages_by_embedding( List[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on embedding similarity. """ - all_memories = self.memory_interface.get_all_memory(EmbeddingData) + all_memories = self.memory_interface.get_all_prompt_entries(EmbeddingData) similar_messages = [] target_embedding = np.array(chat_message_embedding).reshape(1, -1) diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 9b31a981e..96cebd334 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -33,15 +33,23 @@ class DuckDBMemory(MemoryInterface, metaclass=Singleton): DEFAULT_DB_FILE_NAME = "pyrit_duckdb_storage.db" def __init__( - self, *, db_path: Union[Path, str] = None, embedding_model: EmbeddingSupport = None, has_echo: bool = False + self, + *, + db_path: Union[Path, str] = None, + embedding_model: EmbeddingSupport = None, + disable_embedding: bool = False, + verbose: bool = False ): super(DuckDBMemory, self).__init__() - self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) + if not disable_embedding: + self.memory_embedding = 
default_memory_embedding_factory(embedding_model=embedding_model) + if db_path == ":memory:": self.db_path: Union[Path, str] = ":memory:" else: self.db_path = Path(db_path or Path(RESULTS_PATH, self.DEFAULT_DB_FILE_NAME)).resolve() - self.engine = self._create_engine(has_echo=has_echo) + + self.engine = self._create_engine(has_echo=verbose) self.SessionFactory = sessionmaker(bind=self.engine) self._create_tables_if_not_exist() @@ -76,6 +84,90 @@ def _create_tables_if_not_exist(self): except Exception as e: logger.error(f"Error during table creation: {e}") + def get_all_prompt_entries(self) -> list[PromptMemoryEntry]: + """ + Fetches all entries from the specified table and returns them as model instances. + """ + result = self.query_entries(PromptMemoryEntry) + return result + + def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: + """ + Retrieves a list of ConversationData objects that have the specified conversation ID. + + Args: + conversation_id (str): The conversation ID to filter the table. + + Returns: + list[ConversationData]: A list of ConversationData objects matching the specified conversation ID. + """ + try: + return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id) + except Exception as e: + logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") + return [] + + def get_prompt_entries_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]: + """ + Retrieves a list of ConversationData objects that have the specified normalizer ID. + + Args: + normalizer_id (str): The normalizer ID to filter the table. + + Returns: + list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID. 
+ """ + try: + return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.labels.op('->>')('normalizer_id') == normalizer_id) + except Exception as e: + logger.exception( + f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. {e}" + ) + return [] + + def insert_prompt_entries(self, *, entries: list[PromptMemoryEntry]) -> None: + """ + Inserts a list of prompt entries into the memory storage. + If necessary, generates embedding data for applicable entries + + Args: + entries (list[Base]): The list of database model instances to be inserted. + """ + embedding_entries = [] + + if self.memory_embedding: + for chat_entry in entries: + + embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry) + embedding_entries.append(embedding_entry) + + self.insert_entries(entries=entries) + self.insert_entries(entries=embedding_entries) + + def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool: + """ + Updates entries for a given conversation ID with the specified field values. + + Args: + conversation_id (str): The conversation ID of the entries to be updated. + update_fields (dict): A dictionary of field names and their new values. + + Returns: + bool: True if the update was successful, False otherwise. + """ + # Fetch the relevant entries using query_entries + entries_to_update = self.query_entries( + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id + ) + + # Check if there are entries to update + if not entries_to_update: + logger.info(f"No entries found with conversation_id {conversation_id} to update.") + return False + + # Use the utility function to update the entries + return self.update_entries(entries=entries_to_update, update_fields=update_fields) + def get_all_table_models(self) -> list[Base]: # type: ignore """ Returns a list of all table models used in the database by inspecting the Base registry. 
@@ -117,10 +209,6 @@ def insert_entries(self, *, entries: list[Base]) -> None: # type: ignore session.rollback() logger.exception(f"Error inserting multiple entries into the table: {e}") - #def query_entries_by_label(self, model, *, key: str, value: str) -> list[Base]: # type: ignore - # - # return self.query_entries(PromptMemoryEntry, conditions=model.labels[key] == value) - def query_entries(self, model, *, conditions: Optional = None) -> list[Base]: # type: ignore """ Fetches data from the specified table model with optional conditions. @@ -163,71 +251,6 @@ def update_entries(self, *, entries: list[Base], update_fields: dict) -> bool: logger.exception(f"Error updating entries: {e}") return False - def get_all_memory(self, model: Base) -> list[Base]: # type: ignore - """ - Fetches all entries from the specified table and returns them as model instances. - """ - result = self.query_entries(model) - return result - - def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: - """ - Retrieves a list of ConversationData objects that have the specified conversation ID. - - Args: - conversation_id (str): The conversation ID to filter the table. - - Returns: - list[ConversationData]: A list of ConversationData objects matching the specified conversation ID. - """ - try: - return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id) - except Exception as e: - logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") - return [] - - def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]: - """ - Retrieves a list of ConversationData objects that have the specified normalizer ID. - - Args: - normalizer_id (str): The normalizer ID to filter the table. - - Returns: - list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID. 
- """ - try: - return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.labels.op('->>')('normalizer_id') == normalizer_id) - except Exception as e: - logger.exception( - f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. {e}" - ) - return [] - - def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool: - """ - Updates entries for a given conversation ID with the specified field values. - - Args: - conversation_id (str): The conversation ID of the entries to be updated. - update_fields (dict): A dictionary of field names and their new values. - - Returns: - bool: True if the update was successful, False otherwise. - """ - # Fetch the relevant entries using query_entries - entries_to_update = self.query_entries( - PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id - ) - - # Check if there are entries to update - if not entries_to_update: - logger.info(f"No entries found with conversation_id {conversation_id} to update.") - return False - - # Use the utility function to update the entries - return self.update_entries(entries=entries_to_update, update_fields=update_fields) - def dispose_engine(self): """ Dispose the engine and clean up resources. diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index a4c2cfd3d..082a1884e 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -30,12 +30,15 @@ def generate_embedding_memory_data(self, *, chat_memory: PromptMemoryEntry) -> E Returns: ConversationMemoryEntryMetadata: The generated metadata. 
""" - embedding_data = EmbeddingData( - embedding=self.embedding_model.generate_text_embedding(text=chat_memory.content).data[0].embedding, - embedding_type_name=self.embedding_model.__class__.__name__, - uuid=chat_memory.id, - ) - return embedding_data + if chat_memory.converted_prompt_data_type == "text": + embedding_data = EmbeddingData( + embedding=self.embedding_model.generate_text_embedding(text=chat_memory.converted_prompt_text).data[0].embedding, + embedding_type_name=self.embedding_model.__class__.__name__, + id=chat_memory.id, + ) + return embedding_data + + raise ValueError("Only text data is supported for embedding.") def default_memory_embedding_factory(embedding_model: EmbeddingSupport = None) -> MemoryEmbedding | None: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 6179f5193..94fd6907a 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -8,7 +8,7 @@ from uuid import uuid4 -from pyrit.memory.memory_models import Base, PromptMemoryEntry +from pyrit.memory.memory_models import Base, PromptMemoryEntry, EmbeddingData from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.models import ChatMessage @@ -33,13 +33,13 @@ def __init__(self, embedding_model=None): self.exporter = MemoryExporter() @abc.abstractmethod - def get_all_memory(self, model: Base) -> list[PromptMemoryEntry]: # type: ignore + def get_all_prompt_entries(self) -> list[PromptMemoryEntry]: """ Loads all ConversationData from the memory storage handler. """ @abc.abstractmethod - def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: + def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified conversation ID. 
@@ -51,7 +51,7 @@ def get_memories_with_conversation_id(self, *, conversation_id: str) -> list[Pro """ @abc.abstractmethod - def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]: + def get_prompt_entries_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified normalizer ID. @@ -63,50 +63,48 @@ def get_memories_with_normalizer_id(self, *, normalizer_id: str) -> list[PromptM """ @abc.abstractmethod - def insert_entries(self, *, entries: list[Base]) -> None: # type: ignore + def insert_prompt_entries(self, *, entries: list[EmbeddingData]) -> None: """ Inserts a list of entries into the memory storage. + If necessary, generates embedding data for applicable entries + Args: entries (list[Base]): The list of database model instances to be inserted. """ + @abc.abstractmethod - def get_all_table_models(self) -> list[Base]: # type: ignore + def dispose_engine(self): """ - Returns a list of all table models from the database. - - Returns: - list[Base]: A list of SQLAlchemy models. + Dispose the engine and clean up resources. """ - @abc.abstractmethod - def query_entries(self, model, *, conditions: Optional = None) -> list[Base]: # type: ignore + def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> list[ChatMessage]: """ - Fetches data from the specified table model with optional conditions. + Returns the memory for a given conversation_id. Args: - model: The SQLAlchemy model class corresponding to the table you want to query. - conditions: SQLAlchemy filter conditions (optional). + conversation_id (str): The conversation ID. Returns: - List of model instances representing the rows fetched from the table. + list[ChatMessage]: The list of chat messages. 
""" + memory_entries = self.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) + return [ChatMessage(role=me.role, content=me.converted_prompt_text) for me in memory_entries] # type: ignore - @abc.abstractmethod - def dispose_engine(self): - """ - Dispose the engine and clean up resources. - """ def add_chat_message_to_memory( self, conversation: ChatMessage, conversation_id: str, normalizer_id: str = None, - labels: list[str] = None, + labels: list[str] = {}, ): """ + Deprecated. Will be refactored and removed soon. It currently works incorrectly. + but is included so functionality is maintained. + Adds a single chat conversation entry to the ConversationStore table. If embddings are set, add corresponding embedding entry to the EmbeddingStore table. @@ -116,16 +114,15 @@ def add_chat_message_to_memory( normalizer_id (str): The normalizer ID, labels (list[str]): A list of labels to be added to the memory entry. """ - entries_to_persist = [] - chat_entry = self._create_chat_message_memory_entry( - conversation=conversation, conversation_id=conversation_id, normalizer_id=normalizer_id, labels=labels + + + self.add_chat_messages_to_memory( + conversations=[conversation], + conversation_id=conversation_id, + normalizer_id=normalizer_id, + labels=labels ) - entries_to_persist.append(chat_entry) - if self.memory_embedding: - embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry) - entries_to_persist.append(embedding_entry) - self.insert_entries(entries=entries_to_persist) def add_chat_messages_to_memory( self, @@ -133,9 +130,12 @@ def add_chat_messages_to_memory( conversations: list[ChatMessage], conversation_id: str, normalizer_id: str = None, - labels: list[str] = None, + labels: dict[str:str] = {}, ): """ + Deprecated. Will be refactored and removed soon. It currently works incorrectly. + but is included so functionality is maintained. 
+ Adds multiple chat conversation entries to the ConversationStore table. If embddings are set, add corresponding embedding entries to the EmbeddingStore table. @@ -145,64 +145,23 @@ def add_chat_messages_to_memory( normalizer_id (str): The normalizer ID labels (list[str]): A list of labels to be added to the memory entry. """ - entries_to_persist = [] + entries_to_add = [] for conversation in conversations: - chat_entry = self._create_chat_message_memory_entry( - conversation=conversation, conversation_id=conversation_id, normalizer_id=normalizer_id, labels=labels + entry = PromptMemoryEntry( + role=conversation.role, + conversation_id=conversation_id, + original_prompt_text=conversation.content, + converted_prompt_text=conversation.content, + labels=labels, ) - entries_to_persist.append(chat_entry) - if self.memory_embedding: - embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry) - entries_to_persist.append(embedding_entry) - self.insert_entries(entries=entries_to_persist) + entry.labels["normalizer_id"] = normalizer_id - def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> list[ChatMessage]: - """ - Returns the memory for a given conversation_id. - - Args: - conversation_id (str): The conversation ID. - - Returns: - list[ChatMessage]: The list of chat messages. - """ - memory_entries = self.get_memories_with_conversation_id(conversation_id=conversation_id) - return [ChatMessage(role=me.role, content=me.content) for me in memory_entries] # type: ignore + entries_to_add.append(entry) - def _create_chat_message_memory_entry( - self, - *, - conversation: ChatMessage, - conversation_id: str, - normalizer_id: str = None, - labels: list[str] = None, - ): - """ - Creates a new `ConversationData` instance representing a chat message entry. - - Args: - conversation (ChatMessage): The chat message to be stored. - conversation_id (str): Conversation ID. - normalizer_id (str): Normalizer ID. 
- labels (list[str]): Labels associated with the conversation. - - Returns: - ConversationData: A new instance ready to be persisted in the memory storage. - """ - uuid = uuid4() - new_chat_memory = PromptMemoryEntry( - role=conversation.role, - content=conversation.content, - conversation_id=conversation_id, - normalizer_id=normalizer_id, - uuid=uuid, - labels=labels if labels else [], - sha256=sha256(conversation.content.encode()).hexdigest(), - ) + self.insert_prompt_entries(entries=entries_to_add) - return new_chat_memory def export_all_tables(self, *, export_type: str = "json"): """ @@ -232,7 +191,7 @@ def export_conversation_by_id(self, *, conversation_id: str, file_path: Path = N If not provided, a default path using RESULTS_PATH will be constructed. export_type (str): The format of the export. Defaults to "json". """ - data = self.get_memories_with_conversation_id(conversation_id=conversation_id) + data = self.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) # If file_path is not provided, construct a default using the exporter's results_path if not file_path: diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 5eb2b9ef3..4866681ca 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -2,10 +2,11 @@ # Licensed under the MIT license. 
import enum +import hashlib import uuid from datetime import datetime -from typing import Dict +from typing import Dict, Literal from uuid import uuid4 from pydantic import BaseModel, ConfigDict @@ -16,10 +17,13 @@ Base = declarative_base() - +""" class PromptDataType(enum.Enum): TEXT = 'text' - IMAGE = 'image' + IMAGE_URL = 'image_url' +""" + +PromptDataType = Literal["text", "image_url"] class PromptMemoryEntry(Base): # type: ignore @@ -53,34 +57,81 @@ class PromptMemoryEntry(Base): # type: ignore __tablename__ = "PromptMemoryEntries" __table_args__ = {"extend_existing": True} - id = Column(UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid4) + id = Column(UUID(as_uuid=True), nullable=False, primary_key=True) role: 'Column[ChatMessageRole]' = Column(String, nullable=False) conversation_id = Column(String, nullable=False) - sequence = Column(INTEGER, nullable=False, default=0) - timestamp = Column(DateTime, nullable=False, default=datetime.utcnow) + sequence = Column(INTEGER, nullable=False) + timestamp = Column(DateTime, nullable=False) labels: Column[Dict[str, str]] = Column(JSON) prompt_metadata = Column(JSON) converters: 'Column[list[PromptConverter]]' = Column(JSON) prompt_target: 'Column[PromptTarget]' = Column(JSON) - original_prompt_data_type = Column(Enum(PromptDataType)) - original_prompt_text = Column(String) + original_prompt_data_type: PromptDataType = Column(String, nullable=False) + original_prompt_text = Column(String, nullable=False) original_prompt_data_sha256 = Column(String) - converted_prompt_data_type = Column(Enum(PromptDataType)) + converted_prompt_data_type: PromptDataType = Column(String, nullable=False) converted_prompt_text = Column(String) converted_prompt_data_sha256 = Column(String) idx_conversation_id = Index("idx_conversation_id", "conversation_id") + + def __init__(self, + *, + role: str, + original_prompt_text: str, + converted_prompt_text: str, + id: uuid.UUID = None, + conversation_id: str = None, + sequence: 
int = -1, + labels: Dict[str, str] = None, + prompt_metadata: JSON = None, + converters: 'PromptConverterList' = None, + prompt_target: 'PromptTarget' = None, + original_prompt_data_type: PromptDataType = "text", + converted_prompt_data_type: PromptDataType = "text" + ): + + + self.id = id if id else uuid4() + + self.role = role + self.conversation_id = conversation_id if conversation_id else str(uuid4()) + self.sequence = sequence + + self.timestamp = datetime.utcnow() + self.labels = labels + self.prompt_metadata = prompt_metadata + + self.converters = converters.to_json() if converters else None + self.prompt_target = prompt_target.to_dict() if prompt_target else None + + self.original_prompt_text = original_prompt_text + self.original_prompt_data_type = original_prompt_data_type + self.original_prompt_data_sha256 = self._create_sha256(original_prompt_text) + + + self.converted_prompt_data_type = converted_prompt_data_type + self.converted_prompt_text = converted_prompt_text + self.converted_prompt_data_sha256 = self._create_sha256(converted_prompt_text) + + + + def _create_sha256(self, text: str) -> str: + input_bytes = text.encode('utf-8') + hash_object = hashlib.sha256(input_bytes) + return hash_object.hexdigest() + def __str__(self): - return f"{self.role}: {self.content}" + return f"{self.role}: {self.converted_prompt_text}" class EmbeddingData(Base): # type: ignore """ Represents the embedding data associated with conversation entries in the database. - Each embedding is linked to a specific conversation entry via a 'uuid'. + Each embedding is linked to a specific conversation entry via an id Attributes: uuid (UUID): The primary key, which is a foreign key referencing the UUID in the MemoryEntries table. 
diff --git a/pyrit/orchestrator/prompt_sending_orchestrator.py b/pyrit/orchestrator/prompt_sending_orchestrator.py index 7fa96d020..3af299129 100644 --- a/pyrit/orchestrator/prompt_sending_orchestrator.py +++ b/pyrit/orchestrator/prompt_sending_orchestrator.py @@ -114,4 +114,4 @@ def get_memory(self): Retrieves the memory associated with the prompt normalizer. """ id = self._prompt_normalizer.id - return self._memory.get_memories_with_normalizer_id(normalizer_id=id) + return self._memory.get_prompt_entries_with_normalizer_id(normalizer_id=id) diff --git a/pyrit/orchestrator/red_teaming_orchestrator.py b/pyrit/orchestrator/red_teaming_orchestrator.py index 5f356e72b..ed24eceb6 100644 --- a/pyrit/orchestrator/red_teaming_orchestrator.py +++ b/pyrit/orchestrator/red_teaming_orchestrator.py @@ -81,7 +81,7 @@ def requires_one_to_one_converters(self) -> bool: return True def get_memory(self): - return self._memory.get_memories_with_normalizer_id(normalizer_id=self._prompt_normalizer.id) + return self._memory.get_prompt_entries_with_normalizer_id(normalizer_id=self._prompt_normalizer.id) @abc.abstractmethod def is_conversation_complete(self, messages: list[ChatMessage], *, red_teaming_chat_role: str) -> bool: diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 68b353e22..cc3677967 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -66,7 +66,7 @@ def __init__( self._repetition_penalty = repetition_penalty def set_system_prompt(self, *, prompt: str, conversation_id: str, normalizer_id: str) -> None: - messages = self._memory.get_memories_with_conversation_id(conversation_id=conversation_id) + messages = self._memory.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) if messages: raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") diff --git a/pyrit/prompt_target/openai_chat_target.py 
b/pyrit/prompt_target/openai_chat_target.py index 0a8414c42..22513f5a5 100644 --- a/pyrit/prompt_target/openai_chat_target.py +++ b/pyrit/prompt_target/openai_chat_target.py @@ -31,7 +31,7 @@ def __init__(self) -> None: pass def set_system_prompt(self, *, prompt: str, conversation_id: str, normalizer_id: str) -> None: - messages = self._memory.get_memories_with_conversation_id(conversation_id=conversation_id) + messages = self._memory.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) if messages: raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") diff --git a/tests/analytics/test_conversation_analytics.py b/tests/analytics/test_conversation_analytics.py index ed10f5179..554d6cfc7 100644 --- a/tests/analytics/test_conversation_analytics.py +++ b/tests/analytics/test_conversation_analytics.py @@ -19,14 +19,14 @@ def mock_memory_interface(): def test_get_similar_chat_messages_by_content(mock_memory_interface): # Mock data returned by the memory interface mock_data = [ - PromptMemoryEntry(content="Hello, how are you?", role="user"), - PromptMemoryEntry(content="I'm fine, thank you!", role="assistant"), - PromptMemoryEntry(content="Hello, how are you?", role="assistant"), # Exact match + PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="Hello, how are you?", role="user"), + PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="I'm fine, thank you!", role="assistant"), + PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="Hello, how are you?", role="assistant"), # Exact match ] - mock_memory_interface.get_all_memory.return_value = mock_data + mock_memory_interface.get_all_prompt_entries.return_value = mock_data analytics = ConversationAnalytics(memory_interface=mock_memory_interface) - similar_messages = analytics.get_similar_chat_messages_by_content(chat_message_content="Hello, how are you?") + similar_messages = 
analytics.get_prompt_entries_with_same_converted_content(chat_message_content="Hello, how are you?") # Expect one exact match assert len(similar_messages) == 2 @@ -39,8 +39,8 @@ def test_get_similar_chat_messages_by_content(mock_memory_interface): def test_get_similar_chat_messages_by_embedding(mock_memory_interface): # Mock ConversationData entries conversation_entries = [ - PromptMemoryEntry(uuid=uuid.uuid4(), conversation_id="1", role="user", content="Similar message"), - PromptMemoryEntry(uuid=uuid.uuid4(), conversation_id="2", role="assistant", content="Different message"), + PromptMemoryEntry(original_prompt_text="h", role="user", converted_prompt_text="Similar message"), + PromptMemoryEntry(original_prompt_text="h", role="assistant", converted_prompt_text="Different message"), ] # Mock EmbeddingData entries linked to the ConversationData entries @@ -49,12 +49,12 @@ def test_get_similar_chat_messages_by_embedding(mock_memory_interface): different_embedding = [0.9, 0.8, 0.7] mock_data = [ - EmbeddingData(uuid=conversation_entries[0].id, embedding=similar_embedding, embedding_type_name="model1"), - EmbeddingData(uuid=conversation_entries[1].id, embedding=different_embedding, embedding_type_name="model2"), + EmbeddingData(id=conversation_entries[0].id, embedding=similar_embedding, embedding_type_name="model1"), + EmbeddingData(id=conversation_entries[1].id, embedding=different_embedding, embedding_type_name="model2"), ] - # Mock the get_all_memory method to return the mock EmbeddingData entries - mock_memory_interface.get_all_memory.side_effect = lambda model: ( + # Mock the get_all_prompt_entries method to return the mock EmbeddingData entries + mock_memory_interface.get_all_prompt_entries.side_effect = lambda model: ( mock_data if model == EmbeddingData else conversation_entries ) diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index d5dcd30c0..2ec6578d3 100644 --- a/tests/memory/test_duckdb_memory.py +++ 
b/tests/memory/test_duckdb_memory.py @@ -155,10 +155,9 @@ def test_insert_entry(setup_duckdb_database): id=uuid.uuid4(), conversation_id="123", role="user", - - original_prompt_data_type=PromptDataType.TEXT, + original_prompt_data_type="text", original_prompt_text="Hello", - original_prompt_data_sha256="abc", + converted_prompt_text="Hello", ) # Use the insert_entry method to insert the entry into the database setup_duckdb_database.insert_entry(entry) @@ -169,15 +168,25 @@ def test_insert_entry(setup_duckdb_database): assert inserted_entry is not None assert inserted_entry.role == "user" assert inserted_entry.original_prompt_text == "Hello" - assert inserted_entry.original_prompt_data_sha256 == "abc" + sha265 = "185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969" + assert inserted_entry.original_prompt_data_sha256 == sha265 def test_insert_entry_violates_constraint(setup_duckdb_database): # Generate a fixed UUID fixed_uuid = uuid.uuid4() # Create two entries with the same UUID - entry1 = PromptMemoryEntry(id=fixed_uuid, conversation_id="123", role="user", original_prompt_text="Hello") - entry2 = PromptMemoryEntry(id=fixed_uuid, conversation_id="456", role="user", original_prompt_text="Hello again") + entry1 = PromptMemoryEntry(id=fixed_uuid, + conversation_id="123", + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello") + + entry2 = PromptMemoryEntry(id=fixed_uuid, + conversation_id="456", + role="user", + original_prompt_text="Hello again", + converted_prompt_text="Hello again") # Insert the first entry with setup_duckdb_database.get_session() as session: @@ -193,7 +202,10 @@ def test_insert_entry_violates_constraint(setup_duckdb_database): def test_insert_entries(setup_duckdb_database): entries = [ - PromptMemoryEntry(conversation_id=str(i), role="user", original_prompt_text=f"Message {i}", original_prompt_data_sha256=f"hash{i}") + PromptMemoryEntry(conversation_id=str(i), + role="user", + original_prompt_text=f"Message 
{i}", + converted_prompt_text=f"CMessage {i}") for i in range(5) ] @@ -207,12 +219,15 @@ def test_insert_entries(setup_duckdb_database): assert entry.conversation_id == str(i) assert entry.role == "user" assert entry.original_prompt_text == f"Message {i}" - assert entry.original_prompt_data_sha256 == f"hash{i}" + assert entry.converted_prompt_text == f"CMessage {i}" def test_insert_embedding_entry(setup_duckdb_database): # Create a ConversationData entry - conversation_entry = PromptMemoryEntry(conversation_id="123", role="user", original_prompt_text="Hello", original_prompt_data_sha256="abc") + conversation_entry = PromptMemoryEntry(conversation_id="123", + role="user", + original_prompt_text="Hello", + converted_prompt_text="abc") # Insert the ConversationData entry using the insert_entry method setup_duckdb_database.insert_entry(conversation_entry) @@ -237,7 +252,14 @@ def test_insert_embedding_entry(setup_duckdb_database): def test_query_entries(setup_duckdb_database): # Insert some test data - entries = [PromptMemoryEntry(conversation_id=str(i), role="user", original_prompt_text=f"Message {i}") for i in range(3)] + entries = [ + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"Message {i}") for i in range(3) + ] + setup_duckdb_database.insert_entries(entries=entries) # Query entries without conditions @@ -254,7 +276,11 @@ def test_query_entries(setup_duckdb_database): def test_update_entries(setup_duckdb_database): # Insert a test entry - entry = PromptMemoryEntry(conversation_id="123", role="user", original_prompt_text="Hello") + entry = PromptMemoryEntry(conversation_id="123", + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello") + setup_duckdb_database.insert_entry(entry) # Fetch the entry to update and update its content @@ -271,11 +297,18 @@ def test_update_entries(setup_duckdb_database): def test_get_all_memory(setup_duckdb_database): # Insert some test 
data - entries = [PromptMemoryEntry(conversation_id=str(i), role="user", original_prompt_text=f"Message {i}") for i in range(3)] + entries = [ + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"Message {i}") for i in range(3) + ] + setup_duckdb_database.insert_entries(entries=entries) # Fetch all entries - all_entries = setup_duckdb_database.get_all_memory(PromptMemoryEntry) + all_entries = setup_duckdb_database.get_all_prompt_entries() assert len(all_entries) == 3 @@ -283,8 +316,8 @@ def test_get_memories_with_json_properties(setup_duckdb_database): # Define a specific conversation_id specific_conversation_id = "test_conversation_id" - converters = PromptConverterList([Base64Converter()]).to_json() - target = TextTarget().to_dict() + converters = PromptConverterList([Base64Converter()]) + target = TextTarget() # Start a session with setup_duckdb_database.get_session() as session: @@ -293,10 +326,8 @@ def test_get_memories_with_json_properties(setup_duckdb_database): conversation_id=specific_conversation_id, role="user", sequence=1, - original_prompt_text="Test content", - timestamp=datetime.datetime.utcnow(), - original_prompt_data_sha256="test_sha256", + converted_prompt_text="Test content", labels={"normalizer_id": "id1"}, converters=converters, prompt_target=target, @@ -307,7 +338,7 @@ def test_get_memories_with_json_properties(setup_duckdb_database): session.commit() # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = setup_duckdb_database.get_memories_with_conversation_id( + retrieved_entries = setup_duckdb_database.get_prompt_entries_with_conversation_id( conversation_id=specific_conversation_id ) @@ -319,7 +350,6 @@ def test_get_memories_with_json_properties(setup_duckdb_database): assert retrieved_entry.original_prompt_text == "Test content" # For timestamp, you might want to check if it's close to the 
current time instead of an exact match assert abs((retrieved_entry.timestamp - entry.timestamp).total_seconds()) < 10 # Assuming the test runs quickly - assert retrieved_entry.original_prompt_data_sha256 == "test_sha256" converters = json.loads(retrieved_entry.converters) assert len(converters) == 1 @@ -345,22 +375,22 @@ def test_get_memories_with_normalizer_id(setup_duckdb_database): conversation_id="123", role="user", original_prompt_text="Hello 1", + converted_prompt_text="Hello 1", labels=labels, - timestamp=datetime.datetime.utcnow(), ), PromptMemoryEntry( conversation_id="456", role="user", original_prompt_text="Hello 2", + converted_prompt_text="Hello 2", labels=other_labels, - timestamp=datetime.datetime.utcnow(), ), PromptMemoryEntry( conversation_id="789", role="user", original_prompt_text="Hello 3", + converted_prompt_text="Hello 1", labels=labels, - timestamp=datetime.datetime.utcnow(), ), ] @@ -370,7 +400,7 @@ def test_get_memories_with_normalizer_id(setup_duckdb_database): session.commit() # Ensure all entries are committed to the database # Use the get_memories_with_normalizer_id method to retrieve entries with the specific normalizer_id - retrieved_entries = setup_duckdb_database.get_memories_with_normalizer_id(normalizer_id=specific_normalizer_id) + retrieved_entries = setup_duckdb_database.get_prompt_entries_with_normalizer_id(normalizer_id=specific_normalizer_id) # Verify that the retrieved entries match the expected normalizer_id assert len(retrieved_entries) == 2 # Two entries should have the specific normalizer_id @@ -389,19 +419,19 @@ def test_update_entries_by_conversation_id(setup_duckdb_database): conversation_id=specific_conversation_id, role="user", original_prompt_text="Original content 1", - timestamp=datetime.datetime.utcnow(), + converted_prompt_text="Original content 1", ), PromptMemoryEntry( conversation_id="other_id", role="user", original_prompt_text="Original content 2", - timestamp=datetime.datetime.utcnow() + 
converted_prompt_text="Original content 2", ), PromptMemoryEntry( conversation_id=specific_conversation_id, role="user", original_prompt_text="Original content 3", - timestamp=datetime.datetime.utcnow(), + converted_prompt_text="Original content 3", ), ] diff --git a/tests/memory/test_memory_embedding.py b/tests/memory/test_memory_embedding.py index c34b47319..01b6d4709 100644 --- a/tests/memory/test_memory_embedding.py +++ b/tests/memory/test_memory_embedding.py @@ -51,8 +51,10 @@ def test_memory_encoding_chat_message( memory_encoder_w_mock_embedding_generator: MemoryEmbedding, ): chat_memory = PromptMemoryEntry( - content="hello world!", + original_prompt_text="hello world!", + converted_prompt_text="hello world!", role="user", + converted_prompt_data_type="text", conversation_id="my_session", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) diff --git a/tests/memory/test_memory_encoder.py b/tests/memory/test_memory_encoder.py index 92784b9b0..f7d2bbee9 100644 --- a/tests/memory/test_memory_encoder.py +++ b/tests/memory/test_memory_encoder.py @@ -54,9 +54,11 @@ def test_memory_encoding_chat_message( memory_encoder_w_mock_embedding_generator: MemoryEmbedding, ): chat_memory = PromptMemoryEntry( - content="hello world!", + original_prompt_text="hello world!", + converted_prompt_text="hello world!", role="user", conversation_id="my_session", + converted_prompt_data_type="text", ) metadata = memory_encoder_w_mock_embedding_generator.generate_embedding_memory_data(chat_memory=chat_memory) assert metadata.id == chat_memory.id diff --git a/tests/memory/test_memory_interface.py b/tests/memory/test_memory_interface.py index 8381bff3d..872f0e81e 100644 --- a/tests/memory/test_memory_interface.py +++ b/tests/memory/test_memory_interface.py @@ -11,22 +11,8 @@ from pyrit.models import ChatMessage from pyrit.memory.memory_models import PromptMemoryEntry +from tests.mocks import memory -@pytest.fixture -def memory() 
-> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." - - yield duckdb_memory - duckdb_memory.dispose_engine() def generate_random_string(length: int = 10) -> str: @@ -39,7 +25,7 @@ def test_memory(memory: MemoryInterface): def test_conversation_memory_empty_by_default(memory: MemoryInterface): expected_count = 0 - c = memory.get_all_memory(PromptMemoryEntry) + c = memory.get_all_prompt_entries() assert len(c) == expected_count @@ -48,24 +34,24 @@ def test_count_of_memories_matches_number_of_conversations_added_1( ): expected_count = 1 message = ChatMessage(role="user", content="Hello") - memory.add_chat_message_to_memory(conversation=message, conversation_id="1", labels=[]) - c = memory.get_all_memory(PromptMemoryEntry) + memory.add_chat_message_to_memory(conversation=message, conversation_id="1", labels={}) + c = memory.get_all_prompt_entries() assert len(c) == expected_count -def test_add_chate_message_to_memory_added(memory: MemoryInterface): +def test_add_chat_message_to_memory_added(memory: MemoryInterface): expected_count = 3 memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1") memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1") memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 3"), conversation_id="1") - assert len(memory.get_all_memory(PromptMemoryEntry)) == expected_count + assert len(memory.get_all_prompt_entries()) == expected_count -def 
test_add_chate_messages_to_memory_added(memory: MemoryInterface): +def test_add_chat_messages_to_memory_added(memory: MemoryInterface): messages = [ ChatMessage(role="user", content="Hello 1"), ChatMessage(role="user", content="Hello 2"), ] memory.add_chat_messages_to_memory(conversations=messages, conversation_id="1") - assert len(memory.get_all_memory(PromptMemoryEntry)) == len(messages) + assert len(memory.get_all_prompt_entries()) == len(messages) diff --git a/tests/mocks.py b/tests/mocks.py index e274a5cf5..4d0e88aef 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -3,6 +3,10 @@ from contextlib import AbstractAsyncContextManager +import pytest +from sqlalchemy import inspect + +from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.prompt_target import PromptTarget @@ -61,3 +65,19 @@ def send_prompt(self, normalized_prompt: str, conversation_id: str, normalizer_i async def send_prompt_async(self, normalized_prompt: str, conversation_id: str, normalizer_id: str) -> None: self.prompt_sent.append(normalized_prompt) + +@pytest.fixture +def memory() -> MemoryInterface: # type: ignore + # Create an in-memory DuckDB engine + duckdb_memory = DuckDBMemory(db_path=":memory:") + + # Reset the database to ensure a clean state + duckdb_memory.reset_database() + inspector = inspect(duckdb_memory.engine) + + # Verify that tables are created as expected + assert "PromptMemoryEntries" in inspector.get_table_names(), "PromptMemoryEntries table not created." + assert "EmbeddingData" in inspector.get_table_names(), "EmbeddingData table not created." 
+ + yield duckdb_memory + duckdb_memory.dispose_engine() diff --git a/tests/test_prompt_target.py b/tests/test_prompt_target.py index 6b688dd36..1a623a756 100644 --- a/tests/test_prompt_target.py +++ b/tests/test_prompt_target.py @@ -11,6 +11,7 @@ from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.prompt_target import AzureOpenAIChatTarget +from tests.mocks import memory @pytest.fixture def openai_mock_return() -> ChatCompletion: @@ -34,24 +35,6 @@ def openai_mock_return() -> ChatCompletion: def chat_completion_engine() -> AzureOpenAIChatTarget: return AzureOpenAIChatTarget(deployment_name="test", endpoint="test", api_key="test") - -@pytest.fixture -def memory() -> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." 
- - yield duckdb_memory - duckdb_memory.dispose_engine() - - @pytest.fixture def azure_openai_target(memory: DuckDBMemory): @@ -66,10 +49,10 @@ def azure_openai_target(memory: DuckDBMemory): def test_set_system_prompt(azure_openai_target: AzureOpenAIChatTarget): azure_openai_target.set_system_prompt(prompt="system prompt", conversation_id="1", normalizer_id="2") - chats = azure_openai_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_openai_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "system" - assert chats[0].content == "system prompt" + assert chats[0].converted_prompt_text == "system prompt" def test_send_prompt_user_no_system(azure_openai_target: AzureOpenAIChatTarget, openai_mock_return: ChatCompletion): @@ -79,7 +62,7 @@ def test_send_prompt_user_no_system(azure_openai_target: AzureOpenAIChatTarget, normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" ) - chats = azure_openai_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_openai_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 2, f"Expected 2 chats, got {len(chats)}" assert chats[0].role == "user" assert chats[1].role == "assistant" @@ -95,7 +78,7 @@ def test_send_prompt_with_system(azure_openai_target: AzureOpenAIChatTarget, ope normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" ) - chats = azure_openai_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_openai_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 3, f"Expected 3 chats, got {len(chats)}" assert chats[0].role == "system" assert chats[1].role == "user" diff --git a/tests/test_prompt_target_azure_blob_storage.py 
b/tests/test_prompt_target_azure_blob_storage.py index 808329cb6..d34323992 100644 --- a/tests/test_prompt_target_azure_blob_storage.py +++ b/tests/test_prompt_target_azure_blob_storage.py @@ -10,22 +10,8 @@ from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.prompt_target import AzureBlobStorageTarget +from tests.mocks import memory -@pytest.fixture -def memory() -> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." - - yield duckdb_memory - duckdb_memory.dispose_engine() @pytest.fixture @@ -87,10 +73,10 @@ def test_send_prompt(mock_upload, azure_blob_storage_target: AzureBlobStorageTar assert blob_url.__contains__(azure_blob_storage_target._container_url) assert blob_url.__contains__(".txt") - chats = azure_blob_storage_target._memory.get_memories_with_conversation_id(conversation_id="1") + chats = azure_blob_storage_target._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "user" - assert chats[0].content == __name__ + assert chats[0].converted_prompt_text == __name__ @patch("azure.storage.blob.aio.ContainerClient.upload_blob") @@ -103,7 +89,7 @@ async def test_send_prompt_async(mock_upload_async, azure_blob_storage_target: A assert blob_url.__contains__(azure_blob_storage_target._container_url) assert blob_url.__contains__(".txt") - chats = azure_blob_storage_target._memory.get_memories_with_conversation_id(conversation_id="2") + chats = 
azure_blob_storage_target._memory.get_prompt_entries_with_conversation_id(conversation_id="2") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "user" - assert chats[0].content == __name__ + assert chats[0].converted_prompt_text == __name__ diff --git a/tests/test_prompt_target_text.py b/tests/test_prompt_target_text.py index 56679bc13..f84d4b126 100644 --- a/tests/test_prompt_target_text.py +++ b/tests/test_prompt_target_text.py @@ -10,22 +10,7 @@ from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.prompt_target import TextTarget - -@pytest.fixture -def memory() -> MemoryInterface: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." 
- - yield duckdb_memory - duckdb_memory.dispose_engine() +from tests.mocks import memory def test_send_prompt_user_no_system(memory: DuckDBMemory): @@ -35,7 +20,7 @@ def test_send_prompt_user_no_system(memory: DuckDBMemory): normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" ) - chats = no_op._memory.get_memories_with_conversation_id(conversation_id="1") + chats = no_op._memory.get_prompt_entries_with_conversation_id(conversation_id="1") assert len(chats) == 1, f"Expected 1 chat, got {len(chats)}" assert chats[0].role == "user" diff --git a/tests/test_red_teaming_orchestrator.py b/tests/test_red_teaming_orchestrator.py index b14f1f830..76bf28b65 100644 --- a/tests/test_red_teaming_orchestrator.py +++ b/tests/test_red_teaming_orchestrator.py @@ -341,7 +341,7 @@ def test_reach_goal_after_two_turns_end_token( mock_target.return_value = expected_target_response target_response = red_teaming_orchestrator.apply_attack_strategy_until_completion() assert target_response == expected_target_response - conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries(PromptMemoryEntry) # Expecting three conversation threads (two with red teaming chat and one with prompt target) assert len(conversations) == 7, f"Expected 7 conversations, got {len(conversations)}" check_conversations( From c3b2dfd9b12b2a9cf9647367f8b6d7e20f661cb8 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 28 Mar 2024 11:15:09 -0700 Subject: [PATCH 08/15] tests passing --- pyrit/memory/memory_models.py | 4 +- pyrit/orchestrator/benchmark_orchestrator.py | 2 +- .../end_token_red_teaming_orchestrator.py | 2 +- pyrit/orchestrator/orchestrator_class.py | 8 ++- .../orchestrator/red_teaming_orchestrator.py | 2 +- tests/memory/test_memory_exporter.py | 24 ++++----- tests/mocks.py | 23 +++++++++ tests/test_red_teaming_orchestrator.py | 51 +++++++------------ 8 
files changed, 62 insertions(+), 54 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 4866681ca..13db0304c 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -34,10 +34,10 @@ class PromptMemoryEntry(Base): # type: ignore __tablename__ (str): The name of the database table. __table_args__ (dict): Additional arguments for the database table. id (UUID): The unique identifier for the memory entry. - role (PromptType): system, request_segment, response + role (PromptType): system, assistant, user conversation_id (str): The identifier for the conversation which is associated with a single target. sequence (int): The order of the conversation within a conversation_id. - Can be the same number for multi-part requests or responses. + Can be the same number for multi-part requests or multi-part responses. timestamp (DateTime): The timestamp of the memory entry. labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. prompt_metadata (JSON): The metadata associated with the prompt. 
diff --git a/pyrit/orchestrator/benchmark_orchestrator.py b/pyrit/orchestrator/benchmark_orchestrator.py index d49458e7b..d3e3c03fe 100644 --- a/pyrit/orchestrator/benchmark_orchestrator.py +++ b/pyrit/orchestrator/benchmark_orchestrator.py @@ -31,7 +31,7 @@ def __init__( chat_model_under_evaluation: PromptChatTarget, scorer: QuestionAnswerScorer, memory: MemoryInterface | None = None, - memory_labels: list[str] = ["question-answering-benchmark-orchestrator"], + memory_labels: dict[str:str] = None, evaluation_prompt: str | None = None, batch_size: int = 1, verbose: bool = False, diff --git a/pyrit/orchestrator/end_token_red_teaming_orchestrator.py b/pyrit/orchestrator/end_token_red_teaming_orchestrator.py index da30b4c82..25af59115 100644 --- a/pyrit/orchestrator/end_token_red_teaming_orchestrator.py +++ b/pyrit/orchestrator/end_token_red_teaming_orchestrator.py @@ -24,7 +24,7 @@ def __init__( end_token: Optional[str] = RED_TEAM_CONVERSATION_END_TOKEN, prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: list[str] = ["red-teaming-orchestrator"], + memory_labels: dict[str:str] = None, verbose: bool = False, ) -> None: """Creates an orchestrator to manage conversations between a red teaming target and a prompt target. 
diff --git a/pyrit/orchestrator/orchestrator_class.py b/pyrit/orchestrator/orchestrator_class.py index 04bc7154a..0553a1efb 100644 --- a/pyrit/orchestrator/orchestrator_class.py +++ b/pyrit/orchestrator/orchestrator_class.py @@ -22,14 +22,18 @@ def __init__( *, prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: list[str] = [], + memory_labels: dict[str:str] = {}, verbose: bool = False, ): self._prompt_converters = prompt_converters if prompt_converters else [NoOpConverter()] self._memory = memory or DuckDBMemory() - self._global_memory_labels = memory_labels self._verbose = verbose + if memory_labels: + self._global_memory_labels = memory_labels + + self._global_memory_labels = {"orchestrator": str(self.__class__.__name__)} + if self._verbose: logging.basicConfig(level=logging.INFO) diff --git a/pyrit/orchestrator/red_teaming_orchestrator.py b/pyrit/orchestrator/red_teaming_orchestrator.py index ed24eceb6..2628e8c57 100644 --- a/pyrit/orchestrator/red_teaming_orchestrator.py +++ b/pyrit/orchestrator/red_teaming_orchestrator.py @@ -38,7 +38,7 @@ def __init__( initial_red_teaming_prompt: str = "Begin Conversation", prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: list[str] = ["red-teaming-orchestrator"], + memory_labels: dict[str:str] = None, verbose: bool = False, ) -> None: """Creates an orchestrator to manage conversations between a red teaming target and a prompt target. 
diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 3b87d9ff8..2d0e6a5df 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -9,14 +9,8 @@ from pyrit.memory.memory_models import PromptMemoryEntry from sqlalchemy.inspection import inspect +from tests.mocks import sample_conversations -@pytest.fixture -def sample_conversations(): - # Create some instances of ConversationStore with sample data - return [ - PromptMemoryEntry(role="User", content="Hello, how are you?", conversation_id="12345"), - PromptMemoryEntry(role="Bot", content="I'm fine, thank you!", conversation_id="12345"), - ] def model_to_dict(instance): @@ -34,14 +28,16 @@ def test_export_to_json_creates_file(tmp_path, sample_conversations): with open(file_path, "r") as f: content = json.load(f) # Perform more detailed checks on content if necessary - assert len(content) == 2 # Simple check for the number of items + assert len(content) == 3 # Simple check for the number of items # Convert each ConversationStore instance to a dictionary expected_content = [model_to_dict(conv) for conv in sample_conversations] for expected, actual in zip(expected_content, content): assert expected["role"] == actual["role"] - assert expected["content"] == actual["content"] + assert expected["converted_prompt_text"] == actual["converted_prompt_text"] assert expected["conversation_id"] == actual["conversation_id"] + assert expected["original_prompt_data_type"] == actual["original_prompt_data_type"] + assert expected["original_prompt_text"] == actual["original_prompt_text"] def test_export_data_with_conversations(tmp_path, sample_conversations): @@ -59,10 +55,10 @@ def test_export_data_with_conversations(tmp_path, sample_conversations): # Read the file and verify its contents with open(file_path, "r") as f: content = json.load(f) - assert len(content) == 2 # Check for the expected number of items - assert content[0]["role"] == "User" - assert 
content[0]["content"] == "Hello, how are you?" + assert len(content) == 3 # Check for the expected number of items + assert content[0]["role"] == "user" + assert content[0]["converted_prompt_text"] == "Hello, how are you?" assert content[0]["conversation_id"] == "12345" - assert content[1]["role"] == "Bot" - assert content[1]["content"] == "I'm fine, thank you!" + assert content[1]["role"] == "assistant" + assert content[1]["converted_prompt_text"] == "I'm fine, thank you!" assert content[1]["conversation_id"] == "12345" diff --git a/tests/mocks.py b/tests/mocks.py index 4d0e88aef..2807cb7a0 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -7,6 +7,7 @@ from sqlalchemy import inspect from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.prompt_target import PromptTarget @@ -81,3 +82,25 @@ def memory() -> MemoryInterface: # type: ignore yield duckdb_memory duckdb_memory.dispose_engine() + +@pytest.fixture +def sample_conversations(): + return [ + PromptMemoryEntry( + role="user", + original_prompt_text="original prompt text", + converted_prompt_text="Hello, how are you?", + conversation_id="12345"), + PromptMemoryEntry( + role="assistant", + original_prompt_text="original prompt text", + converted_prompt_text="I'm fine, thank you!", + conversation_id="12345" + ), + PromptMemoryEntry( + role="assistant", + original_prompt_text="original prompt text", + converted_prompt_text="I'm fine, thank you!", + conversation_id="33333" + ), + ] \ No newline at end of file diff --git a/tests/test_red_teaming_orchestrator.py b/tests/test_red_teaming_orchestrator.py index 76bf28b65..9f2806039 100644 --- a/tests/test_red_teaming_orchestrator.py +++ b/tests/test_red_teaming_orchestrator.py @@ -18,22 +18,7 @@ from pyrit.memory import DuckDBMemory from pyrit.common.path import DATASETS_PATH - -@pytest.fixture -def memory() -> DuckDBMemory: # type: ignore - # Create an in-memory DuckDB engine - duckdb_memory = 
DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "ConversationStore" in inspector.get_table_names(), "ConversationStore table not created." - assert "EmbeddingStore" in inspector.get_table_names(), "EmbeddingStore table not created." - - yield duckdb_memory - duckdb_memory.dispose_engine() +from tests.mocks import memory @pytest.fixture @@ -79,18 +64,18 @@ def check_conversations( # first conversation (with red teaming chat bot) assert conversations[0].conversation_id == conversations[1].conversation_id == conversations[2].conversation_id assert conversations[0].role == "system" - assert conversations[0].content == red_teaming_meta_prompt + assert conversations[0].converted_prompt_text == red_teaming_meta_prompt assert conversations[1].role == "user" - assert conversations[1].content == initial_red_teaming_prompt + assert conversations[1].converted_prompt_text == initial_red_teaming_prompt assert conversations[2].role == "assistant" - assert conversations[2].content == expected_red_teaming_responses[0] + assert conversations[2].converted_prompt_text == expected_red_teaming_responses[0] # second conversation (with prompt target) assert conversations[3 - index_offset].conversation_id == conversations[4 - index_offset].conversation_id - assert conversations[3 - index_offset].normalizer_id == conversations[4 - index_offset].normalizer_id + assert conversations[3 - index_offset].labels["normalizer_id"] == conversations[4 - index_offset].labels["normalizer_id"] assert conversations[3 - index_offset].role == "user" - assert conversations[3 - index_offset].content == expected_red_teaming_responses[0] + assert conversations[3 - index_offset].converted_prompt_text == expected_red_teaming_responses[0] assert conversations[4 - index_offset].role == "assistant" - assert conversations[4 - index_offset].content == 
expected_target_responses[0] + assert conversations[4 - index_offset].converted_prompt_text == expected_target_responses[0] if stop_after_n_conversations == 2: return @@ -103,20 +88,20 @@ def check_conversations( # third conversation (with red teaming chatbot) assert conversations[5 - index_offset].conversation_id == conversations[6 - index_offset].conversation_id assert conversations[5 - index_offset].role == "user" - assert conversations[5 - index_offset].content == expected_target_responses[0] + assert conversations[5 - index_offset].converted_prompt_text == expected_target_responses[0] assert conversations[6 - index_offset].role == "assistant" - assert conversations[6 - index_offset].content == expected_red_teaming_responses[1] + assert conversations[6 - index_offset].converted_prompt_text == expected_red_teaming_responses[1] if stop_after_n_conversations == 3: return # fourth conversation (with prompt target) assert conversations[7 - index_offset].conversation_id == conversations[8 - index_offset].conversation_id - assert conversations[7 - index_offset].normalizer_id == conversations[8 - index_offset].normalizer_id + assert conversations[7 - index_offset].labels["normalizer_id"] == conversations[8 - index_offset].labels["normalizer_id"] assert conversations[7 - index_offset].role == "user" - assert conversations[7 - index_offset].content == expected_red_teaming_responses[1] + assert conversations[7 - index_offset].converted_prompt_text == expected_red_teaming_responses[1] assert conversations[8 - index_offset].role == "assistant" - assert conversations[8 - index_offset].content == expected_target_responses[1] + assert conversations[8 - index_offset].converted_prompt_text == expected_target_responses[1] @pytest.mark.parametrize("attack_strategy_as_str", [True, False]) @@ -152,7 +137,7 @@ def test_send_prompt_twice( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt() assert target_response == 
expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 5, f"Expected 5 conversations, got {len(conversations)}" check_conversations( @@ -173,7 +158,7 @@ def test_send_prompt_twice( mock_target.return_value = expected_target_responses[1] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[1] - conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting another two conversation threads assert len(conversations) == 9, f"Expected 9 conversations, got {len(conversations)}" check_conversations( @@ -218,7 +203,7 @@ def test_send_fixed_prompt_then_generated_prompt( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt(prompt=fixed_input_prompt) assert target_response == expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 2, f"Expected 2 conversations, got {len(conversations)}" check_conversations( @@ -240,7 +225,7 @@ def test_send_fixed_prompt_then_generated_prompt( mock_target.return_value = expected_target_responses[1] target_response = red_teaming_orchestrator.send_prompt() assert target_response == expected_target_responses[1] - conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting another two conversation threads assert 
len(conversations) == 7, f"Expected 7 conversations, got {len(conversations)}" check_conversations( @@ -285,7 +270,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( mock_target.return_value = expected_target_responses[0] target_response = red_teaming_orchestrator.send_prompt(prompt=fixed_input_prompt) assert target_response == expected_target_responses[0] - conversations = red_teaming_orchestrator._memory.get_all_memory(PromptMemoryEntry) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting two conversation threads (one with red teaming chat and one with prompt target) assert len(conversations) == 2, f"Expected 2 conversations, got {len(conversations)}" check_conversations( @@ -341,7 +326,7 @@ def test_reach_goal_after_two_turns_end_token( mock_target.return_value = expected_target_response target_response = red_teaming_orchestrator.apply_attack_strategy_until_completion() assert target_response == expected_target_response - conversations = red_teaming_orchestrator._memory.get_all_prompt_entries(PromptMemoryEntry) + conversations = red_teaming_orchestrator._memory.get_all_prompt_entries() # Expecting three conversation threads (two with red teaming chat and one with prompt target) assert len(conversations) == 7, f"Expected 7 conversations, got {len(conversations)}" check_conversations( From 151a7f7e1d77725a31d27b1ecea62e8d3c539c9d Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 28 Mar 2024 11:49:17 -0700 Subject: [PATCH 09/15] build changes --- pyrit/memory/duckdb_memory.py | 10 +- pyrit/memory/memory_embedding.py | 4 +- pyrit/memory/memory_interface.py | 10 +- pyrit/memory/memory_models.py | 48 ++++---- pyrit/prompt_converter/prompt_converter.py | 12 +- pyrit/prompt_target/prompt_target.py | 6 +- .../analytics/test_conversation_analytics.py | 8 +- tests/memory/test_duckdb_memory.py | 115 ++++++++++-------- tests/memory/test_memory_exporter.py | 1 - tests/memory/test_memory_interface.py | 1 - tests/mocks.py 
| 15 ++- tests/test_prompt_target.py | 2 + .../test_prompt_target_azure_blob_storage.py | 1 - tests/test_red_teaming_orchestrator.py | 10 +- 14 files changed, 131 insertions(+), 112 deletions(-) diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 96cebd334..443cf42ed 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -38,7 +38,7 @@ def __init__( db_path: Union[Path, str] = None, embedding_model: EmbeddingSupport = None, disable_embedding: bool = False, - verbose: bool = False + verbose: bool = False, ): super(DuckDBMemory, self).__init__() if not disable_embedding: @@ -102,7 +102,9 @@ def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> li list[ConversationData]: A list of ConversationData objects matching the specified conversation ID. """ try: - return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id) + return self.query_entries( + PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id + ) except Exception as e: logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}") return [] @@ -118,7 +120,9 @@ def get_prompt_entries_with_normalizer_id(self, *, normalizer_id: str) -> list[P list[ConversationData]: A list of ConversationData objects matching the specified normalizer ID. """ try: - return self.query_entries(PromptMemoryEntry, conditions=PromptMemoryEntry.labels.op('->>')('normalizer_id') == normalizer_id) + return self.query_entries( + PromptMemoryEntry, conditions=PromptMemoryEntry.labels.op("->>")("normalizer_id") == normalizer_id + ) except Exception as e: logger.exception( f"Unexpected error: Failed to retrieve ConversationData with normalizer_id {normalizer_id}. 
{e}" diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index 082a1884e..0ad7c68da 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -32,7 +32,9 @@ def generate_embedding_memory_data(self, *, chat_memory: PromptMemoryEntry) -> E """ if chat_memory.converted_prompt_data_type == "text": embedding_data = EmbeddingData( - embedding=self.embedding_model.generate_text_embedding(text=chat_memory.converted_prompt_text).data[0].embedding, + embedding=self.embedding_model.generate_text_embedding(text=chat_memory.converted_prompt_text) + .data[0] + .embedding, embedding_type_name=self.embedding_model.__class__.__name__, id=chat_memory.id, ) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 94fd6907a..87a9a33cb 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -73,7 +73,6 @@ def insert_prompt_entries(self, *, entries: list[EmbeddingData]) -> None: entries (list[Base]): The list of database model instances to be inserted. """ - @abc.abstractmethod def dispose_engine(self): """ @@ -93,7 +92,6 @@ def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> lis memory_entries = self.get_prompt_entries_with_conversation_id(conversation_id=conversation_id) return [ChatMessage(role=me.role, content=me.converted_prompt_text) for me in memory_entries] # type: ignore - def add_chat_message_to_memory( self, conversation: ChatMessage, @@ -115,15 +113,10 @@ def add_chat_message_to_memory( labels (list[str]): A list of labels to be added to the memory entry. 
""" - self.add_chat_messages_to_memory( - conversations=[conversation], - conversation_id=conversation_id, - normalizer_id=normalizer_id, - labels=labels + conversations=[conversation], conversation_id=conversation_id, normalizer_id=normalizer_id, labels=labels ) - def add_chat_messages_to_memory( self, *, @@ -162,7 +155,6 @@ def add_chat_messages_to_memory( self.insert_prompt_entries(entries=entries_to_add) - def export_all_tables(self, *, export_type: str = "json"): """ Exports all table data using the specified exporter. diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 13db0304c..63fbc325b 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -40,7 +40,9 @@ class PromptMemoryEntry(Base): # type: ignore Can be the same number for multi-part requests or multi-part responses. timestamp (DateTime): The timestamp of the memory entry. labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. - prompt_metadata (JSON): The metadata associated with the prompt. + prompt_metadata (JSON): The metadata associated with the prompt. This can be specific to any scenarios. + Because memory is how components talk with each other, this can be component specific. + e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. converters (list[PromptConverter]): The converters for the prompt. prompt_target (PromptTarget): The target for the prompt. 
original_prompt_data_type (PromptDataType): The data type of the original prompt (text, image) @@ -58,14 +60,14 @@ class PromptMemoryEntry(Base): # type: ignore __tablename__ = "PromptMemoryEntries" __table_args__ = {"extend_existing": True} id = Column(UUID(as_uuid=True), nullable=False, primary_key=True) - role: 'Column[ChatMessageRole]' = Column(String, nullable=False) + role: "Column[ChatMessageRole]" = Column(String, nullable=False) conversation_id = Column(String, nullable=False) sequence = Column(INTEGER, nullable=False) timestamp = Column(DateTime, nullable=False) labels: Column[Dict[str, str]] = Column(JSON) prompt_metadata = Column(JSON) - converters: 'Column[list[PromptConverter]]' = Column(JSON) - prompt_target: 'Column[PromptTarget]' = Column(JSON) + converters: "Column[list[PromptConverter]]" = Column(JSON) + prompt_target: "Column[PromptTarget]" = Column(JSON) original_prompt_data_type: PromptDataType = Column(String, nullable=False) original_prompt_text = Column(String, nullable=False) @@ -77,23 +79,22 @@ class PromptMemoryEntry(Base): # type: ignore idx_conversation_id = Index("idx_conversation_id", "conversation_id") - - def __init__(self, - *, - role: str, - original_prompt_text: str, - converted_prompt_text: str, - id: uuid.UUID = None, - conversation_id: str = None, - sequence: int = -1, - labels: Dict[str, str] = None, - prompt_metadata: JSON = None, - converters: 'PromptConverterList' = None, - prompt_target: 'PromptTarget' = None, - original_prompt_data_type: PromptDataType = "text", - converted_prompt_data_type: PromptDataType = "text" - ): - + def __init__( + self, + *, + role: str, + original_prompt_text: str, + converted_prompt_text: str, + id: uuid.UUID = None, + conversation_id: str = None, + sequence: int = -1, + labels: Dict[str, str] = None, + prompt_metadata: JSON = None, + converters: "PromptConverterList" = None, + prompt_target: "PromptTarget" = None, + original_prompt_data_type: PromptDataType = "text", + 
converted_prompt_data_type: PromptDataType = "text", + ): self.id = id if id else uuid4() @@ -112,15 +113,12 @@ def __init__(self, self.original_prompt_data_type = original_prompt_data_type self.original_prompt_data_sha256 = self._create_sha256(original_prompt_text) - self.converted_prompt_data_type = converted_prompt_data_type self.converted_prompt_text = converted_prompt_text self.converted_prompt_data_sha256 = self._create_sha256(converted_prompt_text) - - def _create_sha256(self, text: str) -> str: - input_bytes = text.encode('utf-8') + input_bytes = text.encode("utf-8") hash_object = hashlib.sha256(input_bytes) return hash_object.hexdigest() diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 2528ad980..6808c14a5 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -4,6 +4,7 @@ import abc import json + class PromptConverter(abc.ABC): """ A prompt converter is responsible for converting prompts into multiple representations. 
@@ -28,14 +29,15 @@ def is_one_to_one_converter(self) -> bool: pass def to_dict(self): - public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith('_')} - public_attributes['__type__'] = self.__class__.__name__ - public_attributes['__module__'] = self.__class__.__module__ + public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + public_attributes["__type__"] = self.__class__.__name__ + public_attributes["__module__"] = self.__class__.__module__ return public_attributes -class PromptConverterList(): + +class PromptConverterList: def __init__(self, converters: list[PromptConverter]) -> None: self.converters = converters def to_json(self): - return json.dumps([converter.to_dict() for converter in self.converters]) \ No newline at end of file + return json.dumps([converter.to_dict() for converter in self.converters]) diff --git a/pyrit/prompt_target/prompt_target.py b/pyrit/prompt_target/prompt_target.py index cc9f2b484..19d659b6d 100644 --- a/pyrit/prompt_target/prompt_target.py +++ b/pyrit/prompt_target/prompt_target.py @@ -32,7 +32,7 @@ async def send_prompt_async(self, *, normalized_prompt: str, conversation_id: st """ def to_dict(self): - public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith('_')} - public_attributes['__type__'] = self.__class__.__name__ - public_attributes['__module__'] = self.__class__.__module__ + public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + public_attributes["__type__"] = self.__class__.__name__ + public_attributes["__module__"] = self.__class__.__module__ return json.dumps(public_attributes) diff --git a/tests/analytics/test_conversation_analytics.py b/tests/analytics/test_conversation_analytics.py index 554d6cfc7..c4c6c05db 100644 --- a/tests/analytics/test_conversation_analytics.py +++ b/tests/analytics/test_conversation_analytics.py @@ -21,12 +21,16 @@ def 
test_get_similar_chat_messages_by_content(mock_memory_interface): mock_data = [ PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="Hello, how are you?", role="user"), PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="I'm fine, thank you!", role="assistant"), - PromptMemoryEntry(original_prompt_text="h", converted_prompt_text="Hello, how are you?", role="assistant"), # Exact match + PromptMemoryEntry( + original_prompt_text="h", converted_prompt_text="Hello, how are you?", role="assistant" + ), # Exact match ] mock_memory_interface.get_all_prompt_entries.return_value = mock_data analytics = ConversationAnalytics(memory_interface=mock_memory_interface) - similar_messages = analytics.get_prompt_entries_with_same_converted_content(chat_message_content="Hello, how are you?") + similar_messages = analytics.get_prompt_entries_with_same_converted_content( + chat_message_content="Hello, how are you?" + ) # Expect one exact match assert len(similar_messages) == 2 diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 2ec6578d3..f7d9d9966 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -57,21 +57,23 @@ def test_conversation_data_schema(setup_duckdb_database): column_names = [col["name"] for col in columns] # Expected columns in ConversationData - expected_columns = ["id", - "role", - "conversation_id", - "sequence", - "timestamp", - "labels", - "prompt_metadata", - "converters", - "prompt_target", - "original_prompt_data_type", - "original_prompt_text", - "original_prompt_data_sha256", - "converted_prompt_data_type", - "converted_prompt_text", - "converted_prompt_data_sha256"] + expected_columns = [ + "id", + "role", + "conversation_id", + "sequence", + "timestamp", + "labels", + "prompt_metadata", + "converters", + "prompt_target", + "original_prompt_data_type", + "original_prompt_text", + "original_prompt_data_sha256", + "converted_prompt_data_type", + 
"converted_prompt_text", + "converted_prompt_data_sha256", + ] for column in expected_columns: assert column in column_names, f"{column} not found in PromptMemoryEntries schema." @@ -120,7 +122,6 @@ def test_conversation_data_column_types(setup_duckdb_database): ), f"Expected {column} to be a subclass of {expected_type}, got {column_types[column]} instead." - def test_embedding_data_column_types(setup_duckdb_database): inspector = inspect(setup_duckdb_database.engine) columns = inspector.get_columns("EmbeddingData") @@ -176,17 +177,17 @@ def test_insert_entry_violates_constraint(setup_duckdb_database): # Generate a fixed UUID fixed_uuid = uuid.uuid4() # Create two entries with the same UUID - entry1 = PromptMemoryEntry(id=fixed_uuid, - conversation_id="123", - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello") - - entry2 = PromptMemoryEntry(id=fixed_uuid, - conversation_id="456", - role="user", - original_prompt_text="Hello again", - converted_prompt_text="Hello again") + entry1 = PromptMemoryEntry( + id=fixed_uuid, conversation_id="123", role="user", original_prompt_text="Hello", converted_prompt_text="Hello" + ) + + entry2 = PromptMemoryEntry( + id=fixed_uuid, + conversation_id="456", + role="user", + original_prompt_text="Hello again", + converted_prompt_text="Hello again", + ) # Insert the first entry with setup_duckdb_database.get_session() as session: @@ -202,10 +203,12 @@ def test_insert_entry_violates_constraint(setup_duckdb_database): def test_insert_entries(setup_duckdb_database): entries = [ - PromptMemoryEntry(conversation_id=str(i), - role="user", - original_prompt_text=f"Message {i}", - converted_prompt_text=f"CMessage {i}") + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"CMessage {i}", + ) for i in range(5) ] @@ -224,10 +227,9 @@ def test_insert_entries(setup_duckdb_database): def test_insert_embedding_entry(setup_duckdb_database): # Create a 
ConversationData entry - conversation_entry = PromptMemoryEntry(conversation_id="123", - role="user", - original_prompt_text="Hello", - converted_prompt_text="abc") + conversation_entry = PromptMemoryEntry( + conversation_id="123", role="user", original_prompt_text="Hello", converted_prompt_text="abc" + ) # Insert the ConversationData entry using the insert_entry method setup_duckdb_database.insert_entry(conversation_entry) @@ -253,12 +255,14 @@ def test_insert_embedding_entry(setup_duckdb_database): def test_query_entries(setup_duckdb_database): # Insert some test data entries = [ - PromptMemoryEntry( - conversation_id=str(i), - role="user", - original_prompt_text=f"Message {i}", - converted_prompt_text=f"Message {i}") for i in range(3) - ] + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"Message {i}", + ) + for i in range(3) + ] setup_duckdb_database.insert_entries(entries=entries) @@ -276,10 +280,9 @@ def test_query_entries(setup_duckdb_database): def test_update_entries(setup_duckdb_database): # Insert a test entry - entry = PromptMemoryEntry(conversation_id="123", - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello") + entry = PromptMemoryEntry( + conversation_id="123", role="user", original_prompt_text="Hello", converted_prompt_text="Hello" + ) setup_duckdb_database.insert_entry(entry) @@ -287,7 +290,9 @@ def test_update_entries(setup_duckdb_database): entries_to_update = setup_duckdb_database.query_entries( PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == "123" ) - setup_duckdb_database.update_entries(entries=entries_to_update, update_fields={"original_prompt_text": "Updated Hello"}) + setup_duckdb_database.update_entries( + entries=entries_to_update, update_fields={"original_prompt_text": "Updated Hello"} + ) # Verify the entry was updated with setup_duckdb_database.get_session() as session: @@ -298,12 +303,14 @@ def 
test_update_entries(setup_duckdb_database): def test_get_all_memory(setup_duckdb_database): # Insert some test data entries = [ - PromptMemoryEntry( - conversation_id=str(i), - role="user", - original_prompt_text=f"Message {i}", - converted_prompt_text=f"Message {i}") for i in range(3) - ] + PromptMemoryEntry( + conversation_id=str(i), + role="user", + original_prompt_text=f"Message {i}", + converted_prompt_text=f"Message {i}", + ) + for i in range(3) + ] setup_duckdb_database.insert_entries(entries=entries) @@ -400,7 +407,9 @@ def test_get_memories_with_normalizer_id(setup_duckdb_database): session.commit() # Ensure all entries are committed to the database # Use the get_memories_with_normalizer_id method to retrieve entries with the specific normalizer_id - retrieved_entries = setup_duckdb_database.get_prompt_entries_with_normalizer_id(normalizer_id=specific_normalizer_id) + retrieved_entries = setup_duckdb_database.get_prompt_entries_with_normalizer_id( + normalizer_id=specific_normalizer_id + ) # Verify that the retrieved entries match the expected normalizer_id assert len(retrieved_entries) == 2 # Two entries should have the specific normalizer_id diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 2d0e6a5df..032614de8 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -12,7 +12,6 @@ from tests.mocks import sample_conversations - def model_to_dict(instance): """Converts a SQLAlchemy model instance into a dictionary.""" return {c.key: getattr(instance, c.key) for c in inspect(instance).mapper.column_attrs} diff --git a/tests/memory/test_memory_interface.py b/tests/memory/test_memory_interface.py index 872f0e81e..99d2cc4b7 100644 --- a/tests/memory/test_memory_interface.py +++ b/tests/memory/test_memory_interface.py @@ -14,7 +14,6 @@ from tests.mocks import memory - def generate_random_string(length: int = 10) -> str: return "".join(random.choice(ascii_lowercase) for _ in 
range(length)) diff --git a/tests/mocks.py b/tests/mocks.py index 2807cb7a0..47ef9a685 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -67,6 +67,7 @@ def send_prompt(self, normalized_prompt: str, conversation_id: str, normalizer_i async def send_prompt_async(self, normalized_prompt: str, conversation_id: str, normalizer_id: str) -> None: self.prompt_sent.append(normalized_prompt) + @pytest.fixture def memory() -> MemoryInterface: # type: ignore # Create an in-memory DuckDB engine @@ -83,6 +84,7 @@ def memory() -> MemoryInterface: # type: ignore yield duckdb_memory duckdb_memory.dispose_engine() + @pytest.fixture def sample_conversations(): return [ @@ -90,17 +92,18 @@ def sample_conversations(): role="user", original_prompt_text="original prompt text", converted_prompt_text="Hello, how are you?", - conversation_id="12345"), + conversation_id="12345", + ), PromptMemoryEntry( role="assistant", original_prompt_text="original prompt text", converted_prompt_text="I'm fine, thank you!", - conversation_id="12345" - ), + conversation_id="12345", + ), PromptMemoryEntry( role="assistant", original_prompt_text="original prompt text", converted_prompt_text="I'm fine, thank you!", - conversation_id="33333" - ), - ] \ No newline at end of file + conversation_id="33333", + ), + ] diff --git a/tests/test_prompt_target.py b/tests/test_prompt_target.py index 1a623a756..2cc032f29 100644 --- a/tests/test_prompt_target.py +++ b/tests/test_prompt_target.py @@ -13,6 +13,7 @@ from tests.mocks import memory + @pytest.fixture def openai_mock_return() -> ChatCompletion: return ChatCompletion( @@ -35,6 +36,7 @@ def openai_mock_return() -> ChatCompletion: def chat_completion_engine() -> AzureOpenAIChatTarget: return AzureOpenAIChatTarget(deployment_name="test", endpoint="test", api_key="test") + @pytest.fixture def azure_openai_target(memory: DuckDBMemory): diff --git a/tests/test_prompt_target_azure_blob_storage.py b/tests/test_prompt_target_azure_blob_storage.py index 
d34323992..65c14ae60 100644 --- a/tests/test_prompt_target_azure_blob_storage.py +++ b/tests/test_prompt_target_azure_blob_storage.py @@ -13,7 +13,6 @@ from tests.mocks import memory - @pytest.fixture def azure_blob_storage_target(memory: DuckDBMemory): return AzureBlobStorageTarget( diff --git a/tests/test_red_teaming_orchestrator.py b/tests/test_red_teaming_orchestrator.py index 9f2806039..64849eee9 100644 --- a/tests/test_red_teaming_orchestrator.py +++ b/tests/test_red_teaming_orchestrator.py @@ -71,7 +71,10 @@ def check_conversations( assert conversations[2].converted_prompt_text == expected_red_teaming_responses[0] # second conversation (with prompt target) assert conversations[3 - index_offset].conversation_id == conversations[4 - index_offset].conversation_id - assert conversations[3 - index_offset].labels["normalizer_id"] == conversations[4 - index_offset].labels["normalizer_id"] + assert ( + conversations[3 - index_offset].labels["normalizer_id"] + == conversations[4 - index_offset].labels["normalizer_id"] + ) assert conversations[3 - index_offset].role == "user" assert conversations[3 - index_offset].converted_prompt_text == expected_red_teaming_responses[0] assert conversations[4 - index_offset].role == "assistant" @@ -97,7 +100,10 @@ def check_conversations( # fourth conversation (with prompt target) assert conversations[7 - index_offset].conversation_id == conversations[8 - index_offset].conversation_id - assert conversations[7 - index_offset].labels["normalizer_id"] == conversations[8 - index_offset].labels["normalizer_id"] + assert ( + conversations[7 - index_offset].labels["normalizer_id"] + == conversations[8 - index_offset].labels["normalizer_id"] + ) assert conversations[7 - index_offset].role == "user" assert conversations[7 - index_offset].converted_prompt_text == expected_red_teaming_responses[1] assert conversations[8 - index_offset].role == "assistant" From 4cf48ba7d06e7f7295b87ed99c1ccfb18738bd78 Mon Sep 17 00:00:00 2001 From: Richard 
Lundeen Date: Thu, 28 Mar 2024 12:33:44 -0700 Subject: [PATCH 10/15] fixing up build things --- pyrit/memory/memory_interface.py | 6 +-- pyrit/memory/memory_models.py | 24 +++++------- .../analytics/test_conversation_analytics.py | 1 - tests/memory/test_duckdb_memory.py | 7 +--- tests/memory/test_memory_exporter.py | 5 +-- tests/memory/test_memory_interface.py | 38 +++++++++---------- tests/mocks.py | 4 +- tests/test_prompt_target.py | 6 +-- .../test_prompt_target_azure_blob_storage.py | 6 +-- tests/test_prompt_target_text.py | 10 ++--- tests/test_red_teaming_orchestrator.py | 22 +++++------ 11 files changed, 55 insertions(+), 74 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 87a9a33cb..5db4e2b7b 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2,13 +2,9 @@ # Licensed under the MIT license. import abc -from hashlib import sha256 -from typing import Optional from pathlib import Path -from uuid import uuid4 - -from pyrit.memory.memory_models import Base, PromptMemoryEntry, EmbeddingData +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.models import ChatMessage diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 63fbc325b..336a50460 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-import enum import hashlib import uuid @@ -10,18 +9,13 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict -from sqlalchemy import Column, String, DateTime, Float, Enum, JSON, ForeignKey, Index, INTEGER, ARRAY +from sqlalchemy import Column, String, DateTime, Float, JSON, ForeignKey, Index, INTEGER, ARRAY from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.dialects.postgresql import UUID Base = declarative_base() -""" -class PromptDataType(enum.Enum): - TEXT = 'text' - IMAGE_URL = 'image_url' -""" PromptDataType = Literal["text", "image_url"] @@ -30,6 +24,8 @@ class PromptMemoryEntry(Base): # type: ignore """ Represents the prompt data. + Because of the nature of database and sql alchemy, type ignores are abundant :) + Attributes: __tablename__ (str): The name of the database table. __table_args__ (dict): Additional arguments for the database table. @@ -60,20 +56,20 @@ class PromptMemoryEntry(Base): # type: ignore __tablename__ = "PromptMemoryEntries" __table_args__ = {"extend_existing": True} id = Column(UUID(as_uuid=True), nullable=False, primary_key=True) - role: "Column[ChatMessageRole]" = Column(String, nullable=False) + role: "Column[ChatMessageRole]" = Column(String, nullable=False) # type: ignore # noqa conversation_id = Column(String, nullable=False) sequence = Column(INTEGER, nullable=False) timestamp = Column(DateTime, nullable=False) labels: Column[Dict[str, str]] = Column(JSON) prompt_metadata = Column(JSON) - converters: "Column[list[PromptConverter]]" = Column(JSON) - prompt_target: "Column[PromptTarget]" = Column(JSON) + converters: "Column[list[PromptConverter]]" = Column(JSON) # type: ignore # noqa + prompt_target: "Column[PromptTarget]" = Column(JSON) # type: ignore # noqa - original_prompt_data_type: PromptDataType = Column(String, nullable=False) + original_prompt_data_type: PromptDataType = Column(String, nullable=False) # type: ignore original_prompt_text = Column(String, nullable=False) 
original_prompt_data_sha256 = Column(String) - converted_prompt_data_type: PromptDataType = Column(String, nullable=False) + converted_prompt_data_type: PromptDataType = Column(String, nullable=False) # type: ignore converted_prompt_text = Column(String) converted_prompt_data_sha256 = Column(String) @@ -90,8 +86,8 @@ def __init__( sequence: int = -1, labels: Dict[str, str] = None, prompt_metadata: JSON = None, - converters: "PromptConverterList" = None, - prompt_target: "PromptTarget" = None, + converters: "PromptConverterList" = None, # type: ignore # noqa + prompt_target: "PromptTarget" = None, # type: ignore # noqa original_prompt_data_type: PromptDataType = "text", converted_prompt_data_type: PromptDataType = "text", ): diff --git a/tests/analytics/test_conversation_analytics.py b/tests/analytics/test_conversation_analytics.py index c4c6c05db..950ead8bb 100644 --- a/tests/analytics/test_conversation_analytics.py +++ b/tests/analytics/test_conversation_analytics.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
import pytest -import uuid from unittest.mock import MagicMock from pyrit.memory.memory_interface import MemoryInterface diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index f7d9d9966..86c5b6172 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -4,19 +4,16 @@ import json import pytest import uuid -import datetime from unittest.mock import MagicMock from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import inspect -from sqlalchemy import String, DateTime, Float, Enum, JSON, ForeignKey, Index, INTEGER, ARRAY +from sqlalchemy import String, DateTime, INTEGER, ARRAY from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql.sqltypes import NullType -from sqlalchemy.types import String, DateTime -from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData, PromptDataType +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory import DuckDBMemory -from pyrit.prompt_converter import PromptConverter from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_converter.prompt_converter import PromptConverterList from pyrit.prompt_target.text_target import TextTarget diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 032614de8..85c9e64a9 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -3,13 +3,10 @@ import json -import pytest - from pyrit.memory.memory_exporter import MemoryExporter -from pyrit.memory.memory_models import PromptMemoryEntry from sqlalchemy.inspection import inspect -from tests.mocks import sample_conversations +from tests.mocks import sample_conversations_fixture as sample_conversations # noqa: F401 def model_to_dict(instance): diff --git a/tests/memory/test_memory_interface.py b/tests/memory/test_memory_interface.py index 99d2cc4b7..261ee48b7 100644 --- a/tests/memory/test_memory_interface.py +++ 
b/tests/memory/test_memory_interface.py @@ -4,53 +4,49 @@ import random from string import ascii_lowercase -import pytest -from sqlalchemy import inspect - -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory import MemoryInterface from pyrit.models import ChatMessage -from pyrit.memory.memory_models import PromptMemoryEntry -from tests.mocks import memory +from tests.mocks import memory_fixture as memory # noqa: F401 def generate_random_string(length: int = 10) -> str: return "".join(random.choice(ascii_lowercase) for _ in range(length)) -def test_memory(memory: MemoryInterface): - assert memory +def test_memory(memory_fixture: MemoryInterface): + assert memory_fixture -def test_conversation_memory_empty_by_default(memory: MemoryInterface): +def test_conversation_memory_empty_by_default(memory_fixture: MemoryInterface): expected_count = 0 - c = memory.get_all_prompt_entries() + c = memory_fixture.get_all_prompt_entries() assert len(c) == expected_count def test_count_of_memories_matches_number_of_conversations_added_1( - memory: MemoryInterface, + memory_fixture: MemoryInterface, ): expected_count = 1 message = ChatMessage(role="user", content="Hello") - memory.add_chat_message_to_memory(conversation=message, conversation_id="1", labels={}) - c = memory.get_all_prompt_entries() + memory_fixture.add_chat_message_to_memory(conversation=message, conversation_id="1", labels={}) + c = memory_fixture.get_all_prompt_entries() assert len(c) == expected_count -def test_add_chat_message_to_memory_added(memory: MemoryInterface): +def test_add_chat_message_to_memory_added(memory_fixture: MemoryInterface): expected_count = 3 - memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1") - memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1") - memory.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 3"), 
conversation_id="1") - assert len(memory.get_all_prompt_entries()) == expected_count + memory_fixture.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1") + memory_fixture.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1") + memory_fixture.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 3"), conversation_id="1") + assert len(memory_fixture.get_all_prompt_entries()) == expected_count -def test_add_chat_messages_to_memory_added(memory: MemoryInterface): +def test_add_chat_messages_to_memory_added(memory_fixture: MemoryInterface): messages = [ ChatMessage(role="user", content="Hello 1"), ChatMessage(role="user", content="Hello 2"), ] - memory.add_chat_messages_to_memory(conversations=messages, conversation_id="1") - assert len(memory.get_all_prompt_entries()) == len(messages) + memory_fixture.add_chat_messages_to_memory(conversations=messages, conversation_id="1") + assert len(memory_fixture.get_all_prompt_entries()) == len(messages) diff --git a/tests/mocks.py b/tests/mocks.py index 47ef9a685..3f050abf1 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -69,7 +69,7 @@ async def send_prompt_async(self, normalized_prompt: str, conversation_id: str, @pytest.fixture -def memory() -> MemoryInterface: # type: ignore +def memory_fixture() -> MemoryInterface: # type: ignore # Create an in-memory DuckDB engine duckdb_memory = DuckDBMemory(db_path=":memory:") @@ -86,7 +86,7 @@ def memory() -> MemoryInterface: # type: ignore @pytest.fixture -def sample_conversations(): +def sample_conversations_fixture(): return [ PromptMemoryEntry( role="user", diff --git a/tests/test_prompt_target.py b/tests/test_prompt_target.py index 2cc032f29..4a3780866 100644 --- a/tests/test_prompt_target.py +++ b/tests/test_prompt_target.py @@ -11,7 +11,7 @@ from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.prompt_target import AzureOpenAIChatTarget 
-from tests.mocks import memory +from tests.mocks import memory_fixture @pytest.fixture @@ -38,13 +38,13 @@ def chat_completion_engine() -> AzureOpenAIChatTarget: @pytest.fixture -def azure_openai_target(memory: DuckDBMemory): +def azure_openai_target(memory_fixture: DuckDBMemory): return AzureOpenAIChatTarget( deployment_name="test", endpoint="test", api_key="test", - memory=memory, + memory=memory_fixture, ) diff --git a/tests/test_prompt_target_azure_blob_storage.py b/tests/test_prompt_target_azure_blob_storage.py index 65c14ae60..c2f774c0e 100644 --- a/tests/test_prompt_target_azure_blob_storage.py +++ b/tests/test_prompt_target_azure_blob_storage.py @@ -10,15 +10,15 @@ from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.prompt_target import AzureBlobStorageTarget -from tests.mocks import memory +from tests.mocks import memory_fixture @pytest.fixture -def azure_blob_storage_target(memory: DuckDBMemory): +def azure_blob_storage_target(memory_fixture: DuckDBMemory): return AzureBlobStorageTarget( container_url="https://test.blob.core.windows.net/test", sas_token="valid_sas_token", - memory=memory, + memory=memory_fixture, ) diff --git a/tests/test_prompt_target_text.py b/tests/test_prompt_target_text.py index f84d4b126..2675a3d4e 100644 --- a/tests/test_prompt_target_text.py +++ b/tests/test_prompt_target_text.py @@ -10,11 +10,11 @@ from pyrit.memory import DuckDBMemory, MemoryInterface from pyrit.prompt_target import TextTarget -from tests.mocks import memory +from tests.mocks import memory_fixture -def test_send_prompt_user_no_system(memory: DuckDBMemory): - no_op = TextTarget(memory=memory) +def test_send_prompt_user_no_system(memory_fixture: DuckDBMemory): + no_op = TextTarget(memory=memory_fixture) no_op.send_prompt( normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" @@ -25,11 +25,11 @@ def test_send_prompt_user_no_system(memory: DuckDBMemory): assert chats[0].role == "user" -def 
test_send_prompt_stream(memory: DuckDBMemory): +def test_send_prompt_stream(memory_fixture: DuckDBMemory): with NamedTemporaryFile(mode="w+", delete=False) as tmp_file: prompt = "hi, I am a victim chatbot, how can I help?" - no_op = TextTarget(memory=memory, text_stream=tmp_file) + no_op = TextTarget(memory=memory_fixture, text_stream=tmp_file) no_op.send_prompt(normalized_prompt=prompt, conversation_id="1", normalizer_id="2") tmp_file.seek(0) diff --git a/tests/test_red_teaming_orchestrator.py b/tests/test_red_teaming_orchestrator.py index 64849eee9..539a8bad6 100644 --- a/tests/test_red_teaming_orchestrator.py +++ b/tests/test_red_teaming_orchestrator.py @@ -18,7 +18,7 @@ from pyrit.memory import DuckDBMemory from pyrit.common.path import DATASETS_PATH -from tests.mocks import memory +from tests.mocks import memory_fixture @pytest.fixture @@ -27,12 +27,12 @@ def chat_completion_engine() -> AzureOpenAIChatTarget: @pytest.fixture -def prompt_target(memory) -> AzureOpenAIChatTarget: +def prompt_target(memory_fixture) -> AzureOpenAIChatTarget: return AzureOpenAIChatTarget( deployment_name="test", endpoint="test", api_key="test", - memory=memory, + memory=memory_fixture, ) @@ -116,7 +116,7 @@ def test_send_prompt_twice( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_fixture: DuckDBMemory, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -126,7 +126,7 @@ def test_send_prompt_twice( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory, + "memory": memory_fixture, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -183,7 +183,7 @@ def test_send_fixed_prompt_then_generated_prompt( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_fixture: DuckDBMemory, 
attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -193,7 +193,7 @@ def test_send_fixed_prompt_then_generated_prompt( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory, + "memory": memory_fixture, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -250,7 +250,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_fixture: DuckDBMemory, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -260,7 +260,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory, + "memory": memory_fixture, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -305,7 +305,7 @@ def test_reach_goal_after_two_turns_end_token( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory: DuckDBMemory, + memory_fixture: DuckDBMemory, attack_strategy_as_str: bool, ): attack_strategy: Union[str | AttackStrategy] = ( @@ -314,7 +314,7 @@ def test_reach_goal_after_two_turns_end_token( red_teaming_orchestrator = EndTokenRedTeamingOrchestrator( red_teaming_chat=chat_completion_engine, - memory=memory, + memory=memory_fixture, attack_strategy=attack_strategy, initial_red_teaming_prompt="how can I help you?", prompt_target=prompt_target, From e9d3b967b746e427f1cbbf331a98d3841697d74e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 28 Mar 2024 14:46:04 -0700 Subject: [PATCH 11/15] build and bug fixes --- pyrit/analytics/conversation_analytics.py | 6 +-- pyrit/memory/duckdb_memory.py | 27 +++++++++- pyrit/memory/memory_interface.py | 28 +++-------- pyrit/memory/memory_models.py | 6 +-- 
pyrit/orchestrator/benchmark_orchestrator.py | 2 +- .../end_token_red_teaming_orchestrator.py | 2 +- pyrit/orchestrator/orchestrator_class.py | 2 +- .../orchestrator/red_teaming_orchestrator.py | 2 +- .../scoring_red_teaming_orchestrator.py | 2 +- tests/memory/test_memory_exporter.py | 9 +++- tests/memory/test_memory_interface.py | 49 ++++++++++++------- tests/mocks.py | 8 ++- tests/test_prompt_target.py | 15 ++++-- .../test_prompt_target_azure_blob_storage.py | 15 ++++-- tests/test_prompt_target_text.py | 20 +++++--- tests/test_red_teaming_orchestrator.py | 33 +++++++------ 16 files changed, 137 insertions(+), 89 deletions(-) diff --git a/pyrit/analytics/conversation_analytics.py b/pyrit/analytics/conversation_analytics.py index f3bb5b1ed..b019406aa 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -6,7 +6,6 @@ from sklearn.metrics.pairwise import cosine_similarity from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import ConversationMessageWithSimilarity, EmbeddingMessageWithSimilarity -from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData class ConversationAnalytics: @@ -37,7 +36,7 @@ def get_prompt_entries_with_same_converted_content( list[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on content. """ - all_memories = self.memory_interface.get_all_prompt_entries(PromptMemoryEntry) + all_memories = self.memory_interface.get_all_prompt_entries() similar_messages = [] for memory in all_memories: @@ -67,7 +66,8 @@ def get_similar_chat_messages_by_embedding( List[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on embedding similarity. 
""" - all_memories = self.memory_interface.get_all_prompt_entries(EmbeddingData) + + all_memories = self.memory_interface.get_all_embeddings() similar_messages = [] target_embedding = np.array(chat_message_embedding).reshape(1, -1) diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 443cf42ed..25c2feefa 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -12,7 +12,7 @@ from sqlalchemy.engine.base import Engine from contextlib import closing -from pyrit.memory.memory_models import PromptMemoryEntry, Base +from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base from pyrit.memory.memory_embedding import default_memory_embedding_factory from pyrit.memory.memory_interface import MemoryInterface from pyrit.interfaces import EmbeddingSupport @@ -91,6 +91,13 @@ def get_all_prompt_entries(self) -> list[PromptMemoryEntry]: result = self.query_entries(PromptMemoryEntry) return result + def get_all_embeddings(self) -> list[EmbeddingData]: + """ + Fetches all entries from the specified table and returns them as model instances. + """ + result = self.query_entries(EmbeddingData) + return result + def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: """ Retrieves a list of ConversationData objects that have the specified conversation ID. @@ -255,6 +262,24 @@ def update_entries(self, *, entries: list[Base], update_fields: dict) -> bool: logger.exception(f"Error updating entries: {e}") return False + def export_all_tables(self, *, export_type: str = "json"): + """ + Exports all table data using the specified exporter. + + Iterates over all tables, retrieves their data, and exports each to a file named after the table. + + Args: + export_type (str): The format to export the data in (defaults to "json"). 
+ """ + table_models = self.get_all_table_models() + + for model in table_models: + data = self.query_entries(model) + table_name = model.__tablename__ + file_extension = f".{export_type}" + file_path = RESULTS_PATH / f"{table_name}{file_extension}" + self.exporter.export_data(data, file_path=file_path, export_type=export_type) + def dispose_engine(self): """ Dispose the engine and clean up resources. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 5db4e2b7b..ba703c68d 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -34,6 +34,12 @@ def get_all_prompt_entries(self) -> list[PromptMemoryEntry]: Loads all ConversationData from the memory storage handler. """ + @abc.abstractmethod + def get_all_embeddings(self) -> list[EmbeddingData]: + """ + Loads all EmbeddingData from the memory storage handler. + """ + @abc.abstractmethod def get_prompt_entries_with_conversation_id(self, *, conversation_id: str) -> list[PromptMemoryEntry]: """ @@ -93,7 +99,7 @@ def add_chat_message_to_memory( conversation: ChatMessage, conversation_id: str, normalizer_id: str = None, - labels: list[str] = {}, + labels: dict[str, str] = {}, ): """ Deprecated. Will be refactored and removed soon. It currently works incorrectly. @@ -119,7 +125,7 @@ def add_chat_messages_to_memory( conversations: list[ChatMessage], conversation_id: str, normalizer_id: str = None, - labels: dict[str:str] = {}, + labels: dict[str, str] = {}, ): """ Deprecated. Will be refactored and removed soon. It currently works incorrectly. @@ -151,24 +157,6 @@ def add_chat_messages_to_memory( self.insert_prompt_entries(entries=entries_to_add) - def export_all_tables(self, *, export_type: str = "json"): - """ - Exports all table data using the specified exporter. - - Iterates over all tables, retrieves their data, and exports each to a file named after the table. - - Args: - export_type (str): The format to export the data in (defaults to "json"). 
- """ - table_models = self.get_all_table_models() - - for model in table_models: - data = self.query_entries(model) - table_name = model.__tablename__ - file_extension = f".{export_type}" - file_path = RESULTS_PATH / f"{table_name}{file_extension}" - self.exporter.export_data(data, file_path=file_path, export_type=export_type) - def export_conversation_by_id(self, *, conversation_id: str, file_path: Path = None, export_type: str = "json"): """ Exports conversation data with the given conversation ID to a specified file. diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 336a50460..449ab5dff 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -60,7 +60,7 @@ class PromptMemoryEntry(Base): # type: ignore conversation_id = Column(String, nullable=False) sequence = Column(INTEGER, nullable=False) timestamp = Column(DateTime, nullable=False) - labels: Column[Dict[str, str]] = Column(JSON) + labels: Column[Dict[str, str]] = Column(JSON) # type: ignore prompt_metadata = Column(JSON) converters: "Column[list[PromptConverter]]" = Column(JSON) # type: ignore # noqa prompt_target: "Column[PromptTarget]" = Column(JSON) # type: ignore # noqa @@ -92,7 +92,7 @@ def __init__( converted_prompt_data_type: PromptDataType = "text", ): - self.id = id if id else uuid4() + self.id = id if id else uuid4() # type: ignore self.role = role self.conversation_id = conversation_id if conversation_id else str(uuid4()) @@ -100,7 +100,7 @@ def __init__( self.timestamp = datetime.utcnow() self.labels = labels - self.prompt_metadata = prompt_metadata + self.prompt_metadata = prompt_metadata # type: ignore self.converters = converters.to_json() if converters else None self.prompt_target = prompt_target.to_dict() if prompt_target else None diff --git a/pyrit/orchestrator/benchmark_orchestrator.py b/pyrit/orchestrator/benchmark_orchestrator.py index d3e3c03fe..db0b9114a 100644 --- a/pyrit/orchestrator/benchmark_orchestrator.py +++ 
b/pyrit/orchestrator/benchmark_orchestrator.py @@ -31,7 +31,7 @@ def __init__( chat_model_under_evaluation: PromptChatTarget, scorer: QuestionAnswerScorer, memory: MemoryInterface | None = None, - memory_labels: dict[str:str] = None, + memory_labels: dict[str, str] = None, evaluation_prompt: str | None = None, batch_size: int = 1, verbose: bool = False, diff --git a/pyrit/orchestrator/end_token_red_teaming_orchestrator.py b/pyrit/orchestrator/end_token_red_teaming_orchestrator.py index 25af59115..55959cabf 100644 --- a/pyrit/orchestrator/end_token_red_teaming_orchestrator.py +++ b/pyrit/orchestrator/end_token_red_teaming_orchestrator.py @@ -24,7 +24,7 @@ def __init__( end_token: Optional[str] = RED_TEAM_CONVERSATION_END_TOKEN, prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: dict[str:str] = None, + memory_labels: dict[str, str] = None, verbose: bool = False, ) -> None: """Creates an orchestrator to manage conversations between a red teaming target and a prompt target. 
diff --git a/pyrit/orchestrator/orchestrator_class.py b/pyrit/orchestrator/orchestrator_class.py index 0553a1efb..ef57b385f 100644 --- a/pyrit/orchestrator/orchestrator_class.py +++ b/pyrit/orchestrator/orchestrator_class.py @@ -22,7 +22,7 @@ def __init__( *, prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: dict[str:str] = {}, + memory_labels: dict[str, str] = {}, verbose: bool = False, ): self._prompt_converters = prompt_converters if prompt_converters else [NoOpConverter()] diff --git a/pyrit/orchestrator/red_teaming_orchestrator.py b/pyrit/orchestrator/red_teaming_orchestrator.py index 2628e8c57..3f2fcdd9c 100644 --- a/pyrit/orchestrator/red_teaming_orchestrator.py +++ b/pyrit/orchestrator/red_teaming_orchestrator.py @@ -38,7 +38,7 @@ def __init__( initial_red_teaming_prompt: str = "Begin Conversation", prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: dict[str:str] = None, + memory_labels: dict[str, str] = None, verbose: bool = False, ) -> None: """Creates an orchestrator to manage conversations between a red teaming target and a prompt target. diff --git a/pyrit/orchestrator/scoring_red_teaming_orchestrator.py b/pyrit/orchestrator/scoring_red_teaming_orchestrator.py index 39c35f30f..67031b88a 100644 --- a/pyrit/orchestrator/scoring_red_teaming_orchestrator.py +++ b/pyrit/orchestrator/scoring_red_teaming_orchestrator.py @@ -22,7 +22,7 @@ def __init__( scorer: SupportTextClassification, prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: list[str] = ["red-teaming-orchestrator"], + memory_labels: dict[str, str] = ["red-teaming-orchestrator"], verbose: bool = False, ) -> None: """Creates an orchestrator to manage conversations between a red teaming bot and a prompt target. 
diff --git a/tests/memory/test_memory_exporter.py b/tests/memory/test_memory_exporter.py index 85c9e64a9..3d56aeccf 100644 --- a/tests/memory/test_memory_exporter.py +++ b/tests/memory/test_memory_exporter.py @@ -2,11 +2,18 @@ # Licensed under the MIT license. import json +import pytest from pyrit.memory.memory_exporter import MemoryExporter +from pyrit.memory.memory_models import PromptMemoryEntry + from sqlalchemy.inspection import inspect +from tests.mocks import get_sample_conversations + -from tests.mocks import sample_conversations_fixture as sample_conversations # noqa: F401 +@pytest.fixture +def sample_conversations() -> list[PromptMemoryEntry]: + return get_sample_conversations() def model_to_dict(instance): diff --git a/tests/memory/test_memory_interface.py b/tests/memory/test_memory_interface.py index 261ee48b7..40520b1a6 100644 --- a/tests/memory/test_memory_interface.py +++ b/tests/memory/test_memory_interface.py @@ -1,52 +1,65 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from typing import Generator +import pytest import random from string import ascii_lowercase from pyrit.memory import MemoryInterface from pyrit.models import ChatMessage -from tests.mocks import memory_fixture as memory # noqa: F401 +from tests.mocks import get_memory_interface + + +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() def generate_random_string(length: int = 10) -> str: return "".join(random.choice(ascii_lowercase) for _ in range(length)) -def test_memory(memory_fixture: MemoryInterface): - assert memory_fixture +def test_memory(memory_interface: MemoryInterface): + assert memory_interface -def test_conversation_memory_empty_by_default(memory_fixture: MemoryInterface): +def test_conversation_memory_empty_by_default(memory_interface: MemoryInterface): expected_count = 0 - c = memory_fixture.get_all_prompt_entries() + c = memory_interface.get_all_prompt_entries() assert len(c) == expected_count def test_count_of_memories_matches_number_of_conversations_added_1( - memory_fixture: MemoryInterface, + memory_interface: MemoryInterface, ): expected_count = 1 message = ChatMessage(role="user", content="Hello") - memory_fixture.add_chat_message_to_memory(conversation=message, conversation_id="1", labels={}) - c = memory_fixture.get_all_prompt_entries() + memory_interface.add_chat_message_to_memory(conversation=message, conversation_id="1", labels={}) + c = memory_interface.get_all_prompt_entries() assert len(c) == expected_count -def test_add_chat_message_to_memory_added(memory_fixture: MemoryInterface): +def test_add_chat_message_to_memory_added(memory_interface: MemoryInterface): expected_count = 3 - memory_fixture.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1") - memory_fixture.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1") - 
memory_fixture.add_chat_message_to_memory(conversation=ChatMessage(role="user", content="Hello 3"), conversation_id="1") - assert len(memory_fixture.get_all_prompt_entries()) == expected_count - - -def test_add_chat_messages_to_memory_added(memory_fixture: MemoryInterface): + memory_interface.add_chat_message_to_memory( + conversation=ChatMessage(role="user", content="Hello 1"), conversation_id="1" + ) + memory_interface.add_chat_message_to_memory( + conversation=ChatMessage(role="user", content="Hello 2"), conversation_id="1" + ) + memory_interface.add_chat_message_to_memory( + conversation=ChatMessage(role="user", content="Hello 3"), conversation_id="1" + ) + assert len(memory_interface.get_all_prompt_entries()) == expected_count + + +def test_add_chat_messages_to_memory_added(memory_interface: MemoryInterface): messages = [ ChatMessage(role="user", content="Hello 1"), ChatMessage(role="user", content="Hello 2"), ] - memory_fixture.add_chat_messages_to_memory(conversations=messages, conversation_id="1") - assert len(memory_fixture.get_all_prompt_entries()) == len(messages) + memory_interface.add_chat_messages_to_memory(conversations=messages, conversation_id="1") + assert len(memory_interface.get_all_prompt_entries()) == len(messages) diff --git a/tests/mocks.py b/tests/mocks.py index 3f050abf1..858531a2c 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -2,8 +2,8 @@ # Licensed under the MIT license. 
from contextlib import AbstractAsyncContextManager +from typing import Generator -import pytest from sqlalchemy import inspect from pyrit.memory import DuckDBMemory, MemoryInterface @@ -68,8 +68,7 @@ async def send_prompt_async(self, normalized_prompt: str, conversation_id: str, self.prompt_sent.append(normalized_prompt) -@pytest.fixture -def memory_fixture() -> MemoryInterface: # type: ignore +def get_memory_interface() -> Generator[MemoryInterface, None, None]: # Create an in-memory DuckDB engine duckdb_memory = DuckDBMemory(db_path=":memory:") @@ -85,8 +84,7 @@ def memory_fixture() -> MemoryInterface: # type: ignore duckdb_memory.dispose_engine() -@pytest.fixture -def sample_conversations_fixture(): +def get_sample_conversations(): return [ PromptMemoryEntry( role="user", diff --git a/tests/test_prompt_target.py b/tests/test_prompt_target.py index 4a3780866..a25fdc899 100644 --- a/tests/test_prompt_target.py +++ b/tests/test_prompt_target.py @@ -1,17 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from typing import Generator from unittest.mock import patch import pytest -from sqlalchemy import inspect from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory.memory_interface import MemoryInterface from pyrit.prompt_target import AzureOpenAIChatTarget -from tests.mocks import memory_fixture +from tests.mocks import get_memory_interface + + +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() @pytest.fixture @@ -38,13 +43,13 @@ def chat_completion_engine() -> AzureOpenAIChatTarget: @pytest.fixture -def azure_openai_target(memory_fixture: DuckDBMemory): +def azure_openai_target(memory_interface: MemoryInterface): return AzureOpenAIChatTarget( deployment_name="test", endpoint="test", api_key="test", - memory=memory_fixture, + memory=memory_interface, ) diff --git a/tests/test_prompt_target_azure_blob_storage.py b/tests/test_prompt_target_azure_blob_storage.py index c2f774c0e..44f5e5550 100644 --- a/tests/test_prompt_target_azure_blob_storage.py +++ b/tests/test_prompt_target_azure_blob_storage.py @@ -2,23 +2,28 @@ # Licensed under the MIT license. 
import os +from typing import Generator import pytest -from sqlalchemy import inspect from unittest.mock import patch -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory import MemoryInterface from pyrit.prompt_target import AzureBlobStorageTarget -from tests.mocks import memory_fixture +from tests.mocks import get_memory_interface @pytest.fixture -def azure_blob_storage_target(memory_fixture: DuckDBMemory): +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() + + +@pytest.fixture +def azure_blob_storage_target(memory_interface: MemoryInterface): return AzureBlobStorageTarget( container_url="https://test.blob.core.windows.net/test", sas_token="valid_sas_token", - memory=memory_fixture, + memory=memory_interface, ) diff --git a/tests/test_prompt_target_text.py b/tests/test_prompt_target_text.py index 2675a3d4e..506af4171 100644 --- a/tests/test_prompt_target_text.py +++ b/tests/test_prompt_target_text.py @@ -3,18 +3,22 @@ import os from tempfile import NamedTemporaryFile +from typing import Generator import pytest -from sqlalchemy import inspect - -from pyrit.memory import DuckDBMemory, MemoryInterface +from pyrit.memory import MemoryInterface from pyrit.prompt_target import TextTarget -from tests.mocks import memory_fixture +from tests.mocks import get_memory_interface + + +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() -def test_send_prompt_user_no_system(memory_fixture: DuckDBMemory): - no_op = TextTarget(memory=memory_fixture) +def test_send_prompt_user_no_system(memory_interface: MemoryInterface): + no_op = TextTarget(memory=memory_interface) no_op.send_prompt( normalized_prompt="hi, I am a victim chatbot, how can I help?", conversation_id="1", normalizer_id="2" @@ -25,11 +29,11 @@ def test_send_prompt_user_no_system(memory_fixture: DuckDBMemory): assert chats[0].role == "user" -def 
test_send_prompt_stream(memory_fixture: DuckDBMemory): +def test_send_prompt_stream(memory_interface: MemoryInterface): with NamedTemporaryFile(mode="w+", delete=False) as tmp_file: prompt = "hi, I am a victim chatbot, how can I help?" - no_op = TextTarget(memory=memory_fixture, text_stream=tmp_file) + no_op = TextTarget(memory=memory_interface, text_stream=tmp_file) no_op.send_prompt(normalized_prompt=prompt, conversation_id="1", normalizer_id="2") tmp_file.seek(0) diff --git a/tests/test_red_teaming_orchestrator.py b/tests/test_red_teaming_orchestrator.py index 539a8bad6..95b3f1223 100644 --- a/tests/test_red_teaming_orchestrator.py +++ b/tests/test_red_teaming_orchestrator.py @@ -2,23 +2,26 @@ # Licensed under the MIT license. import pathlib -from typing import Union +from typing import Generator, Union from unittest.mock import Mock, patch +from pyrit.memory.memory_interface import MemoryInterface from pyrit.prompt_converter.prompt_converter import PromptConverter from pyrit.prompt_target.prompt_target import PromptTarget import pytest -from sqlalchemy import inspect -from pyrit.memory.memory_models import PromptMemoryEntry from pyrit.orchestrator import ScoringRedTeamingOrchestrator, EndTokenRedTeamingOrchestrator from pyrit.orchestrator.end_token_red_teaming_orchestrator import RED_TEAM_CONVERSATION_END_TOKEN from pyrit.prompt_target import AzureOpenAIChatTarget from pyrit.models import AttackStrategy, ChatMessage, Score -from pyrit.memory import DuckDBMemory from pyrit.common.path import DATASETS_PATH -from tests.mocks import memory_fixture +from tests.mocks import get_memory_interface + + +@pytest.fixture +def memory_interface() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() @pytest.fixture @@ -27,12 +30,12 @@ def chat_completion_engine() -> AzureOpenAIChatTarget: @pytest.fixture -def prompt_target(memory_fixture) -> AzureOpenAIChatTarget: +def prompt_target(memory_interface) -> AzureOpenAIChatTarget: return 
AzureOpenAIChatTarget( deployment_name="test", endpoint="test", api_key="test", - memory=memory_fixture, + memory=memory_interface, ) @@ -116,7 +119,7 @@ def test_send_prompt_twice( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory_fixture: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -126,7 +129,7 @@ def test_send_prompt_twice( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory_fixture, + "memory": memory_interface, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -183,7 +186,7 @@ def test_send_fixed_prompt_then_generated_prompt( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory_fixture: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -193,7 +196,7 @@ def test_send_fixed_prompt_then_generated_prompt( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory_fixture, + "memory": memory_interface, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -250,7 +253,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory_fixture: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, OrchestratorType: type, ): @@ -260,7 +263,7 @@ def test_send_fixed_prompt_beyond_first_iteration_failure( kwargs = { "red_teaming_chat": chat_completion_engine, - "memory": memory_fixture, + "memory": memory_interface, "attack_strategy": attack_strategy, "initial_red_teaming_prompt": "how can I help you?", "prompt_target": prompt_target, @@ -305,7 +308,7 @@ def 
test_reach_goal_after_two_turns_end_token( prompt_target: PromptTarget, chat_completion_engine: AzureOpenAIChatTarget, simple_attack_strategy: AttackStrategy, - memory_fixture: DuckDBMemory, + memory_interface: MemoryInterface, attack_strategy_as_str: bool, ): attack_strategy: Union[str | AttackStrategy] = ( @@ -314,7 +317,7 @@ def test_reach_goal_after_two_turns_end_token( red_teaming_orchestrator = EndTokenRedTeamingOrchestrator( red_teaming_chat=chat_completion_engine, - memory=memory_fixture, + memory=memory_interface, attack_strategy=attack_strategy, initial_red_teaming_prompt="how can I help you?", prompt_target=prompt_target, From 6c8bfc83b65bf8941cd58aff11dcbe230af66393 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 29 Mar 2024 09:00:02 -0700 Subject: [PATCH 12/15] fixing build and test --- pyrit/analytics/conversation_analytics.py | 4 ++-- pyrit/orchestrator/scoring_red_teaming_orchestrator.py | 2 +- tests/analytics/test_conversation_analytics.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pyrit/analytics/conversation_analytics.py b/pyrit/analytics/conversation_analytics.py index b019406aa..dd876ca58 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -67,12 +67,12 @@ def get_similar_chat_messages_by_embedding( the similar chat messages based on embedding similarity. 
""" - all_memories = self.memory_interface.get_all_embeddings() + all_embdedding_memory = self.memory_interface.get_all_embeddings() similar_messages = [] target_embedding = np.array(chat_message_embedding).reshape(1, -1) - for memory in all_memories: + for memory in all_embdedding_memory: if not hasattr(memory, "embedding") or memory.embedding is None: continue diff --git a/pyrit/orchestrator/scoring_red_teaming_orchestrator.py b/pyrit/orchestrator/scoring_red_teaming_orchestrator.py index 67031b88a..9a96fcdf5 100644 --- a/pyrit/orchestrator/scoring_red_teaming_orchestrator.py +++ b/pyrit/orchestrator/scoring_red_teaming_orchestrator.py @@ -22,7 +22,7 @@ def __init__( scorer: SupportTextClassification, prompt_converters: Optional[list[PromptConverter]] = None, memory: Optional[MemoryInterface] = None, - memory_labels: dict[str, str] = ["red-teaming-orchestrator"], + memory_labels: dict[str, str] = None, verbose: bool = False, ) -> None: """Creates an orchestrator to manage conversations between a red teaming bot and a prompt target. 
diff --git a/tests/analytics/test_conversation_analytics.py b/tests/analytics/test_conversation_analytics.py index 950ead8bb..be4691fc3 100644 --- a/tests/analytics/test_conversation_analytics.py +++ b/tests/analytics/test_conversation_analytics.py @@ -51,15 +51,14 @@ def test_get_similar_chat_messages_by_embedding(mock_memory_interface): similar_embedding = [0.1, 0.2, 0.31] # Slightly different, but should be similar different_embedding = [0.9, 0.8, 0.7] - mock_data = [ + mock_embeddings = [ EmbeddingData(id=conversation_entries[0].id, embedding=similar_embedding, embedding_type_name="model1"), EmbeddingData(id=conversation_entries[1].id, embedding=different_embedding, embedding_type_name="model2"), ] # Mock the get_all_prompt_entries method to return the mock EmbeddingData entries - mock_memory_interface.get_all_prompt_entries.side_effect = lambda model: ( - mock_data if model == EmbeddingData else conversation_entries - ) + mock_memory_interface.get_all_embeddings.return_value = mock_embeddings + mock_memory_interface.get_all_prompt_entries.return_value = conversation_entries analytics = ConversationAnalytics(memory_interface=mock_memory_interface) similar_messages = analytics.get_similar_chat_messages_by_embedding( From db9b492794b68df4212e219f153c831ce8e9cead Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 29 Mar 2024 19:16:39 -0700 Subject: [PATCH 13/15] adding additional tests and fixing bugs --- pyrit/memory/__init__.py | 4 +- pyrit/memory/duckdb_memory.py | 7 +- pyrit/memory/memory_models.py | 9 +- pyrit/orchestrator/orchestrator_class.py | 10 ++- pyrit/prompt_converter/__init__.py | 3 +- pyrit/prompt_target/prompt_target.py | 2 +- tests/memory/test_duckdb_memory.py | 70 ++++++++++++--- tests/memory/test_memory_models.py | 104 +++++++++++++++++++++++ 8 files changed, 187 insertions(+), 22 deletions(-) create mode 100644 tests/memory/test_memory_models.py diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index c6af289a9..226a45311 
100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from pyrit.memory.memory_models import PromptMemoryEntry +from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory.duckdb_memory import DuckDBMemory from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter -__all__ = ["PromptMemoryEntry", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"] +__all__ = ["PromptMemoryEntry", "EmbeddingData", "MemoryInterface", "MemoryEmbedding", "DuckDBMemory", "MemoryExporter"] diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index 25c2feefa..eb600f0db 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -148,12 +148,15 @@ def insert_prompt_entries(self, *, entries: list[PromptMemoryEntry]) -> None: if self.memory_embedding: for chat_entry in entries: - embedding_entry = self.memory_embedding.generate_embedding_memory_data(chat_memory=chat_entry) embedding_entries.append(embedding_entry) + # The ordering of this is weird because after memories are inserted, we lose the reference to them + # and also entries must be inserted before embeddings because of the foreing key constraint self.insert_entries(entries=entries) - self.insert_entries(entries=embedding_entries) + + if embedding_entries: + self.insert_entries(entries=embedding_entries) def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool: """ diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 449ab5dff..ea9c4450e 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -41,6 +41,7 @@ class PromptMemoryEntry(Base): # type: ignore e.g. 
the URI from a file uploaded to a blob store, or a document type you want to upload. converters (list[PromptConverter]): The converters for the prompt. prompt_target (PromptTarget): The target for the prompt. + orchestrator (Orchestrator): The orchestrator for the prompt. original_prompt_data_type (PromptDataType): The data type of the original prompt (text, image) original_prompt_text (str): The text of the original prompt. If prompt is an image, it's a link. original_prompt_data_sha256 (str): The SHA256 hash of the original prompt data. @@ -64,6 +65,7 @@ class PromptMemoryEntry(Base): # type: ignore prompt_metadata = Column(JSON) converters: "Column[list[PromptConverter]]" = Column(JSON) # type: ignore # noqa prompt_target: "Column[PromptTarget]" = Column(JSON) # type: ignore # noqa + orchestrator: "Column[Orchestrator]" = Column(JSON) # type: ignore # noqa original_prompt_data_type: PromptDataType = Column(String, nullable=False) # type: ignore original_prompt_text = Column(String, nullable=False) @@ -88,6 +90,7 @@ def __init__( prompt_metadata: JSON = None, converters: "PromptConverterList" = None, # type: ignore # noqa prompt_target: "PromptTarget" = None, # type: ignore # noqa + orchestrator: "Orchestrator" = None, # type: ignore # noqa original_prompt_data_type: PromptDataType = "text", converted_prompt_data_type: PromptDataType = "text", ): @@ -103,7 +106,8 @@ def __init__( self.prompt_metadata = prompt_metadata # type: ignore self.converters = converters.to_json() if converters else None - self.prompt_target = prompt_target.to_dict() if prompt_target else None + self.prompt_target = prompt_target.to_json() if prompt_target else None + self.orchestrator = orchestrator.to_json() if orchestrator else None self.original_prompt_text = original_prompt_text self.original_prompt_data_type = original_prompt_data_type @@ -113,6 +117,9 @@ def __init__( self.converted_prompt_text = converted_prompt_text self.converted_prompt_data_sha256 = 
self._create_sha256(converted_prompt_text) + def is_sequence_set(self) -> bool: + return self.sequence != -1 + def _create_sha256(self, text: str) -> str: input_bytes = text.encode("utf-8") hash_object = hashlib.sha256(input_bytes) diff --git a/pyrit/orchestrator/orchestrator_class.py b/pyrit/orchestrator/orchestrator_class.py index ef57b385f..68e3c1d8e 100644 --- a/pyrit/orchestrator/orchestrator_class.py +++ b/pyrit/orchestrator/orchestrator_class.py @@ -2,10 +2,11 @@ # Licensed under the MIT license. import abc +import json import logging - from typing import Optional +from uuid import uuid4 from pyrit.memory import MemoryInterface, DuckDBMemory from pyrit.prompt_converter import PromptConverter, NoOpConverter @@ -67,3 +68,10 @@ def dispose_db_engine(self) -> None: Dispose DuckDB database engine to release database connections and resources. """ self._memory.dispose_engine() + + def to_json(self): + s = {} + s["__type__"] = self.__class__.__name__ + s["__module__"] = self.__class__.__module__ + s["id"] = str(uuid4()) + return json.dumps(s) diff --git a/pyrit/prompt_converter/__init__.py b/pyrit/prompt_converter/__init__.py index e6ec88189..1ddcef9dc 100644 --- a/pyrit/prompt_converter/__init__.py +++ b/pyrit/prompt_converter/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from pyrit.prompt_converter.prompt_converter import PromptConverter +from pyrit.prompt_converter.prompt_converter import PromptConverter, PromptConverterList from pyrit.prompt_converter.ascii_art_converter import AsciiArtConverter from pyrit.prompt_converter.base64_converter import Base64Converter @@ -19,6 +19,7 @@ "Base64Converter", "NoOpConverter", "PromptConverter", + "PromptConverterList", "ROT13Converter", "StringJoinConverter", "TranslationConverter", diff --git a/pyrit/prompt_target/prompt_target.py b/pyrit/prompt_target/prompt_target.py index 19d659b6d..c2e90a140 100644 --- a/pyrit/prompt_target/prompt_target.py +++ b/pyrit/prompt_target/prompt_target.py @@ -31,7 +31,7 @@ async def send_prompt_async(self, *, normalized_prompt: str, conversation_id: st Sends a normalized prompt async to the prompt target. """ - def to_dict(self): + def to_json(self): public_attributes = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} public_attributes["__type__"] = self.__class__.__name__ public_attributes["__module__"] = self.__class__.__module__ diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 86c5b6172..99949298e 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -2,6 +2,8 @@ # Licensed under the MIT license. 
import json +import os +from typing import Generator import pytest import uuid from unittest.mock import MagicMock @@ -12,28 +14,18 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.sql.sqltypes import NullType +from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory import DuckDBMemory from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_converter.prompt_converter import PromptConverterList from pyrit.prompt_target.text_target import TextTarget +from tests.mocks import get_memory_interface @pytest.fixture -def setup_duckdb_database(): - # Create an in-memory DuckDB engine - duckdb_memory = DuckDBMemory(db_path=":memory:") - - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector = inspect(duckdb_memory.engine) - - # Verify that tables are created as expected - assert "PromptMemoryEntries" in inspector.get_table_names(), "PromptMemoryEntries table not created." - assert "EmbeddingData" in inspector.get_table_names(), "EmbeddingData table not created." - - yield duckdb_memory - duckdb_memory.dispose_engine() +def setup_duckdb_database() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() @pytest.fixture @@ -170,6 +162,34 @@ def test_insert_entry(setup_duckdb_database): assert inserted_entry.original_prompt_data_sha256 == sha265 +def test_insert_prompt_memories_inserts_embedding(): + mock_embedding_model = MagicMock() + duckdb_memory = DuckDBMemory(db_path=":memory:", embedding_model=mock_embedding_model) + + # Reset the database to ensure a clean state + duckdb_memory.reset_database() + inspector = inspect(duckdb_memory.engine) + + # Verify that tables are created as expected + assert "PromptMemoryEntries" in inspector.get_table_names(), "PromptMemoryEntries table not created." 
+ assert "EmbeddingData" in inspector.get_table_names(), "EmbeddingData table not created." + + with duckdb_memory.get_session(): + id = uuid.uuid4() + entry = PromptMemoryEntry( + id=id, + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + + duckdb_memory.insert_prompt_entries(entries=[entry]) + + duckdb_memory.dispose_engine() + mock_embedding_model.generate_text_embedding.assert_called_once(), \ + "Embedding data should be generated since we passed this model in" + + def test_insert_entry_violates_constraint(setup_duckdb_database): # Generate a fixed UUID fixed_uuid = uuid.uuid4() @@ -249,6 +269,28 @@ def test_insert_embedding_entry(setup_duckdb_database): assert persisted_embedding_entry.embedding_type_name == "test_type" +def test_disable_embedding(): + mock_embedding_model = MagicMock() + duckdb_memory = DuckDBMemory(db_path=":memory:", disable_embedding=True, embedding_model=mock_embedding_model) + duckdb_memory.dispose_engine() + # Even though we passed an embedding_model, embedding should be disabled. + + assert ( + duckdb_memory.memory_embedding is None + ), "disable_memory flag was passed, so memory embedding should be disabled." + + +def test_default_embedding(): + os.environ["AZURE_OPENAI_EMBEDDING_KEY"] = "mock_key" + os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"] = "mock_key" + duckdb_memory = DuckDBMemory(db_path=":memory:") + duckdb_memory.dispose_engine() + + assert ( + duckdb_memory.memory_embedding is not None + ), "Memory embedding should be enabled when set with environment variables." + + def test_query_entries(setup_duckdb_database): # Insert some test data entries = [ diff --git a/tests/memory/test_memory_models.py b/tests/memory/test_memory_models.py new file mode 100644 index 000000000..bd960a204 --- /dev/null +++ b/tests/memory/test_memory_models.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import json +import time + +from datetime import datetime +from unittest.mock import MagicMock +from pyrit.memory import PromptMemoryEntry, EmbeddingData +from pyrit.orchestrator import PromptSendingOrchestrator +from pyrit.prompt_converter import Base64Converter, PromptConverterList +from tests.mocks import MockPromptTarget + + +def test_id_set(): + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + assert entry.id is not None + +def test_datetime_set(): + now = datetime.utcnow() + time.sleep(.1) + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + assert entry.timestamp > now + +def test_is_sequence_set_false(): + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + assert entry.is_sequence_set() == False + + +def test_is_sequence_set_true(): + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + sequence=1 + ) + assert entry.is_sequence_set() == True + +def test_converters_serialize(): + converters = PromptConverterList([Base64Converter()]) + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + converters=converters + ) + assert entry.converters == \ + '[{"__type__": "Base64Converter", "__module__": "pyrit.prompt_converter.base64_converter"}]' + +def test_prompt_targets_serialize(): + target = MockPromptTarget() + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + prompt_target=target + ) + + j = json.loads(entry.prompt_target) + + assert j["__type__"] == "MockPromptTarget" + assert j["__module__"] == "tests.mocks" + +def test_orchestrators_serialize(): + orchestrator = PromptSendingOrchestrator( + prompt_target=MagicMock(), + memory=MagicMock() + ) + + entry = PromptMemoryEntry( + role="user", + 
original_prompt_text="Hello", + converted_prompt_text="Hello", + orchestrator=orchestrator + ) + + j = json.loads(entry.orchestrator) + + assert j["id"] is not None + assert j["__type__"] == "PromptSendingOrchestrator" + assert j["__module__"] == "pyrit.orchestrator.prompt_sending_orchestrator" + + +def test_hashes_generated(): + entry = PromptMemoryEntry( + role="user", + original_prompt_text="Hello1", + converted_prompt_text="Hello2", + ) + + assert entry.original_prompt_data_sha256 == "948edbe7ede5aa7423476ae29dcd7d61e7711a071aea0d83698377effa896525" + assert entry.converted_prompt_data_sha256 == "be98c2510e417405647facb89399582fc499c3de4452b3014857f92e6baad9a9" \ No newline at end of file From 4d0a7ea1a293bb5af75ae5707b1af23d319bac44 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Sat, 30 Mar 2024 18:43:16 -0700 Subject: [PATCH 14/15] configuring pytest --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 14fc521c8..8edc523fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,9 @@ dev = [ "types-PyYAML>=6.0.12.9", ] +[tool.pytest.ini_options] +pythonpath = ["."] + [tool.mypy] plugins = [] ignore_missing_imports = true From a88fb1fe50137890e9c7f9db7af8ba83f0a3da1c Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Sat, 30 Mar 2024 19:10:42 -0700 Subject: [PATCH 15/15] making calls to enable_embedding explicit --- pyrit/memory/duckdb_memory.py | 6 -- pyrit/memory/memory_embedding.py | 4 +- pyrit/memory/memory_interface.py | 7 +++ tests/memory/test_duckdb_memory.py | 54 ++++++++--------- tests/memory/test_memory_embedding.py | 4 +- tests/memory/test_memory_models.py | 85 ++++++++++++--------------- tests/mocks.py | 2 + 7 files changed, 78 insertions(+), 84 deletions(-) diff --git a/pyrit/memory/duckdb_memory.py b/pyrit/memory/duckdb_memory.py index eb600f0db..8bb5efc29 100644 --- a/pyrit/memory/duckdb_memory.py +++ b/pyrit/memory/duckdb_memory.py @@ -13,9 +13,7 @@ from contextlib import 
closing from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base -from pyrit.memory.memory_embedding import default_memory_embedding_factory from pyrit.memory.memory_interface import MemoryInterface -from pyrit.interfaces import EmbeddingSupport from pyrit.common.path import RESULTS_PATH from pyrit.common.singleton import Singleton @@ -36,13 +34,9 @@ def __init__( self, *, db_path: Union[Path, str] = None, - embedding_model: EmbeddingSupport = None, - disable_embedding: bool = False, verbose: bool = False, ): super(DuckDBMemory, self).__init__() - if not disable_embedding: - self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) if db_path == ":memory:": self.db_path: Union[Path, str] = ":memory:" diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index 0ad7c68da..3dc03ccda 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -54,4 +54,6 @@ def default_memory_embedding_factory(embedding_model: EmbeddingSupport = None) - model = AzureTextEmbedding(api_key=api_key, endpoint=api_base, deployment=deployment) return MemoryEmbedding(embedding_model=model) else: - return None + raise ValueError( + "No embedding model was provided and no Azure OpenAI embedding model was found in the environment." 
+ ) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index ba703c68d..f147c046c 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -4,6 +4,7 @@ import abc from pathlib import Path +from pyrit.memory.memory_embedding import default_memory_embedding_factory from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter @@ -28,6 +29,12 @@ def __init__(self, embedding_model=None): # Initialize the MemoryExporter instance self.exporter = MemoryExporter() + def enable_embedding(self, embedding_model=None): + self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) + + def disable_embedding(self): + self.memory_embedding = None + @abc.abstractmethod def get_all_prompt_entries(self) -> list[PromptMemoryEntry]: """ diff --git a/tests/memory/test_duckdb_memory.py b/tests/memory/test_duckdb_memory.py index 99949298e..d38a7de5e 100644 --- a/tests/memory/test_duckdb_memory.py +++ b/tests/memory/test_duckdb_memory.py @@ -16,7 +16,6 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData -from pyrit.memory import DuckDBMemory from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_converter.prompt_converter import PromptConverterList from pyrit.prompt_target.text_target import TextTarget @@ -162,19 +161,12 @@ def test_insert_entry(setup_duckdb_database): assert inserted_entry.original_prompt_data_sha256 == sha265 -def test_insert_prompt_memories_inserts_embedding(): - mock_embedding_model = MagicMock() - duckdb_memory = DuckDBMemory(db_path=":memory:", embedding_model=mock_embedding_model) +def test_insert_prompt_memories_inserts_embedding(setup_duckdb_database): - # Reset the database to ensure a clean state - duckdb_memory.reset_database() - inspector 
= inspect(duckdb_memory.engine) + embedding_mock = MagicMock() + setup_duckdb_database.enable_embedding(embedding_model=embedding_mock) - # Verify that tables are created as expected - assert "PromptMemoryEntries" in inspector.get_table_names(), "PromptMemoryEntries table not created." - assert "EmbeddingData" in inspector.get_table_names(), "EmbeddingData table not created." - - with duckdb_memory.get_session(): + with setup_duckdb_database.get_session(): id = uuid.uuid4() entry = PromptMemoryEntry( id=id, @@ -183,11 +175,12 @@ def test_insert_prompt_memories_inserts_embedding(): converted_prompt_text="Hello", ) - duckdb_memory.insert_prompt_entries(entries=[entry]) + setup_duckdb_database.insert_prompt_entries(entries=[entry]) + + setup_duckdb_database.dispose_engine() - duckdb_memory.dispose_engine() - mock_embedding_model.generate_text_embedding.assert_called_once(), \ - "Embedding data should be generated since we passed this model in" + # Embedding data should be generated since we passed this model in + embedding_mock.generate_text_embedding.assert_called_once() def test_insert_entry_violates_constraint(setup_duckdb_database): @@ -269,28 +262,35 @@ def test_insert_embedding_entry(setup_duckdb_database): assert persisted_embedding_entry.embedding_type_name == "test_type" -def test_disable_embedding(): - mock_embedding_model = MagicMock() - duckdb_memory = DuckDBMemory(db_path=":memory:", disable_embedding=True, embedding_model=mock_embedding_model) - duckdb_memory.dispose_engine() - # Even though we passed an embedding_model, embedding should be disabled. +def test_disable_embedding(setup_duckdb_database): + setup_duckdb_database.disable_embedding() assert ( - duckdb_memory.memory_embedding is None + setup_duckdb_database.memory_embedding is None ), "disable_memory flag was passed, so memory embedding should be disabled." 
-def test_default_embedding(): +def test_default_enable_embedding(setup_duckdb_database): os.environ["AZURE_OPENAI_EMBEDDING_KEY"] = "mock_key" - os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"] = "mock_key" - duckdb_memory = DuckDBMemory(db_path=":memory:") - duckdb_memory.dispose_engine() + os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"] = "embedding" + os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] = "deployment" + + setup_duckdb_database.enable_embedding() assert ( - duckdb_memory.memory_embedding is not None + setup_duckdb_database.memory_embedding is not None ), "Memory embedding should be enabled when set with environment variables." +def test_default_embedding_raises(setup_duckdb_database): + os.environ["AZURE_OPENAI_EMBEDDING_KEY"] = "" + os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"] = "" + os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] = "" + + with pytest.raises(ValueError): + setup_duckdb_database.enable_embedding() + + def test_query_entries(setup_duckdb_database): # Insert some test data entries = [ diff --git a/tests/memory/test_memory_embedding.py b/tests/memory/test_memory_embedding.py index 01b6d4709..8e459bf53 100644 --- a/tests/memory/test_memory_embedding.py +++ b/tests/memory/test_memory_embedding.py @@ -84,5 +84,5 @@ def test_default_memory_embedding_factory_without_embedding_model_and_environmen monkeypatch.delenv("AZURE_OPENAI_EMBEDDING_ENDPOINT", raising=False) monkeypatch.delenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", raising=False) - memory_embedding = default_memory_embedding_factory() - assert memory_embedding is None + with pytest.raises(ValueError): + default_memory_embedding_factory() diff --git a/tests/memory/test_memory_models.py b/tests/memory/test_memory_models.py index bd960a204..f1e0fe76e 100644 --- a/tests/memory/test_memory_models.py +++ b/tests/memory/test_memory_models.py @@ -6,7 +6,7 @@ from datetime import datetime from unittest.mock import MagicMock -from pyrit.memory import PromptMemoryEntry, EmbeddingData +from pyrit.memory 
import PromptMemoryEntry from pyrit.orchestrator import PromptSendingOrchestrator from pyrit.prompt_converter import Base64Converter, PromptConverterList from tests.mocks import MockPromptTarget @@ -14,77 +14,66 @@ def test_id_set(): entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello", - ) + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) assert entry.id is not None + def test_datetime_set(): now = datetime.utcnow() - time.sleep(.1) + time.sleep(0.1) entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello", - ) + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) assert entry.timestamp > now + def test_is_sequence_set_false(): entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello", - ) - assert entry.is_sequence_set() == False + role="user", + original_prompt_text="Hello", + converted_prompt_text="Hello", + ) + assert entry.is_sequence_set() is False def test_is_sequence_set_true(): - entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello", - sequence=1 - ) - assert entry.is_sequence_set() == True + entry = PromptMemoryEntry(role="user", original_prompt_text="Hello", converted_prompt_text="Hello", sequence=1) + assert entry.is_sequence_set() + def test_converters_serialize(): converters = PromptConverterList([Base64Converter()]) entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello", - converters=converters - ) - assert entry.converters == \ - '[{"__type__": "Base64Converter", "__module__": "pyrit.prompt_converter.base64_converter"}]' + role="user", original_prompt_text="Hello", converted_prompt_text="Hello", converters=converters + ) + assert ( + entry.converters == '[{"__type__": "Base64Converter", "__module__": 
"pyrit.prompt_converter.base64_converter"}]' + ) + def test_prompt_targets_serialize(): target = MockPromptTarget() entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello", - prompt_target=target - ) + role="user", original_prompt_text="Hello", converted_prompt_text="Hello", prompt_target=target + ) j = json.loads(entry.prompt_target) assert j["__type__"] == "MockPromptTarget" assert j["__module__"] == "tests.mocks" + def test_orchestrators_serialize(): - orchestrator = PromptSendingOrchestrator( - prompt_target=MagicMock(), - memory=MagicMock() - ) + orchestrator = PromptSendingOrchestrator(prompt_target=MagicMock(), memory=MagicMock()) entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello", - converted_prompt_text="Hello", - orchestrator=orchestrator - ) + role="user", original_prompt_text="Hello", converted_prompt_text="Hello", orchestrator=orchestrator + ) j = json.loads(entry.orchestrator) @@ -95,10 +84,10 @@ def test_orchestrators_serialize(): def test_hashes_generated(): entry = PromptMemoryEntry( - role="user", - original_prompt_text="Hello1", - converted_prompt_text="Hello2", - ) + role="user", + original_prompt_text="Hello1", + converted_prompt_text="Hello2", + ) assert entry.original_prompt_data_sha256 == "948edbe7ede5aa7423476ae29dcd7d61e7711a071aea0d83698377effa896525" - assert entry.converted_prompt_data_sha256 == "be98c2510e417405647facb89399582fc499c3de4452b3014857f92e6baad9a9" \ No newline at end of file + assert entry.converted_prompt_data_sha256 == "be98c2510e417405647facb89399582fc499c3de4452b3014857f92e6baad9a9" diff --git a/tests/mocks.py b/tests/mocks.py index 858531a2c..3e3cce529 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -72,6 +72,8 @@ def get_memory_interface() -> Generator[MemoryInterface, None, None]: # Create an in-memory DuckDB engine duckdb_memory = DuckDBMemory(db_path=":memory:") + duckdb_memory.disable_embedding() + # Reset the database to ensure a clean 
state duckdb_memory.reset_database() inspector = inspect(duckdb_memory.engine)