Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
TYPE_CHECKING,
Any,
)
from urllib.parse import urlparse
from urllib.parse import ParseResult, urlparse

import requests
from fsspec import AbstractFileSystem
Expand Down Expand Up @@ -244,7 +244,7 @@ def _gs(properties: Properties) -> AbstractFileSystem:
)


def _adls(properties: Properties) -> AbstractFileSystem:
def _adls(properties: Properties, hostname: str | None = None) -> AbstractFileSystem:
# https://fsspec.github.io/adlfs/api/

from adlfs import AzureBlobFileSystem
Expand All @@ -259,6 +259,10 @@ def _adls(properties: Properties) -> AbstractFileSystem:
if ADLS_SAS_TOKEN not in properties:
properties[ADLS_SAS_TOKEN] = sas_token

# Fallback: extract account_name from URI hostname (e.g. "account.dfs.core.windows.net" -> "account")
if hostname and ADLS_ACCOUNT_NAME not in properties:
properties[ADLS_ACCOUNT_NAME] = hostname.split(".")[0]

class StaticTokenCredential(AsyncTokenCredential):
_DEFAULT_EXPIRY_SECONDS = 3600

Expand Down Expand Up @@ -300,7 +304,7 @@ def _hf(properties: Properties) -> AbstractFileSystem:
)


SCHEME_TO_FS = {
SCHEME_TO_FS: dict[str, Callable[..., AbstractFileSystem]] = {
"": _file,
"file": _file,
"s3": _s3,
Expand All @@ -313,6 +317,8 @@ def _hf(properties: Properties) -> AbstractFileSystem:
"hf": _hf,
}

_ADLS_SCHEMES = frozenset({"abfs", "abfss", "wasb", "wasbs"})


class FsspecInputFile(InputFile):
"""An input file implementation for the FsspecFileIO.
Expand Down Expand Up @@ -414,8 +420,7 @@ class FsspecFileIO(FileIO):
"""A FileIO implementation that uses fsspec."""

def __init__(self, properties: Properties):
self._scheme_to_fs = {}
self._scheme_to_fs.update(SCHEME_TO_FS)
self._scheme_to_fs: dict[str, Callable[..., AbstractFileSystem]] = dict(SCHEME_TO_FS)
self._thread_locals = threading.local()
super().__init__(properties=properties)

Expand All @@ -429,7 +434,7 @@ def new_input(self, location: str) -> FsspecInputFile:
FsspecInputFile: An FsspecInputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self._get_fs_from_uri(uri)
return FsspecInputFile(location=location, fs=fs)

def new_output(self, location: str) -> FsspecOutputFile:
Expand All @@ -442,7 +447,7 @@ def new_output(self, location: str) -> FsspecOutputFile:
FsspecOutputFile: An FsspecOutputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self._get_fs_from_uri(uri)
return FsspecOutputFile(location=location, fs=fs)

def delete(self, location: str | InputFile | OutputFile) -> None:
Expand All @@ -459,20 +464,30 @@ def delete(self, location: str | InputFile | OutputFile) -> None:
str_location = location

uri = urlparse(str_location)
fs = self.get_fs(uri.scheme)
fs = self._get_fs_from_uri(uri)
fs.rm(str_location)

def get_fs(self, scheme: str) -> AbstractFileSystem:
def _get_fs_from_uri(self, uri: "ParseResult") -> AbstractFileSystem:
    """Resolve the filesystem for an already-parsed location URI.

    Schemes listed in ``_ADLS_SCHEMES`` forward the URI hostname to ``get_fs``;
    every other scheme is resolved from the scheme alone.

    Args:
        uri: The parsed location.

    Returns:
        The filesystem returned by ``get_fs`` for this URI.
    """
    scheme = uri.scheme
    if scheme not in _ADLS_SCHEMES:
        return self.get_fs(scheme)
    # ADLS-family scheme: pass the hostname along with the scheme.
    return self.get_fs(scheme, uri.hostname)

def get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme, cached per thread."""
if not hasattr(self._thread_locals, "get_fs_cached"):
self._thread_locals.get_fs_cached = lru_cache(self._get_fs)

return self._thread_locals.get_fs_cached(scheme)
return self._thread_locals.get_fs_cached(scheme, hostname)

def _get_fs(self, scheme: str) -> AbstractFileSystem:
def _get_fs(self, scheme: str, hostname: str | None = None) -> AbstractFileSystem:
    """Build the filesystem for a scheme; this method itself does no caching.

    Args:
        scheme: URI scheme used to look up the registered filesystem factory.
        hostname: Host part of the location URI. It is only forwarded for the
            schemes in ``_ADLS_SCHEMES``, where ``_adls`` uses it as a fallback
            source for the account name.

    Returns:
        The filesystem produced for ``scheme``.

    Raises:
        ValueError: If no filesystem is registered for ``scheme``.
    """
    if scheme not in self._scheme_to_fs:
        raise ValueError(f"No registered filesystem for scheme: {scheme}")

    if scheme in _ADLS_SCHEMES:
        # _adls stores the hostname-derived account name in the mapping it is
        # given. Handing it the shared self.properties would pin the first ADLS
        # host seen, and every later URI on a different storage account would
        # silently reuse that account. Pass a copy whenever a hostname is
        # supplied; without a hostname the call is unchanged.
        adls_properties = dict(self.properties) if hostname else self.properties
        return _adls(adls_properties, hostname)

    return self._scheme_to_fs[scheme](self.properties)

def __getstate__(self) -> dict[str, Any]:
Expand Down
58 changes: 58 additions & 0 deletions tests/io/test_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,64 @@ def test_adls_account_name_sas_token_extraction() -> None:
)


