Skip to content

Commit f5e7a32

Browse files
committed
MAINT: Improve docstrings in pyrit/memory (#1176)
1 parent 93aca73 commit f5e7a32

File tree

9 files changed

+431
-94
lines changed

9 files changed

+431
-94
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ extend-select = [
251251
# Temporary ignores for pyrit/ subdirectories until issue #1176
252252
# https://github.com/Azure/PyRIT/issues/1176 is fully resolved
253253
# TODO: Remove these ignores once the issues are fixed
254-
"pyrit/{auxiliary_attacks,exceptions,executor,memory,models,prompt_converter,prompt_normalizer,prompt_target,score,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
254+
"pyrit/{auxiliary_attacks,exceptions,executor,models,prompt_converter,prompt_normalizer,prompt_target,score,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
255255
"pyrit/__init__.py" = ["D104"]
256256

257257
[tool.ruff.lint.pydocstyle]

pyrit/memory/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
"""
5+
Provide functionality for storing and retrieving conversation history and embeddings.
6+
7+
This package defines the core `MemoryInterface` and concrete implementations for different storage backends.
8+
"""
9+
410
from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry, SeedEntry, AttackResultEntry
511
from pyrit.memory.memory_interface import MemoryInterface
612

pyrit/memory/azure_sql_memory.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,19 @@ def __init__(
6060
results_sas_token: Optional[str] = None,
6161
verbose: bool = False,
6262
):
63+
"""
64+
Initialize an Azure SQL Memory backend.
65+
66+
Args:
67+
connection_string (Optional[str]): The connection string for the Azure Sql Database. If not provided,
68+
it falls back to the 'AZURE_SQL_DB_CONNECTION_STRING' environment variable.
69+
results_container_url (Optional[str]): The URL to an Azure Storage Container. If not provided,
70+
it falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL' environment variable.
71+
results_sas_token (Optional[str]): The Shared Access Signature (SAS) token for the storage container.
72+
If not provided, falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN' environment variable.
73+
verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False.
74+
"""
75+
self._init_storage_io()
6376
self._connection_string = default_values.get_required_value(
6477
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string
6578
)
@@ -114,7 +127,7 @@ def _init_storage_io(self):
114127

115128
def _create_auth_token(self) -> None:
116129
"""
117-
Creates an Azure Entra ID access token.
130+
Create an Azure Entra ID access token.
118131
Stores the token and its expiry time.
119132
"""
120133
azure_auth = AzureAuth(token_scope=self.TOKEN_URL)
@@ -133,13 +146,19 @@ def _refresh_token_if_needed(self) -> None:
133146

134147
def _create_engine(self, *, has_echo: bool) -> Engine:
135148
"""
136-
Creates the SQLAlchemy engine for Azure SQL Server.
149+
Create the SQLAlchemy engine for Azure SQL Server.
137150
138151
Creates an engine bound to the specified server and database. The `has_echo` parameter
139152
controls the verbosity of SQL execution logging.
140153
141154
Args:
142155
has_echo (bool): Flag to enable detailed SQL execution logging.
156+
157+
Returns:
158+
Engine: SQLAlchemy engine bound to the AZURE SQL Database.
159+
160+
Raises:
161+
SQLAlchemyError: If the engine creation fails.
143162
"""
144163
try:
145164
# Create the SQLAlchemy engine.
@@ -156,6 +175,8 @@ def _create_engine(self, *, has_echo: bool) -> Engine:
156175

157176
def _enable_azure_authorization(self) -> None:
158177
"""
178+
Enable Azure token-based authorization for SQL connections.
179+
159180
The following is necessary because of how SQLAlchemy and PyODBC handle connection creation. In PyODBC, the
160181
token is passed outside the connection string in the `connect()` method. Since SQLAlchemy lazy-loads
161182
its connections, we need to set this as a separate argument to the `connect()` method. In SQLALchemy
@@ -184,7 +205,7 @@ def provide_token(_dialect, _conn_rec, cargs, cparams):
184205

185206
def _create_tables_if_not_exist(self):
186207
"""
187-
Creates all tables defined in the Base metadata, if they don't already exist in the database.
208+
Create all tables defined in the Base metadata, if they don't already exist in the database.
188209
189210
Raises:
190211
Exception: If there's an issue creating the tables in the database.
@@ -198,7 +219,7 @@ def _create_tables_if_not_exist(self):
198219

199220
def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None:
200221
"""
201-
Inserts embedding data into memory storage.
222+
Insert embedding data into memory storage.
202223
"""
203224
self._insert_entries(entries=embedding_data)
204225

@@ -295,7 +316,7 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
295316

296317
def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
297318
"""
298-
SQL Azure implementation for filtering AttackResults by targeted harm categories.
319+
Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
299320
300321
Uses JSON_QUERY() function specific to SQL Azure to check if categories exist in the JSON array.
301322
@@ -333,7 +354,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories
333354

334355
def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
335356
"""
336-
SQL Azure implementation for filtering AttackResults by labels.
357+
Get the SQL Azure implementation for filtering AttackResults by labels.
337358
338359
Uses JSON_VALUE() function specific to SQL Azure with parameterized queries.
339360
@@ -364,7 +385,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
364385

365386
def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any:
366387
"""
367-
SQL Azure implementation for filtering ScenarioResults by labels.
388+
Get the SQL Azure implementation for filtering ScenarioResults by labels.
368389
369390
Uses JSON_VALUE() function specific to SQL Azure.
370391
@@ -385,7 +406,7 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any
385406

386407
def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any:
387408
"""
388-
SQL Azure implementation for filtering ScenarioResults by target endpoint.
409+
Get the SQL Azure implementation for filtering ScenarioResults by target endpoint.
389410
390411
Uses JSON_VALUE() function specific to SQL Azure.
391412
@@ -402,7 +423,7 @@ def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> An
402423

403424
def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any:
404425
"""
405-
SQL Azure implementation for filtering ScenarioResults by target model name.
426+
Get the SQL Azure implementation for filtering ScenarioResults by target model name.
406427
407428
Uses JSON_VALUE() function specific to SQL Azure.
408429
@@ -419,7 +440,7 @@ def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any
419440

420441
def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
421442
"""
422-
Inserts a list of message pieces into the memory storage.
443+
Insert a list of message pieces into the memory storage.
423444
424445
"""
425446
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces])
@@ -434,17 +455,23 @@ def dispose_engine(self):
434455

435456
def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]:
436457
"""
437-
Fetches all entries from the specified table and returns them as model instances.
458+
Fetch all entries from the specified table and returns them as model instances.
459+
460+
Returns:
461+
Sequence[EmbeddingDataEntry]: A sequence of EmbeddingDataEntry instances representing all stored embeddings.
438462
"""
439463
result: Sequence[EmbeddingDataEntry] = self._query_entries(EmbeddingDataEntry)
440464
return result
441465

442466
def _insert_entry(self, entry: Base) -> None: # type: ignore
443467
"""
444-
Inserts an entry into the Table.
468+
Insert an entry into the Table.
445469
446470
Args:
447471
entry: An instance of a SQLAlchemy model to be added to the Table.
472+
473+
Raises:
474+
SQLAlchemyError: If the insertion fails.
448475
"""
449476
with closing(self.get_session()) as session:
450477
try:
@@ -459,7 +486,15 @@ def _insert_entry(self, entry: Base) -> None: # type: ignore
459486
# common between SQLAlchemy-based implementations, regardless of engine.
460487
# Perhaps we should find a way to refactor
461488
def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore
462-
"""Inserts multiple entries into the database."""
489+
"""
490+
Insert multiple entries into the database.
491+
492+
Args:
493+
entries (Sequence[Base]): A sequence of SQLAlchemy model instances to insert.
494+
495+
Raises:
496+
SQLAlchemyError: If the insertion fails.
497+
"""
463498
with closing(self.get_session()) as session:
464499
try:
465500
session.add_all(entries)
@@ -471,7 +506,10 @@ def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore
471506

472507
def get_session(self) -> Session:
473508
"""
474-
Provides a session for database operations.
509+
Provide a session for database operations.
510+
511+
Returns:
512+
Session: A new SQLAlchemy session bound to the configured engine.
475513
"""
476514
return self.SessionFactory()
477515

@@ -484,15 +522,19 @@ def _query_entries(
484522
join_scores: bool = False,
485523
) -> MutableSequence[Model]:
486524
"""
487-
Fetches data from the specified table model with optional conditions.
525+
Fetch data from the specified table model with optional conditions.
488526
489527
Args:
528+
Model: The SQLAlchemy model class to query.
490529
conditions: SQLAlchemy filter conditions (Optional).
491530
distinct: Flag to return distinct rows (defaults to False).
492531
join_scores: Flag to join the scores table with entries (defaults to False).
493532
494533
Returns:
495534
List of model instances representing the rows fetched from the table.
535+
536+
Raises:
537+
SQLAlchemyError: If the query fails.
496538
"""
497539
with closing(self.get_session()) as session:
498540
try:
@@ -515,14 +557,18 @@ def _query_entries(
515557

516558
def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict) -> bool: # type: ignore
517559
"""
518-
Updates the given entries with the specified field values.
560+
Update the given entries with the specified field values.
519561
520562
Args:
521563
entries (Sequence[Base]): A list of SQLAlchemy model instances to be updated.
522564
update_fields (dict): A dictionary of field names and their new values.
523565
524566
Returns:
525567
bool: True if the update was successful, False otherwise.
568+
569+
Raises:
570+
ValueError: If 'update_fields' is empty.
571+
SQLAlchemyError: If the update fails.
526572
"""
527573
if not update_fields:
528574
raise ValueError("update_fields must be provided to update prompt entries.")

pyrit/memory/central_memory.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
class CentralMemory:
1212
"""
13-
Provides a centralized memory instance across the framework. The provided memory
14-
instance will be reused for future calls.
13+
Provide a centralized memory instance across the framework.
14+
The provided memory instance will be reused for future calls.
1515
"""
1616

1717
_memory_instance: MemoryInterface = None
@@ -30,7 +30,13 @@ def set_memory_instance(cls, passed_memory: MemoryInterface) -> None:
3030
@classmethod
3131
def get_memory_instance(cls) -> MemoryInterface:
3232
"""
33-
Returns a centralized memory instance.
33+
Return a centralized memory instance.
34+
35+
Returns:
36+
MemoryInterface: The singleton memory instance.
37+
38+
Raises:
39+
ValueError: If the central memory instance has not been set.
3440
"""
3541
if cls._memory_instance:
3642
logger.info(f"Using existing memory instance: {type(cls._memory_instance).__name__}")

pyrit/memory/memory_embedding.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,32 @@ class MemoryEmbedding:
1818
"""
1919

2020
def __init__(self, *, embedding_model: Optional[EmbeddingSupport] = None):
21+
"""
22+
Initialize the memory embedding helper with a backing embedding model.
23+
24+
Args:
25+
embedding_model (Optional[EmbeddingSupport]): The embedding model used to
26+
generate text embeddings. If not provided, a ValueError is raised.
27+
28+
Raises:
29+
ValueError: If `embedding_model` is not provided.
30+
"""
2131
if embedding_model is None:
2232
raise ValueError("embedding_model must be set.")
2333
self.embedding_model = embedding_model
2434

2535
def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> EmbeddingDataEntry:
2636
"""
27-
Generates metadata for a message piece.
37+
Generate metadata for a message piece.
2838
2939
Args:
3040
message_piece (MessagePiece): the message piece for which to generate a text embedding
3141
3242
Returns:
3343
EmbeddingDataEntry: The generated metadata.
44+
45+
Raises:
46+
ValueError: If the message piece is not of type text.
3447
"""
3548
if message_piece.converted_value_data_type == "text":
3649
embedding_data = EmbeddingDataEntry(
@@ -46,6 +59,24 @@ def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> Embe
4659

4760

4861
def default_memory_embedding_factory(embedding_model: Optional[EmbeddingSupport] = None) -> MemoryEmbedding | None:
62+
"""
63+
Create a MemoryEmbedding instance with default or provided embedding model.
64+
65+
Factory function that creates a MemoryEmbedding instance. If an embedding_model
66+
is provided, it uses that model. Otherwise, it attempts to create an Azure
67+
OpenAI embedding model from environment variables.
68+
69+
Args:
70+
embedding_model: Optional embedding model to use. If not provided,
71+
attempts to create AzureTextEmbedding from environment variables.
72+
73+
Returns:
74+
MemoryEmbedding: Configured memory embedding instance.
75+
76+
Raises:
77+
ValueError: If no embedding model is provided and required Azure
78+
OpenAI environment variables are not set.
79+
"""
4980
if embedding_model:
5081
return MemoryEmbedding(embedding_model=embedding_model)
5182

pyrit/memory/memory_exporter.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ class MemoryExporter:
1616
"""
1717

1818
def __init__(self):
19+
"""
20+
Initialize the MemoryExporter.
21+
22+
Sets up the available export formats using the strategy design pattern.
23+
"""
1924
# Using strategy design pattern for export functionality.
2025
self.export_strategies = {
2126
"json": self.export_to_json,
@@ -28,7 +33,7 @@ def export_data(
2833
self, data: list[MessagePiece], *, file_path: Optional[Path] = None, export_type: str = "json"
2934
): # type: ignore
3035
"""
31-
Exports the provided data to a file in the specified format.
36+
Export the provided data to a file in the specified format.
3237
3338
Args:
3439
data (list[MessagePiece]): The data to be exported, as a list of MessagePiece instances.
@@ -49,7 +54,7 @@ def export_data(
4954

5055
def export_to_json(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore
5156
"""
52-
Exports the provided data to a JSON file at the specified file path.
57+
Export the provided data to a JSON file at the specified file path.
5358
Each item in the data list, representing a row from the table,
5459
is converted to a dictionary before being written to the file.
5560
@@ -72,7 +77,7 @@ def export_to_json(self, data: list[MessagePiece], file_path: Path = None) -> No
7277

7378
def export_to_csv(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore
7479
"""
75-
Exports the provided data to a CSV file at the specified file path.
80+
Export the provided data to a CSV file at the specified file path.
7681
Each item in the data list, representing a row from the table,
7782
is converted to a dictionary before being written to the file.
7883
@@ -98,7 +103,7 @@ def export_to_csv(self, data: list[MessagePiece], file_path: Path = None) -> Non
98103

99104
def export_to_markdown(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore
100105
"""
101-
Exports the provided data to a Markdown file at the specified file path.
106+
Export the provided data to a Markdown file at the specified file path.
102107
Each item in the data list is converted to a dictionary and formatted as a table.
103108
104109
Args:

0 commit comments

Comments
 (0)