Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions dbt/adapters/sqlserver/sqlserver_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,41 @@
import struct
import time
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain, repeat
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union

import agate
import dbt_common.exceptions
import pyodbc
from azure.core.credentials import AccessToken
from azure.identity import (
AzureCliCredential,
ClientSecretCredential,
DefaultAzureCredential,
EnvironmentCredential,
ManagedIdentityCredential,
)

try:
from azure.core.credentials import AccessToken
except ModuleNotFoundError:
@dataclass
class AccessToken:
token: str
expires_on: int


try:
from azure.identity import (
AzureCliCredential,
ClientSecretCredential,
DefaultAzureCredential,
EnvironmentCredential,
ManagedIdentityCredential,
)

_AZURE_IDENTITY_IMPORT_ERROR = None
except ModuleNotFoundError as exc:
AzureCliCredential = None
ClientSecretCredential = None
DefaultAzureCredential = None
EnvironmentCredential = None
ManagedIdentityCredential = None
_AZURE_IDENTITY_IMPORT_ERROR = exc

from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus
Expand Down Expand Up @@ -51,6 +72,15 @@
}


def _require_azure_identity(authentication: str) -> None:
if _AZURE_IDENTITY_IMPORT_ERROR is not None:
raise dbt_common.exceptions.DbtRuntimeError(
"Azure authentication '{}' requires the optional dependency 'azure-identity'. "
"Install it with `pip install azure-identity` or use a non-Azure authentication mode."
.format(authentication)
) from _AZURE_IDENTITY_IMPORT_ERROR


def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes:
"""
Convert bytes to a Microsoft windows byte string.
Expand Down Expand Up @@ -110,6 +140,7 @@ def get_cli_access_token(
Access token.
"""
_ = credentials
_require_azure_identity("cli")
token = AzureCliCredential().get_token(
scope, timeout=getattr(credentials, "login_timeout", None)
)
Expand All @@ -132,6 +163,7 @@ def get_auto_access_token(
out : AccessToken
The access token.
"""
_require_azure_identity("auto")
token = DefaultAzureCredential().get_token(
scope, timeout=getattr(credentials, "login_timeout", None)
)
Expand All @@ -154,6 +186,7 @@ def get_environment_access_token(
out : AccessToken
The access token.
"""
_require_azure_identity("environment")
token = EnvironmentCredential().get_token(
scope, timeout=getattr(credentials, "login_timeout", None)
)
Expand All @@ -177,6 +210,7 @@ def get_msi_access_token(
The access token.
"""
_ = credentials
_require_azure_identity("msi")
token = ManagedIdentityCredential().get_token(scope)
return token

Expand All @@ -198,6 +232,7 @@ def get_sp_access_token(
The access token.
"""
_ = scope
_require_azure_identity("serviceprincipal")
token = ClientSecretCredential(
str(credentials.tenant_id),
str(credentials.client_id),
Expand Down
20 changes: 18 additions & 2 deletions tests/functional/adapter/dbt/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import pytest
from dbt.tests.adapter.utils import fixture_cast_bool_to_text, fixture_dateadd, fixture_listagg
from dbt.tests.adapter.utils import (
fixture_cast_bool_to_text,
fixture_dateadd,
fixture_listagg,
fixture_split_part,
)
from dbt.tests.adapter.utils.test_any_value import BaseAnyValue
from dbt.tests.adapter.utils.test_array_append import BaseArrayAppend
from dbt.tests.adapter.utils.test_array_concat import BaseArrayConcat
Expand Down Expand Up @@ -340,7 +345,18 @@ class TestSafeCast(BaseSafeCast):


class TestSplitPart(BaseSplitPart):
pass
@pytest.fixture(scope="class")
def models(self):
model_sql = """
-- depends_on: {{ ref('data_split_part') }}
""" + self.interpolate_macro_namespace(
fixture_split_part.models__test_split_part_sql, "split_part"
)

return {
"test_split_part.yml": fixture_split_part.models__test_split_part_yml,
"test_split_part.sql": model_sql,
}


class TestStringLiteral(BaseStringLiteral):
Expand Down
10 changes: 9 additions & 1 deletion tests/functional/adapter/mssql/test_cross_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ def create_secondary_db(self, project):
)

def cleanup_secondary_database(self, project):
drop_sql = "DROP DATABASE IF EXISTS secondary_db"
drop_sql = """
USE [master]

IF EXISTS (SELECT * FROM sys.databases WHERE name = 'secondary_db')
BEGIN
ALTER DATABASE [secondary_db] SET SINGLE_USER WITH ROLLBACK IMMEDIATE
DROP DATABASE [secondary_db]
END
"""
with get_connection(project.adapter):
project.adapter.execute(
drop_sql.format(database=project.database),
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/adapters/mssql/test_sqlserver_connection_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from azure.identity import AzureCliCredential
from dbt_common.exceptions import DbtRuntimeError

from dbt.adapters.sqlserver import sqlserver_connections
from dbt.adapters.sqlserver.sqlserver_connections import ( # byte_array_to_datetime,
bool_to_connection_string_arg,
get_pyodbc_attrs_before_credentials,
Expand Down Expand Up @@ -33,6 +35,26 @@ def test_get_pyodbc_attrs_before_empty_dict_when_service_principal(
assert attrs_before == {}


def test_get_pyodbc_attrs_before_sql_auth_without_azure_identity(
credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", ModuleNotFoundError())

attrs_before = get_pyodbc_attrs_before_credentials(credentials)

assert attrs_before == {}


def test_get_pyodbc_attrs_before_cli_auth_requires_azure_identity(
credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch
) -> None:
credentials.authentication = "cli"
monkeypatch.setattr(sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", ModuleNotFoundError())

with pytest.raises(DbtRuntimeError, match="requires the optional dependency 'azure-identity'"):
get_pyodbc_attrs_before_credentials(credentials)


@pytest.mark.parametrize(
"key, value, expected",
[("somekey", False, "somekey=No"), ("somekey", True, "somekey=Yes")],
Expand Down