Skip to content

Commit 2106eda

Browse files
feat: encryption support (#37)
Co-authored-by: quettabit <27509167+quettabit@users.noreply.github.com>
1 parent 8f0c870 commit 2106eda

12 files changed

Lines changed: 296 additions & 15 deletions

File tree

docs/source/api-reference.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@
9292
9393
.. autoenum:: AppendRetryPolicy
9494
95+
.. autoenum:: Encryption
96+
9597
.. autoenum:: StorageClass
9698
9799
.. autoenum:: TimestampingMode

src/s2_sdk/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Batching,
3030
CommandRecord,
3131
Compression,
32+
Encryption,
3233
Endpoints,
3334
ExactMatch,
3435
Gauge,
@@ -84,6 +85,7 @@
8485
"Record",
8586
"AppendInput",
8687
"AppendAck",
88+
"Encryption",
8789
"IndexedAppendAck",
8890
"StreamPosition",
8991
"SeqNum",

src/s2_sdk/_append_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class AppendSession:
3737
"_client",
3838
"_compression",
3939
"_error",
40+
"_encryption_key",
4041
"_permits",
4142
"_queue",
4243
"_retry",
@@ -53,11 +54,13 @@ def __init__(
5354
compression: Compression,
5455
max_unacked_bytes: int,
5556
max_unacked_batches: int | None,
57+
encryption_key: str | None = None,
5658
) -> None:
5759
self._client = client
5860
self._stream_name = stream_name
5961
self._retry = retry
6062
self._compression = compression
63+
self._encryption_key = encryption_key
6164
self._permits = _AppendPermits(max_unacked_bytes, max_unacked_batches)
6265

6366
self._queue: asyncio.Queue[AppendInput | None] = asyncio.Queue()
@@ -121,6 +124,7 @@ async def _run(self) -> None:
121124
retry=self._retry,
122125
compression=self._compression,
123126
ack_timeout=self._client._request_timeout,
127+
encryption_key=self._encryption_key,
124128
):
125129
self._resolve_next(ack)
126130
except BaseException as e:

src/s2_sdk/_mappers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
BasinConfig,
1212
BasinInfo,
1313
BasinScope,
14+
Encryption,
1415
ExactMatch,
1516
Gauge,
1617
Label,
@@ -47,6 +48,8 @@ def basin_config_to_json(config: BasinConfig | None) -> dict[str, Any] | None:
4748
result["default_stream_config"] = stream_config_to_json(
4849
config.default_stream_config
4950
)
51+
if config.stream_cipher is not None:
52+
result["stream_cipher"] = config.stream_cipher.value
5053
if config.create_stream_on_append is not None:
5154
result["create_stream_on_append"] = config.create_stream_on_append
5255
if config.create_stream_on_read is not None:
@@ -56,8 +59,10 @@ def basin_config_to_json(config: BasinConfig | None) -> dict[str, Any] | None:
5659

5760
def basin_config_from_json(data: dict[str, Any]) -> BasinConfig:
5861
dsc = data.get("default_stream_config")
62+
stream_cipher = data.get("stream_cipher")
5963
return BasinConfig(
6064
default_stream_config=stream_config_from_json(dsc) if dsc else None,
65+
stream_cipher=Encryption(stream_cipher) if stream_cipher else None,
6166
create_stream_on_append=data.get("create_stream_on_append"),
6267
create_stream_on_read=data.get("create_stream_on_read"),
6368
)
@@ -146,10 +151,12 @@ def stream_info_from_json(data: dict[str, Any]) -> StreamInfo:
146151
created_at = datetime.fromisoformat(data["created_at"])
147152
deleted_at_str = data.get("deleted_at")
148153
deleted_at = datetime.fromisoformat(deleted_at_str) if deleted_at_str else None
154+
cipher = data.get("cipher")
149155
return StreamInfo(
150156
name=data["name"],
151157
created_at=created_at,
152158
deleted_at=deleted_at,
159+
cipher=Encryption(cipher) if cipher else None,
153160
)
154161

155162

src/s2_sdk/_ops.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import uuid
23
from collections.abc import AsyncIterator
34
from datetime import datetime
@@ -31,11 +32,20 @@
3132
from s2_sdk._producer import Producer
3233
from s2_sdk._retrier import Retrier, http_retry_on, is_safe_to_retry_unary
3334
from s2_sdk._s2s._read_session import run_read_session
34-
from s2_sdk._types import ONE_MIB, Compression, Endpoints, Retry, Timeout, metered_bytes
35+
from s2_sdk._types import (
36+
_S2_ENCRYPTION_KEY_HEADER,
37+
ONE_MIB,
38+
Compression,
39+
Endpoints,
40+
Retry,
41+
Timeout,
42+
metered_bytes,
43+
)
3544
from s2_sdk._validators import (
3645
validate_append_input,
3746
validate_basin,
3847
validate_batching,
48+
validate_encryption_key,
3949
validate_max_unacked,
4050
validate_retry,
4151
)
@@ -608,23 +618,40 @@ async def create_stream(
608618
)
609619
return stream_info_from_json(response.json())
610620

