diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index a4c5c347c..ffaf9f826 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -2,20 +2,41 @@ import struct import time from contextlib import contextmanager +from dataclasses import dataclass from itertools import chain, repeat from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union import agate import dbt_common.exceptions import pyodbc -from azure.core.credentials import AccessToken -from azure.identity import ( - AzureCliCredential, - ClientSecretCredential, - DefaultAzureCredential, - EnvironmentCredential, - ManagedIdentityCredential, -) + +try: + from azure.core.credentials import AccessToken +except ModuleNotFoundError: + @dataclass + class AccessToken: + token: str + expires_on: int + + +try: + from azure.identity import ( + AzureCliCredential, + ClientSecretCredential, + DefaultAzureCredential, + EnvironmentCredential, + ManagedIdentityCredential, + ) + + _AZURE_IDENTITY_IMPORT_ERROR = None +except ModuleNotFoundError as exc: + AzureCliCredential = None + ClientSecretCredential = None + DefaultAzureCredential = None + EnvironmentCredential = None + ManagedIdentityCredential = None + _AZURE_IDENTITY_IMPORT_ERROR = exc + from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState from dbt.adapters.events.logging import AdapterLogger from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus @@ -51,6 +72,15 @@ } +def _require_azure_identity(authentication: str) -> None: + if _AZURE_IDENTITY_IMPORT_ERROR is not None: + raise dbt_common.exceptions.DbtRuntimeError( + "Azure authentication '{}' requires the optional dependency 'azure-identity'. " + "Install it with `pip install azure-identity` or use a non-Azure authentication mode." + .format(authentication) + ) from _AZURE_IDENTITY_IMPORT_ERROR + + def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: """ Convert bytes to a Microsoft windows byte string. @@ -110,6 +140,7 @@ def get_cli_access_token( Access token. """ _ = credentials + _require_azure_identity("cli") token = AzureCliCredential().get_token( scope, timeout=getattr(credentials, "login_timeout", None) ) @@ -132,6 +163,7 @@ def get_auto_access_token( out : AccessToken The access token. """ + _require_azure_identity("auto") token = DefaultAzureCredential().get_token( scope, timeout=getattr(credentials, "login_timeout", None) ) @@ -154,6 +186,7 @@ def get_environment_access_token( out : AccessToken The access token. """ + _require_azure_identity("environment") token = EnvironmentCredential().get_token( scope, timeout=getattr(credentials, "login_timeout", None) ) @@ -177,6 +210,7 @@ def get_msi_access_token( The access token. """ _ = credentials + _require_azure_identity("msi") token = ManagedIdentityCredential().get_token(scope) return token @@ -198,6 +232,7 @@ def get_sp_access_token( The access token. """ _ = scope + _require_azure_identity("serviceprincipal") token = ClientSecretCredential( str(credentials.tenant_id), str(credentials.client_id), diff --git a/tests/functional/adapter/dbt/test_utils.py b/tests/functional/adapter/dbt/test_utils.py index 4d4587012..2feb5ee1a 100644 --- a/tests/functional/adapter/dbt/test_utils.py +++ b/tests/functional/adapter/dbt/test_utils.py @@ -1,5 +1,10 @@ import pytest -from dbt.tests.adapter.utils import fixture_cast_bool_to_text, fixture_dateadd, fixture_listagg +from dbt.tests.adapter.utils import ( + fixture_cast_bool_to_text, + fixture_dateadd, + fixture_listagg, + fixture_split_part, +) from dbt.tests.adapter.utils.test_any_value import BaseAnyValue from dbt.tests.adapter.utils.test_array_append import BaseArrayAppend from dbt.tests.adapter.utils.test_array_concat import BaseArrayConcat @@ -340,7 +345,18 @@ class TestSafeCast(BaseSafeCast): class TestSplitPart(BaseSplitPart): - pass + @pytest.fixture(scope="class") + def models(self): + model_sql = """ +-- depends_on: {{ ref('data_split_part') }} +""" + self.interpolate_macro_namespace( + fixture_split_part.models__test_split_part_sql, "split_part" + ) + + return { + "test_split_part.yml": fixture_split_part.models__test_split_part_yml, + "test_split_part.sql": model_sql, + } class TestStringLiteral(BaseStringLiteral): diff --git a/tests/functional/adapter/mssql/test_cross_db.py b/tests/functional/adapter/mssql/test_cross_db.py index 72802d6f6..8c4b91a2a 100644 --- a/tests/functional/adapter/mssql/test_cross_db.py +++ b/tests/functional/adapter/mssql/test_cross_db.py @@ -55,7 +55,15 @@ def create_secondary_db(self, project): ) def cleanup_secondary_database(self, project): - drop_sql = "DROP DATABASE IF EXISTS secondary_db" + drop_sql = """ + USE [master] + + IF EXISTS (SELECT * FROM sys.databases WHERE name = 'secondary_db') + BEGIN + ALTER DATABASE [secondary_db] SET SINGLE_USER WITH ROLLBACK IMMEDIATE + DROP DATABASE [secondary_db] + END + """ with get_connection(project.adapter): project.adapter.execute( drop_sql.format(database=project.database), diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index 2acb2520b..be909afb4 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -1,6 +1,8 @@ import pytest from azure.identity import AzureCliCredential +from dbt_common.exceptions import DbtRuntimeError +from dbt.adapters.sqlserver import sqlserver_connections from dbt.adapters.sqlserver.sqlserver_connections import ( # byte_array_to_datetime, bool_to_connection_string_arg, get_pyodbc_attrs_before_credentials, @@ -33,6 +35,26 @@ def test_get_pyodbc_attrs_before_empty_dict_when_service_principal( assert attrs_before == {} +def test_get_pyodbc_attrs_before_sql_auth_without_azure_identity( + credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", ModuleNotFoundError()) + + attrs_before = get_pyodbc_attrs_before_credentials(credentials) + + assert attrs_before == {} + + +def test_get_pyodbc_attrs_before_cli_auth_requires_azure_identity( + credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch +) -> None: + credentials.authentication = "cli" + monkeypatch.setattr(sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", ModuleNotFoundError()) + + with pytest.raises(DbtRuntimeError, match="requires the optional dependency 'azure-identity'"): + get_pyodbc_attrs_before_credentials(credentials) + + @pytest.mark.parametrize( "key, value, expected", [("somekey", False, "somekey=No"), ("somekey", True, "somekey=Yes")],