1+ import base64
12import uuid
23from collections .abc import AsyncIterator
34from datetime import datetime
3132from s2_sdk ._producer import Producer
3233from s2_sdk ._retrier import Retrier , http_retry_on , is_safe_to_retry_unary
3334from s2_sdk ._s2s ._read_session import run_read_session
34- from s2_sdk ._types import ONE_MIB , Compression , Endpoints , Retry , Timeout , metered_bytes
35+ from s2_sdk ._types import (
36+ _S2_ENCRYPTION_KEY_HEADER ,
37+ ONE_MIB ,
38+ Compression ,
39+ Endpoints ,
40+ Retry ,
41+ Timeout ,
42+ metered_bytes ,
43+ )
3544from s2_sdk ._validators import (
3645 validate_append_input ,
3746 validate_basin ,
3847 validate_batching ,
48+ validate_encryption_key ,
3949 validate_max_unacked ,
4050 validate_retry ,
4151)
@@ -608,23 +618,40 @@ async def create_stream(
608618 )
609619 return stream_info_from_json (response .json ())
610620
611- def stream (self , name : str ) -> "S2Stream" :
621+ def stream (
622+ self ,
623+ name : str ,
624+ * ,
625+ encryption_key : bytes | str | None = None ,
626+ ) -> "S2Stream" :
612627 """Get an :class:`S2Stream` for performing stream-level operations.
613628
614629 Args:
615630 name: Name of the stream.
631+ encryption_key: Key for encrypting records on append and decrypting
632+ them on read. Required when encryption is enabled via
633+ :attr:`BasinConfig.stream_cipher` (see :class:`Encryption`
634+ for supported algorithms).
635+ If ``bytes``, it will get converted to a base64 encoded str.
636+ If ``str``, it must be base64 encoded.
616637
617638 Returns:
618639 An :class:`S2Stream` bound to the given stream name.
619640
620641 Tip:
621642 Also available via subscript: ``s2["my-basin"]["my-stream"]``.
622643 """
644+ if isinstance (encryption_key , str ):
645+ validate_encryption_key (encryption_key )
646+ elif isinstance (encryption_key , bytes ):
647+ encryption_key = base64 .b64encode (encryption_key ).decode ()
648+
623649 return S2Stream (
624650 name ,
625651 self ._client ,
626652 retry = self ._retry ,
627653 compression = self ._compression ,
654+ encryption_key = encryption_key ,
628655 )
629656
630657 @fallible
@@ -757,6 +784,7 @@ class S2Stream:
757784 "_name" ,
758785 "_client" ,
759786 "_compression" ,
787+ "_encryption_key" ,
760788 "_retry" ,
761789 "_retrier" ,
762790 "_append_retrier" ,
@@ -769,11 +797,13 @@ def __init__(
769797 * ,
770798 retry : Retry ,
771799 compression : Compression ,
800+ encryption_key : str | None = None ,
772801 ) -> None :
773802 self ._name = name
774803 self ._client = client
775804 self ._retry = retry
776805 self ._compression = compression
806+ self ._encryption_key = encryption_key
777807 self ._retrier = Retrier (
778808 should_retry_on = http_retry_on ,
779809 max_attempts = retry .max_attempts ,
@@ -797,6 +827,15 @@ def name(self) -> str:
797827 """Stream name."""
798828 return self ._name
799829
830+ def _request_headers (
831+ self , headers : dict [str , str ] | None = None
832+ ) -> dict [str , str ] | None :
833+ if self ._encryption_key is None :
834+ return headers
835+ merged = dict (headers or {})
836+ merged [_S2_ENCRYPTION_KEY_HEADER ] = self ._encryption_key
837+ return merged
838+
800839 @fallible
801840 async def check_tail (self ) -> types .StreamPosition :
802841 """Check the tail of a stream.
@@ -831,10 +870,12 @@ async def append(self, inp: types.AppendInput) -> types.AppendAck:
831870 "POST" ,
832871 _stream_path (self .name , "/records" ),
833872 content = body ,
834- headers = {
835- "content-type" : "application/x-protobuf" ,
836- "accept" : "application/x-protobuf" ,
837- },
873+ headers = self ._request_headers (
874+ {
875+ "content-type" : "application/x-protobuf" ,
876+ "accept" : "application/x-protobuf" ,
877+ }
878+ ),
838879 )
839880 ack = pb .AppendAck ()
840881 ack .ParseFromString (response .content )
@@ -878,6 +919,7 @@ def append_session(
878919 compression = self ._compression ,
879920 max_unacked_bytes = max_unacked_bytes ,
880921 max_unacked_batches = max_unacked_batches ,
922+ encryption_key = self ._encryption_key ,
881923 )
882924
883925 def producer (
@@ -922,6 +964,7 @@ def producer(
922964 stream_name = self .name ,
923965 retry = self ._retry ,
924966 compression = self ._compression ,
967+ encryption_key = self ._encryption_key ,
925968 fencing_token = fencing_token ,
926969 match_seq_num = match_seq_num ,
927970 max_unacked_bytes = max_unacked_bytes ,
@@ -971,7 +1014,7 @@ async def read(
9711014 "GET" ,
9721015 _stream_path (self .name , "/records" ),
9731016 params = params ,
974- headers = {"accept" : "application/x-protobuf" },
1017+ headers = self . _request_headers ( {"accept" : "application/x-protobuf" }) ,
9751018 )
9761019
9771020 proto_batch = pb .ReadBatch ()
@@ -1030,6 +1073,7 @@ async def read_session(
10301073 wait ,
10311074 ignore_command_records ,
10321075 retry = self ._retry ,
1076+ encryption_key = self ._encryption_key ,
10331077 ):
10341078 yield batch
10351079
0 commit comments