|
23 | 23 | AsyncIterator, |
24 | 24 | ) |
25 | 25 |
|
| 26 | +import os |
| 27 | +import json |
26 | 28 | import httpx |
27 | 29 | import warnings |
28 | 30 | from typing_extensions import Literal |
29 | 31 |
|
30 | 32 | from ..._types import Body, Query, Headers |
31 | 33 | from ..._utils._utils import deepcopy_minimal, with_sts_token, async_with_sts_token |
32 | | -from ..._utils._key_agreement import aes_gcm_decrypt_base64_string, aes_gcm_decrypt_base64_list |
| 34 | +from ..._utils._key_agreement import aes_gcm_decrypt_base64_string, aes_gcm_decrypt_base64_list, decrypt_validate |
33 | 35 | from ..._base_client import make_request_options |
34 | 36 | from ..._resource import SyncAPIResource, AsyncAPIResource |
35 | 37 | from ..._compat import cached_property |
@@ -69,7 +71,8 @@ def _process_messages( |
69 | 71 | part["text"] = f(part["text"]) |
70 | 72 | elif part.get("type", None) == "image_url": |
71 | 73 | if part["image_url"]["url"].startswith("data:"): |
72 | | - part["image_url"]["url"] = f(part["image_url"]["url"]) |
| 74 | + part["image_url"]["url"] = f( |
| 75 | + part["image_url"]["url"]) |
73 | 76 | else: |
74 | 77 | warnings.warn( |
75 | 78 | "encryption is not supported for image url, " |
@@ -103,15 +106,16 @@ def _encrypt( |
103 | 106 | model: str, |
104 | 107 | messages: Iterable[ChatCompletionMessageParam], |
105 | 108 | extra_headers: Headers, |
106 | | - ) -> tuple[bytes, bytes]: |
107 | | - client = self._client._get_endpoint_certificate(model) |
| 109 | + ) -> tuple[bytes, bytes, str, str]: |
| 110 | + client, ring_id, key_id = self._client._get_endpoint_certificate(model) |
108 | 111 | _crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair() |
109 | 112 | extra_headers["X-Session-Token"] = session_token |
110 | 113 | _process_messages( |
111 | 114 | messages, |
112 | | - lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x), |
| 115 | + lambda x: client.encrypt_string_with_key( |
| 116 | + _crypto_key, _crypto_nonce, x), |
113 | 117 | ) |
114 | | - return _crypto_key, _crypto_nonce |
| 118 | + return _crypto_key, _crypto_nonce, ring_id, key_id |
115 | 119 |
|
116 | 120 | def _decrypt_chunk( |
117 | 121 | self, key: bytes, nonce: bytes, resp: Stream[ChatCompletionChunk] |
@@ -142,10 +146,13 @@ def _decrypt( |
142 | 146 | choice.message is not None and choice.finish_reason != 'content_filter' |
143 | 147 | and choice.message.content is not None |
144 | 148 | ): |
145 | | - content = aes_gcm_decrypt_base64_string( |
146 | | - key, nonce, choice.message.content |
147 | | - ) |
148 | | - if content == '': |
| 149 | + try: |
| 150 | + content = aes_gcm_decrypt_base64_string( |
| 151 | + key, nonce, choice.message.content |
| 152 | + ) |
| 153 | + except Exception: |
| 154 | + content = '' |
| 155 | + if content == '' or not decrypt_validate(choice.message.content): |
149 | 156 | content = aes_gcm_decrypt_base64_list( |
150 | 157 | key, nonce, choice.message.content |
151 | 158 | ) |
@@ -197,7 +204,15 @@ def create( |
197 | 204 | ): |
198 | 205 | is_encrypt = True |
199 | 206 | messages = deepcopy_minimal(messages) |
200 | | - e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers) |
| 207 | + e2e_key, e2e_nonce, ring_id, key_id = self._encrypt( |
| 208 | + model, messages, extra_headers) |
| 209 | + if os.environ.get("VOLC_ARK_ENCRYPTION") == "AICC": |
| 210 | + info = { |
| 211 | + 'Version': 'AICCv0.1', |
| 212 | + 'RingID': ring_id, |
| 213 | + 'KeyID': key_id, |
| 214 | + } |
| 215 | + extra_headers["X-Encrypt-Info"] = json.dumps(info) |
201 | 216 |
|
202 | 217 | resp = self._post( |
203 | 218 | "/chat/completions", |
@@ -257,15 +272,16 @@ def _encrypt( |
257 | 272 | model: str, |
258 | 273 | messages: Iterable[ChatCompletionMessageParam], |
259 | 274 | extra_headers: Headers, |
260 | | - ) -> tuple[bytes, bytes]: |
261 | | - client = self._client._get_endpoint_certificate(model) |
| 275 | + ) -> tuple[bytes, bytes, str, str]: |
| 276 | + client, ring_id, key_id = self._client._get_endpoint_certificate(model) |
262 | 277 | _crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair() |
263 | 278 | extra_headers["X-Session-Token"] = session_token |
264 | 279 | _process_messages( |
265 | 280 | messages, |
266 | | - lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x), |
| 281 | + lambda x: client.encrypt_string_with_key( |
| 282 | + _crypto_key, _crypto_nonce, x), |
267 | 283 | ) |
268 | | - return _crypto_key, _crypto_nonce |
| 284 | + return _crypto_key, _crypto_nonce, ring_id, key_id |
269 | 285 |
|
270 | 286 | async def _decrypt_chunk( |
271 | 287 | self, key: bytes, nonce: bytes, resp: AsyncStream[ChatCompletionChunk] |
@@ -296,10 +312,13 @@ async def _decrypt( |
296 | 312 | choice.message is not None and choice.finish_reason != 'content_filter' |
297 | 313 | and choice.message.content is not None |
298 | 314 | ): |
299 | | - content = aes_gcm_decrypt_base64_string( |
300 | | - key, nonce, choice.message.content |
301 | | - ) |
302 | | - if content == '': |
| 315 | + try: |
| 316 | + content = aes_gcm_decrypt_base64_string( |
| 317 | + key, nonce, choice.message.content |
| 318 | + ) |
| 319 | + except Exception: |
| 320 | + content = '' |
| 321 | + if content == '' or not decrypt_validate(choice.message.content): |
303 | 322 | content = aes_gcm_decrypt_base64_list( |
304 | 323 | key, nonce, choice.message.content |
305 | 324 | ) |
@@ -351,7 +370,15 @@ async def create( |
351 | 370 | ): |
352 | 371 | is_encrypt = True |
353 | 372 | messages = deepcopy_minimal(messages) |
354 | | - e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers) |
| 373 | + e2e_key, e2e_nonce, ring_id, key_id = self._encrypt( |
| 374 | + model, messages, extra_headers) |
| 375 | + if os.environ.get("VOLC_ARK_ENCRYPTION") == "AICC": |
| 376 | + info = { |
| 377 | + 'Version': 'AICCv0.1', |
| 378 | + 'RingID': ring_id, |
| 379 | + 'KeyID': key_id, |
| 380 | + } |
| 381 | + extra_headers["X-Encrypt-Info"] = json.dumps(info) |
355 | 382 |
|
356 | 383 | resp = await self._post( |
357 | 384 | "/chat/completions", |
|
0 commit comments