611-
def stream(self, name: str) -> "S2Stream":
621+
def stream(
622+
self,
623+
name: str,
624+
*,
625+
encryption_key: bytes | str | None = None,
626+
) -> "S2Stream":
612627
"""Get an :class:`S2Stream` for performing stream-level operations.
613628
614629
Args:
615630
name: Name of the stream.
631+
encryption_key: Key for encrypting records on append and decrypting
632+
them on read. Required when encryption is enabled via
633+
:attr:`BasinConfig.stream_cipher` (see :class:`Encryption`
634+
for supported algorithms).
635+
If ``bytes``, it will get converted to a base64 encoded str.
636+
If ``str``, it must be base64 encoded.
616637
617638
Returns:
618639
An :class:`S2Stream` bound to the given stream name.
619640
620641
Tip:
621642
Also available via subscript: ``s2["my-basin"]["my-stream"]``.
622643
"""
644+
if isinstance(encryption_key, str):
645+
validate_encryption_key(encryption_key)
646+
elif isinstance(encryption_key, bytes):
647+
encryption_key = base64.b64encode(encryption_key).decode()
648+
623649
return S2Stream(
624650
name,
625651
self._client,
626652
retry=self._retry,
627653
compression=self._compression,
654+
encryption_key=encryption_key,
628655
)
629656

