|
11 | 11 | from upath import UPath |
12 | 12 | from upath.implementations.cloud import GCSPath |
13 | 13 | from upath.implementations.cloud import S3Path |
| 14 | +from upath.registry import get_upath_class |
| 15 | +from upath.registry import register_implementation |
14 | 16 | from upath.types import ReadablePath |
15 | 17 | from upath.types import WritablePath |
16 | 18 |
|
@@ -112,12 +114,35 @@ def test_subclass(local_testdir): |
112 | 114 | class MyPath(UPath): |
113 | 115 | pass |
114 | 116 |
|
115 | | - with pytest.warns( |
116 | | - DeprecationWarning, match=r"MyPath\(...\) detected protocol '' .*" |
117 | | - ): |
118 | | - path = MyPath(local_testdir) |
119 | | - assert str(path) == pathlib.Path(local_testdir).as_posix() |
| 117 | + with pytest.raises(ValueError, match=r".*incompatible with"): |
| 118 | + MyPath(local_testdir) |
| 119 | + |
| 120 | + |
| 121 | +@pytest.fixture(scope="function") |
| 122 | +def upath_registry_snapshot(): |
| 123 | + """Save and restore the upath registry state around a test.""" |
| 124 | + from upath.registry import _registry |
| 125 | + |
| 126 | + # Save the current state of the registry's mutable mapping |
| 127 | + saved_m = _registry._m.maps[0].copy() |
| 128 | + try: |
| 129 | + yield |
| 130 | + finally: |
| 131 | + # Restore the registry state |
| 132 | + _registry._m.maps[0].clear() |
| 133 | + _registry._m.maps[0].update(saved_m) |
| 134 | + get_upath_class.cache_clear() |
| 135 | + |
| 136 | + |
| 137 | +def test_subclass_registered(upath_registry_snapshot): |
| 138 | + class MyPath(UPath): |
| 139 | + pass |
| 140 | + |
| 141 | + register_implementation("memory", MyPath, clobber=True) |
| 142 | + path = MyPath("memory:///test_path") |
| 143 | + assert str(path) == "memory:///test_path" |
120 | 144 | assert issubclass(MyPath, UPath) |
| 145 | + assert isinstance(path, MyPath) |
121 | 146 | assert isinstance(path, pathlib_abc.ReadablePath) |
122 | 147 | assert isinstance(path, pathlib_abc.WritablePath) |
123 | 148 | assert not isinstance(path, pathlib.Path) |
@@ -453,33 +478,99 @@ def test_open_a_local_upath(tmp_path, protocol): |
453 | 478 | @pytest.mark.parametrize( |
454 | 479 | "uri,protocol", |
455 | 480 | [ |
| 481 | + # s3 compatible protocols |
456 | 482 | ("s3://bucket/folder", "s3"), |
457 | | - ("gs://bucket/folder", "gs"), |
| 483 | + ("s3a://bucket/folder", "s3a"), |
458 | 484 | ("bucket/folder", "s3"), |
| 485 | + # gcs compatible |
| 486 | + ("gs://bucket/folder", "gs"), |
| 487 | + ("gcs://bucket/folder", "gcs"), |
| 488 | + ("bucket/folder", "gs"), |
| 489 | + # azure compatible |
| 490 | + ("az://container/blob", "az"), |
| 491 | + ("abfs://container/blob", "abfs"), |
| 492 | + ("abfss://container/blob", "abfss"), |
| 493 | + ("adl://container/blob", "adl"), |
| 494 | + # memory |
459 | 495 | ("memory://folder", "memory"), |
| 496 | + ("/folder", "memory"), |
| 497 | + # file/local |
460 | 498 | ("file:/tmp/folder", "file"), |
461 | 499 | ("/tmp/folder", "file"), |
| 500 | + ("file:/tmp/folder", "local"), |
| 501 | + ("/tmp/folder", "local"), |
462 | 502 | ("/tmp/folder", ""), |
463 | 503 | ("a/b/c", ""), |
| 504 | + # http/https |
| 505 | + ("http://example.com/path", "http"), |
| 506 | + ("https://example.com/path", "https"), |
| 507 | + # ftp |
| 508 | + ("ftp://example.com/path", "ftp"), |
| 509 | + # sftp/ssh |
| 510 | + ("sftp://example.com/path", "sftp"), |
| 511 | + ("ssh://example.com/path", "ssh"), |
| 512 | + # smb |
| 513 | + ("smb://server/share/path", "smb"), |
| 514 | + # hdfs |
| 515 | + ("hdfs://namenode/path", "hdfs"), |
| 516 | + # webdav - requires base_url, skip for now |
| 517 | + # github |
| 518 | + ("github://owner:repo@branch/path", "github"), |
| 519 | + # data |
| 520 | + ("data:text/plain;base64,SGVsbG8=", "data"), |
| 521 | + # huggingface |
| 522 | + ("hf://datasets/user/repo/path", "hf"), |
464 | 523 | ], |
465 | 524 | ) |
466 | 525 | def test_constructor_compatible_protocol_uri(uri, protocol): |
467 | 526 | p = UPath(uri, protocol=protocol) |
468 | 527 | assert p.protocol == protocol |
469 | 528 |
|
470 | 529 |
|
471 | | -@pytest.mark.parametrize( |
472 | | - "uri,protocol", |
473 | | - [ |
474 | | - ("s3://bucket/folder", "gs"), |
475 | | - ("gs://bucket/folder", "s3"), |
476 | | - ("memory://folder", "s3"), |
477 | | - ("file:/tmp/folder", "s3"), |
478 | | - ("s3://bucket/folder", ""), |
479 | | - ("memory://folder", ""), |
480 | | - ("file:/tmp/folder", ""), |
481 | | - ], |
482 | | -) |
| 530 | +# Protocol to sample URI mapping |
| 531 | +_PROTOCOL_URIS = { |
| 532 | + "s3": "s3://bucket/folder", |
| 533 | + "gs": "gs://bucket/folder", |
| 534 | + "az": "az://container/blob", |
| 535 | + "memory": "memory://folder", |
| 536 | + "file": "file:/tmp/folder", |
| 537 | + "http": "http://example.com/path", |
| 538 | + "ftp": "ftp://example.com/path", |
| 539 | + "sftp": "sftp://example.com/path", |
| 540 | + "smb": "smb://server/share/path", |
| 541 | + "hdfs": "hdfs://namenode/path", |
| 542 | +} |
| 543 | + |
| 544 | +# Generate incompatible combinations: each protocol with URIs from other protocols |
| 545 | +_INCOMPATIBLE_CASES = [ |
| 546 | + (_PROTOCOL_URIS[uri_protocol], target_protocol) |
| 547 | + for target_protocol in _PROTOCOL_URIS |
| 548 | + for uri_protocol in _PROTOCOL_URIS |
| 549 | + if target_protocol != uri_protocol |
| 550 | +] |
| 551 | + |
| 552 | +# Also test explicit empty protocol with protocol-prefixed URIs |
| 553 | +_INCOMPATIBLE_CASES.extend([(uri, "") for uri in _PROTOCOL_URIS.values()]) |
| 554 | + |
| 555 | + |
| 556 | +@pytest.mark.parametrize("uri,protocol", _INCOMPATIBLE_CASES) |
483 | 557 | def test_constructor_incompatible_protocol_uri(uri, protocol): |
484 | | - with pytest.raises(ValueError, match=r".*incompatible with"): |
| 558 | + with pytest.raises(TypeError, match=r".*incompatible with"): |
485 | 559 | UPath(uri, protocol=protocol) |
| 560 | + |
| 561 | + |
| 562 | +# Test subclass instantiation with incompatible URIs |
| 563 | +# Use protocols that have registered implementations we can get via get_upath_class |
| 564 | +_SUBCLASS_INCOMPATIBLE_CASES = [ |
| 565 | + (_PROTOCOL_URIS[uri_protocol], target_protocol) |
| 566 | + for target_protocol in _PROTOCOL_URIS |
| 567 | + for uri_protocol in _PROTOCOL_URIS |
| 568 | + if target_protocol != uri_protocol |
| 569 | +] |
| 570 | + |
| 571 | + |
| 572 | +@pytest.mark.parametrize("uri,protocol", _SUBCLASS_INCOMPATIBLE_CASES) |
| 573 | +def test_subclass_constructor_incompatible_protocol_uri(uri, protocol): |
| 574 | + cls = get_upath_class(protocol) |
| 575 | + with pytest.raises(TypeError, match=r".*incompatible with"): |
| 576 | + cls(uri) |
0 commit comments