@@ -60,6 +60,19 @@ def __init__(
6060 results_sas_token : Optional [str ] = None ,
6161 verbose : bool = False ,
6262 ):
63+ """
64+ Initialize an Azure SQL Memory backend.
65+
66+ Args:
67+ connection_string (Optional[str]): The connection string for the Azure Sql Database. If not provided,
68+ it falls back to the 'AZURE_SQL_DB_CONNECTION_STRING' environment variable.
69+ results_container_url (Optional[str]): The URL to an Azure Storage Container. If not provided,
70+ it falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL' environment variable.
71+ results_sas_token (Optional[str]): The Shared Access Signature (SAS) token for the storage container.
72+ If not provided, falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN' environment variable.
73+ verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False.
74+ """
75+ self ._init_storage_io ()
6376 self ._connection_string = default_values .get_required_value (
6477 env_var_name = self .AZURE_SQL_DB_CONNECTION_STRING , passed_value = connection_string
6578 )
@@ -114,7 +127,7 @@ def _init_storage_io(self):
114127
115128 def _create_auth_token (self ) -> None :
116129 """
117- Creates an Azure Entra ID access token.
130+ Create an Azure Entra ID access token.
118131 Stores the token and its expiry time.
119132 """
120133 azure_auth = AzureAuth (token_scope = self .TOKEN_URL )
@@ -133,13 +146,19 @@ def _refresh_token_if_needed(self) -> None:
133146
134147 def _create_engine (self , * , has_echo : bool ) -> Engine :
135148 """
136- Creates the SQLAlchemy engine for Azure SQL Server.
149+ Create the SQLAlchemy engine for Azure SQL Server.
137150
138151 Creates an engine bound to the specified server and database. The `has_echo` parameter
139152 controls the verbosity of SQL execution logging.
140153
141154 Args:
142155 has_echo (bool): Flag to enable detailed SQL execution logging.
156+
157+ Returns:
158+ Engine: SQLAlchemy engine bound to the AZURE SQL Database.
159+
160+ Raises:
161+ SQLAlchemyError: If the engine creation fails.
143162 """
144163 try :
145164 # Create the SQLAlchemy engine.
@@ -156,6 +175,8 @@ def _create_engine(self, *, has_echo: bool) -> Engine:
156175
157176 def _enable_azure_authorization (self ) -> None :
158177 """
178+ Enable Azure token-based authorization for SQL connections.
179+
159180 The following is necessary because of how SQLAlchemy and PyODBC handle connection creation. In PyODBC, the
160181 token is passed outside the connection string in the `connect()` method. Since SQLAlchemy lazy-loads
161182 its connections, we need to set this as a separate argument to the `connect()` method. In SQLALchemy
@@ -184,7 +205,7 @@ def provide_token(_dialect, _conn_rec, cargs, cparams):
184205
185206 def _create_tables_if_not_exist (self ):
186207 """
187- Creates all tables defined in the Base metadata, if they don't already exist in the database.
208+ Create all tables defined in the Base metadata, if they don't already exist in the database.
188209
189210 Raises:
190211 Exception: If there's an issue creating the tables in the database.
@@ -198,7 +219,7 @@ def _create_tables_if_not_exist(self):
198219
199220 def _add_embeddings_to_memory (self , * , embedding_data : Sequence [EmbeddingDataEntry ]) -> None :
200221 """
201- Inserts embedding data into memory storage.
222+ Insert embedding data into memory storage.
202223 """
203224 self ._insert_entries (entries = embedding_data )
204225
@@ -295,7 +316,7 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
295316
296317 def _get_attack_result_harm_category_condition (self , * , targeted_harm_categories : Sequence [str ]) -> Any :
297318 """
298- SQL Azure implementation for filtering AttackResults by targeted harm categories.
319+ Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
299320
300321 Uses JSON_QUERY() function specific to SQL Azure to check if categories exist in the JSON array.
301322
@@ -333,7 +354,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories
333354
334355 def _get_attack_result_label_condition (self , * , labels : dict [str , str ]) -> Any :
335356 """
336- SQL Azure implementation for filtering AttackResults by labels.
357+ Get the SQL Azure implementation for filtering AttackResults by labels.
337358
338359 Uses JSON_VALUE() function specific to SQL Azure with parameterized queries.
339360
@@ -364,7 +385,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
364385
365386 def _get_scenario_result_label_condition (self , * , labels : dict [str , str ]) -> Any :
366387 """
367- SQL Azure implementation for filtering ScenarioResults by labels.
388+ Get the SQL Azure implementation for filtering ScenarioResults by labels.
368389
369390 Uses JSON_VALUE() function specific to SQL Azure.
370391
@@ -385,7 +406,7 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any
385406
386407 def _get_scenario_result_target_endpoint_condition (self , * , endpoint : str ) -> Any :
387408 """
388- SQL Azure implementation for filtering ScenarioResults by target endpoint.
409+ Get the SQL Azure implementation for filtering ScenarioResults by target endpoint.
389410
390411 Uses JSON_VALUE() function specific to SQL Azure.
391412
@@ -402,7 +423,7 @@ def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> An
402423
403424 def _get_scenario_result_target_model_condition (self , * , model_name : str ) -> Any :
404425 """
405- SQL Azure implementation for filtering ScenarioResults by target model name.
426+ Get the SQL Azure implementation for filtering ScenarioResults by target model name.
406427
407428 Uses JSON_VALUE() function specific to SQL Azure.
408429
@@ -419,7 +440,7 @@ def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any
419440
420441 def add_message_pieces_to_memory (self , * , message_pieces : Sequence [MessagePiece ]) -> None :
421442 """
422- Inserts a list of message pieces into the memory storage.
443+ Insert a list of message pieces into the memory storage.
423444
424445 """
425446 self ._insert_entries (entries = [PromptMemoryEntry (entry = piece ) for piece in message_pieces ])
@@ -434,17 +455,23 @@ def dispose_engine(self):
434455
435456 def get_all_embeddings (self ) -> Sequence [EmbeddingDataEntry ]:
436457 """
437- Fetches all entries from the specified table and returns them as model instances.
458+ Fetch all entries from the specified table and returns them as model instances.
459+
460+ Returns:
461+ Sequence[EmbeddingDataEntry]: A sequence of EmbeddingDataEntry instances representing all stored embeddings.
438462 """
439463 result : Sequence [EmbeddingDataEntry ] = self ._query_entries (EmbeddingDataEntry )
440464 return result
441465
442466 def _insert_entry (self , entry : Base ) -> None : # type: ignore
443467 """
444- Inserts an entry into the Table.
468+ Insert an entry into the Table.
445469
446470 Args:
447471 entry: An instance of a SQLAlchemy model to be added to the Table.
472+
473+ Raises:
474+ SQLAlchemyError: If the insertion fails.
448475 """
449476 with closing (self .get_session ()) as session :
450477 try :
@@ -459,7 +486,15 @@ def _insert_entry(self, entry: Base) -> None: # type: ignore
459486 # common between SQLAlchemy-based implementations, regardless of engine.
460487 # Perhaps we should find a way to refactor
461488 def _insert_entries (self , * , entries : Sequence [Base ]) -> None : # type: ignore
462- """Inserts multiple entries into the database."""
489+ """
490+ Insert multiple entries into the database.
491+
492+ Args:
493+ entries (Sequence[Base]): A sequence of SQLAlchemy model instances to insert.
494+
495+ Raises:
496+ SQLAlchemyError: If the insertion fails.
497+ """
463498 with closing (self .get_session ()) as session :
464499 try :
465500 session .add_all (entries )
@@ -471,7 +506,10 @@ def _insert_entries(self, *, entries: Sequence[Base]) -> None: # type: ignore
471506
472507 def get_session (self ) -> Session :
473508 """
474- Provides a session for database operations.
509+ Provide a session for database operations.
510+
511+ Returns:
512+ Session: A new SQLAlchemy session bound to the configured engine.
475513 """
476514 return self .SessionFactory ()
477515
@@ -484,15 +522,19 @@ def _query_entries(
484522 join_scores : bool = False ,
485523 ) -> MutableSequence [Model ]:
486524 """
487- Fetches data from the specified table model with optional conditions.
525+ Fetch data from the specified table model with optional conditions.
488526
489527 Args:
528+ Model: The SQLAlchemy model class to query.
490529 conditions: SQLAlchemy filter conditions (Optional).
491530 distinct: Flag to return distinct rows (defaults to False).
492531 join_scores: Flag to join the scores table with entries (defaults to False).
493532
494533 Returns:
495534 List of model instances representing the rows fetched from the table.
535+
536+ Raises:
537+ SQLAlchemyError: If the query fails.
496538 """
497539 with closing (self .get_session ()) as session :
498540 try :
@@ -515,14 +557,18 @@ def _query_entries(
515557
516558 def _update_entries (self , * , entries : MutableSequence [Base ], update_fields : dict ) -> bool : # type: ignore
517559 """
518- Updates the given entries with the specified field values.
560+ Update the given entries with the specified field values.
519561
520562 Args:
521563 entries (Sequence[Base]): A list of SQLAlchemy model instances to be updated.
522564 update_fields (dict): A dictionary of field names and their new values.
523565
524566 Returns:
525567 bool: True if the update was successful, False otherwise.
568+
569+ Raises:
570+ ValueError: If 'update_fields' is empty.
571+ SQLAlchemyError: If the update fails.
526572 """
527573 if not update_fields :
528574 raise ValueError ("update_fields must be provided to update prompt entries." )
0 commit comments