630657
@fallible
@@ -757,6 +784,7 @@ class S2Stream:
757784
"_name",
758785
"_client",
759786
"_compression",
787+
"_encryption_key",
760788
"_retry",
761789
"_retrier",
762790
"_append_retrier",
@@ -769,11 +797,13 @@ def __init__(
769797
*,
770798
retry: Retry,
771799
compression: Compression,
800+
encryption_key: str | None = None,
772801
) -> None:
773802
self._name = name
774803
self._client = client
775804
self._retry = retry
776805
self._compression = compression
806+
self._encryption_key = encryption_key
777807
self._retrier = Retrier(
778808
should_retry_on=http_retry_on,
779809
max_attempts=retry.max_attempts,
@@ -797,6 +827,15 @@ def name(self) -> str:
797827
"""Stream name."""
798828
return self._name
799829

830+
def _request_headers(
831+
self, headers: dict[str, str] | None = None
832+
) -> dict[str, str] | None:
833+
if self._encryption_key is None:
834+
return headers
835+
merged = dict(headers or {})
836+
merged[_S2_ENCRYPTION_KEY_HEADER] = self._encryption_key
837+
return merged
838+
800839
@fallible
801840
async def check_tail(self) -> types.StreamPosition:
802841
"""Check the tail of a stream.
@@ -831,10 +870,12 @@ async def append(self, inp: types.AppendInput) -> types.AppendAck:
831870
"POST",
832871
_stream_path(self.name, "/records"),
833872
content=body,
834-
headers={
835-
"content-type": "application/x-protobuf",
836-
"accept": "application/x-protobuf",
837-
},
873+
headers=self._request_headers(
874+
{
875+
"content-type": "application/x-protobuf",
876+
"accept": "application/x-protobuf",
877+
}
878+
),
838879
)
839880
ack = pb.AppendAck()
840881
ack.ParseFromString(response.content)
@@ -878,6 +919,7 @@ def append_session(
878919
compression=self._compression,
879920
max_unacked_bytes=max_unacked_bytes,
880921
max_unacked_batches=max_unacked_batches,
922+
encryption_key=self._encryption_key,
881923
)
882924

883925
def producer(
@@ -922,6 +964,7 @@ def producer(
922964
stream_name=self.name,
923965
retry=self._retry,
924966
compression=self._compression,
967+
encryption_key=self._encryption_key,
925968
fencing_token=fencing_token,
926969
match_seq_num=match_seq_num,
927970
max_unacked_bytes=max_unacked_bytes,
@@ -971,7 +1014,7 @@ async def read(
9711014
"GET",
9721015
_stream_path(self.name, "/records"),
9731016
params=params,
974-
headers={"accept": "application/x-protobuf"},
1017+
headers=self._request_headers({"accept": "application/x-protobuf"}),
9751018
)
9761019

9771020
proto_batch = pb.ReadBatch()
@@ -1030,6 +1073,7 @@ async def read_session(
10301073
wait,
10311074
ignore_command_records,
10321075
retry=self._retry,
1076+
encryption_key=self._encryption_key,
10331077
):
10341078
yield batch
10351079

src/s2_sdk/_producer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
match_seq_num: int | None,
6161
max_unacked_bytes: int,
6262
batching: Batching,
63+
encryption_key: str | None = None,
6364
) -> None:
6465
self._session = AppendSession(
6566
client=client,
@@ -68,6 +69,7 @@ def __init__(
6869
compression=compression,
6970
max_unacked_bytes=max_unacked_bytes,
7071
max_unacked_batches=None,
72+
encryption_key=encryption_key,
7173
)
7274
self._fencing_token = fencing_token
7375
self._match_seq_num = match_seq_num

src/s2_sdk/_s2s/_append_session.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
read_messages,
2020
)
2121
from s2_sdk._types import (
22+
_S2_ENCRYPTION_KEY_HEADER,
2223
AppendAck,
2324
AppendInput,
2425
AppendRetryPolicy,
@@ -43,6 +44,7 @@ async def run_append_session(
4344
retry: Retry,
4445
compression: Compression,
4546
ack_timeout: float | None = None,
47+
encryption_key: str | None = None,
4648
) -> AsyncIterable[AppendAck]:
4749
input_queue: asyncio.Queue[AppendInput | None] = asyncio.Queue(
4850
maxsize=_QUEUE_MAX_SIZE
@@ -85,6 +87,7 @@ async def retrying_inner():
8587
compression,
8688
frame_signal,
8789
ack_timeout,
90+
encryption_key,
8891
)
8992
return
9093
except Exception as e:
@@ -135,14 +138,19 @@ async def _run_attempt(
135138
compression: Compression,
136139
frame_signal: FrameSignal | None,
137140
ack_timeout: float | None = None,
141+
encryption_key: str | None = None,
138142
) -> None:
143+
headers = {
144+
"content-type": "s2s/proto",
145+
"accept": "s2s/proto",
146+
}
147+
if encryption_key is not None:
148+
headers[_S2_ENCRYPTION_KEY_HEADER] = encryption_key
149+
139150
async with client.streaming_request(
140151
"POST",
141152
_stream_records_path(stream_name),
142-
headers={
143-
"content-type": "s2s/proto",
144-
"accept": "s2s/proto",
145-
},
153+
headers=headers,
146154
content=_body_gen(inflight_inputs, input_queue, pending_resend, compression),
147155
frame_signal=frame_signal,
148156
) as response:

src/s2_sdk/_s2s/_read_session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from s2_sdk._s2s import _stream_records_path
1313
from s2_sdk._s2s._protocol import parse_error_info, read_messages
1414
from s2_sdk._types import (
15+
_S2_ENCRYPTION_KEY_HEADER,
1516
ReadBatch,
1617
ReadLimit,
1718
Retry,
@@ -36,6 +37,7 @@ async def run_read_session(
3637
wait: int | None,
3738
ignore_command_records: bool,
3839
retry: Retry,
40+
encryption_key: str | None = None,
3941
) -> AsyncIterable[ReadBatch]:
4042
params = _build_read_params(start, limit, until_timestamp, clamp_to_tail, wait)
4143
backoffs = compute_backoffs(
@@ -50,6 +52,10 @@ async def run_read_session(
5052

5153
last_tail_at: float | None = None
5254

55+
headers = {"content-type": "s2s/proto"}
56+
if encryption_key is not None:
57+
headers[_S2_ENCRYPTION_KEY_HEADER] = encryption_key
58+
5359
while True:
5460
if wait is not None:
5561
params["wait"] = _remaining_wait(wait, last_tail_at)
@@ -59,7 +65,7 @@ async def run_read_session(
5965
"GET",
6066
_stream_records_path(stream_name),
6167
params=params,
62-
headers={"content-type": "s2s/proto"},
68+
headers=headers,
6369
) as response:
6470
if response.status_code != 200:
6571
body = await response.aread()

src/s2_sdk/_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
T = TypeVar("T")
1212

1313
ONE_MIB = 1024 * 1024
14+
_S2_ENCRYPTION_KEY_HEADER = "s2-encryption-key"
1415

1516

1617
def _parse_scheme(url: str) -> str:
@@ -50,6 +51,13 @@ class AppendRetryPolicy(_DocEnum):
5051
)
5152

5253

54+
class Encryption(_DocEnum):
55+
"""Encryption algorithm."""
56+
57+
AEGIS_256 = "aegis-256", "AEGIS-256."
58+
AES_256_GCM = "aes-256-gcm", "AES-256-GCM."
59+
60+
5361
class Endpoints:
5462
"""S2 service endpoints. See `endpoints <https://s2.dev/docs/api/endpoints>`_."""
5563

@@ -479,6 +487,9 @@ class BasinConfig:
479487
default_stream_config: StreamConfig | None = None
480488
"""Default configuration for streams in this basin."""
481489

490+
stream_cipher: Encryption | None = None
491+
"""Encryption algorithm to apply to newly created streams in the basin."""
492+
482493
create_stream_on_append: bool | None = None
483494
"""Create stream on append if it doesn't exist."""
484495

@@ -516,6 +527,9 @@ class StreamInfo:
516527
deleted_at: datetime | None
517528
"""Deletion time if the stream is being deleted."""
518529

530+
cipher: Encryption | None = None
531+
"""Encryption algorithm for this stream, if encryption is enabled."""
532+
519533

520534
@dataclass(slots=True)
521535
class ExactMatch:

0 commit comments

Comments
 (0)