Skip to content
10 changes: 9 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,16 @@ repos:
rev: v1.15.0
hooks:
- id: mypy
name: mypy (strict)
files: &strict_modules ^pyrit/(auth|analytics|embedding|exceptions|memory|prompt_normalizer)/
args: [--install-types, --non-interactive, --ignore-missing-imports, --sqlite-cache, --cache-dir=.mypy_cache, --strict]
entry: mypy
language: system
types: [ python ]
- id: mypy
name: mypy (regular)
exclude: *strict_modules
args: [--install-types, --non-interactive, --ignore-missing-imports, --sqlite-cache, --cache-dir=.mypy_cache]
name: mypy
entry: mypy
language: system
types: [ python ]
Expand Down
22 changes: 15 additions & 7 deletions pyrit/analytics/conversation_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,32 @@ def get_similar_chat_messages_by_embedding(
if similarity_score >= threshold:
similar_messages.append(
EmbeddingMessageWithSimilarity(
score=float(similarity_score),
uuid=memory.id,
metric="cosine_similarity", # type: ignore
score=float(similarity_score), uuid=memory.id, metric="cosine_similarity"
)
)

return similar_messages


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""
Calculate the cosine similarity between two vectors.
Calculate the cosine similarity between two 1D vectors.

Args:
a (np.ndarray): The first vector.
b (np.ndarray): The second vector.

Returns:
np.ndarray: The cosine similarity between the two vectors.
float: The cosine similarity between the two 1D vectors.

Raises:
ValueError: If the input vectors are not 1D.
"""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
# Ensure we are dealing with 1D vectors
if a.ndim != 1 or b.ndim != 1:
raise ValueError("Inputs must be 1D vectors")

dot_product = np.dot(a, b)
norms = np.linalg.norm(a) * np.linalg.norm(b)

return float(dot_product / norms)
2 changes: 1 addition & 1 deletion pyrit/analytics/result_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _compute_stats(successes: int, failures: int, undetermined: int) -> AttackSt
)


def analyze_results(attack_results: list[AttackResult]) -> dict:
def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats | dict[str, AttackStats]]:
"""
Analyze a list of AttackResult objects and return overall and grouped statistics.

Expand Down
27 changes: 16 additions & 11 deletions pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
import time
from typing import Callable, Union
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from urllib.parse import urlparse

import msal
Expand All @@ -20,6 +22,9 @@
get_bearer_token_provider as get_async_bearer_token_provider,
)

if TYPE_CHECKING:
import azure.cognitiveservices.speech as speechsdk

from pyrit.auth.auth_config import REFRESH_TOKEN_BEFORE_MSEC
from pyrit.auth.authenticator import Authenticator

Expand All @@ -34,7 +39,7 @@ class TokenProviderCredential:
get_azure_token_provider) and Azure SDK clients that require a TokenCredential object.
"""

def __init__(self, token_provider: Callable[[], Union[str, Callable]]) -> None:
def __init__(self, token_provider: Callable[[], Union[str, Callable[..., Any]]]) -> None:
"""
Initialize TokenProviderCredential.

Expand All @@ -43,7 +48,7 @@ def __init__(self, token_provider: Callable[[], Union[str, Callable]]) -> None:
"""
self._token_provider = token_provider

def get_token(self, *scopes, **kwargs) -> AccessToken:
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""
Get an access token.

Expand Down Expand Up @@ -116,7 +121,7 @@ def get_token(self) -> str:
return self.token


def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = ""):
def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = "") -> str:
"""
Get access token from Azure CLI.

Expand All @@ -130,13 +135,13 @@ def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = ""):
try:
credential = AzureCliCredential(tenant_id=tenant_id)
token = credential.get_token(scope)
return token.token
return cast(str, token.token)
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with tenant ID '{tenant_id}': {e}")
raise


def get_access_token_from_azure_msi(*, client_id: str, scope: str):
def get_access_token_from_azure_msi(*, client_id: str, scope: str) -> str:
"""
Connect to an AOAI endpoint via managed identity credential attached to an Azure resource.
For proper setup and configuration of MSI
Expand All @@ -152,13 +157,13 @@ def get_access_token_from_azure_msi(*, client_id: str, scope: str):
try:
credential = ManagedIdentityCredential(client_id=client_id)
token = credential.get_token(scope)
return token.token
return cast(str, token.token)
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}")
raise


def get_access_token_from_msa_public_client(*, client_id: str, scope: str):
def get_access_token_from_msa_public_client(*, client_id: str, scope: str) -> str:
"""
Use MSA account to connect to an AOAI endpoint via interactive login. A browser window
will open and ask for login credentials.
Expand All @@ -173,7 +178,7 @@ def get_access_token_from_msa_public_client(*, client_id: str, scope: str):
try:
app = msal.PublicClientApplication(client_id)
result = app.acquire_token_interactive(scopes=[scope])
return result["access_token"]
return cast(str, result["access_token"])
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}")
raise
Expand Down Expand Up @@ -298,7 +303,7 @@ def get_azure_openai_auth(endpoint: str): # type: ignore[no-untyped-def]
return get_azure_async_token_provider(scope)


def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str):
def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str) -> speechsdk.SpeechConfig:
"""
Get the speech config using key/region pair (for key auth scenarios) or resource_id/region pair
(for Entra auth scenarios).
Expand Down Expand Up @@ -338,7 +343,7 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi
raise ValueError("Insufficient information provided for Azure Speech service.")


def get_speech_config_from_default_azure_credential(resource_id: str, region: str):
def get_speech_config_from_default_azure_credential(resource_id: str, region: str) -> speechsdk.SpeechConfig:
"""
Get the speech config for the given resource ID and region.

