Skip to content

Commit 85c8e70

Browse files
committed
feat: introduce ChannelOptions for enhanced channel configuration
- Added `ChannelOptions` class to handle channel parameters and cipher configurations. - Updated `RealtimeChannel` to support `ChannelOptions`
1 parent 30fdc5d commit 85c8e70

2 files changed

Lines changed: 179 additions & 8 deletions

File tree

ably/realtime/realtime_channel.py

Lines changed: 122 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22
import asyncio
33
import logging
4-
from typing import Optional, TYPE_CHECKING
4+
from typing import Optional, TYPE_CHECKING, Dict, Any, Union
55
from ably.realtime.connection import ConnectionState
66
from ably.transport.websockettransport import ProtocolMessageAction
77
from ably.rest.channel import Channel, Channels as RestChannels
@@ -14,10 +14,75 @@
1414

1515
if TYPE_CHECKING:
1616
from ably.realtime.realtime import AblyRealtime
17+
from ably.util.crypto import CipherParams
1718

1819
log = logging.getLogger(__name__)
1920

2021

22+
class ChannelOptions:
23+
"""Channel options for Ably Realtime channels
24+
25+
Attributes
26+
----------
27+
cipher : CipherParams, optional
28+
Requests encryption for this channel when not null, and specifies encryption-related parameters.
29+
params : Dict[str, str], optional
30+
Channel parameters that configure the behavior of the channel.
31+
"""
32+
33+
def __init__(self, cipher: Optional[CipherParams] = None, params: Optional[dict] = None):
34+
self.__cipher = cipher
35+
self.__params = params
36+
# Validate params
37+
if self.__params and not isinstance(self.__params, dict):
38+
raise AblyException("params must be a dictionary", 40000, 400)
39+
40+
@property
41+
def cipher(self):
42+
"""Get cipher configuration"""
43+
return self.__cipher
44+
45+
@property
46+
def params(self) -> Dict[str, str]:
47+
"""Get channel parameters"""
48+
return self.__params
49+
50+
def __eq__(self, other):
51+
"""Check equality with another ChannelOptions instance"""
52+
if not isinstance(other, ChannelOptions):
53+
return False
54+
55+
return (self.__cipher == other.__cipher and
56+
self.__params == other.__params)
57+
58+
def __hash__(self):
59+
"""Make ChannelOptions hashable"""
60+
return hash((
61+
self.__cipher,
62+
tuple(sorted(self.__params.items())) if self.__params else None,
63+
))
64+
65+
def to_dict(self) -> Dict[str, Any]:
66+
"""Convert to dictionary representation"""
67+
result = {}
68+
if self.__cipher is not None:
69+
result['cipher'] = self.__cipher
70+
if self.__params:
71+
result['params'] = self.__params
72+
return result
73+
74+
@classmethod
75+
def from_dict(cls, options_dict: Dict[str, Any]) -> 'ChannelOptions':
76+
"""Create ChannelOptions from dictionary"""
77+
if not isinstance(options_dict, dict):
78+
raise AblyException("options must be a dictionary", 40000, 400)
79+
80+
return cls(
81+
cipher=options_dict.get('cipher'),
82+
params=options_dict.get('params'),
83+
)
84+
85+
2186
class RealtimeChannel(EventEmitter, Channel):
2287
"""
2388
Ably Realtime Channel
@@ -43,23 +108,40 @@ class RealtimeChannel(EventEmitter, Channel):
43108
Unsubscribe to messages from a channel
44109
"""
45110

46-
def __init__(self, realtime: AblyRealtime, name: str):
111+
def __init__(self, realtime: AblyRealtime, name: str, channel_options: Optional[ChannelOptions] = None):
47112
EventEmitter.__init__(self)
48113
self.__name = name
49114
self.__realtime = realtime
50115
self.__state = ChannelState.INITIALIZED
51116
self.__message_emitter = EventEmitter()
52117
self.__state_timer: Optional[Timer] = None
53118
self.__attach_resume = False
119+
self.__attach_serial: Optional[str] = None
54120
self.__channel_serial: Optional[str] = None
55121
self.__retry_timer: Optional[Timer] = None
56122
self.__error_reason: Optional[AblyException] = None
123+
self.__channel_options = channel_options or ChannelOptions()
124+
self.__params: Optional[Dict[str, str]] = None
57125

