Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/code-quality-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down Expand Up @@ -82,6 +83,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/publish-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
19 changes: 19 additions & 0 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,25 @@ def do_rollback(self, dbapi_connection):
# Databricks SQL Does not support transactions
pass

def do_ping(self, dbapi_connection):
"""Check if the connection is usable.

Called by SQLAlchemy when pool_pre_ping=True before checking out
a connection from the pool. If this returns False, the connection
is invalidated and a new one is created.

Any error during the ping means the connection is unusable
"""
try:
cursor = dbapi_connection.cursor()
try:
cursor.execute("SELECT 1")
finally:
cursor.close()
return True
except Exception:
return False

@reflection.cache
def has_table(
self, connection, table_name, schema=None, catalog=None, **kwargs
Expand Down
47 changes: 47 additions & 0 deletions tests/test_local/e2e/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,3 +541,50 @@ def test_table_comment_reflection(self, inspector: Inspector, table: Table):
def test_column_comment(self, inspector: Inspector, table: Table):
result = inspector.get_columns(table.name)[0].get("comment")
assert result == "column comment"


def test_pool_pre_ping_with_closed_connection(connection_details):
"""Test that pool_pre_ping detects closed connections and creates new ones.

When a pooled connection is closed (simulating session expiration),
do_ping() detects it and SQLAlchemy creates a new connection.
"""
conn_string, connect_args = version_agnostic_connect_arguments(connection_details)

engine = create_engine(
conn_string,
connect_args=connect_args,
pool_pre_ping=True,
pool_size=1,
max_overflow=0,
)

# Step 1: Use a connection and record its session ID
with engine.connect() as conn:
result = conn.execute(text("SELECT VERSION()")).scalar()
assert result is not None

raw_conn = conn.connection.dbapi_connection
session_id_1 = raw_conn.get_session_id_hex()
assert session_id_1 is not None

# Step 2: Close the pooled connection to simulate session expiration
pooled_conn = engine.pool._pool.queue[0]
pooled_conn.driver_connection.close()
assert not pooled_conn.driver_connection.open

# Step 3: pool_pre_ping should detect the dead connection and create a new one
with engine.connect() as conn:
result = conn.execute(text("SELECT VERSION()")).scalar()
assert result is not None

raw_conn = conn.connection.dbapi_connection
session_id_2 = raw_conn.get_session_id_hex()
assert session_id_2 is not None

assert session_id_1 != session_id_2, (
"pool_pre_ping should have detected the closed connection "
"and created a new one with a different session ID"
)

engine.dispose()
45 changes: 45 additions & 0 deletions tests/test_local/test_do_ping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests for DatabricksDialect.do_ping() method."""
from unittest.mock import MagicMock, patch
import pytest
from databricks.sqlalchemy import DatabricksDialect


class TestDoPing:
@pytest.fixture
def dialect(self):
return DatabricksDialect()

def test_ping_success(self, dialect):
"""do_ping returns True when SELECT 1 succeeds."""
mock_conn = MagicMock()
assert dialect.do_ping(mock_conn) is True
mock_conn.cursor.assert_called_once()
mock_conn.cursor().execute.assert_called_once_with("SELECT 1")
mock_conn.cursor().close.assert_called_once()

def test_ping_cursor_fails(self, dialect):
"""do_ping returns False when cursor() raises (connection closed)."""
mock_conn = MagicMock()
mock_conn.cursor.side_effect = Exception("Cannot create cursor from closed connection")
assert dialect.do_ping(mock_conn) is False

def test_ping_execute_fails(self, dialect):
"""do_ping returns False when execute() raises (session expired)."""
mock_conn = MagicMock()
mock_conn.cursor().execute.side_effect = Exception("Invalid SessionHandle")
assert dialect.do_ping(mock_conn) is False

def test_ping_cursor_closed_on_success(self, dialect):
"""Cursor is closed after a successful ping."""
mock_conn = MagicMock()
dialect.do_ping(mock_conn)
mock_conn.cursor().close.assert_called_once()

def test_ping_cursor_closed_on_execute_failure(self, dialect):
"""Cursor is closed even when execute() fails."""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_cursor.execute.side_effect = Exception("network error")
dialect.do_ping(mock_conn)
mock_cursor.close.assert_called_once()
Loading