diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 166ce52..32ba53f 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -30,6 +30,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true @@ -82,6 +83,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 26fa671..58cca6f 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -33,6 +33,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml index 42060c6..58053ce 100644 --- a/.github/workflows/publish-test.yml +++ b/.github/workflows/publish-test.yml @@ -21,6 +21,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4a51297..cf3d4da 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -24,6 +24,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: "2.2.1" virtualenvs-create: true virtualenvs-in-project: true installer-parallel: true diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 9148de7..aaba181 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -336,6 +336,25 @@ def do_rollback(self, dbapi_connection): # Databricks SQL Does not support transactions pass + def do_ping(self, dbapi_connection): + """Check if the connection is usable. + + Called by SQLAlchemy when pool_pre_ping=True before checking out + a connection from the pool. If this returns False, the connection + is invalidated and a new one is created. + + Any error during the ping means the connection is unusable + """ + try: + cursor = dbapi_connection.cursor() + try: + cursor.execute("SELECT 1") + finally: + cursor.close() + return True + except Exception: + return False + @reflection.cache def has_table( self, connection, table_name, schema=None, catalog=None, **kwargs diff --git a/tests/test_local/e2e/test_basic.py b/tests/test_local/e2e/test_basic.py index ce0b5d8..c7ab88e 100644 --- a/tests/test_local/e2e/test_basic.py +++ b/tests/test_local/e2e/test_basic.py @@ -541,3 +541,50 @@ def test_table_comment_reflection(self, inspector: Inspector, table: Table): def test_column_comment(self, inspector: Inspector, table: Table): result = inspector.get_columns(table.name)[0].get("comment") assert result == "column comment" + + +def test_pool_pre_ping_with_closed_connection(connection_details): + """Test that pool_pre_ping detects closed connections and creates new ones. + + When a pooled connection is closed (simulating session expiration), + do_ping() detects it and SQLAlchemy creates a new connection. + """ + conn_string, connect_args = version_agnostic_connect_arguments(connection_details) + + engine = create_engine( + conn_string, + connect_args=connect_args, + pool_pre_ping=True, + pool_size=1, + max_overflow=0, + ) + + # Step 1: Use a connection and record its session ID + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + raw_conn = conn.connection.dbapi_connection + session_id_1 = raw_conn.get_session_id_hex() + assert session_id_1 is not None + + # Step 2: Close the pooled connection to simulate session expiration + pooled_conn = engine.pool._pool.queue[0] + pooled_conn.driver_connection.close() + assert not pooled_conn.driver_connection.open + + # Step 3: pool_pre_ping should detect the dead connection and create a new one + with engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")).scalar() + assert result is not None + + raw_conn = conn.connection.dbapi_connection + session_id_2 = raw_conn.get_session_id_hex() + assert session_id_2 is not None + + assert session_id_1 != session_id_2, ( + "pool_pre_ping should have detected the closed connection " + "and created a new one with a different session ID" + ) + + engine.dispose() diff --git a/tests/test_local/test_do_ping.py b/tests/test_local/test_do_ping.py new file mode 100644 index 0000000..c00cfa7 --- /dev/null +++ b/tests/test_local/test_do_ping.py @@ -0,0 +1,45 @@ +"""Tests for DatabricksDialect.do_ping() method.""" +from unittest.mock import MagicMock, patch +import pytest +from databricks.sqlalchemy import DatabricksDialect + + +class TestDoPing: + @pytest.fixture + def dialect(self): + return DatabricksDialect() + + def test_ping_success(self, dialect): + """do_ping returns True when SELECT 1 succeeds.""" + mock_conn = MagicMock() + assert dialect.do_ping(mock_conn) is True + mock_conn.cursor.assert_called_once() + mock_conn.cursor().execute.assert_called_once_with("SELECT 1") + mock_conn.cursor().close.assert_called_once() + + def test_ping_cursor_fails(self, dialect): + """do_ping returns False when cursor() raises (connection closed).""" + mock_conn = MagicMock() + mock_conn.cursor.side_effect = Exception("Cannot create cursor from closed connection") + assert dialect.do_ping(mock_conn) is False + + def test_ping_execute_fails(self, dialect): + """do_ping returns False when execute() raises (session expired).""" + mock_conn = MagicMock() + mock_conn.cursor().execute.side_effect = Exception("Invalid SessionHandle") + assert dialect.do_ping(mock_conn) is False + + def test_ping_cursor_closed_on_success(self, dialect): + """Cursor is closed after a successful ping.""" + mock_conn = MagicMock() + dialect.do_ping(mock_conn) + mock_conn.cursor().close.assert_called_once() + + def test_ping_cursor_closed_on_execute_failure(self, dialect): + """Cursor is closed even when execute() fails.""" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.execute.side_effect = Exception("network error") + dialect.do_ping(mock_conn) + mock_cursor.close.assert_called_once()