diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 40c3e06e..9b488c6d 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -11,15 +11,33 @@ from mssql_python.logging import logger from mssql_python.constants import AuthType, ConstantsDDBC +from mssql_python.connection_string_parser import _ConnectionStringParser # Module-level credential instance cache. # Reusing credential objects allows the Azure Identity SDK's built-in # in-memory token cache to work, avoiding redundant token acquisitions. # See: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/TOKEN_CACHING.md -_credential_cache: Dict[str, object] = {} +# +# Cache is keyed on (auth_type, sorted credential_kwargs), which is +# bounded by the distinct credentials a single process ever uses. +_credential_cache: Dict[object, object] = {} _credential_cache_lock = threading.Lock() +def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]): + """Build a hashable cache key from auth_type and optional credential kwargs. + + Returns the plain auth_type string when no kwargs are provided so that + callers caching by string (the original behavior) keep working. When + kwargs are present (e.g. user-assigned MSI client_id), the key is a + tuple of ``(auth_type, sorted_kwargs_items)`` so different kwargs map + to different cached credentials. + """ + if not credential_kwargs: + return auth_type + return (auth_type, tuple(sorted(credential_kwargs.items()))) + + class AADAuth: """Handles Azure Active Directory authentication""" @@ -37,24 +55,26 @@ def get_token_struct(token: str) -> bytes: return struct.pack(f" bytes: + def get_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> bytes: """Get DDBC token struct for the specified authentication type.""" - token_struct, _ = AADAuth._acquire_token(auth_type) + token_struct, _ = AADAuth._acquire_token(auth_type, credential_kwargs) return token_struct @staticmethod - def get_raw_token(auth_type: str) -> str: + def get_raw_token(auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None) -> str: """Acquire a raw JWT for the mssql-py-core connection (bulk copy). Uses the cached credential instance so the Azure Identity SDK's built-in token cache can serve a valid token without a round-trip when the previous token has not yet expired. """ - _, raw_token = AADAuth._acquire_token(auth_type) + _, raw_token = AADAuth._acquire_token(auth_type, credential_kwargs) return raw_token @staticmethod - def _acquire_token(auth_type: str) -> Tuple[bytes, str]: + def _acquire_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None + ) -> Tuple[bytes, str]: """Internal: acquire token and return (ddbc_struct, raw_jwt).""" # Import Azure libraries inside method to support test mocking # pylint: disable=import-outside-toplevel @@ -63,6 +83,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: DefaultAzureCredential, DeviceCodeCredential, InteractiveBrowserCredential, + ManagedIdentityCredential, ) from azure.core.exceptions import ClientAuthenticationError except ImportError as e: @@ -76,6 +97,7 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: "default": DefaultAzureCredential, "devicecode": DeviceCodeCredential, "interactive": InteractiveBrowserCredential, + "msi": ManagedIdentityCredential, } credential_class = credential_map.get(auth_type) @@ -89,20 +111,22 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: credential_class.__name__, ) + kwargs = credential_kwargs or {} + cache_key = _credential_cache_key(auth_type, kwargs) try: with _credential_cache_lock: - if auth_type not in _credential_cache: + if cache_key not in _credential_cache: logger.debug( "get_token: Creating new credential instance for auth_type=%s", auth_type, ) - _credential_cache[auth_type] = credential_class() + _credential_cache[cache_key] = credential_class(**kwargs) else: logger.debug( "get_token: Reusing cached credential instance for auth_type=%s", auth_type, ) - credential = _credential_cache[auth_type] + credential = _credential_cache[cache_key] raw_token = credential.get_token("https://database.windows.net/.default").token logger.info( "get_token: Azure AD token acquired successfully - token_length=%d chars", @@ -130,6 +154,28 @@ def _acquire_token(auth_type: str) -> Tuple[bytes, str]: raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e +def _extract_msi_client_id(connection_string: str) -> Optional[str]: + """Pull UID out of a connection string for user-assigned MSI. + + For ActiveDirectoryMSI, UID (when present) carries the user-assigned + identity's ``client_id``. Returns None for system-assigned MSI. + + Uses the canonical ``_ConnectionStringParser`` so braced ODBC values + are handled correctly: a ``UID={hello=world}`` resolves to the value + ``hello=world`` (no surrounding braces, no false split on the inner + ``=``), and a semicolon inside a legitimate braced value (e.g. + ``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``. + """ + # Connection.__init__ already parsed the same string through + # _ConnectionStringParser via _construct_connection_string, so by the + # time we get here the input is guaranteed parseable. No defensive + # try/except: a parse failure now means a real bug upstream and should + # propagate, not silently degrade user-assigned MSI to system-assigned. + parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) + uid = (parsed.get("uid") or "").strip() + return uid or None + + def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: """ Process connection parameters and extract authentication type. @@ -180,6 +226,10 @@ def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[ # Default authentication (uses DefaultAzureCredential) logger.debug("process_auth_parameters: Default Azure authentication detected") auth_type = "default" + elif value_lower == AuthType.MSI.value: + # Managed identity authentication (system- or user-assigned) + logger.debug("process_auth_parameters: Managed identity authentication detected") + auth_type = "msi" modified_parameters.append(param) logger.debug( @@ -212,7 +262,9 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]: return result -def get_auth_token(auth_type: str) -> Optional[bytes]: +def get_auth_token( + auth_type: str, credential_kwargs: Optional[Dict[str, str]] = None +) -> Optional[bytes]: """Get DDBC authentication token struct based on auth type.""" logger.debug("get_auth_token: Starting - auth_type=%s", auth_type) if not auth_type: @@ -225,7 +277,7 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: return None # Let Windows handle AADInteractive natively try: - token = AADAuth.get_token(auth_type) + token = AADAuth.get_token(auth_type, credential_kwargs) logger.info("get_auth_token: Token acquired successfully - auth_type=%s", auth_type) return token except (ValueError, RuntimeError) as e: @@ -246,6 +298,7 @@ def extract_auth_type(connection_string: str) -> Optional[str]: AuthType.INTERACTIVE.value: "interactive", AuthType.DEVICE_CODE.value: "devicecode", AuthType.DEFAULT.value: "default", + AuthType.MSI.value: "msi", } for part in connection_string.split(";"): key, _, value = part.strip().partition("=") @@ -256,16 +309,28 @@ def extract_auth_type(connection_string: str) -> Optional[str]: def process_connection_string( connection_string: str, -) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]: +) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]: """ Process connection string and handle authentication. + NOTE: Returns a 4-tuple. Callers must unpack all four elements. + Destructuring with three names raises ``ValueError: too many values + to unpack``. The fourth element (``credential_kwargs``) is needed by + Connection.__init__ to persist credential constructor args (e.g. the + user-assigned MSI ``client_id``) for the bulkcopy fresh-token path, + since UID is stripped from the sanitized connection string. + Args: connection_string: The connection string to process Returns: - Tuple[str, Optional[Dict], Optional[str]]: Processed connection string, - attrs_before dict if needed, and auth_type string for bulk copy token acquisition + Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]: + Processed connection string, attrs_before dict if needed, auth_type + string for bulk copy token acquisition, and credential constructor + kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on + the Connection so bulkcopy can re-use them when acquiring a fresh + token after sanitization has stripped UID from the connection + string. Raises: ValueError: If the connection string is invalid or empty @@ -301,12 +366,33 @@ def process_connection_string( modified_parameters, auth_type = process_auth_parameters(parameters) + # Capture credential kwargs (e.g. user-assigned MSI client_id) before + # remove_sensitive_params strips UID from the parameter list. Pass the + # original connection_string (not modified_parameters) so the helper can + # use the canonical _ConnectionStringParser — handles braced values like + # UID={hello=world} correctly. + credential_kwargs: Dict[str, str] = {} + if auth_type == "msi": + client_id = _extract_msi_client_id(connection_string) + if client_id: + credential_kwargs["client_id"] = client_id + logger.debug( + "process_connection_string: ActiveDirectoryMSI with UID — " + "user-assigned managed identity selected (client_id length=%d)", + len(client_id), + ) + else: + logger.debug( + "process_connection_string: ActiveDirectoryMSI without UID — " + "system-assigned managed identity selected" + ) + if auth_type: logger.info( "process_connection_string: Authentication type detected - auth_type=%s", auth_type ) modified_parameters = remove_sensitive_params(modified_parameters) - token_struct = get_auth_token(auth_type) + token_struct = get_auth_token(auth_type, credential_kwargs or None) if token_struct: logger.info( "process_connection_string: Token authentication configured successfully - auth_type=%s", @@ -316,6 +402,7 @@ def process_connection_string( ";".join(modified_parameters) + ";", {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, auth_type, + credential_kwargs or None, ) else: logger.warning( @@ -326,4 +413,9 @@ def process_connection_string( "process_connection_string: Connection string processing complete - has_auth=%s", bool(auth_type), ) - return ";".join(modified_parameters) + ";", None, auth_type + return ( + ";".join(modified_parameters) + ";", + None, + auth_type, + credential_kwargs or None, + ) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 83457650..0933560b 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -321,6 +321,12 @@ def __init__( # We intentionally do NOT cache the token — a fresh one is acquired # each time bulkcopy() is called to avoid expired-token errors. self._auth_type = None + # Credential constructor kwargs (e.g. user-assigned MSI client_id) + # captured at __init__ time before remove_sensitive_params strips UID + # from self.connection_str. bulkcopy() re-uses these when acquiring a + # fresh token; re-parsing self.connection_str at that point would miss + # them because UID is already gone. + self._credential_kwargs: Optional[Dict[str, str]] = None # Check if the connection string contains authentication parameters # This is important for processing the connection string correctly. @@ -335,6 +341,7 @@ def __init__( # On Windows Interactive, process_connection_string returns None # (DDBC handles auth natively), so fall back to the connection string. self._auth_type = connection_result[2] or extract_auth_type(self.connection_str) + self._credential_kwargs = connection_result[3] self._closed = False self._timeout = timeout diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 549737c6..5de02ece 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -337,6 +337,7 @@ class AuthType(Enum): INTERACTIVE = "activedirectoryinteractive" DEVICE_CODE = "activedirectorydevicecode" DEFAULT = "activedirectorydefault" + MSI = "activedirectorymsi" class SQLTypes: diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index f0b1d6a6..ece27c61 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2933,11 +2933,17 @@ def bulkcopy( # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: - # Fresh token acquisition for mssql-py-core connection + # Fresh token acquisition for mssql-py-core connection. credential + # kwargs (e.g. user-assigned MSI client_id) were captured by + # Connection.__init__ before remove_sensitive_params stripped UID + # from connection_str — re-parsing here would miss them. from mssql_python.auth import AADAuth try: - raw_token = AADAuth.get_raw_token(self.connection._auth_type) + raw_token = AADAuth.get_raw_token( + self.connection._auth_type, + self.connection._credential_kwargs, + ) except (RuntimeError, ValueError) as e: raise RuntimeError( f"Bulk copy failed: unable to acquire Azure AD token " diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index f680518b..f8df6f6f 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -44,6 +44,17 @@ class MockInteractiveBrowserCredential: def get_token(self, scope): return MockToken() + class MockManagedIdentityCredential: + # Captures construction kwargs so user-assigned MSI tests can assert + # client_id was forwarded correctly. + last_init_kwargs = None + + def __init__(self, **kwargs): + MockManagedIdentityCredential.last_init_kwargs = kwargs + + def get_token(self, scope): + return MockToken() + # Mock ClientAuthenticationError class MockClientAuthenticationError(Exception): pass @@ -52,6 +63,7 @@ class MockIdentity: DefaultAzureCredential = MockDefaultAzureCredential DeviceCodeCredential = MockDeviceCodeCredential InteractiveBrowserCredential = MockInteractiveBrowserCredential + ManagedIdentityCredential = MockManagedIdentityCredential class MockCore: class exceptions: @@ -87,6 +99,7 @@ def test_auth_type_constants(self): assert AuthType.INTERACTIVE.value == "activedirectoryinteractive" assert AuthType.DEVICE_CODE.value == "activedirectorydevicecode" assert AuthType.DEFAULT.value == "activedirectorydefault" + assert AuthType.MSI.value == "activedirectorymsi" class TestAADAuth: @@ -317,6 +330,16 @@ def test_default_auth(self): _, auth_type = process_auth_parameters(params) assert auth_type == "default" + def test_msi_auth(self): + params = ["Authentication=ActiveDirectoryMSI", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type == "msi" + + def test_msi_auth_case_insensitive(self): + params = ["authentication=activedirectorymsi", "Server=test"] + _, auth_type = process_auth_parameters(params) + assert auth_type == "msi" + class TestRemoveSensitiveParams: def test_remove_sensitive_parameters(self): @@ -344,7 +367,7 @@ def test_remove_sensitive_parameters(self): class TestProcessConnectionString: def test_process_connection_string_with_default_auth(self): conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -352,10 +375,11 @@ def test_process_connection_string_with_default_auth(self): assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) assert auth_type == "default" + assert credential_kwargs is None def test_process_connection_string_no_auth(self): conn_str = "Server=test;Database=testdb;UID=user;PWD=password" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -363,11 +387,12 @@ def test_process_connection_string_no_auth(self): assert "PWD=password" in result_str assert attrs is None assert auth_type is None + assert credential_kwargs is None def test_process_connection_string_interactive_non_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Darwin") conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str @@ -375,6 +400,7 @@ def test_process_connection_string_interactive_non_windows(self, monkeypatch): assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) assert auth_type == "interactive" + assert credential_kwargs is None def test_error_handling(): @@ -407,6 +433,9 @@ def test_devicecode(self): == "devicecode" ) + def test_msi(self): + assert extract_auth_type("Server=test;Authentication=ActiveDirectoryMSI;") == "msi" + def test_no_auth(self): assert extract_auth_type("Server=test;Database=db;") is None @@ -414,6 +443,159 @@ def test_unsupported_auth(self): assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None +class TestManagedIdentity: + """Tests for ActiveDirectoryMSI support (system- and user-assigned).""" + + def test_get_token_system_assigned_msi(self): + """System-assigned MSI: ManagedIdentityCredential() constructed with no kwargs.""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + token_struct = AADAuth.get_token("msi") + assert isinstance(token_struct, bytes) + assert az.ManagedIdentityCredential.last_init_kwargs == {} + + def test_get_raw_token_system_assigned_msi(self): + raw_token = AADAuth.get_raw_token("msi") + assert raw_token == SAMPLE_TOKEN + + def test_get_token_user_assigned_msi(self): + """User-assigned MSI: client_id is forwarded to the credential constructor.""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + client_id = "11111111-2222-3333-4444-555555555555" + token_struct = AADAuth.get_token("msi", {"client_id": client_id}) + assert isinstance(token_struct, bytes) + assert az.ManagedIdentityCredential.last_init_kwargs == {"client_id": client_id} + + def test_msi_separate_cache_entries_per_client_id(self): + """System-assigned and user-assigned MSI must not share a cached credential.""" + AADAuth.get_token("msi") # system-assigned + AADAuth.get_token("msi", {"client_id": "abc"}) + AADAuth.get_token("msi", {"client_id": "def"}) + + # System-assigned uses the bare string key; user-assigned uses tuples. + assert "msi" in _credential_cache + assert ("msi", (("client_id", "abc"),)) in _credential_cache + assert ("msi", (("client_id", "def"),)) in _credential_cache + assert _credential_cache["msi"] is not _credential_cache[("msi", (("client_id", "abc"),))] + + def test_process_connection_string_msi_strips_uid_and_returns_kwargs(self): + """MSI connection strings: UID is stripped from the ODBC connection + string but the client_id is captured as credential_kwargs (so it can + be persisted on the Connection for the bulkcopy fresh-token path).""" + az = sys.modules["azure.identity"] + + az.ManagedIdentityCredential.last_init_kwargs = None + conn_str = ( + "Server=test;Authentication=ActiveDirectoryMSI;" + "UID=11111111-2222-3333-4444-555555555555;Database=testdb" + ) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) + + assert auth_type == "msi" + assert "UID=" not in result_str + assert "Authentication=" not in result_str + assert "Server=test" in result_str + assert "Database=testdb" in result_str + assert attrs is not None + assert az.ManagedIdentityCredential.last_init_kwargs == { + "client_id": "11111111-2222-3333-4444-555555555555" + } + # client_id must be returned so Connection can persist it for the + # bulkcopy fresh-token path (UID is gone from result_str by then). + assert credential_kwargs == {"client_id": "11111111-2222-3333-4444-555555555555"} + + def test_process_connection_string_msi_system_assigned_no_kwargs(self): + """System-assigned MSI: no UID → credential_kwargs is None.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;Database=testdb" + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs is None + + def test_msi_braced_uid_value_is_unwrapped(self): + """A braced UID value (UID={hello=world}) must be unwrapped by the + canonical _ConnectionStringParser; the inner '=' must NOT split the + value. Without parser-aware extraction the helper would return + '{hello=world}' verbatim and ManagedIdentityCredential would reject + it.""" + conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID={hello=world};Database=testdb" + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs == {"client_id": "hello=world"} + + def test_msi_braced_uid_with_semicolon_is_preserved(self): + """A braced UID value containing a semicolon (legal under ODBC) must + be returned intact, not truncated at the inner ';'.""" + weird_id = "abc;def;ghi" + conn_str = ( + f"Server=test;Authentication=ActiveDirectoryMSI;" f"UID={{{weird_id}}};Database=testdb" + ) + _, _, auth_type, credential_kwargs = process_connection_string(conn_str) + assert auth_type == "msi" + assert credential_kwargs == {"client_id": weird_id} + + def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): + """Regression test (cursor.bulkcopy() end-to-end) for the silent + system-assigned fallback: the bulkcopy fresh-token code path must + forward Connection._credential_kwargs to AADAuth.get_raw_token, + not re-parse the (now UID-stripped) connection_str. + + Fails if cursor.py is reverted to call extract_credential_kwargs on + self.connection.connection_str, OR if Connection stops persisting + _credential_kwargs.""" + from mssql_python.cursor import Cursor + + client_id = "11111111-2222-3333-4444-555555555555" + + # Mock Connection holding what Connection.__init__ would store after + # process_connection_string strips UID from the user-supplied string. + mock_conn = MagicMock() + # Post-sanitization string: NO UID. If cursor re-parses this, the + # forwarded kwargs will be {} and the assert below will fail. + mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" + mock_conn._auth_type = "msi" + mock_conn._credential_kwargs = {"client_id": client_id} + mock_conn._is_connected = True + + cursor = Cursor.__new__(Cursor) + cursor._connection = mock_conn + cursor.closed = False + cursor.hstmt = None + + captured = {} + + def fake_get_raw_token(auth_type, credential_kwargs=None): + captured["auth_type"] = auth_type + captured["credential_kwargs"] = credential_kwargs + return SAMPLE_TOKEN + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = lambda ctx, **kwargs: mock_pycore_conn + + with ( + patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}), + patch("mssql_python.auth.AADAuth.get_raw_token", side_effect=fake_get_raw_token), + ): + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + assert captured["auth_type"] == "msi" + assert captured["credential_kwargs"] == {"client_id": client_id}, ( + f"bulkcopy must forward Connection._credential_kwargs verbatim; " + f"got {captured['credential_kwargs']!r}. If this is {{}} or None, " + f"the cursor likely re-parses the (UID-stripped) connection_str." + ) + + class TestCredentialInstanceCache: """Tests for the credential instance caching behavior.""" @@ -624,6 +806,50 @@ def test_auth_type_stored_on_connection(self, mock_ddbc_conn): assert conn._auth_type == "default" conn.close() + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_persisted_for_user_assigned_msi(self, mock_ddbc_conn): + """Connection.__init__ must capture MSI client_id BEFORE + remove_sensitive_params strips UID, and persist it on + self._credential_kwargs so cursor.bulkcopy() can use it later.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + client_id = "11111111-2222-3333-4444-555555555555" + conn = connect( + f"Server=test;Database=testdb;Authentication=ActiveDirectoryMSI;UID={client_id}" + ) + assert conn._auth_type == "msi" + assert conn._credential_kwargs == {"client_id": client_id} + # And the connection_str on the Connection should NOT contain UID + # (this is what makes _credential_kwargs the source of truth). + assert "UID=" not in conn.connection_str + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_none_for_system_assigned_msi(self, mock_ddbc_conn): + """System-assigned MSI: no UID → _credential_kwargs stays None.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryMSI") + assert conn._auth_type == "msi" + assert conn._credential_kwargs is None + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_kwargs_none_for_non_msi_auth(self, mock_ddbc_conn): + """Non-MSI auth types must not pick up credential_kwargs even if + UID is present (e.g. SQL auth UID).""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect( + "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault;UID=user@x" + ) + assert conn._auth_type == "default" + assert conn._credential_kwargs is None + conn.close() + class TestCredentialCacheThreadSafety: """Verify thread-safe behavior of credential instance cache.""" @@ -760,7 +986,7 @@ class TestProcessConnectionStringTokenFailureFallthrough: def test_returns_none_attrs_when_token_acquisition_fails(self): """When auth type is detected but token acquisition fails, - process_connection_string should return (conn_str, None, auth_type).""" + process_connection_string should return (conn_str, None, auth_type, kwargs).""" import sys azure_identity = sys.modules["azure.identity"] @@ -773,7 +999,7 @@ def __init__(self): try: azure_identity.DefaultAzureCredential = CredentialThatAlwaysFails conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type = process_connection_string(conn_str) + result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) # Auth type was detected assert auth_type == "default" @@ -782,5 +1008,7 @@ def __init__(self): # Connection string is still returned (sensitive params removed) assert "Server=test" in result_str assert "Database=testdb" in result_str + # Default auth has no credential kwargs + assert credential_kwargs is None finally: azure_identity.DefaultAzureCredential = original