Expand Down
5 changes: 3 additions & 2 deletions pyrit/auth/azure_storage_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT license.

from datetime import datetime, timedelta
from typing import cast
from urllib.parse import urlparse

from azure.identity.aio import DefaultAzureCredential
Expand Down Expand Up @@ -88,11 +89,11 @@ async def get_sas_token(container_url: str) -> str:
account_name=storage_account_name,
container_name=container_name,
user_delegation_key=user_delegation_key,
permission=ContainerSasPermissions(read=True, write=True, create=True, list=True, delete=True),
permission=ContainerSasPermissions(read=True, write=True, create=True, list=True, delete=True), # type: ignore
expiry=expiry_time,
start=start_time,
)
finally:
await credential.close()

return sas_token
return cast(str, sas_token)
2 changes: 1 addition & 1 deletion pyrit/common/default_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
logger = logging.getLogger(__name__)


def get_required_value(*, env_var_name: str, passed_value: Any) -> Any:
def get_required_value(*, env_var_name: str, passed_value: Any) -> Optional[str]:
"""
Get a required value from an environment variable or a passed value,
preferring the passed value.
Expand Down
4 changes: 2 additions & 2 deletions pyrit/embedding/_text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import abc
from typing import Union
from typing import Any, Union

import tenacity
from openai import AzureOpenAI, OpenAI
Expand All @@ -29,7 +29,7 @@ def __init__(self) -> None:
)

@tenacity.retry(wait=tenacity.wait_fixed(0.1), stop=tenacity.stop_after_delay(3))
def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse:
def generate_text_embedding(self, text: str, **kwargs: Any) -> EmbeddingResponse:
"""
Generate text embedding.

Expand Down
6 changes: 3 additions & 3 deletions pyrit/embedding/openai_text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import asyncio
from typing import Awaitable, Callable, Optional
from typing import Any, Awaitable, Callable, Optional

import tenacity
from openai import AsyncOpenAI
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
super().__init__()

@tenacity.retry(wait=tenacity.wait_fixed(0.1), stop=tenacity.stop_after_delay(3))
async def generate_text_embedding_async(self, text: str, **kwargs) -> EmbeddingResponse:
async def generate_text_embedding_async(self, text: str, **kwargs: Any) -> EmbeddingResponse:
"""
Generate text embedding asynchronously.

Expand Down Expand Up @@ -99,7 +99,7 @@ async def generate_text_embedding_async(self, text: str, **kwargs) -> EmbeddingR
)
return embedding_response

def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse:
def generate_text_embedding(self, text: str, **kwargs: Any) -> EmbeddingResponse:
"""
Generate text embedding synchronously by calling the async method.

Expand Down
17 changes: 10 additions & 7 deletions pyrit/exceptions/exception_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import os
from abc import ABC
from typing import Callable, Optional
from typing import Any, Callable, Optional

from openai import RateLimitError
from tenacity import (
Expand Down Expand Up @@ -147,7 +147,9 @@ def __init__(self, *, message: str = "No prompt placeholder") -> None:
super().__init__(message=message)


def pyrit_custom_result_retry(retry_function: Callable, retry_max_num_attempts: Optional[int] = None) -> Callable:
def pyrit_custom_result_retry(
retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None
) -> Callable[..., Any]:
"""
A decorator to apply retry logic with exponential backoff to a function.

Expand All @@ -166,8 +168,9 @@ def pyrit_custom_result_retry(retry_function: Callable, retry_max_num_attempts:
Callable: The decorated function with retry logic applied.
"""

def inner_retry(func):
def inner_retry(func: Callable[..., Any]) -> Callable[..., Any]:
# Use static value if explicitly provided, otherwise use dynamic getter
stop_strategy: stop_base
if retry_max_num_attempts is not None:
stop_strategy = stop_after_attempt(retry_max_num_attempts)
else:
Expand All @@ -184,7 +187,7 @@ def inner_retry(func):
return inner_retry


def pyrit_target_retry(func: Callable) -> Callable:
def pyrit_target_retry(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to apply retry logic with exponential backoff to a function.

Expand All @@ -209,7 +212,7 @@ def pyrit_target_retry(func: Callable) -> Callable:
)(func)


def pyrit_json_retry(func: Callable) -> Callable:
def pyrit_json_retry(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to apply retry logic to a function.

Expand All @@ -230,7 +233,7 @@ def pyrit_json_retry(func: Callable) -> Callable:
)(func)


def pyrit_placeholder_retry(func: Callable) -> Callable:
def pyrit_placeholder_retry(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to apply retry logic.

Expand All @@ -254,7 +257,7 @@ def pyrit_placeholder_retry(func: Callable) -> Callable:
def handle_bad_request_exception(
response_text: str,
request: MessagePiece,
is_content_filter=False,
is_content_filter: bool = False,
error_code: int = 400,
) -> Message:
if (
Expand Down
2 changes: 1 addition & 1 deletion pyrit/exceptions/exceptions_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logger = logging.getLogger(__name__)


def log_exception(retry_state: RetryCallState):
def log_exception(retry_state: RetryCallState) -> None:
# Log each retry attempt with exception details at ERROR level
elapsed_time = time.monotonic() - retry_state.start_time
call_count = retry_state.attempt_number
Expand Down
Loading