58126
# Used to listen to state changes internally, if we use the public event emitter interface then internals
59127
# will be disrupted if the user called .off() to remove all listeners
60128
self.__internal_state_emitter = EventEmitter()
61129

62-
Channel.__init__(self, realtime, name, {})
130+
# Pass channel options as dictionary to parent Channel class
131+
Channel.__init__(self, realtime, name, self.__channel_options.to_dict())
132+
133+
async def set_options(self, channel_options: ChannelOptions) -> None:
134+
"""Set channel options"""
135+
should_reattach = self.should_reattach_to_set_options(channel_options)
136+
self.__channel_options = channel_options
137+
# Update parent class options
138+
self.options = channel_options.to_dict()
139+
140+
if should_reattach:
141+
self._attach_impl()
142+
state_change = await self.__internal_state_emitter.once_async()
143+
if state_change.current in (ChannelState.SUSPENDED, ChannelState.FAILED):
144+
raise state_change.reason
63145

64146
# RTL4
65147
async def attach(self) -> None:
@@ -108,6 +190,7 @@ def _attach_impl(self):
108190
# RTL4c
109191
attach_msg = {
110192
"action": ProtocolMessageAction.ATTACH,
193+
"params": self.__channel_options.params,
111194
"channel": self.name,
112195
}
113196

@@ -292,8 +375,6 @@ def _on_message(self, proto_msg: dict) -> None:
292375
action = proto_msg.get('action')
293376
# RTL4c1
294377
channel_serial = proto_msg.get('channelSerial')
295-
if channel_serial:
296-
self.__channel_serial = channel_serial
297378
# TM2a, TM2c, TM2f
298379
Message.update_inner_message_fields(proto_msg)
299380

@@ -303,6 +384,10 @@ def _on_message(self, proto_msg: dict) -> None:
303384
exception = None
304385
resumed = False
305386

387+
self.__attach_serial = channel_serial
388+
self.__channel_serial = channel_serial
389+
self.__params = proto_msg.get('params')
390+
306391
if error:
307392
exception = AblyException.from_dict(error)
308393

@@ -327,6 +412,7 @@ def _on_message(self, proto_msg: dict) -> None:
327412
self._request_state(ChannelState.ATTACHING)
328413
elif action == ProtocolMessageAction.MESSAGE:
329414
messages = Message.from_encoded_array(proto_msg.get('messages'))
415+
self.__channel_serial = channel_serial
330416
for message in messages:
331417
self.__message_emitter._emit(message.name, message)
332418
elif action == ProtocolMessageAction.ERROR:
@@ -431,6 +517,11 @@ def __on_retry_timer_expire(self) -> None:
431517
log.info("RealtimeChannel retry timer expired, attempting a new attach")
432518
self._request_state(ChannelState.ATTACHING)
433519

520+
def should_reattach_to_set_options(self, new_options: ChannelOptions) -> bool:
521+
if self.state != ChannelState.ATTACHING and self.state != ChannelState.ATTACHED:
522+
return False
523+
return self.__channel_options != new_options
524+
434525
# RTL23
435526
@property
436527
def name(self) -> str:
@@ -453,6 +544,11 @@ def error_reason(self) -> Optional[AblyException]:
453544
"""An AblyException instance describing the last error which occurred on the channel, if any."""
454545
return self.__error_reason
455546

547+
@property
548+
def params(self) -> Dict[str, str]:
549+
"""Get channel parameters"""
550+
return self.__params
551+
456552

457553
class Channels(RestChannels):
458554
"""Creates and destroys RealtimeChannel objects.
@@ -466,19 +562,38 @@ class Channels(RestChannels):
466562
"""
467563

