diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5200613b..c980e69ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,8 +54,16 @@ repos: rev: v1.15.0 hooks: - id: mypy + name: mypy (strict) + files: &strict_modules ^pyrit/(auth|analytics|embedding|exceptions|memory|prompt_normalizer)/ + args: [--install-types, --non-interactive, --ignore-missing-imports, --sqlite-cache, --cache-dir=.mypy_cache, --strict] + entry: mypy + language: system + types: [ python ] + - id: mypy + name: mypy (regular) + exclude: *strict_modules args: [--install-types, --non-interactive, --ignore-missing-imports, --sqlite-cache, --cache-dir=.mypy_cache] - name: mypy entry: mypy language: system types: [ python ] diff --git a/pyrit/analytics/conversation_analytics.py b/pyrit/analytics/conversation_analytics.py index b3dcef613..305572ec0 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -83,24 +83,32 @@ def get_similar_chat_messages_by_embedding( if similarity_score >= threshold: similar_messages.append( EmbeddingMessageWithSimilarity( - score=float(similarity_score), - uuid=memory.id, - metric="cosine_similarity", # type: ignore + score=float(similarity_score), uuid=memory.id, metric="cosine_similarity" ) ) return similar_messages -def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray: +def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: """ - Calculate the cosine similarity between two vectors. + Calculate the cosine similarity between two 1D vectors. Args: a (np.ndarray): The first vector. b (np.ndarray): The second vector. Returns: - np.ndarray: The cosine similarity between the two vectors. + float: The cosine similarity between the two 1D vectors. + + Raises: + ValueError: If the input vectors are not 1D. """ - return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + # Ensure we are dealing with 1D vectors + if a.ndim != 1 or b.ndim != 1: + raise ValueError("Inputs must be 1D vectors") + + dot_product = np.dot(a, b) + norms = np.linalg.norm(a) * np.linalg.norm(b) + + return float(dot_product / norms) diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index 009938c6e..7db3c02e8 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -31,7 +31,7 @@ def _compute_stats(successes: int, failures: int, undetermined: int) -> AttackSt ) -def analyze_results(attack_results: list[AttackResult]) -> dict: +def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats | dict[str, AttackStats]]: """ Analyze a list of AttackResult objects and return overall and grouped statistics. diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 34d31f667..44b801600 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -1,9 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging import time -from typing import Callable, Union +from typing import TYPE_CHECKING, Any, Callable, Union, cast from urllib.parse import urlparse import msal @@ -20,6 +22,9 @@ get_bearer_token_provider as get_async_bearer_token_provider, ) +if TYPE_CHECKING: + import azure.cognitiveservices.speech as speechsdk + from pyrit.auth.auth_config import REFRESH_TOKEN_BEFORE_MSEC from pyrit.auth.authenticator import Authenticator @@ -34,7 +39,7 @@ class TokenProviderCredential: get_azure_token_provider) and Azure SDK clients that require a TokenCredential object. """ - def __init__(self, token_provider: Callable[[], Union[str, Callable]]) -> None: + def __init__(self, token_provider: Callable[[], Union[str, Callable[..., Any]]]) -> None: """ Initialize TokenProviderCredential. @@ -43,7 +48,7 @@ def __init__(self, token_provider: Callable[[], Union[str, Callable]]) -> None: """ self._token_provider = token_provider - def get_token(self, *scopes, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """ Get an access token. @@ -116,7 +121,7 @@ def get_token(self) -> str: return self.token -def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = ""): +def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = "") -> str: """ Get access token from Azure CLI. @@ -130,13 +135,13 @@ def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = ""): try: credential = AzureCliCredential(tenant_id=tenant_id) token = credential.get_token(scope) - return token.token + return cast(str, token.token) except Exception as e: logger.error(f"Failed to obtain token for '{scope}' with tenant ID '{tenant_id}': {e}") raise -def get_access_token_from_azure_msi(*, client_id: str, scope: str): +def get_access_token_from_azure_msi(*, client_id: str, scope: str) -> str: """ Connect to an AOAI endpoint via managed identity credential attached to an Azure resource. For proper setup and configuration of MSI @@ -152,13 +157,13 @@ def get_access_token_from_azure_msi(*, client_id: str, scope: str): try: credential = ManagedIdentityCredential(client_id=client_id) token = credential.get_token(scope) - return token.token + return cast(str, token.token) except Exception as e: logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}") raise -def get_access_token_from_msa_public_client(*, client_id: str, scope: str): +def get_access_token_from_msa_public_client(*, client_id: str, scope: str) -> str: """ Use MSA account to connect to an AOAI endpoint via interactive login. A browser window will open and ask for login credentials. @@ -173,7 +178,7 @@ def get_access_token_from_msa_public_client(*, client_id: str, scope: str): try: app = msal.PublicClientApplication(client_id) result = app.acquire_token_interactive(scopes=[scope]) - return result["access_token"] + return cast(str, result["access_token"]) except Exception as e: logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}") raise @@ -298,7 +303,7 @@ def get_azure_openai_auth(endpoint: str): # type: ignore[no-untyped-def] return get_azure_async_token_provider(scope) -def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str): +def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str) -> speechsdk.SpeechConfig: """ Get the speech config using key/region pair (for key auth scenarios) or resource_id/region pair (for Entra auth scenarios). @@ -338,7 +343,7 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi raise ValueError("Insufficient information provided for Azure Speech service.") -def get_speech_config_from_default_azure_credential(resource_id: str, region: str): +def get_speech_config_from_default_azure_credential(resource_id: str, region: str) -> speechsdk.SpeechConfig: """ Get the speech config for the given resource ID and region. diff --git a/pyrit/auth/azure_storage_auth.py b/pyrit/auth/azure_storage_auth.py index aff8792df..5df1740b0 100644 --- a/pyrit/auth/azure_storage_auth.py +++ b/pyrit/auth/azure_storage_auth.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. from datetime import datetime, timedelta +from typing import cast from urllib.parse import urlparse from azure.identity.aio import DefaultAzureCredential @@ -88,11 +89,11 @@ async def get_sas_token(container_url: str) -> str: account_name=storage_account_name, container_name=container_name, user_delegation_key=user_delegation_key, - permission=ContainerSasPermissions(read=True, write=True, create=True, list=True, delete=True), + permission=ContainerSasPermissions(read=True, write=True, create=True, list=True, delete=True), # type: ignore expiry=expiry_time, start=start_time, ) finally: await credential.close() - return sas_token + return cast(str, sas_token) diff --git a/pyrit/common/default_values.py b/pyrit/common/default_values.py index 4f9fc6a52..31ebbb26c 100644 --- a/pyrit/common/default_values.py +++ b/pyrit/common/default_values.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) -def get_required_value(*, env_var_name: str, passed_value: Any) -> Any: +def get_required_value(*, env_var_name: str, passed_value: Any) -> Optional[str]: """ Get a required value from an environment variable or a passed value, preferring the passed value. diff --git a/pyrit/embedding/_text_embedding.py b/pyrit/embedding/_text_embedding.py index c181b6caf..55857c725 100644 --- a/pyrit/embedding/_text_embedding.py +++ b/pyrit/embedding/_text_embedding.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import abc -from typing import Union +from typing import Any, Union import tenacity from openai import AzureOpenAI, OpenAI @@ -29,7 +29,7 @@ def __init__(self) -> None: ) @tenacity.retry(wait=tenacity.wait_fixed(0.1), stop=tenacity.stop_after_delay(3)) - def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse: + def generate_text_embedding(self, text: str, **kwargs: Any) -> EmbeddingResponse: """ Generate text embedding. diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index f0d819e15..157e5676f 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import asyncio -from typing import Awaitable, Callable, Optional +from typing import Any, Awaitable, Callable, Optional import tenacity from openai import AsyncOpenAI @@ -70,7 +70,7 @@ def __init__( super().__init__() @tenacity.retry(wait=tenacity.wait_fixed(0.1), stop=tenacity.stop_after_delay(3)) - async def generate_text_embedding_async(self, text: str, **kwargs) -> EmbeddingResponse: + async def generate_text_embedding_async(self, text: str, **kwargs: Any) -> EmbeddingResponse: """ Generate text embedding asynchronously. @@ -99,7 +99,7 @@ async def generate_text_embedding_async(self, text: str, **kwargs) -> EmbeddingR ) return embedding_response - def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse: + def generate_text_embedding(self, text: str, **kwargs: Any) -> EmbeddingResponse: """ Generate text embedding synchronously by calling the async method. diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index 1c742e77f..10aad6497 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -5,7 +5,7 @@ import logging import os from abc import ABC -from typing import Callable, Optional +from typing import Any, Callable, Optional from openai import RateLimitError from tenacity import ( @@ -147,7 +147,9 @@ def __init__(self, *, message: str = "No prompt placeholder") -> None: super().__init__(message=message) -def pyrit_custom_result_retry(retry_function: Callable, retry_max_num_attempts: Optional[int] = None) -> Callable: +def pyrit_custom_result_retry( + retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None +) -> Callable[..., Any]: """ A decorator to apply retry logic with exponential backoff to a function. @@ -166,8 +168,9 @@ def pyrit_custom_result_retry(retry_function: Callable, retry_max_num_attempts: Callable: The decorated function with retry logic applied. """ - def inner_retry(func): + def inner_retry(func: Callable[..., Any]) -> Callable[..., Any]: # Use static value if explicitly provided, otherwise use dynamic getter + stop_strategy: stop_base if retry_max_num_attempts is not None: stop_strategy = stop_after_attempt(retry_max_num_attempts) else: @@ -184,7 +187,7 @@ def inner_retry(func): return inner_retry -def pyrit_target_retry(func: Callable) -> Callable: +def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]: """ A decorator to apply retry logic with exponential backoff to a function. @@ -209,7 +212,7 @@ def pyrit_target_retry(func: Callable) -> Callable: )(func) -def pyrit_json_retry(func: Callable) -> Callable: +def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]: """ A decorator to apply retry logic to a function. @@ -230,7 +233,7 @@ def pyrit_json_retry(func: Callable) -> Callable: )(func) -def pyrit_placeholder_retry(func: Callable) -> Callable: +def pyrit_placeholder_retry(func: Callable[..., Any]) -> Callable[..., Any]: """ A decorator to apply retry logic. @@ -254,7 +257,7 @@ def pyrit_placeholder_retry(func: Callable) -> Callable: def handle_bad_request_exception( response_text: str, request: MessagePiece, - is_content_filter=False, + is_content_filter: bool = False, error_code: int = 400, ) -> Message: if ( diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index 258f65267..e986421e9 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -def log_exception(retry_state: RetryCallState): +def log_exception(retry_state: RetryCallState) -> None: # Log each retry attempt with exception details at ERROR level elapsed_time = time.monotonic() - retry_state.start_time call_count = retry_state.attempt_number diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 375aa4855..a16b3b3d0 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -13,6 +13,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload, sessionmaker from sqlalchemy.orm.session import Session +from sqlalchemy.sql.expression import TextClause from pyrit.auth.azure_auth import AzureAuth from pyrit.common import default_values @@ -118,7 +119,7 @@ def _resolve_sas_token(env_var_name: str, passed_value: Optional[str] = None) -> except ValueError: return None - def _init_storage_io(self): + def _init_storage_io(self) -> None: # Handle for Azure Blob Storage when using Azure SQL memory. self.results_storage_io = AzureBlobStorageIO( container_url=self._results_container_url, sas_token=self._results_container_sas_token @@ -187,7 +188,7 @@ def _enable_azure_authorization(self) -> None: """ @event.listens_for(self.engine, "do_connect") - def provide_token(_dialect, _conn_rec, cargs, cparams): + def provide_token(_dialect: Any, _conn_rec: Any, cargs: list[Any], cparams: dict[str, Any]) -> None: # Refresh token if it's close to expiry self._refresh_token_if_needed() @@ -202,7 +203,7 @@ def provide_token(_dialect, _conn_rec, cargs, cparams): # add the encoded token cparams["attrs_before"] = {self.SQL_COPT_SS_ACCESS_TOKEN: packed_azure_token} - def _create_tables_if_not_exist(self): + def _create_tables_if_not_exist(self) -> None: """ Create all tables defined in the Base metadata, if they don't already exist in the database. @@ -222,7 +223,7 @@ def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEnt """ self._insert_entries(entries=embedding_data) - def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list: + def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[TextClause]: """ Generate SQL conditions for filtering message pieces by memory labels. @@ -260,7 +261,7 @@ def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: json_id=str(attack_id) ) - def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list: + def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: """ Generate SQL conditions for filtering by prompt metadata. @@ -284,7 +285,9 @@ def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int] condition = text(conditions).bindparams(**{key: str(value) for key, value in prompt_metadata.items()}) return [condition] - def _get_message_pieces_prompt_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list: + def _get_message_pieces_prompt_metadata_conditions( + self, *, prompt_metadata: dict[str, Union[str, int]] + ) -> list[TextClause]: """ Generate SQL conditions for filtering message pieces by prompt metadata. @@ -298,7 +301,7 @@ def _get_message_pieces_prompt_metadata_conditions(self, *, prompt_metadata: dic """ return self._get_metadata_conditions(prompt_metadata=prompt_metadata) - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: + def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> TextClause: """ Generate SQL condition for filtering seed prompts by metadata. @@ -403,7 +406,7 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: + def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause: """ Get the SQL Azure implementation for filtering ScenarioResults by target endpoint. @@ -420,7 +423,7 @@ def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> An AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint""" ).bindparams(endpoint=f"%{endpoint.lower()}%") - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: + def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause: """ Get the SQL Azure implementation for filtering ScenarioResults by target model name. @@ -444,7 +447,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] """ self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) - def dispose_engine(self): + def dispose_engine(self) -> None: """ Dispose the engine and clean up resources. """ @@ -462,7 +465,7 @@ def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: result: Sequence[EmbeddingDataEntry] = self._query_entries(EmbeddingDataEntry) return result - def _insert_entry(self, entry: Base) -> None: # type: ignore + def _insert_entry(self, entry: Base) -> None: """ Insert an entry into the Table. @@ -484,7 +487,7 @@ def _insert_entry(self, entry: Base) -> None: # type: ignore # The following methods are not part of MemoryInterface, but seem # common between SQLAlchemy-based implementations, regardless of engine. # Perhaps we should find a way to refactor - def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore + def _insert_entries(self, *, entries: Sequence[Base]) -> None: """ Insert multiple entries into the database. @@ -514,9 +517,9 @@ def get_session(self) -> Session: def _query_entries( self, - Model, + model_class: type[Model], *, - conditions: Optional[Any] = None, # type: ignore + conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, ) -> MutableSequence[Model]: @@ -524,7 +527,7 @@ def _query_entries( Fetch data from the specified table model with optional conditions. Args: - Model: The SQLAlchemy model class to query. + model_class: The SQLAlchemy model class to query. conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (defaults to False). join_scores: Flag to join the scores table with entries (defaults to False). @@ -537,10 +540,10 @@ def _query_entries( """ with closing(self.get_session()) as session: try: - query = session.query(Model) - if join_scores and Model == PromptMemoryEntry: + query = session.query(model_class) + if join_scores and model_class == PromptMemoryEntry: query = query.options(joinedload(PromptMemoryEntry.scores)) - elif Model == AttackResultEntry: + elif model_class == AttackResultEntry: query = query.options( joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), joinedload(AttackResultEntry.last_score), @@ -551,10 +554,10 @@ def _query_entries( return query.distinct().all() return query.all() except SQLAlchemyError as e: - logger.exception(f"Error fetching data from table {Model.__tablename__}: {e}") + logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore raise - def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict) -> bool: # type: ignore + def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict[str, Any]) -> bool: """ Update the given entries with the specified field values. @@ -595,7 +598,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict logger.exception(f"Error updating entries: {e}") raise - def reset_database(self): + def reset_database(self) -> None: """Drop and recreate existing tables.""" # Drop all existing tables Base.metadata.drop_all(self.engine) diff --git a/pyrit/memory/memory_exporter.py b/pyrit/memory/memory_exporter.py index f4ffb0c01..59679c6c1 100644 --- a/pyrit/memory/memory_exporter.py +++ b/pyrit/memory/memory_exporter.py @@ -15,7 +15,7 @@ class MemoryExporter: This class utilizes the strategy design pattern to select the appropriate export format. """ - def __init__(self): + def __init__(self) -> None: """ Initialize the MemoryExporter. @@ -29,7 +29,9 @@ def __init__(self): # Future formats can be added here } - def export_data(self, data: list[MessagePiece], *, file_path: Optional[Path] = None, export_type: str = "json"): # type: ignore + def export_data( + self, data: list[MessagePiece], *, file_path: Optional[Path] = None, export_type: str = "json" + ) -> None: """ Export the provided data to a file in the specified format. @@ -50,7 +52,7 @@ def export_data(self, data: list[MessagePiece], *, file_path: Optional[Path] = N else: raise ValueError(f"Unsupported export format: {export_type}") - def export_to_json(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore + def export_to_json(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: """ Export the provided data to a JSON file at the specified file path. Each item in the data list, representing a row from the table, @@ -73,7 +75,7 @@ def export_to_json(self, data: list[MessagePiece], file_path: Path = None) -> No with open(file_path, "w") as f: json.dump(export_data, f, indent=4) - def export_to_csv(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore + def export_to_csv(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: """ Export the provided data to a CSV file at the specified file path. Each item in the data list, representing a row from the table, @@ -99,7 +101,7 @@ def export_to_csv(self, data: list[MessagePiece], file_path: Path = None) -> Non writer.writeheader() writer.writerows(export_data) - def export_to_markdown(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore + def export_to_markdown(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: """ Export the provided data to a Markdown file at the specified file path. Each item in the data list is converted to a dictionary and formatted as a table. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index bd8a87feb..6d2ad9c0a 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -66,7 +66,7 @@ class MemoryInterface(abc.ABC): results_path: str = None engine: Engine = None - def __init__(self, embedding_model=None): + def __init__(self, embedding_model: Optional[Any] = None) -> None: """ Initialize the MemoryInterface. @@ -83,7 +83,7 @@ def __init__(self, embedding_model=None): # Ensure cleanup at process exit self.cleanup() - def enable_embedding(self, embedding_model=None): + def enable_embedding(self, embedding_model: Optional[Any] = None) -> None: """ Enable embedding functionality for the memory interface. @@ -97,7 +97,7 @@ def enable_embedding(self, embedding_model=None): """ self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) - def disable_embedding(self): + def disable_embedding(self) -> None: """ Disable embedding functionality for the memory interface. @@ -112,13 +112,13 @@ def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: """ @abc.abstractmethod - def _init_storage_io(self): + def _init_storage_io(self) -> None: """ Initialize the storage IO handler results_storage_io. """ @abc.abstractmethod - def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list: + def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[Any]: """ Return a list of conditions for filtering memory entries based on memory labels. @@ -133,7 +133,9 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str """ @abc.abstractmethod - def _get_message_pieces_prompt_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list: + def _get_message_pieces_prompt_metadata_conditions( + self, *, prompt_metadata: dict[str, Union[str, int]] + ) -> list[Any]: """ Return a list of conditions for filtering memory entries based on prompt metadata. @@ -179,17 +181,17 @@ def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEnt @abc.abstractmethod def _query_entries( self, - Model, + model_class: type[Model], *, conditions: Optional[Any] = None, distinct: bool = False, - join_scores: bool = False, # type: ignore - ) -> MutableSequence[Model]: # type: ignore + join_scores: bool = False, + ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. Args: - Model: The SQLAlchemy model class corresponding to the table you want to query. + model_class: The SQLAlchemy model class corresponding to the table you want to query. conditions: SQLAlchemy filter conditions (Optional). distinct: Whether to return distinct rows only. Defaults to False. join_scores: Whether to join the scores table. Defaults to False. @@ -199,7 +201,7 @@ def _query_entries( """ @abc.abstractmethod - def _insert_entry(self, entry: Base) -> None: # type: ignore + def _insert_entry(self, entry: Base) -> None: """ Insert an entry into the Table. @@ -208,11 +210,11 @@ def _insert_entry(self, entry: Base) -> None: # type: ignore """ @abc.abstractmethod - def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore + def _insert_entries(self, *, entries: Sequence[Base]) -> None: """Insert multiple entries into the database.""" @abc.abstractmethod - def get_session(self): # type: ignore + def get_session(self) -> Any: """ Provide a SQLAlchemy session for transactional operations. @@ -220,7 +222,7 @@ def get_session(self): # type: ignore Session: A SQLAlchemy session bound to the engine. """ - def _update_entry(self, entry: Base) -> None: # type: ignore + def _update_entry(self, entry: Base) -> None: """ Update an existing entry in the Table using merge. @@ -248,7 +250,7 @@ def _update_entry(self, entry: Base) -> None: # type: ignore raise @abc.abstractmethod - def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict) -> bool: # type: ignore + def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict[str, Any]) -> bool: """ Update the given entries with the specified field values. @@ -360,7 +362,7 @@ def get_scores( conditions: list[Any] = [] if score_ids: - conditions.append(ScoreEntry.id.in_(score_ids)) # type: ignore + conditions.append(ScoreEntry.id.in_(score_ids)) if score_type: conditions.append(ScoreEntry.score_type == score_type) if score_category: @@ -537,7 +539,7 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id)) if prompt_ids: prompt_ids = [str(pi) for pi in prompt_ids] - conditions.append(PromptMemoryEntry.id.in_(prompt_ids)) # type: ignore + conditions.append(PromptMemoryEntry.id.in_(prompt_ids)) if labels: conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels)) if prompt_metadata: @@ -547,20 +549,20 @@ def get_message_pieces( if sent_before: conditions.append(PromptMemoryEntry.timestamp <= sent_before) if original_values: - conditions.append(PromptMemoryEntry.original_value.in_(original_values)) # type: ignore + conditions.append(PromptMemoryEntry.original_value.in_(original_values)) if converted_values: - conditions.append(PromptMemoryEntry.converted_value.in_(converted_values)) # type: ignore + conditions.append(PromptMemoryEntry.converted_value.in_(converted_values)) if data_type: conditions.append(PromptMemoryEntry.converted_value_data_type == data_type) if not_data_type: conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if converted_value_sha256: - conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) # type: ignore + conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) try: memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True - ) # type: ignore + ) message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries] return sort_message_pieces(message_pieces=message_pieces) except Exception as e: @@ -672,7 +674,7 @@ def add_message_to_memory(self, *, request: Message) -> None: self._add_embeddings_to_memory(embedding_data=embedding_entries) - def _update_sequence(self, *, message_pieces: Sequence[MessagePiece]): + def _update_sequence(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Update the sequence number of the message pieces in the conversation. @@ -689,7 +691,7 @@ def _update_sequence(self, *, message_pieces: Sequence[MessagePiece]): for piece in message_pieces: piece.sequence = sequence - def update_prompt_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool: + def update_prompt_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict[str, Any]) -> bool: """ Update prompt entries for a given conversation ID with the specified field values. @@ -723,7 +725,7 @@ def update_prompt_entries_by_conversation_id(self, *, conversation_id: str, upda logger.error(f"Failed to update entries with conversation_id {conversation_id}.") return success - def update_labels_by_conversation_id(self, *, conversation_id: str, labels: dict) -> bool: + def update_labels_by_conversation_id(self, *, conversation_id: str, labels: dict[str, Any]) -> bool: """ Update the labels of prompt entries in memory for a given conversation ID. @@ -756,12 +758,12 @@ def update_prompt_metadata_by_conversation_id( ) @abc.abstractmethod - def dispose_engine(self): + def dispose_engine(self) -> None: """ Dispose the engine and clean up resources. """ - def cleanup(self): + def cleanup(self) -> None: """ Ensure cleanup on process exit. """ @@ -824,17 +826,17 @@ def get_seeds( # Apply filters for non-list fields if value: - conditions.append(SeedEntry.value.contains(value)) # type: ignore + conditions.append(SeedEntry.value.contains(value)) if value_sha256: - conditions.append(SeedEntry.value_sha256.in_(value_sha256)) # type: ignore + conditions.append(SeedEntry.value_sha256.in_(value_sha256)) if dataset_name: conditions.append(SeedEntry.dataset_name == dataset_name) elif dataset_name_pattern: - conditions.append(SeedEntry.dataset_name.like(dataset_name_pattern)) # type: ignore + conditions.append(SeedEntry.dataset_name.like(dataset_name_pattern)) if prompt_group_ids: - conditions.append(SeedEntry.prompt_group_id.in_(prompt_group_ids)) # type: ignore + conditions.append(SeedEntry.prompt_group_id.in_(prompt_group_ids)) if data_types: - data_type_conditions = SeedEntry.data_type.in_(data_types) # type: ignore + data_type_conditions = SeedEntry.data_type.in_(data_types) conditions.append(data_type_conditions) if added_by: conditions.append(SeedEntry.added_by == added_by) @@ -857,18 +859,18 @@ def get_seeds( memory_entries: Sequence[SeedEntry] = self._query_entries( SeedEntry, conditions=and_(*conditions) if conditions else None, - ) # type: ignore + ) return [memory_entry.get_seed() for memory_entry in memory_entries] except Exception as e: logger.exception(f"Failed to retrieve prompts with dataset name {dataset_name} with error {e}") raise def _add_list_conditions( - self, field: InstrumentedAttribute, conditions: list, values: Optional[Sequence[str]] = None + self, field: InstrumentedAttribute[Any], conditions: list[Any], values: Optional[Sequence[str]] = None ) -> None: if values: for value in values: - conditions.append(field.contains(value)) # type: ignore + conditions.append(field.contains(value)) async def _serialize_seed_value(self, prompt: Seed) -> str: """ @@ -926,7 +928,7 @@ async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Op if prompt.date_added is None: prompt.date_added = current_time - prompt.set_encoding_metadata() + prompt.set_encoding_metadata() # type: ignore # Handle serialization for image, audio & video SeedPrompts if prompt.data_type in ["image_path", "audio_path", "video_path"]: @@ -1211,14 +1213,14 @@ def get_attack_results( if len(attack_result_ids) == 0: # Empty list means no results return [] - conditions.append(AttackResultEntry.id.in_(attack_result_ids)) # type: ignore + conditions.append(AttackResultEntry.id.in_(attack_result_ids)) if conversation_id: - conditions.append(AttackResultEntry.conversation_id == conversation_id) # type: ignore + conditions.append(AttackResultEntry.conversation_id == conversation_id) if objective: - conditions.append(AttackResultEntry.objective.contains(objective)) # type: ignore + conditions.append(AttackResultEntry.objective.contains(objective)) if objective_sha256: - conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256)) # type: ignore + conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256)) if outcome: conditions.append(AttackResultEntry.outcome == outcome) @@ -1399,26 +1401,26 @@ def get_scenario_results( if len(scenario_result_ids) == 0: # Empty list means no results return [] - conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids)) # type: ignore + conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids)) if scenario_name: # Normalize CLI snake_case names (e.g., "foundry" or "content_harms") # to class names (e.g., "Foundry" or "ContentHarms") # This allows users to query with either format normalized_name = ScenarioResult.normalize_scenario_name(scenario_name) - conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name)) # type: ignore + conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name)) if scenario_version is not None: - conditions.append(ScenarioResultEntry.scenario_version == scenario_version) # type: ignore + conditions.append(ScenarioResultEntry.scenario_version == scenario_version) if pyrit_version: - conditions.append(ScenarioResultEntry.pyrit_version == pyrit_version) # type: ignore + conditions.append(ScenarioResultEntry.pyrit_version == pyrit_version) if added_after: - conditions.append(ScenarioResultEntry.completion_time >= added_after) # type: ignore + conditions.append(ScenarioResultEntry.completion_time >= added_after) if added_before: - conditions.append(ScenarioResultEntry.completion_time <= added_before) # type: ignore + conditions.append(ScenarioResultEntry.completion_time <= added_before) if labels: # Use database-specific JSON query method @@ -1453,7 +1455,7 @@ def get_scenario_results( # Query all AttackResults in a single batch if there are any if all_conversation_ids: # Build condition to query multiple conversation IDs at once - attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)] # type: ignore + attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)] attack_entries: Sequence[AttackResultEntry] = self._query_entries( AttackResultEntry, conditions=and_(*attack_conditions) ) @@ -1475,7 +1477,7 @@ def get_scenario_results( logger.exception(f"Failed to retrieve scenario results with error {e}") raise - def print_schema(self): + def print_schema(self) -> None: """Print the schema of all tables in the database.""" metadata = MetaData() metadata.reflect(bind=self.engine) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 67e664ac7..24f7b8072 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -21,13 +21,13 @@ Unicode, ) from sqlalchemy.dialects.sqlite import CHAR -from sqlalchemy.orm import ( # type: ignore +from sqlalchemy.orm import ( DeclarativeBase, Mapped, mapped_column, relationship, ) -from sqlalchemy.types import Uuid # type: ignore +from sqlalchemy.types import Uuid from pyrit.common.utils import to_sha256 from pyrit.models import ( @@ -47,7 +47,7 @@ ) -class CustomUUID(TypeDecorator): +class CustomUUID(TypeDecorator[uuid.UUID]): """ A custom UUID type that works consistently across different database backends. For SQLite, stores UUIDs as strings and converts them back to UUID objects. @@ -57,7 +57,7 @@ class CustomUUID(TypeDecorator): impl = CHAR cache_ok = True - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: Any) -> Any: """ Load the dialect-specific implementation for UUID handling. @@ -72,7 +72,7 @@ def load_dialect_impl(self, dialect): else: return dialect.type_descriptor(Uuid()) - def process_bind_param(self, value, dialect): + def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Optional[str]: """ Process a parameter value before binding it to a database statement. @@ -85,7 +85,7 @@ def process_bind_param(self, value, dialect): """ return str(value) if value else None - def process_result_value(self, value, dialect): + def process_result_value(self, value: uuid.UUID | str | None, dialect: Any) -> Optional[uuid.UUID]: """ Process a result value after it has been retrieved from the database. @@ -96,9 +96,11 @@ def process_result_value(self, value, dialect): Returns: UUID or None: The UUID object or None if value is None. """ + if value is None: + return None if dialect.name == "sqlite": - return uuid.UUID(value) if value else None - return value + return uuid.UUID(value) if isinstance(value, str) else value + return value if isinstance(value, uuid.UUID) else uuid.UUID(value) class Base(DeclarativeBase): @@ -248,7 +250,7 @@ def get_message_piece(self) -> MessagePiece: message_piece.scores = [score.get_score() for score in self.scores] return message_piece - def __str__(self): + def __str__(self) -> str: """ Return a string representation of the memory entry. @@ -260,7 +262,7 @@ def __str__(self): return f": {self.role}: {self.converted_value}" -class EmbeddingDataEntry(Base): # type: ignore +class EmbeddingDataEntry(Base): """ Represents the embedding data associated with conversation entries in the database. Each embedding is linked to a specific conversation entry via an id. @@ -276,10 +278,10 @@ class EmbeddingDataEntry(Base): # type: ignore __table_args__ = {"extend_existing": True} id = mapped_column(Uuid(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True) # Use ARRAY for PostgreSQL, JSON for SQLite and MSSQL (SQL Server/Azure SQL) - embedding = mapped_column(ARRAY(Float).with_variant(JSON, "sqlite").with_variant(JSON, "mssql")) # type: ignore + embedding = mapped_column(ARRAY(Float).with_variant(JSON, "sqlite").with_variant(JSON, "mssql")) embedding_type_name = mapped_column(String) - def __str__(self): + def __str__(self) -> str: """ Return a string representation of the embedding data entry (its ID). @@ -289,7 +291,7 @@ def __str__(self): return f"{self.id}" -class ScoreEntry(Base): # type: ignore +class ScoreEntry(Base): """ Represents the Score Memory Entry. @@ -335,7 +337,7 @@ def __init__(self, *, entry: Score): self.objective = entry.objective @staticmethod - def _normalize_scorer_identifier(identifier: Dict) -> Dict[str, str]: + def _normalize_scorer_identifier(identifier: Dict[str, Any]) -> Dict[str, str]: """ Normalize scorer identifier to ensure all values are strings for database storage. Handles nested sub_identifiers by converting them to JSON strings. @@ -377,7 +379,7 @@ def get_score(self) -> Score: ) @staticmethod - def _denormalize_scorer_identifier(identifier: Dict[str, str]) -> Dict: + def _denormalize_scorer_identifier(identifier: Dict[str, str]) -> Dict[str, Any]: """ Denormalize scorer identifier when reading from database. Parses JSON strings back into dicts/lists for sub_identifier field. @@ -399,7 +401,7 @@ def _denormalize_scorer_identifier(identifier: Dict[str, str]) -> Dict: pass return denormalized - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: """ Convert this database entry to a dictionary. @@ -529,7 +531,7 @@ def __init__(self, *, entry: Seed): self.id = entry.id self.value = entry.value self.value_sha256 = entry.value_sha256 - self.data_type = entry.data_type # type: ignore + self.data_type = entry.data_type self.name = entry.name self.dataset_name = entry.dataset_name self.harm_categories = entry.harm_categories # type: ignore @@ -539,7 +541,7 @@ def __init__(self, *, entry: Seed): self.source = entry.source self.date_added = entry.date_added self.added_by = entry.added_by - self.prompt_metadata = entry.metadata # type: ignore + self.prompt_metadata = entry.metadata self.parameters = None if is_objective else entry.parameters # type: ignore self.prompt_group_id = entry.prompt_group_id self.sequence = None if is_objective else entry.sequence # type: ignore @@ -674,7 +676,7 @@ def __init__(self, *, entry: AttackResult): self.executed_turns = entry.executed_turns self.execution_time_ms = entry.execution_time_ms - self.outcome = entry.outcome.value # type: ignore + self.outcome = entry.outcome.value self.outcome_reason = entry.outcome_reason self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata) @@ -822,14 +824,14 @@ class ScenarioResultEntry(Base): scenario_description = mapped_column(Unicode, nullable=True) scenario_version = mapped_column(INTEGER, nullable=False, default=1) pyrit_version = mapped_column(String, nullable=False) - scenario_init_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) - objective_target_identifier: Mapped[dict] = mapped_column(JSON, nullable=False) - objective_scorer_identifier: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) + scenario_init_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) + objective_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) + objective_scorer_identifier: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"]] = mapped_column( String, nullable=False, default="CREATED" ) attack_results_json: Mapped[str] = mapped_column(Unicode, nullable=False) - labels: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) + labels: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) number_tries: Mapped[int] = mapped_column(INTEGER, nullable=False, default=0) completion_time = mapped_column(DateTime, nullable=False) timestamp = mapped_column(DateTime, nullable=False) @@ -849,7 +851,7 @@ def __init__(self, *, entry: ScenarioResult): self.scenario_init_data = entry.scenario_identifier.init_data self.objective_target_identifier = entry.objective_target_identifier self.objective_scorer_identifier = entry.objective_scorer_identifier - self.scenario_run_state = entry.scenario_run_state # type: ignore + self.scenario_run_state = entry.scenario_run_state self.labels = entry.labels self.number_tries = entry.number_tries self.completion_time = entry.completion_time @@ -905,9 +907,10 @@ def get_conversation_ids_by_attack_name(self) -> dict[str, list[str]]: Returns: Dictionary mapping attack names to lists of conversation IDs """ - return json.loads(self.attack_results_json) + result: dict[str, list[str]] = json.loads(self.attack_results_json) + return result - def __str__(self): + def __str__(self) -> str: """ Return a string representation of the scenario result entry. diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 30d3e5cd7..30a251cf7 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -13,6 +13,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload, sessionmaker from sqlalchemy.orm.session import Session +from sqlalchemy.sql.expression import TextClause from pyrit.common.path import DB_DATA_PATH from pyrit.common.singleton import Singleton @@ -69,7 +70,7 @@ def __init__( self.SessionFactory = sessionmaker(bind=self.engine) self._create_tables_if_not_exist() - def _init_storage_io(self): + def _init_storage_io(self) -> None: # Handles disk-based storage for SQLite local memory. self.results_storage_io = DiskStorageIO() @@ -98,7 +99,7 @@ def _create_engine(self, *, has_echo: bool) -> Engine: logger.exception(f"Error creating the engine for the database: {e}") raise - def _create_tables_if_not_exist(self): + def _create_tables_if_not_exist(self) -> None: """ Create all tables defined in the Base metadata, if they don't already exist in the database. @@ -122,7 +123,7 @@ def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: result: Sequence[EmbeddingDataEntry] = self._query_entries(EmbeddingDataEntry) return result - def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list: + def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[TextClause]: """ Generate SQLAlchemy filter conditions for filtering conversation pieces by memory labels. For SQLite, we use JSON_EXTRACT function to handle JSON fields. @@ -138,7 +139,9 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str condition = text(json_conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) return [condition] - def _get_message_pieces_prompt_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list: + def _get_message_pieces_prompt_metadata_conditions( + self, *, prompt_metadata: dict[str, Union[str, int]] + ) -> list[TextClause]: """ Generate SQLAlchemy filter conditions for filtering conversation pieces by prompt metadata. @@ -187,7 +190,7 @@ def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEnt """ self._insert_entries(entries=embedding_data) - def get_all_table_models(self) -> list[type[Base]]: # type: ignore + def get_all_table_models(self) -> list[type[Base]]: """ Return a list of all table models used in the database by inspecting the Base registry. @@ -199,17 +202,17 @@ def get_all_table_models(self) -> list[type[Base]]: # type: ignore def _query_entries( self, - Model, + model_class: type[Model], *, conditions: Optional[Any] = None, distinct: bool = False, - join_scores: bool = False, # type: ignore + join_scores: bool = False, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. Args: - Model: The SQLAlchemy model class corresponding to the table you want to query. + model_class: The SQLAlchemy model class corresponding to the table you want to query. conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (default is False). join_scores: Flag to join the scores table (default is False). @@ -222,10 +225,10 @@ def _query_entries( """ with closing(self.get_session()) as session: try: - query = session.query(Model) - if join_scores and Model == PromptMemoryEntry: + query = session.query(model_class) + if join_scores and model_class == PromptMemoryEntry: query = query.options(joinedload(PromptMemoryEntry.scores)) - elif Model == AttackResultEntry: + elif model_class == AttackResultEntry: query = query.options( joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), joinedload(AttackResultEntry.last_score), @@ -236,10 +239,10 @@ def _query_entries( return query.distinct().all() return query.all() except SQLAlchemyError as e: - logger.exception(f"Error fetching data from table {Model.__tablename__}: {e}") + logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore raise - def _insert_entry(self, entry: Base) -> None: # type: ignore + def _insert_entry(self, entry: Base) -> None: """ Insert an entry into the Table. @@ -258,7 +261,7 @@ def _insert_entry(self, entry: Base) -> None: # type: ignore logger.exception(f"Error inserting entry into the table: {e}") raise - def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore + def _insert_entries(self, *, entries: Sequence[Base]) -> None: """ Insert multiple entries into the database. @@ -274,7 +277,7 @@ def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore logger.exception(f"Error inserting multiple entries into the table: {e}") raise - def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict) -> bool: # type: ignore + def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict[str, Any]) -> bool: """ Update the given entries with the specified field values. @@ -416,7 +419,7 @@ def export_conversations( json.dump(merged_data, f, indent=4) return file_path - def print_schema(self): + def print_schema(self) -> None: """ Print the schema of all tables in the SQLite database. """ @@ -430,7 +433,7 @@ def print_schema(self): default = f" DEFAULT {column.default}" if column.default else "" print(f" {column.name}: {column.type} {nullable}{default}") - def export_all_tables(self, *, export_type: str = "json"): + def export_all_tables(self, *, export_type: str = "json") -> None: """ Export all table data using the specified exporter. @@ -442,12 +445,12 @@ def export_all_tables(self, *, export_type: str = "json"): table_models = self.get_all_table_models() for model in table_models: - data = self._query_entries(model) # type: ignore + data = self._query_entries(model) table_name = model.__tablename__ file_extension = f".{export_type}" file_path = DB_DATA_PATH / f"{table_name}{file_extension}" # Convert to list for exporter compatibility - self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) + self.exporter.export_data(list(data), file_path=file_path, export_type=export_type) # type: ignore def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: """ diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 0fd207d48..f15a16f43 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -24,7 +24,7 @@ class Message: message_pieces (Sequence[MessagePiece]): The list of message pieces. """ - def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False): + def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False) -> None: if not message_pieces: raise ValueError("Message must have at least one message piece.") self.message_pieces = message_pieces @@ -133,7 +133,7 @@ def set_simulated_role(self) -> None: if piece._role == "assistant": piece._role = "simulated_assistant" - def validate(self): + def validate(self) -> None: """ Validates the request response. """ diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 1a8f73d58..7a58c0c16 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -140,7 +140,7 @@ def __init__( self.scores = scores if scores else [] self.targeted_harm_categories = targeted_harm_categories if targeted_harm_categories else [] - async def set_sha256_values_async(self): + async def set_sha256_values_async(self) -> None: """ This method computes the SHA256 hash values asynchronously. It should be called after object creation if `original_value` and `converted_value` are set. diff --git a/pyrit/models/score.py b/pyrit/models/score.py index ed04fde49..6a9a422dd 100644 --- a/pyrit/models/score.py +++ b/pyrit/models/score.py @@ -4,7 +4,7 @@ import uuid from dataclasses import dataclass from datetime import datetime -from typing import Dict, List, Literal, Optional, Union, get_args +from typing import Any, Dict, List, Literal, Optional, Union, get_args ScoreType = Literal["true_false", "float_scale"] @@ -112,7 +112,7 @@ def validate(self, scorer_type, score_value): except ValueError: raise ValueError(f"Float scale scorers require a numeric score value. Got {score_value}") - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "id": str(self.id), "score_value": self.score_value, diff --git a/pyrit/models/seed.py b/pyrit/models/seed.py index fd3a846c9..13d76220c 100644 --- a/pyrit/models/seed.py +++ b/pyrit/models/seed.py @@ -153,7 +153,7 @@ def render_template_value_silent(self, **kwargs) -> str: logging.error("Error rendering template: %s", e) return self.value - async def set_sha256_value_async(self): + async def set_sha256_value_async(self) -> None: """ This method computes the SHA256 hash value asynchronously. It should be called after prompt `value` is serialized to text, diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 6ed95f64d..4be2729f1 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -156,7 +156,7 @@ async def _replace_text_match(self, match): result = await self.convert_async(prompt=match, input_type="text") return result - def get_identifier(self): + def get_identifier(self) -> dict[str, str]: """ Returns an identifier dictionary for the converter.