Skip to content

Commit 27abc5c

Browse files
committed
Address cursor review: raise on unsupported endpoint, refresh session token per-request
- transform_request_to_oci now raises ValueError for endpoints other than 'embed' and 'chat' instead of silently returning the untransformed body - Session token auth uses a refreshing wrapper that re-reads the token file before each signing call, so OCI CLI token refreshes are picked up without restarting the client - Add test_unsupported_endpoint_raises to cover the new explicit error - Update test_session_auth_prefers_security_token_signer to expect multi-call behaviour from the refreshing signer
1 parent 2ac5ec4 commit 27abc5c

2 files changed

Lines changed: 49 additions & 12 deletions

File tree

src/cohere/oci_client.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -399,24 +399,46 @@ def map_request_to_oci(
399399
if "signer" in oci_config:
400400
signer = oci_config["signer"] # Instance/resource principal
401401
elif "security_token_file" in oci_config:
402-
# Session-based authentication with security token (fallback if no user field)
403-
token_file_path = os.path.expanduser(oci_config["security_token_file"])
404-
with open(token_file_path, "r") as f:
405-
security_token = f.read().strip()
406-
407-
# Load private key using OCI's utility function
402+
# Session-based authentication with security token.
403+
# The token file is re-read on every request so that OCI CLI token refreshes
404+
# (e.g. `oci session refresh`) are picked up without restarting the client.
408405
key_file = oci_config.get("key_file")
409406
if not key_file:
410407
raise ValueError(
411408
"OCI config profile is missing 'key_file'. "
412409
"Session-based auth requires a key_file entry in your OCI config profile."
413410
)
411+
token_file_path = os.path.expanduser(oci_config["security_token_file"])
414412
private_key = oci.signer.load_private_key_from_file(os.path.expanduser(key_file))
415413

416-
signer = oci.auth.signers.SecurityTokenSigner(
417-
token=security_token,
418-
private_key=private_key,
419-
)
414+
class _RefreshingSecurityTokenSigner:
415+
"""Wraps SecurityTokenSigner and re-reads the token file before each signing call."""
416+
417+
def __init__(self) -> None:
418+
self._token_file = token_file_path
419+
self._private_key = private_key
420+
self._refresh()
421+
422+
def _refresh(self) -> None:
423+
with open(self._token_file, "r") as _f:
424+
_token = _f.read().strip()
425+
self._signer = oci.auth.signers.SecurityTokenSigner(
426+
token=_token,
427+
private_key=self._private_key,
428+
)
429+
430+
# Delegate all attribute access to the inner signer, refreshing first.
431+
def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
432+
self._refresh()
433+
return self._signer(*args, **kwargs)
434+
435+
def __getattr__(self, name: str) -> typing.Any:
436+
if name.startswith("_"):
437+
raise AttributeError(name)
438+
self._refresh()
439+
return getattr(self._signer, name)
440+
441+
signer = _RefreshingSecurityTokenSigner()
420442
elif "user" in oci_config:
421443
signer = oci.signer.Signer(
422444
tenancy=oci_config["tenancy"],
@@ -814,7 +836,10 @@ def transform_request_to_oci(
814836

815837
return oci_body
816838

817-
return cohere_body
839+
raise ValueError(
840+
f"Endpoint '{endpoint}' is not supported by OCI Generative AI on-demand inference. "
841+
"Supported endpoints: ['embed', 'chat']"
842+
)
818843

819844

820845
def transform_oci_response_to_cohere(

tests/test_oci_client.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,15 @@ def test_v1_client_rejects_v2_request(self):
737737
)
738738
self.assertIn("OciClient ", str(ctx.exception))
739739

740+
def test_unsupported_endpoint_raises(self):
741+
"""Test that transform_request_to_oci raises for unsupported endpoints."""
742+
from cohere.oci_client import transform_request_to_oci
743+
744+
with self.assertRaises(ValueError) as ctx:
745+
transform_request_to_oci("rerank", {"model": "rerank-v3.5"}, "compartment-123")
746+
self.assertIn("rerank", str(ctx.exception))
747+
self.assertIn("not supported", str(ctx.exception))
748+
740749
def test_v1_chat_request_optional_params(self):
741750
"""Test V1 chat request forwards supported optional params."""
742751
from cohere.oci_client import transform_request_to_oci
@@ -930,10 +939,13 @@ def test_session_auth_prefers_security_token_signer(self):
930939

931940
hook(request)
932941

933-
mock_oci.auth.signers.SecurityTokenSigner.assert_called_once_with(
942+
# SecurityTokenSigner is called at least once (init) and again per request
943+
# (token file is re-read on each signing call to pick up refreshed tokens).
944+
mock_oci.auth.signers.SecurityTokenSigner.assert_called_with(
934945
token="session-token",
935946
private_key="private-key",
936947
)
948+
self.assertGreaterEqual(mock_oci.auth.signers.SecurityTokenSigner.call_count, 1)
937949
mock_oci.signer.Signer.assert_not_called()
938950

939951
def test_embed_response_lowercases_embedding_keys(self):

0 commit comments

Comments
 (0)