def test_adls_account_name_extracted_from_uri_hostname() -> None:
    """Check that the ABFSS URI hostname supplies account_name when properties omit it."""
    # No "adls.account-name" key here, so the hostname is the only source.
    properties: Properties = {
        "adls.tenant-id": "test-tenant-id",
        "adls.client-id": "test-client-id",
        "adls.client-secret": "test-client-secret",
    }
    location = (
        "abfss://dd-michelada-us3-prod-dog@usagestorageprod.dfs.core.windows.net"
        "/unified_datasets/aggregated/data/file.parquet"
    )
    # account_name is the first dot-separated label of the URI hostname.
    expected_kwargs = {
        "connection_string": None,
        "credential": None,
        "account_name": "usagestorageprod",
        "account_key": None,
        "sas_token": None,
        "tenant_id": "test-tenant-id",
        "client_id": "test-client-id",
        "client_secret": "test-client-secret",
        "account_host": None,
        "anon": None,
    }

    with mock.patch("adlfs.AzureBlobFileSystem") as mock_adlfs:
        file_io = FsspecFileIO(properties=properties)
        file_io.new_input(location=location)

        mock_adlfs.assert_called_with(**expected_kwargs)


def test_adls_account_name_not_overridden_when_in_properties() -> None:
    """Check that an explicit adls.account-name wins over the URI hostname."""
    properties: Properties = {
        "adls.account-name": "explicitly-configured-account",
        "adls.tenant-id": "test-tenant-id",
        "adls.client-id": "test-client-id",
        "adls.client-secret": "test-client-secret",
    }
    # The hostname names a different account than the one configured above.
    location = "abfss://container@usagestorageprod.dfs.core.windows.net/path/file.parquet"
    expected_kwargs = {
        "connection_string": None,
        "credential": None,
        "account_name": "explicitly-configured-account",
        "account_key": None,
        "sas_token": None,
        "tenant_id": "test-tenant-id",
        "client_id": "test-client-id",
        "client_secret": "test-client-secret",
        "account_host": None,
        "anon": None,
    }

    with mock.patch("adlfs.AzureBlobFileSystem") as mock_adlfs:
        file_io = FsspecFileIO(properties=properties)
        file_io.new_input(location=location)

        mock_adlfs.assert_called_with(**expected_kwargs)


@pytest.mark.gcs
def test_fsspec_new_input_file_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None:
"""Test creating a new input file from a fsspec file-io"""
Expand Down