Skip to content

Commit 6120872

Browse files
committed
[AIT-316] refactor: enforce strict Annotation type usage and extend handling
- Refactored to mandate the `Annotation` type across annotation-related methods in `RealtimeAnnotations` and `RestAnnotations`. - Introduced `_copy_with` in `Annotation` for simplified object cloning with modifications. - Enhanced data validation in `encode_data` to raise `AblyException` for unsupported payloads.
1 parent 20288a6 commit 6120872

5 files changed

Lines changed: 95 additions & 27 deletions

File tree

ably/realtime/annotations.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ably.rest.annotations import RestAnnotations, construct_validate_annotation
77
from ably.transport.websockettransport import ProtocolMessageAction
8-
from ably.types.annotation import AnnotationAction
8+
from ably.types.annotation import Annotation, AnnotationAction
99
from ably.types.channelstate import ChannelState
1010
from ably.types.flags import Flag
1111
from ably.util.eventemitter import EventEmitter
@@ -40,13 +40,13 @@ def __init__(self, channel: RealtimeChannel, connection_manager: ConnectionManag
4040
self.__subscriptions = EventEmitter()
4141
self.__rest_annotations = RestAnnotations(channel)
4242

43-
async def publish(self, msg_or_serial, annotation: dict, params: dict | None = None):
43+
async def publish(self, msg_or_serial, annotation: Annotation, params: dict | None = None):
4444
"""
4545
Publish an annotation on a message via the realtime connection.
4646
4747
Args:
4848
msg_or_serial: Either a message serial (string) or a Message object
49-
annotation: Dict containing annotation properties (type, name, data, etc.)
49+
annotation: Annotation object
5050
params: Optional dict of query parameters
5151
5252
Returns:
@@ -87,7 +87,7 @@ async def publish(self, msg_or_serial, annotation: dict, params: dict | None = N
8787
async def delete(
8888
self,
8989
msg_or_serial,
90-
annotation: dict,
90+
annotation: Annotation,
9191
params: dict | None = None,
9292
):
9393
"""
@@ -98,7 +98,7 @@ async def delete(
9898
9999
Args:
100100
msg_or_serial: Either a message serial (string) or a Message object
101-
annotation: Dict containing annotation properties
101+
annotation: Annotation containing annotation properties
102102
params: Optional dict of query parameters
103103
104104
Returns:
@@ -107,9 +107,11 @@ async def delete(
107107
Raises:
108108
AblyException: If the request fails or inputs are invalid
109109
"""
110-
annotation_values = annotation.copy()
111-
annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE
112-
return await self.publish(msg_or_serial, annotation_values, params)
110+
return await self.publish(
111+
msg_or_serial,
112+
annotation._copy_with(action=AnnotationAction.ANNOTATION_DELETE),
113+
params,
114+
)
113115

114116
async def subscribe(self, *args):
115117
"""
@@ -163,6 +165,10 @@ async def subscribe(self, *args):
163165
# Check if ANNOTATION_SUBSCRIBE mode is enabled
164166
if self.__channel.state == ChannelState.ATTACHED:
165167
if Flag.ANNOTATION_SUBSCRIBE not in self.__channel.modes:
168+
if annotation_type is not None:
169+
self.__subscriptions.off(annotation_type, listener)
170+
else:
171+
self.__subscriptions.off(listener)
166172
raise AblyException(
167173
message="You are trying to add an annotation listener, but you haven't requested the "
168174
"annotation_subscribe channel mode in ChannelOptions, so this won't do anything "

ably/rest/annotations.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ def serial_from_msg_or_serial(msg_or_serial):
4949
return message_serial
5050

5151

52-
def construct_validate_annotation(msg_or_serial, annotation: dict):
52+
def construct_validate_annotation(msg_or_serial, annotation: Annotation) -> Annotation:
5353
"""
5454
Construct and validate an Annotation from input values.
5555
5656
Args:
5757
msg_or_serial: Either a string serial or a Message object
58-
annotation: Dict of annotation properties or Annotation object
58+
annotation: Annotation object
5959
6060
Returns:
6161
Annotation: The constructed annotation
@@ -65,18 +65,17 @@ def construct_validate_annotation(msg_or_serial, annotation: dict):
6565
"""
6666
message_serial = serial_from_msg_or_serial(msg_or_serial)
6767

68-
if not annotation or (not isinstance(annotation, dict) and not isinstance(annotation, Annotation)):
68+
if not annotation or not isinstance(annotation, Annotation):
6969
raise AblyException(
7070
message='Second argument of annotations.publish() must be a dict or Annotation '
7171
'(the intended annotation to publish)',
7272
status_code=400,
7373
code=40003,
7474
)
7575

76-
annotation_values = annotation.copy()
77-
annotation_values['message_serial'] = message_serial
78-
79-
return Annotation.from_values(annotation_values)
76+
return annotation._copy_with(
77+
message_serial=message_serial,
78+
)
8079

8180

8281
class RestAnnotations:
@@ -109,15 +108,15 @@ def __base_path_for_serial(self, serial):
109108
async def publish(
110109
self,
111110
msg_or_serial,
112-
annotation: dict | Annotation,
111+
annotation: Annotation,
113112
params: dict | None = None,
114113
):
115114
"""
116115
Publish an annotation on a message.
117116
118117
Args:
119118
msg_or_serial: Either a message serial (string) or a Message object
120-
annotation: Dict containing annotation properties (type, name, data, etc.) or Annotation object
119+
annotation: Annotation object
121120
params: Optional dict of query parameters
122121
123122
Returns:
@@ -152,7 +151,7 @@ async def publish(
152151
async def delete(
153152
self,
154153
msg_or_serial,
155-
annotation: dict | Annotation,
154+
annotation: Annotation,
156155
params: dict | None = None,
157156
):
158157
"""
@@ -163,7 +162,7 @@ async def delete(
163162
164163
Args:
165164
msg_or_serial: Either a message serial (string) or a Message object
166-
annotation: Dict containing annotation properties or Annotation object
165+
annotation: Annotation object
167166
params: Optional dict of query parameters
168167
169168
Returns:
@@ -172,13 +171,11 @@ async def delete(
172171
Raises:
173172
AblyException: If the request fails or inputs are invalid
174173
"""
175-
# Set action to delete
176-
if isinstance(annotation, Annotation):
177-
annotation_values = annotation.as_dict()
178-
else:
179-
annotation_values = annotation.copy()
180-
annotation_values['action'] = AnnotationAction.ANNOTATION_DELETE
181-
return await self.publish(msg_or_serial, annotation_values, params)
174+
return await self.publish(
175+
msg_or_serial,
176+
annotation._copy_with(action=AnnotationAction.ANNOTATION_DELETE),
177+
params,
178+
)
182179

183180
async def get(self, msg_or_serial, params: dict | None = None):
184181
"""

ably/rest/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(self, ably: AblyRest | AblyRealtime, options: Options):
8989

9090
async def get_auth_transport_param(self):
9191
auth_credentials = {}
92-
if self.auth_options.client_id:
92+
if self.auth_options.client_id and self.auth_options.client_id != '*':
9393
auth_credentials["clientId"] = self.auth_options.client_id
9494
if self.__auth_mechanism == Auth.Method.BASIC:
9595
key_name = self.__auth_options.key_name

ably/types/annotation.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
log = logging.getLogger(__name__)
99

1010

11+
# Sentinel value to distinguish between "not provided" and "explicitly None"
12+
_UNSET = object()
13+
14+
1115
class AnnotationAction(IntEnum):
1216
"""Annotation action types"""
1317
ANNOTATION_CREATE = 0
@@ -59,6 +63,7 @@ def __init__(self,
5963
self.__client_id = to_text(client_id) if client_id is not None else None
6064
self.__timestamp = timestamp
6165
self.__extras = extras
66+
self.__encoding = encoding
6267

6368
def __eq__(self, other):
6469
if isinstance(other, Annotation):
@@ -204,6 +209,62 @@ def __str__(self):
204209
def __repr__(self):
205210
return self.__str__()
206211

212+
def _copy_with(self,
213+
action=_UNSET,
214+
serial=_UNSET,
215+
message_serial=_UNSET,
216+
type=_UNSET,
217+
name=_UNSET,
218+
count=_UNSET,
219+
data=_UNSET,
220+
encoding=_UNSET,
221+
client_id=_UNSET,
222+
timestamp=_UNSET,
223+
extras=_UNSET):
224+
"""
225+
Create a copy of this Annotation with optionally modified fields.
226+
227+
To explicitly set a field to None, pass None as the value.
228+
Fields not provided will retain their original values.
229+
230+
Args:
231+
action: Override the action type (or None to clear it)
232+
serial: Override the serial (or None to clear it)
233+
message_serial: Override the message serial (or None to clear it)
234+
type: Override the type (or None to clear it)
235+
name: Override the name (or None to clear it)
236+
count: Override the count (or None to clear it)
237+
data: Override the data payload (or None to clear it)
238+
encoding: Override the encoding format (or None to clear it)
239+
client_id: Override the client ID (or None to clear it)
240+
timestamp: Override the timestamp (or None to clear it)
241+
extras: Override the extras metadata (or None to clear it)
242+
243+
Returns:
244+
A new Annotation instance with the specified fields updated
245+
246+
Example:
247+
# Keep existing name, change type
248+
new_ann = annotation.copy_with(type="like")
249+
250+
# Explicitly set name to None
251+
new_ann = annotation.copy_with(name=None)
252+
"""
253+
# Get encoding from the mixin's property
254+
return Annotation(
255+
action=self.__action if action is _UNSET else action,
256+
serial=self.__serial if serial is _UNSET else serial,
257+
message_serial=self.__message_serial if message_serial is _UNSET else message_serial,
258+
type=self.__type if type is _UNSET else type,
259+
name=self.__name if name is _UNSET else name,
260+
count=self.__count if count is _UNSET else count,
261+
data=self.__data if data is _UNSET else data,
262+
encoding=self.__encoding if encoding is _UNSET else encoding,
263+
client_id=self.__client_id if client_id is _UNSET else client_id,
264+
timestamp=self.__timestamp if timestamp is _UNSET else timestamp,
265+
extras=self.__extras if extras is _UNSET else extras,
266+
)
267+
207268

208269
def make_annotation_response_handler(cipher=None):
209270
"""Create a response handler for annotation API responses"""

ably/util/encoding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any
44

55
from ably.util.crypto import CipherData
6+
from ably.util.exceptions import AblyException
67

78

89
def encode_data(data: Any, encoding_array: list, binary: bool = False):
@@ -29,6 +30,9 @@ def encode_data(data: Any, encoding_array: list, binary: bool = False):
2930

3031
result = { 'data': data }
3132

33+
if not (isinstance(data, (bytes, str, list, dict, bytearray)) or data is None):
34+
raise AblyException("Invalid data payload", 400, 40011)
35+
3236
if encoding:
3337
result['encoding'] = '/'.join(encoding).strip('/')
3438

0 commit comments

Comments
 (0)