Skip to content

Commit df2e556

Browse files
committed
fixed serialization
1 parent 5ba4a8e commit df2e556

2 files changed

Lines changed: 41 additions & 1 deletion

File tree

src/altertable_flightsql/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __init__(
166166

167167
def _set_options(self, options: Mapping[str, flight_pb2.SessionOptionValue]):
168168
cmd = flight_pb2.SetSessionOptionsRequest(session_options=options)
169-
action = flight.Action("SetSessionOptions", _pack_command(cmd))
169+
action = flight.Action("SetSessionOptions", cmd.SerializeToString())
170170
list(self._client.do_action(action))
171171

172172
def _execute_query_command(self, cmd) -> flight.FlightStreamReader:

tests/test_client.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from google.protobuf import any_pb2
2+
3+
from altertable_flightsql.client import Client
4+
from altertable_flightsql.generated import arrow_flight_pb2 as flight_pb2
5+
6+
7+
class FakeFlightClient:
8+
def __init__(self):
9+
self.actions = []
10+
11+
def do_action(self, action):
12+
self.actions.append(action)
13+
return []
14+
15+
16+
def _action_body_bytes(action) -> bytes:
17+
body = action.body
18+
if hasattr(body, "to_pybytes"):
19+
return body.to_pybytes()
20+
return bytes(body)
21+
22+
23+
def test_set_options_serializes_flight_session_request_without_any():
24+
flight_client = FakeFlightClient()
25+
client = Client.__new__(Client)
26+
client._client = flight_client
27+
28+
session_options = {
29+
"catalog": flight_pb2.SessionOptionValue(string_value="test_catalog"),
30+
}
31+
client._set_options(session_options)
32+
33+
action = flight_client.actions[0]
34+
request = flight_pb2.SetSessionOptionsRequest(session_options=session_options)
35+
wrapped_request = any_pb2.Any()
36+
wrapped_request.Pack(request)
37+
38+
assert action.type == "SetSessionOptions"
39+
assert _action_body_bytes(action) == request.SerializeToString()
40+
assert _action_body_bytes(action) != wrapped_request.SerializeToString()

0 commit comments

Comments
 (0)