@@ -60,6 +60,18 @@ def __init__(
6060 results_sas_token : Optional [str ] = None ,
6161 verbose : bool = False ,
6262 ):
63+ """
64+ Initialize an Azure SQL Memory backend.
65+
66+ Args:
67+ connection_string (Optional[str]): The connection string for the Azure Sql Database. If not provided,
68+ it falls back to the 'AZURE_SQL_DB_CONNECTION_STRING' environment variable.
69+ results_container_url (Optional[str]): The URL to an Azure Storage Container. If not provided,
70+ it falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL' environment variable.
71+ results_sas_token (Optional[str]): The Shared Access Signature (SAS) token for the storage container.
72+ If not provided, falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN' environment variable.
73+ verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False.
74+ """
6375 self ._connection_string = default_values .get_required_value (
6476 env_var_name = self .AZURE_SQL_DB_CONNECTION_STRING , passed_value = connection_string
6577 )
@@ -114,7 +126,7 @@ def _init_storage_io(self):
114126
115127 def _create_auth_token (self ) -> None :
116128 """
117- Creates an Azure Entra ID access token.
129+ Create an Azure Entra ID access token.
118130 Stores the token and its expiry time.
119131 """
120132 azure_auth = AzureAuth (token_scope = self .TOKEN_URL )
@@ -133,13 +145,19 @@ def _refresh_token_if_needed(self) -> None:
133145
134146 def _create_engine (self , * , has_echo : bool ) -> Engine :
135147 """
136- Creates the SQLAlchemy engine for Azure SQL Server.
148+ Create the SQLAlchemy engine for Azure SQL Server.
137149
138150 Creates an engine bound to the specified server and database. The `has_echo` parameter
139151 controls the verbosity of SQL execution logging.
140152
141153 Args:
142154 has_echo (bool): Flag to enable detailed SQL execution logging.
155+
156+ Returns:
157+ Engine: SQLAlchemy engine bound to the AZURE SQL Database.
158+
159+ Raises:
160+ SQLAlchemyError: If the engine creation fails.
143161 """
144162 try :
145163 # Create the SQLAlchemy engine.
@@ -156,6 +174,8 @@ def _create_engine(self, *, has_echo: bool) -> Engine:
156174
157175 def _enable_azure_authorization (self ) -> None :
158176 """
177+ Enable Azure token-based authorization for SQL connections.
178+
159179 The following is necessary because of how SQLAlchemy and PyODBC handle connection creation. In PyODBC, the
160180 token is passed outside the connection string in the `connect()` method. Since SQLAlchemy lazy-loads
161181 its connections, we need to set this as a separate argument to the `connect()` method. In SQLALchemy
@@ -184,7 +204,7 @@ def provide_token(_dialect, _conn_rec, cargs, cparams):
184204
185205 def _create_tables_if_not_exist (self ):
186206 """
187- Creates all tables defined in the Base metadata, if they don't already exist in the database.
207+ Create all tables defined in the Base metadata, if they don't already exist in the database.
188208
189209 Raises:
190210 Exception: If there's an issue creating the tables in the database.
@@ -198,7 +218,7 @@ def _create_tables_if_not_exist(self):
198218
199219 def _add_embeddings_to_memory (self , * , embedding_data : Sequence [EmbeddingDataEntry ]) -> None :
200220 """
201- Inserts embedding data into memory storage.
221+ Insert embedding data into memory storage.
202222 """
203223 self ._insert_entries (entries = embedding_data )
204224
@@ -295,7 +315,7 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
295315
296316 def _get_attack_result_harm_category_condition (self , * , targeted_harm_categories : Sequence [str ]) -> Any :
297317 """
298- SQL Azure implementation for filtering AttackResults by targeted harm categories.
318+ Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
299319
300320 Uses JSON_QUERY() function specific to SQL Azure to check if categories exist in the JSON array.
301321
@@ -333,7 +353,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories
333353
334354 def _get_attack_result_label_condition (self , * , labels : dict [str , str ]) -> Any :
335355 """
336- SQL Azure implementation for filtering AttackResults by labels.
356+ Get the SQL Azure implementation for filtering AttackResults by labels.
337357
338358 Uses JSON_VALUE() function specific to SQL Azure with parameterized queries.
339359
@@ -364,7 +384,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
364384
365385 def _get_scenario_result_label_condition (self , * , labels : dict [str , str ]) -> Any :
366386 """
367- SQL Azure implementation for filtering ScenarioResults by labels.
387+ Get the SQL Azure implementation for filtering ScenarioResults by labels.
368388
369389 Uses JSON_VALUE() function specific to SQL Azure.
370390
@@ -385,7 +405,7 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any
385405
386406 def _get_scenario_result_target_endpoint_condition (self , * , endpoint : str ) -> Any :
387407 """
388- SQL Azure implementation for filtering ScenarioResults by target endpoint.
408+ Get the SQL Azure implementation for filtering ScenarioResults by target endpoint.
389409
390410 Uses JSON_VALUE() function specific to SQL Azure.
391411
@@ -402,7 +422,7 @@ def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> An
402422
403423 def _get_scenario_result_target_model_condition (self , * , model_name : str ) -> Any :
404424 """
405- SQL Azure implementation for filtering ScenarioResults by target model name.
425+ Get the SQL Azure implementation for filtering ScenarioResults by target model name.
406426
407427 Uses JSON_VALUE() function specific to SQL Azure.
408428
@@ -419,7 +439,7 @@ def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any
419439
420440 def add_message_pieces_to_memory (self , * , message_pieces : Sequence [MessagePiece ]) -> None :
421441 """
422- Inserts a list of message pieces into the memory storage.
442+ Insert a list of message pieces into the memory storage.
423443
424444 """
425445 self ._insert_entries (entries = [PromptMemoryEntry (entry = piece ) for piece in message_pieces ])
@@ -434,17 +454,23 @@ def dispose_engine(self):
434454
435455 def get_all_embeddings (self ) -> Sequence [EmbeddingDataEntry ]:
436456 """
437- Fetches all entries from the specified table and returns them as model instances.
457+ Fetch all entries from the specified table and returns them as model instances.
458+
459+ Returns:
460+ Sequence[EmbeddingDataEntry]: A sequence of EmbeddingDataEntry instances representing all stored embeddings.
438461 """
439462 result : Sequence [EmbeddingDataEntry ] = self ._query_entries (EmbeddingDataEntry )
440463 return result
441464
442465 def _insert_entry (self , entry : Base ) -> None : # type: ignore
443466 """
444- Inserts an entry into the Table.
467+ Insert an entry into the Table.
445468
446469 Args:
447470 entry: An instance of a SQLAlchemy model to be added to the Table.
471+
472+ Raises:
473+ SQLAlchemyError: If the insertion fails.
448474 """
449475 with closing (self .get_session ()) as session :
450476 try :
@@ -459,7 +485,15 @@ def _insert_entry(self, entry: Base) -> None: # type: ignore
459485 # common between SQLAlchemy-based implementations, regardless of engine.
460486 # Perhaps we should find a way to refactor
461487 def _insert_entries (self , * , entries : Sequence [Base ]) -> None : # type: ignore
462- """Inserts multiple entries into the database."""
488+ """
489+ Insert multiple entries into the database.
490+
491+ Args:
492+ entries (Sequence[Base]): A sequence of SQLAlchemy model instances to insert.
493+
494+ Raises:
495+ SQLAlchemyError: If the insertion fails.
496+ """
463497 with closing (self .get_session ()) as session :
464498 try :
465499 session .add_all (entries )
@@ -471,7 +505,10 @@ def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore
471505
472506 def get_session (self ) -> Session :
473507 """
474- Provides a session for database operations.
508+ Provide a session for database operations.
509+
510+ Returns:
511+ Session: A new SQLAlchemy session bound to the configured engine.
475512 """
476513 return self .SessionFactory ()
477514
@@ -484,15 +521,19 @@ def _query_entries(
484521 join_scores : bool = False ,
485522 ) -> MutableSequence [Model ]:
486523 """
487- Fetches data from the specified table model with optional conditions.
524+ Fetch data from the specified table model with optional conditions.
488525
489526 Args:
527+ Model: The SQLAlchemy model class to query.
490528 conditions: SQLAlchemy filter conditions (Optional).
491529 distinct: Flag to return distinct rows (defaults to False).
492530 join_scores: Flag to join the scores table with entries (defaults to False).
493531
494532 Returns:
495533 List of model instances representing the rows fetched from the table.
534+
535+ Raises:
536+ SQLAlchemyError: If the query fails.
496537 """
497538 with closing (self .get_session ()) as session :
498539 try :
@@ -515,14 +556,18 @@ def _query_entries(
515556
516557 def _update_entries (self , * , entries : MutableSequence [Base ], update_fields : dict ) -> bool : # type: ignore
517558 """
518- Updates the given entries with the specified field values.
559+ Update the given entries with the specified field values.
519560
520561 Args:
521562 entries (Sequence[Base]): A list of SQLAlchemy model instances to be updated.
522563 update_fields (dict): A dictionary of field names and their new values.
523564
524565 Returns:
525566 bool: True if the update was successful, False otherwise.
567+
568+ Raises:
569+ ValueError: If 'update_fields' is empty.
570+ SQLAlchemyError: If the update fails.
526571 """
527572 if not update_fields :
528573 raise ValueError ("update_fields must be provided to update prompt entries." )
0 commit comments