Skip to content

Commit 0654585

Browse files
committed
fix(oci): resolve mypy type errors and add type-checking test gate
The OCI client code introduced several mypy errors that went unnoticed because mypy was configured but never enforced in tests or CI. Type fixes: - lazy_oci_deps.py: suppress import-untyped for oci SDK (no type stubs) - oci_client.py: cast response.stream to Iterator[bytes] (httpx types it as SyncByteStream | AsyncByteStream but it's iterable at runtime) - oci_client.py: use .get("model", "") to satisfy str expectation - test_oci_client.py: suppress attr-defined on dynamic module stubs New test gate (tests/test_oci_mypy.py): - Runs mypy on OCI source and test files as part of pytest - Uses --follow-imports=silent to isolate from pre-existing AWS errors - Skips gracefully if mypy is not on PATH - Ensures future type regressions fail the test suite immediately
1 parent a06825a commit 0654585

6 files changed

Lines changed: 687 additions & 294 deletions

File tree

poetry.lock

Lines changed: 613 additions & 283 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ oci = { version = "^2.165.0", optional = true }
5252

5353
[tool.poetry.extras]
5454
oci = ["oci"]
55+
aiohttp=["aiohttp", "httpx-aiohttp"]
5556

5657
[tool.poetry.group.dev.dependencies]
5758
mypy = "==1.13.0"
@@ -99,6 +100,3 @@ section-order = ["future", "standard-library", "third-party", "first-party"]
99100
[build-system]
100101
requires = ["poetry-core"]
101102
build-backend = "poetry.core.masonry.api"
102-
103-
[tool.poetry.extras]
104-
aiohttp=["aiohttp", "httpx-aiohttp"]

src/cohere/manually_maintained/lazy_oci_deps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def lazy_oci() -> Any:
2424
ImportError: If the OCI SDK is not installed
2525
"""
2626
try:
27-
import oci
27+
import oci # type: ignore[import-untyped]
2828
return oci
2929
except ImportError:
3030
raise ImportError(OCI_INSTALLATION_MESSAGE)

src/cohere/oci_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def _hook(response: httpx.Response) -> None:
530530

531531
# For streaming responses, wrap the stream with a transformer
532532
if is_stream:
533-
original_stream = response.stream
533+
original_stream = typing.cast(typing.Iterator[bytes], response.stream)
534534
transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint, is_v2)
535535
response.stream = Streamer(transformed_stream)
536536
# Reset consumption flags
@@ -640,7 +640,7 @@ def transform_request_to_oci(
640640
Returns:
641641
Transformed request body in OCI format
642642
"""
643-
model = normalize_model_for_oci(cohere_body.get("model"))
643+
model = normalize_model_for_oci(cohere_body.get("model", ""))
644644

645645
if endpoint == "embed":
646646
if "texts" in cohere_body:

tests/test_oci_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,19 @@
5555

5656
if "tokenizers" not in sys.modules:
5757
tokenizers_stub = types.ModuleType("tokenizers")
58-
tokenizers_stub.Tokenizer = object
58+
tokenizers_stub.Tokenizer = object # type: ignore[attr-defined]
5959
sys.modules["tokenizers"] = tokenizers_stub
6060

6161
if "fastavro" not in sys.modules:
6262
fastavro_stub = types.ModuleType("fastavro")
63-
fastavro_stub.parse_schema = lambda schema: schema
64-
fastavro_stub.reader = lambda *args, **kwargs: iter(())
65-
fastavro_stub.writer = lambda *args, **kwargs: None
63+
fastavro_stub.parse_schema = lambda schema: schema # type: ignore[attr-defined]
64+
fastavro_stub.reader = lambda *args, **kwargs: iter(()) # type: ignore[attr-defined]
65+
fastavro_stub.writer = lambda *args, **kwargs: None # type: ignore[attr-defined]
6666
sys.modules["fastavro"] = fastavro_stub
6767

6868
if "httpx_sse" not in sys.modules:
6969
httpx_sse_stub = types.ModuleType("httpx_sse")
70-
httpx_sse_stub.connect_sse = lambda *args, **kwargs: None
70+
httpx_sse_stub.connect_sse = lambda *args, **kwargs: None # type: ignore[attr-defined]
7171
sys.modules["httpx_sse"] = httpx_sse_stub
7272

7373

tests/test_oci_mypy.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Mypy type-checking gate for OCI client code.
2+
3+
Runs mypy on OCI source and test files and fails if any type errors are found.
4+
This prevents type regressions from being introduced silently.
5+
6+
Run with:
7+
pytest tests/test_oci_mypy.py
8+
"""
9+
10+
import os
11+
import shutil
12+
import subprocess
13+
import unittest
14+
15+
MYPY_BIN = shutil.which("mypy")
16+
17+
# Files that must stay mypy-clean
18+
OCI_SOURCE_FILES = [
19+
"src/cohere/oci_client.py",
20+
"src/cohere/manually_maintained/lazy_oci_deps.py",
21+
]
22+
23+
OCI_TEST_FILES = [
24+
"tests/test_oci_client.py",
25+
]
26+
27+
# --follow-imports=silent prevents mypy from crawling into transitive
28+
# dependencies (e.g. the AWS client) that have pre-existing errors.
29+
_MYPY_BASE = [
30+
"--config-file", "mypy.ini",
31+
"--follow-imports=silent",
32+
]
33+
34+
35+
def _run_mypy(files: list[str], extra_env: dict[str, str] | None = None) -> tuple[int, str]:
36+
"""Run mypy on the given files and return (exit_code, output)."""
37+
assert MYPY_BIN is not None
38+
env = {**os.environ, **(extra_env or {})}
39+
result = subprocess.run(
40+
[MYPY_BIN, *_MYPY_BASE, *files],
41+
capture_output=True,
42+
text=True,
43+
env=env,
44+
)
45+
return result.returncode, (result.stdout + result.stderr).strip()
46+
47+
48+
@unittest.skipIf(MYPY_BIN is None, "mypy not found on PATH")
49+
class TestOciMypy(unittest.TestCase):
50+
"""Ensure OCI files pass mypy with no new errors."""
51+
52+
def test_oci_source_types(self):
53+
"""OCI source files must be free of mypy errors."""
54+
code, output = _run_mypy(OCI_SOURCE_FILES)
55+
self.assertEqual(code, 0, f"mypy found type errors in OCI source:\n{output}")
56+
57+
def test_oci_test_types(self):
58+
"""OCI test files must be free of mypy errors."""
59+
# PYTHONPATH=src so mypy can resolve `import cohere`
60+
code, output = _run_mypy(OCI_TEST_FILES, extra_env={"PYTHONPATH": "src"})
61+
self.assertEqual(code, 0, f"mypy found type errors in OCI tests:\n{output}")
62+
63+
64+
if __name__ == "__main__":
65+
unittest.main()

0 commit comments

Comments
 (0)