diff --git a/.github/workflows/pr-code-coverage.yml b/.github/workflows/pr-code-coverage.yml index f2f1aad9..c07204f3 100644 --- a/.github/workflows/pr-code-coverage.yml +++ b/.github/workflows/pr-code-coverage.yml @@ -426,8 +426,8 @@ jobs: --arg covered_lines "${{ env.COVERED_LINES }}" \ --arg total_lines "${{ env.TOTAL_LINES }}" \ --arg patch_coverage_pct "${{ env.PATCH_COVERAGE_PCT }}" \ - --arg low_coverage_files "${{ env.LOW_COVERAGE_FILES }}" \ - --arg patch_coverage_summary "${{ env.PATCH_COVERAGE_SUMMARY }}" \ + --arg low_coverage_files "$LOW_COVERAGE_FILES" \ + --arg patch_coverage_summary "$PATCH_COVERAGE_SUMMARY" \ --arg ado_url "${{ env.ADO_URL }}" \ '{ pr_number: $pr_number, diff --git a/mssql_python/connection_string_parser.py b/mssql_python/connection_string_parser.py index 9dd88db2..cdf17620 100644 --- a/mssql_python/connection_string_parser.py +++ b/mssql_python/connection_string_parser.py @@ -108,8 +108,10 @@ def _normalize_params(params: Dict[str, str], warn_rejected: bool = True) -> Dic if normalized_key in _RESERVED_PARAMETERS: continue - # Parameter is allowed - filtered[normalized_key] = value + # First-wins: match ODBC behaviour where the first + # occurrence of a synonym group takes precedence. 
+ if normalized_key not in filtered: + filtered[normalized_key] = value else: # Parameter is not in allow-list # Note: In normal flow, this should be empty since parser validates first diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 4c03b4df..3acbffcc 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -17,7 +17,7 @@ import warnings from typing import List, Union, Any, Optional, Tuple, Sequence, TYPE_CHECKING, Iterable from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes -from mssql_python.helpers import check_error +from mssql_python.helpers import check_error, connstr_to_pycore_params from mssql_python.logging import logger from mssql_python import ddbc_bindings from mssql_python.exceptions import ( @@ -2498,6 +2498,7 @@ def nextset(self) -> Union[bool, None]: ) return True + # ── Mapping from ODBC connection-string keywords (lowercase, as _parse returns) def _bulkcopy( self, table_name: str, @@ -2632,38 +2633,10 @@ def _bulkcopy( "Specify the target database explicitly to avoid accidentally writing to system databases." ) - # Build connection context for bulk copy library - # Note: Password is extracted separately to avoid storing it in the main context - # dict that could be accidentally logged or exposed in error messages. 
- trust_cert = params.get("trustservercertificate", "yes").lower() in ("yes", "true") - - # Parse encryption setting from connection string - encrypt_param = params.get("encrypt") - if encrypt_param is not None: - encrypt_value = encrypt_param.strip().lower() - if encrypt_value in ("yes", "true", "mandatory", "required"): - encryption = "Required" - elif encrypt_value in ("no", "false", "optional"): - encryption = "Optional" - else: - # Pass through unrecognized values (e.g., "Strict") to the underlying driver - encryption = encrypt_param - else: - encryption = "Optional" - - context = { - "server": params.get("server"), - "database": params.get("database"), - "trust_server_certificate": trust_cert, - "encryption": encryption, - } - - # Build pycore_context with appropriate authentication. - # For Azure AD: acquire a FRESH token right now instead of reusing - # the one from connect() time — avoids expired-token errors when - # bulkcopy() is called long after the original connection. - pycore_context = dict(context) + # Translate parsed connection string into the dict py-core expects. 
+ pycore_context = connstr_to_pycore_params(params) + # Token acquisition — only thing cursor must handle (needs azure-identity SDK) if self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection from mssql_python.auth import AADAuth @@ -2680,10 +2653,6 @@ def _bulkcopy( "Bulk copy: acquired fresh Azure AD token for auth_type=%s", self.connection._auth_type, ) - else: - # SQL Server authentication — use uid/password from connection string - pycore_context["user_name"] = params.get("uid", "") - pycore_context["password"] = params.get("pwd", "") pycore_connection = None pycore_cursor = None @@ -2722,9 +2691,8 @@ def _bulkcopy( finally: # Clear sensitive data to minimize memory exposure if pycore_context: - pycore_context.pop("password", None) - pycore_context.pop("user_name", None) - pycore_context.pop("access_token", None) + for key in ("password", "user_name", "access_token"): + pycore_context.pop(key, None) # Clean up bulk copy resources for resource in (pycore_cursor, pycore_connection): if resource and hasattr(resource, "close"): diff --git a/mssql_python/helpers.py b/mssql_python/helpers.py index 4d785b48..8c7b9060 100644 --- a/mssql_python/helpers.py +++ b/mssql_python/helpers.py @@ -250,6 +250,97 @@ def _sanitize_for_logging(input_val: Any, max_length: int = max_log_length) -> s return True, None, sanitized_attr, sanitized_val +def connstr_to_pycore_params(params: dict) -> dict: + """Translate parsed ODBC connection-string params for py-core's bulk copy path. + + When ``cursor.bulkcopy()`` is called, mssql-python opens a *separate* + connection through mssql-py-core. + py-core's ``connection.rs`` expects a Python dict with snake_case keys — + different from the ODBC-style keys that ``_ConnectionStringParser._parse`` + returns. + + This function bridges that gap: it maps lowercase ODBC keys (e.g. 
+ ``"trustservercertificate"``) to py-core keys (``"trust_server_certificate"``) + and converts numeric strings to ``int`` for timeout/size params. + Boolean params (TrustServerCertificate, MultiSubnetFailover) are passed as + strings — ``connection.rs`` validates Yes/No and rejects invalid values. + Unrecognised keys are silently dropped. + """ + # Only keys listed below are forwarded to py-core. + # Unknown/reserved keys (app, workstationid, language, connect_timeout, + # mars_connection) are silently dropped here. In the normal connect() + # path the parser validates keywords first (validate_keywords=True), + # but bulkcopy parses with validation off, so this mapping is the + # authoritative filter in that path. + key_map = { + # auth / credentials + "uid": "user_name", + "pwd": "password", + "trusted_connection": "trusted_connection", + "authentication": "authentication", + # server (accept parser synonyms) + "server": "server", + "addr": "server", + "address": "server", + # database + "database": "database", + "applicationintent": "application_intent", + # encryption / TLS (include snake_case alias the parser may emit) + "encrypt": "encryption", + "trustservercertificate": "trust_server_certificate", + "trust_server_certificate": "trust_server_certificate", + "hostnameincertificate": "host_name_in_certificate", + "servercertificate": "server_certificate", + # Kerberos + "serverspn": "server_spn", + # network + "multisubnetfailover": "multi_subnet_failover", + "ipaddresspreference": "ip_address_preference", + "keepalive": "keep_alive", + "keepaliveinterval": "keep_alive_interval", + # sizing / limits ("packet size" with space is a common pyodbc-ism) + "packetsize": "packet_size", + "packet size": "packet_size", + "connectretrycount": "connect_retry_count", + "connectretryinterval": "connect_retry_interval", + } + int_keys = { + "packet_size", + "connect_retry_count", + "connect_retry_interval", + "keep_alive", + "keep_alive_interval", + } + + pycore_params: 
dict = {}
+
+    for connstr_key, raw_value in params.items():
+        pycore_key = key_map.get(connstr_key)
+        if pycore_key is None:
+            continue
+
+        # First-wins: match ODBC behaviour — first synonym in the
+        # connection string takes precedence (e.g. Addr before Server).
+        if pycore_key in pycore_params:
+            continue
+
+        # ODBC values are always strings; py-core expects native types for int keys.
+        # Boolean params (trust_server_certificate, multi_subnet_failover) are passed
+        # as strings — all Yes/No validation is in connection.rs for single-location
+        # consistency with Encrypt, ApplicationIntent, IPAddressPreference, etc.
+        if pycore_key in int_keys:
+            # Numeric params (timeouts, packet size, etc.) — skip on bad input
+            try:
+                pycore_params[pycore_key] = int(raw_value)
+            except (ValueError, TypeError):
+                pass # let py-core fall back to its compiled-in default
+        else:
+            # String params (server, database, encryption, etc.) — pass through
+            pycore_params[pycore_key] = raw_value
+
+    return pycore_params
+
+
 # Settings functionality moved here to avoid circular imports
 
 # Initialize the locale setting only once at module import time
diff --git a/tests/test_010_connection_string_parser.py b/tests/test_010_connection_string_parser.py index af55004d..d632092d 100644 --- a/tests/test_010_connection_string_parser.py +++ b/tests/test_010_connection_string_parser.py @@ -440,3 +440,99 @@ def test_incomplete_entry_recovery(self): # Should have error about incomplete 'Server' errors = exc_info.value.errors assert any("Server" in err and "Incomplete specification" in err for err in errors) + + +class TestSynonymFirstWins: + """ + Verify that _normalize_params uses first-wins for synonym keys. 
+ + ODBC Driver 18 behaviour (confirmed via live test against sqlcconn.cpp): + - Same key repeated → first-wins (fFromAttrOrProp guard) + - Addr vs Address → same KEY_ADDR slot, first-wins + - Addr/Address vs Server → separate slots, Addr/Address takes priority + + _ConnectionStringParser._parse() rejects exact duplicate keys outright. + These tests cover synonyms that map to the same canonical key during + normalization (e.g. addr/address/server → "Server"). + """ + + @staticmethod + def _normalize(raw: dict) -> dict: + """Shorthand for calling _normalize_params with warnings suppressed.""" + return _ConnectionStringParser._normalize_params(raw, warn_rejected=False) + + # ---- server / addr / address synonyms -------------------------------- + + def test_server_then_addr_first_wins(self): + """Server=A;Addr=B → first-wins keeps A.""" + result = self._normalize({"server": "hostA", "addr": "hostB"}) + assert result["Server"] == "hostA" + + def test_addr_then_server_first_wins(self): + """Addr=A;Server=B → first-wins keeps A.""" + result = self._normalize({"addr": "hostA", "server": "hostB"}) + assert result["Server"] == "hostA" + + def test_address_then_server_first_wins(self): + """Address=A;Server=B → first-wins keeps A.""" + result = self._normalize({"address": "hostA", "server": "hostB"}) + assert result["Server"] == "hostA" + + def test_addr_then_address_first_wins(self): + """Addr=A;Address=B → first-wins keeps A.""" + result = self._normalize({"addr": "hostA", "address": "hostB"}) + assert result["Server"] == "hostA" + + def test_all_three_server_synonyms_first_wins(self): + """Addr=A;Address=B;Server=C → first-wins keeps A.""" + result = self._normalize({"addr": "hostA", "address": "hostB", "server": "hostC"}) + assert result["Server"] == "hostA" + + def test_server_only_no_synonyms(self): + """Single key has no conflict.""" + result = self._normalize({"server": "hostA"}) + assert result["Server"] == "hostA" + + # ---- trustservercertificate / 
trust_server_certificate synonyms ------ + + def test_trustservercertificate_then_snake_case_first_wins(self): + """trustservercertificate=Yes;trust_server_certificate=No → first-wins keeps Yes.""" + result = self._normalize( + {"trustservercertificate": "Yes", "trust_server_certificate": "No"} + ) + assert result["TrustServerCertificate"] == "Yes" + + def test_snake_case_then_trustservercertificate_first_wins(self): + """trust_server_certificate=No;trustservercertificate=Yes → first-wins keeps No.""" + result = self._normalize( + {"trust_server_certificate": "No", "trustservercertificate": "Yes"} + ) + assert result["TrustServerCertificate"] == "No" + + # ---- packetsize / "packet size" synonyms ----------------------------- + + def test_packetsize_then_packet_space_first_wins(self): + """packetsize=8192;packet size=4096 → first-wins keeps 8192.""" + result = self._normalize({"packetsize": "8192", "packet size": "4096"}) + assert result["PacketSize"] == "8192" + + def test_packet_space_then_packetsize_first_wins(self): + """packet size=4096;packetsize=8192 → first-wins keeps 4096.""" + result = self._normalize({"packet size": "4096", "packetsize": "8192"}) + assert result["PacketSize"] == "4096" + + # ---- non-synonym keys are unaffected --------------------------------- + + def test_different_keys_both_kept(self): + """Non-synonym keys should both be present.""" + result = self._normalize({"server": "host", "database": "mydb", "uid": "sa"}) + assert result == {"Server": "host", "Database": "mydb", "UID": "sa"} + + # ---- reserved keys filtered regardless of order ---------------------- + + def test_reserved_keys_always_filtered(self): + """Driver and APP are always stripped, even when first.""" + result = self._normalize({"driver": "foo", "server": "host", "app": "bar"}) + assert "Driver" not in result + assert "APP" not in result + assert result["Server"] == "host" diff --git a/tests/test_011_connection_string_allowlist.py 
b/tests/test_011_connection_string_allowlist.py index 97735bb3..85f41174 100644 --- a/tests/test_011_connection_string_allowlist.py +++ b/tests/test_011_connection_string_allowlist.py @@ -130,8 +130,8 @@ def test__normalize_params_handles_address_variants(self): """Test filtering handles address/addr/server as synonyms.""" params = {"address": "addr1", "addr": "addr2", "server": "server1"} filtered = _ConnectionStringParser._normalize_params(params, warn_rejected=False) - # All three are synonyms that map to 'Server', last one wins - assert filtered["Server"] == "server1" + # All three are synonyms that map to 'Server', first one wins + assert filtered["Server"] == "addr1" assert "Address" not in filtered assert "Addr" not in filtered