Skip to content

Commit a1a709b

Browse files
authored
MAINT: Improve docstrings completeness in pyrit/memory (#1223)
1 parent 0d9b9f6 commit a1a709b

File tree

9 files changed

+430
-94
lines changed

9 files changed

+430
-94
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ extend-select = [
251251
# Temporary ignores for pyrit/ subdirectories until issue #1176
252252
# https://github.com/Azure/PyRIT/issues/1176 is fully resolved
253253
# TODO: Remove these ignores once the issues are fixed
254-
"pyrit/{auxiliary_attacks,exceptions,memory,models,prompt_converter,prompt_target,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
254+
"pyrit/{auxiliary_attacks,exceptions,models,prompt_converter,prompt_target,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
255255
"pyrit/__init__.py" = ["D104"]
256256

257257
[tool.ruff.lint.pydocstyle]

pyrit/memory/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
"""
5+
Provide functionality for storing and retrieving conversation history and embeddings.
6+
7+
This package defines the core `MemoryInterface` and concrete implementations for different storage backends.
8+
"""
9+
410
from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry, SeedEntry, AttackResultEntry
511
from pyrit.memory.memory_interface import MemoryInterface
612

pyrit/memory/azure_sql_memory.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ def __init__(
6060
results_sas_token: Optional[str] = None,
6161
verbose: bool = False,
6262
):
63+
"""
64+
Initialize an Azure SQL Memory backend.
65+
66+
Args:
67+
connection_string (Optional[str]): The connection string for the Azure SQL Database. If not provided,
68+
it falls back to the 'AZURE_SQL_DB_CONNECTION_STRING' environment variable.
69+
results_container_url (Optional[str]): The URL to an Azure Storage Container. If not provided,
70+
it falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL' environment variable.
71+
results_sas_token (Optional[str]): The Shared Access Signature (SAS) token for the storage container.
72+
If not provided, falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN' environment variable.
73+
verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False.
74+
"""
6375
self._connection_string = default_values.get_required_value(
6476
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string
6577
)
@@ -114,7 +126,7 @@ def _init_storage_io(self):
114126

115127
def _create_auth_token(self) -> None:
116128
"""
117-
Creates an Azure Entra ID access token.
129+
Create an Azure Entra ID access token.
118130
Stores the token and its expiry time.
119131
"""
120132
azure_auth = AzureAuth(token_scope=self.TOKEN_URL)
@@ -133,13 +145,19 @@ def _refresh_token_if_needed(self) -> None:
133145

134146
def _create_engine(self, *, has_echo: bool) -> Engine:
135147
"""
136-
Creates the SQLAlchemy engine for Azure SQL Server.
148+
Create the SQLAlchemy engine for Azure SQL Server.
137149
138150
Creates an engine bound to the specified server and database. The `has_echo` parameter
139151
controls the verbosity of SQL execution logging.
140152
141153
Args:
142154
has_echo (bool): Flag to enable detailed SQL execution logging.
155+
156+
Returns:
157+
Engine: SQLAlchemy engine bound to the Azure SQL Database.
158+
159+
Raises:
160+
SQLAlchemyError: If the engine creation fails.
143161
"""
144162
try:
145163
# Create the SQLAlchemy engine.
@@ -156,6 +174,8 @@ def _create_engine(self, *, has_echo: bool) -> Engine:
156174

157175
def _enable_azure_authorization(self) -> None:
158176
"""
177+
Enable Azure token-based authorization for SQL connections.
178+
159179
The following is necessary because of how SQLAlchemy and PyODBC handle connection creation. In PyODBC, the
160180
token is passed outside the connection string in the `connect()` method. Since SQLAlchemy lazy-loads
161181
its connections, we need to set this as a separate argument to the `connect()` method. In SQLAlchemy
@@ -184,7 +204,7 @@ def provide_token(_dialect, _conn_rec, cargs, cparams):
184204

185205
def _create_tables_if_not_exist(self):
186206
"""
187-
Creates all tables defined in the Base metadata, if they don't already exist in the database.
207+
Create all tables defined in the Base metadata, if they don't already exist in the database.
188208
189209
Raises:
190210
Exception: If there's an issue creating the tables in the database.
@@ -198,7 +218,7 @@ def _create_tables_if_not_exist(self):
198218

199219
def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None:
200220
"""
201-
Inserts embedding data into memory storage.
221+
Insert embedding data into memory storage.
202222
"""
203223
self._insert_entries(entries=embedding_data)
204224

@@ -295,7 +315,7 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
295315

296316
def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
297317
"""
298-
SQL Azure implementation for filtering AttackResults by targeted harm categories.
318+
Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
299319
300320
Uses JSON_QUERY() function specific to SQL Azure to check if categories exist in the JSON array.
301321
@@ -333,7 +353,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories
333353

334354
def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
335355
"""
336-
SQL Azure implementation for filtering AttackResults by labels.
356+
Get the SQL Azure implementation for filtering AttackResults by labels.
337357
338358
Uses JSON_VALUE() function specific to SQL Azure with parameterized queries.
339359
@@ -364,7 +384,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
364384

365385
def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any:
366386
"""
367-
SQL Azure implementation for filtering ScenarioResults by labels.
387+
Get the SQL Azure implementation for filtering ScenarioResults by labels.
368388
369389
Uses JSON_VALUE() function specific to SQL Azure.
370390
@@ -385,7 +405,7 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any
385405

386406
def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any:
387407
"""
388-
SQL Azure implementation for filtering ScenarioResults by target endpoint.
408+
Get the SQL Azure implementation for filtering ScenarioResults by target endpoint.
389409
390410
Uses JSON_VALUE() function specific to SQL Azure.
391411
@@ -402,7 +422,7 @@ def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> An
402422

403423
def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any:
404424
"""
405-
SQL Azure implementation for filtering ScenarioResults by target model name.
425+
Get the SQL Azure implementation for filtering ScenarioResults by target model name.
406426
407427
Uses JSON_VALUE() function specific to SQL Azure.
408428
@@ -419,7 +439,7 @@ def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any
419439

420440
def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
421441
"""
422-
Inserts a list of message pieces into the memory storage.
442+
Insert a list of message pieces into the memory storage.
423443
424444
"""
425445
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces])
@@ -434,17 +454,23 @@ def dispose_engine(self):
434454

435455
def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]:
436456
"""
437-
Fetches all entries from the specified table and returns them as model instances.
457+
Fetch all entries from the specified table and return them as model instances.
458+
459+
Returns:
460+
Sequence[EmbeddingDataEntry]: A sequence of EmbeddingDataEntry instances representing all stored embeddings.
438461
"""
439462
result: Sequence[EmbeddingDataEntry] = self._query_entries(EmbeddingDataEntry)
440463
return result
441464

442465
def _insert_entry(self, entry: Base) -> None: # type: ignore
443466
"""
444-
Inserts an entry into the Table.
467+
Insert an entry into the Table.
445468
446469
Args:
447470
entry: An instance of a SQLAlchemy model to be added to the Table.
471+
472+
Raises:
473+
SQLAlchemyError: If the insertion fails.
448474
"""
449475
with closing(self.get_session()) as session:
450476
try:
@@ -459,7 +485,15 @@ def _insert_entry(self, entry: Base) -> None: # type: ignore
459485
# common between SQLAlchemy-based implementations, regardless of engine.
460486
# Perhaps we should find a way to refactor
461487
def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore
462-
"""Inserts multiple entries into the database."""
488+
"""
489+
Insert multiple entries into the database.
490+
491+
Args:
492+
entries (Sequence[Base]): A sequence of SQLAlchemy model instances to insert.
493+
494+
Raises:
495+
SQLAlchemyError: If the insertion fails.
496+
"""
463497
with closing(self.get_session()) as session:
464498
try:
465499
session.add_all(entries)
@@ -471,7 +505,10 @@ def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore
471505

472506
def get_session(self) -> Session:
473507
"""
474-
Provides a session for database operations.
508+
Provide a session for database operations.
509+
510+
Returns:
511+
Session: A new SQLAlchemy session bound to the configured engine.
475512
"""
476513
return self.SessionFactory()
477514

@@ -484,15 +521,19 @@ def _query_entries(
484521
join_scores: bool = False,
485522
) -> MutableSequence[Model]:
486523
"""
487-
Fetches data from the specified table model with optional conditions.
524+
Fetch data from the specified table model with optional conditions.
488525
489526
Args:
527+
Model: The SQLAlchemy model class to query.
490528
conditions: SQLAlchemy filter conditions (Optional).
491529
distinct: Flag to return distinct rows (defaults to False).
492530
join_scores: Flag to join the scores table with entries (defaults to False).
493531
494532
Returns:
495533
List of model instances representing the rows fetched from the table.
534+
535+
Raises:
536+
SQLAlchemyError: If the query fails.
496537
"""
497538
with closing(self.get_session()) as session:
498539
try:
@@ -515,14 +556,18 @@ def _query_entries(
515556

516557
def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict) -> bool: # type: ignore
517558
"""
518-
Updates the given entries with the specified field values.
559+
Update the given entries with the specified field values.
519560
520561
Args:
521562
entries (Sequence[Base]): A list of SQLAlchemy model instances to be updated.
522563
update_fields (dict): A dictionary of field names and their new values.
523564
524565
Returns:
525566
bool: True if the update was successful, False otherwise.
567+
568+
Raises:
569+
ValueError: If 'update_fields' is empty.
570+
SQLAlchemyError: If the update fails.
526571
"""
527572
if not update_fields:
528573
raise ValueError("update_fields must be provided to update prompt entries.")

pyrit/memory/central_memory.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
class CentralMemory:
1212
"""
13-
Provides a centralized memory instance across the framework. The provided memory
14-
instance will be reused for future calls.
13+
Provide a centralized memory instance across the framework.
14+
The provided memory instance will be reused for future calls.
1515
"""
1616

1717
_memory_instance: MemoryInterface = None
@@ -30,7 +30,13 @@ def set_memory_instance(cls, passed_memory: MemoryInterface) -> None:
3030
@classmethod
3131
def get_memory_instance(cls) -> MemoryInterface:
3232
"""
33-
Returns a centralized memory instance.
33+
Return a centralized memory instance.
34+
35+
Returns:
36+
MemoryInterface: The singleton memory instance.
37+
38+
Raises:
39+
ValueError: If the central memory instance has not been set.
3440
"""
3541
if cls._memory_instance:
3642
logger.info(f"Using existing memory instance: {type(cls._memory_instance).__name__}")

pyrit/memory/memory_embedding.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,32 @@ class MemoryEmbedding:
1818
"""
1919

2020
def __init__(self, *, embedding_model: Optional[EmbeddingSupport] = None):
21+
"""
22+
Initialize the memory embedding helper with a backing embedding model.
23+
24+
Args:
25+
embedding_model (Optional[EmbeddingSupport]): The embedding model used to
26+
generate text embeddings. If not provided, a ValueError is raised.
27+
28+
Raises:
29+
ValueError: If `embedding_model` is not provided.
30+
"""
2131
if embedding_model is None:
2232
raise ValueError("embedding_model must be set.")
2333
self.embedding_model = embedding_model
2434

2535
def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> EmbeddingDataEntry:
2636
"""
27-
Generates metadata for a message piece.
37+
Generate metadata for a message piece.
2838
2939
Args:
3040
message_piece (MessagePiece): The message piece for which to generate a text embedding.
3141
3242
Returns:
3343
EmbeddingDataEntry: The generated metadata.
44+
45+
Raises:
46+
ValueError: If the message piece is not of type text.
3447
"""
3548
if message_piece.converted_value_data_type == "text":
3649
embedding_data = EmbeddingDataEntry(
@@ -46,6 +59,24 @@ def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> Embe
4659

4760

4861
def default_memory_embedding_factory(embedding_model: Optional[EmbeddingSupport] = None) -> MemoryEmbedding | None:
62+
"""
63+
Create a MemoryEmbedding instance with default or provided embedding model.
64+
65+
Factory function that creates a MemoryEmbedding instance. If an embedding_model
66+
is provided, it uses that model. Otherwise, it attempts to create an Azure
67+
OpenAI embedding model from environment variables.
68+
69+
Args:
70+
embedding_model: Optional embedding model to use. If not provided,
71+
attempts to create AzureTextEmbedding from environment variables.
72+
73+
Returns:
74+
MemoryEmbedding: Configured memory embedding instance.
75+
76+
Raises:
77+
ValueError: If no embedding model is provided and required Azure
78+
OpenAI environment variables are not set.
79+
"""
4980
if embedding_model:
5081
return MemoryEmbedding(embedding_model=embedding_model)
5182

pyrit/memory/memory_exporter.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ class MemoryExporter:
1616
"""
1717

1818
def __init__(self):
19+
"""
20+
Initialize the MemoryExporter.
21+
22+
Sets up the available export formats using the strategy design pattern.
23+
"""
1924
# Using strategy design pattern for export functionality.
2025
self.export_strategies = {
2126
"json": self.export_to_json,
@@ -28,7 +33,7 @@ def export_data(
2833
self, data: list[MessagePiece], *, file_path: Optional[Path] = None, export_type: str = "json"
2934
): # type: ignore
3035
"""
31-
Exports the provided data to a file in the specified format.
36+
Export the provided data to a file in the specified format.
3237
3338
Args:
3439
data (list[MessagePiece]): The data to be exported, as a list of MessagePiece instances.
@@ -49,7 +54,7 @@ def export_data(
4954

5055
def export_to_json(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore
5156
"""
52-
Exports the provided data to a JSON file at the specified file path.
57+
Export the provided data to a JSON file at the specified file path.
5358
Each item in the data list, representing a row from the table,
5459
is converted to a dictionary before being written to the file.
5560
@@ -72,7 +77,7 @@ def export_to_json(self, data: list[MessagePiece], file_path: Path = None) -> No
7277

7378
def export_to_csv(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore
7479
"""
75-
Exports the provided data to a CSV file at the specified file path.
80+
Export the provided data to a CSV file at the specified file path.
7681
Each item in the data list, representing a row from the table,
7782
is converted to a dictionary before being written to the file.
7883
@@ -98,7 +103,7 @@ def export_to_csv(self, data: list[MessagePiece], file_path: Path = None) -> Non
98103

99104
def export_to_markdown(self, data: list[MessagePiece], file_path: Path = None) -> None: # type: ignore
100105
"""
101-
Exports the provided data to a Markdown file at the specified file path.
106+
Export the provided data to a Markdown file at the specified file path.
102107
Each item in the data list is converted to a dictionary and formatted as a table.
103108
104109
Args:

0 commit comments

Comments
 (0)