468564
# RTS3
469-
def get(self, name: str) -> RealtimeChannel:
565+
def get(self, name: str, options: Optional[Union[dict, ChannelOptions]] = None) -> RealtimeChannel:
470566
"""Creates a new RealtimeChannel object, or returns the existing channel object.
471567
472568
Parameters
473569
----------
474570
475571
name: str
476572
Channel name
573+
options: ChannelOptions or dict, optional
574+
Channel options for the channel
477575
"""
576+
# Convert dict to ChannelOptions if needed
577+
if options is not None:
578+
if isinstance(options, dict):
579+
options = ChannelOptions.from_dict(options)
580+
elif not isinstance(options, ChannelOptions):
581+
raise AblyException("options must be ChannelOptions instance or dictionary", 40000, 400)
582+
478583
if name not in self.__all:
479-
channel = self.__all[name] = RealtimeChannel(self.__ably, name)
584+
channel = self.__all[name] = RealtimeChannel(self.__ably, name, options)
480585
else:
481586
channel = self.__all[name]
587+
# Update options if channel is not attached or currently attaching
588+
if options and channel.should_reattach_to_set_options(options):
589+
raise AblyException(
590+
'Channels.get() cannot be used to set channel options that would cause the channel to '
591+
'reattach. Please, use RealtimeChannel.setOptions() instead.',
592+
400,
593+
40000
594+
)
595+
elif options:
596+
channel.set_options(options)
482597
return channel
483598

484599
# RTS4

test/ably/realtime/realtimechannel_test.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import pytest
3-
from ably.realtime.realtime_channel import ChannelState, RealtimeChannel
3+
from ably.realtime.realtime_channel import ChannelState, RealtimeChannel, ChannelOptions
44
from ably.transport.websockettransport import ProtocolMessageAction
55
from ably.types.message import Message
66
from test.ably.testapp import TestApp
@@ -468,3 +468,59 @@ async def test_channel_error_cleared_upon_connect_from_terminal_state(self):
468468
assert channel.error_reason is None
469469

470470
await ably.close()
471+
472+
async def test_channel_params_received_by_relatime(self):
473+
ably = await TestApp.get_ably_realtime()
474+
channel_name = random_string(5)
475+
channel = ably.channels.get(channel_name, ChannelOptions(params={
476+
"rewind": "1"
477+
}))
478+
await channel.attach()
479+
assert channel.params["rewind"] == "1"
480+
481+
await ably.close()
482+
483+
async def test_channel_params_unknown_params_skipped_by_relatime(self):
484+
ably = await TestApp.get_ably_realtime()
485+
channel_name = random_string(5)
486+
channel = ably.channels.get(channel_name, ChannelOptions(params={
487+
"rewind": "1",
488+
"foo": "bar"
489+
}))
490+
await channel.attach()
491+
assert channel.params["rewind"] == "1"
492+
assert channel.params.get("foo") is None
493+
494+
await ably.close()
495+
496+
async def test_channel_params_as_dict(self):
497+
ably = await TestApp.get_ably_realtime()
498+
channel_name = random_string(5)
499+
channel = ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"}))
500+
await channel.attach()
501+
assert channel.params["delta"] == "vcdiff"
502+
503+
await ably.close()
504+
505+
async def test_channel_get_channel_with_same_params(self):
506+
ably = await TestApp.get_ably_realtime()
507+
channel_name = random_string(5)
508+
channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
509+
await channel.attach()
510+
same_channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
511+
assert channel == same_channel
512+
513+
await ably.close()
514+
515+
async def test_channel_get_channel_with_different_params(self):
516+
ably = await TestApp.get_ably_realtime()
517+
channel_name = random_string(5)
518+
channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
519+
await channel.attach()
520+
521+
with pytest.raises(AblyException):
522+
ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"}))
523+
524+
assert channel.params == {"rewind": "1"}
525+
526+
await ably.close()

0 commit comments

Comments
 (0)