Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cloudbuild/setup_vm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ fi
if [ "$INSTALL_FSSPEC_HEAD" = "true" ]; then
echo '--- Installing fsspec HEAD ---'
pip install --force-reinstall git+https://github.com/fsspec/filesystem_spec.git > /dev/null
echo "fsspec version: $(python3 -c 'import fsspec; print(fsspec.__version__)')"
fi
23 changes: 18 additions & 5 deletions gcsfs/extended_gcsfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from enum import Enum
from glob import has_magic

import grpc
from fsspec import asyn
from fsspec.callbacks import NoOpCallback
from google.api_core import exceptions as api_exceptions
Expand Down Expand Up @@ -179,11 +180,23 @@ async def _get_control_plane_client(self):
"grpc_asyncio"
)
)
channel = transport_cls.create_channel(
credentials=self.credential,
options=[("grpc.primary_user_agent", f"{USER_AGENT}/{version}")],
quota_project_id=self._user_project,
)
channel_kwargs = {
"credentials": self.credential,
"options": [("grpc.primary_user_agent", f"{USER_AGENT}/{version}")],
"quota_project_id": self._user_project,
}
if self._location:
endpoint = self._location.split("://")[-1]
channel_kwargs["host"] = endpoint

if self._location and self._location.startswith("http://"):
host = channel_kwargs["host"]
channel = grpc.aio.insecure_channel(
host, options=channel_kwargs.get("options")
)
else:
channel = transport_cls.create_channel(**channel_kwargs)

transport = transport_cls(channel=channel)
self._storage_control_client = storage_control_v2.StorageControlAsyncClient(
transport=transport
Expand Down
2 changes: 1 addition & 1 deletion gcsfs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _avoid_adc_timeout(monkeypatch):
yield


@pytest.fixture(autouse=True)
@pytest.fixture(scope="session", autouse=True)
def _mock_get_bucket_type_on_emulator():
"""Mock _get_bucket_type to return UNKNOWN instantly on emulator."""
if not is_real_gcs():
Expand Down
97 changes: 87 additions & 10 deletions gcsfs/tests/test_extended_hns_gcsfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2153,24 +2153,101 @@ async def test_get_control_plane_client_quota_project_id(
requester_pays, expected_quota_project
):

fs = ExtendedGcsFileSystem(project="my-project", requester_pays=requester_pays)
with mock.patch.dict(os.environ, {"STORAGE_EMULATOR_HOST": ""}):
ExtendedGcsFileSystem.clear_instance_cache()
fs = ExtendedGcsFileSystem(project="my-project", requester_pays=requester_pays)

mock_transport_cls = mock.Mock()
mock_channel = mock.Mock()
mock_transport_cls.create_channel.return_value = mock_channel

with mock.patch.object(
storage_control_v2.StorageControlAsyncClient,
"get_transport_class",
return_value=mock_transport_cls,
) as mock_get_transport:

await fs._get_control_plane_client()

mock_get_transport.assert_called_once_with("grpc_asyncio")
mock_transport_cls.create_channel.assert_called_once()
kwargs = mock_transport_cls.create_channel.call_args.kwargs
assert kwargs["quota_project_id"] == expected_quota_project


@pytest.mark.asyncio
@pytest.mark.parametrize(
"endpoint_url, env_updates, expected_host, expected_insecure",
[
("https://my-endpoint.com", {}, "my-endpoint.com", False),
(
None,
{
"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "apis-tpczero.goog",
"STORAGE_EMULATOR_HOST": "",
},
"storage.apis-tpczero.goog",
False,
),
(
None,
{
"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "apis-tpczero.goog",
"STORAGE_EMULATOR_HOST": "http://my-emulator.com",
},
"my-emulator.com",
True,
),
(
None,
{"STORAGE_EMULATOR_HOST": "http://my-emulator.com"},
"my-emulator.com",
True,
),
(
None,
{"STORAGE_EMULATOR_HOST": "https://my-emulator.com"},
"my-emulator.com",
False,
),
(None, {"STORAGE_EMULATOR_HOST": ""}, "storage.googleapis.com", False),
],
)
async def test_get_control_plane_client_endpoint(
endpoint_url, env_updates, expected_host, expected_insecure
):
fs_kwargs = {"token": "anon"}
if endpoint_url:
fs_kwargs["endpoint_url"] = endpoint_url

mock_transport_cls = mock.Mock()
mock_channel = mock.Mock()
mock_transport_cls.create_channel.return_value = mock_channel

with mock.patch.object(
storage_control_v2.StorageControlAsyncClient,
"get_transport_class",
return_value=mock_transport_cls,
) as mock_get_transport:
import os

with (
mock.patch.object(
storage_control_v2.StorageControlAsyncClient,
"get_transport_class",
return_value=mock_transport_cls,
),
mock.patch("grpc.aio.insecure_channel") as mock_insecure_channel,
mock.patch.dict(os.environ, env_updates),
):
ExtendedGcsFileSystem.clear_instance_cache()
fs = ExtendedGcsFileSystem(**fs_kwargs)

await fs._get_control_plane_client()

mock_get_transport.assert_called_once_with("grpc_asyncio")
mock_transport_cls.create_channel.assert_called_once()
kwargs = mock_transport_cls.create_channel.call_args.kwargs
assert kwargs["quota_project_id"] == expected_quota_project
if expected_insecure:
mock_insecure_channel.assert_called_once_with(
expected_host, options=mock.ANY
)
else:
mock_transport_cls.create_channel.assert_called_once()
kwargs = mock_transport_cls.create_channel.call_args.kwargs
assert kwargs.get("host") == expected_host


def test_extended_gcsfs_retry_init():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ log_cli_level = "DEBUG"

[tool.isort]
profile = "black"
known_third_party = ["aiohttp", "click", "decorator", "fsspec", "fuse", "google", "google_auth_oauthlib", "numpy", "prettytable", "psutil", "pytest", "pytest_asyncio", "requests", "resource_monitor", "yaml"]
known_third_party = ["aiohttp", "click", "decorator", "fsspec", "fuse", "google", "google_auth_oauthlib", "grpc", "numpy", "prettytable", "psutil", "pytest", "pytest_asyncio", "requests", "resource_monitor", "yaml"]
Loading