Skip to content

Commit 48208f4

Browse files
committed
Address feedback
1 parent d59b196 commit 48208f4

3 files changed

Lines changed: 89 additions & 46 deletions

File tree

src/confluent_kafka/cimpl.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,10 @@ class ShareConsumer:
558558
def subscribe(self, topics: List[str]) -> None: ...
559559
def unsubscribe(self) -> None: ...
560560
def subscription(self) -> List[str]: ...
561-
def consume_batch(self, timeout: float = -1) -> List[Message]: ...
561+
def poll(self, timeout: float = -1) -> List[Message]: ...
562562
def close(self) -> None: ...
563+
def __enter__(self) -> "ShareConsumer": ...
564+
def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> Optional[bool]: ...
563565

564566
class _AdminClientImpl:
565567
def __init__(self, config: Dict[str, Union[str, int, float, bool]]) -> None: ...

src/confluent_kafka/src/ShareConsumer.c

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ static void ShareConsumer_dealloc(ShareConsumerHandle *self) {
6161
if (self->rkshare) {
6262
CallState cs;
6363
CallState_begin((Handle *)self, &cs);
64+
/* TODO KIP-932: Use rd_kafka_share_destroy_flags() once
65+
* available in the librdkafka public API. */
6466
rd_kafka_share_destroy(self->rkshare);
6567
self->rkshare = NULL;
6668
CallState_end((Handle *)self, &cs);
@@ -197,12 +199,12 @@ static PyObject *ShareConsumer_subscription(ShareConsumerHandle *self) {
197199

198200

199201
/**
200-
* @brief Consume a batch of messages from the share consumer.
202+
* @brief Poll for a batch of messages from the share consumer.
201203
*
202204
*/
203-
static PyObject *ShareConsumer_consume_batch(ShareConsumerHandle *self,
204-
PyObject *args,
205-
PyObject *kwargs) {
205+
static PyObject *ShareConsumer_poll(ShareConsumerHandle *self,
206+
PyObject *args,
207+
PyObject *kwargs) {
206208
double tmout = -1.0f;
207209
static char *kws[] = {"timeout", NULL};
208210
rd_kafka_message_t **rkmessages = NULL;
@@ -299,20 +301,7 @@ static PyObject *ShareConsumer_consume_batch(ShareConsumerHandle *self,
299301

300302
/* Handle error from rd_kafka_share_consume_batch() */
301303
if (error) {
302-
const char *error_str = rd_kafka_error_string(error);
303-
int is_fatal = rd_kafka_error_is_fatal(error);
304-
int is_retriable = rd_kafka_error_is_retriable(error);
305-
306-
if (is_fatal) {
307-
PyErr_Format(PyExc_RuntimeError, "Fatal error: %s",
308-
error_str);
309-
} else {
310-
PyErr_Format(KafkaException,
311-
"Error: %s (retriable: %s)", error_str,
312-
is_retriable ? "yes" : "no");
313-
}
314-
315-
rd_kafka_error_destroy(error);
304+
cfl_PyErr_from_error_destroy(error);
316305
free(rkmessages);
317306
return NULL;
318307
}
@@ -348,6 +337,8 @@ static PyObject *ShareConsumer_close(ShareConsumerHandle *self) {
348337
Py_RETURN_NONE;
349338

350339
CallState_begin((Handle *)self, &cs);
340+
/* TODO KIP-932: rd_kafka_share_consumer_close() return type will change
341+
* to rd_kafka_error_t *. Update error handling accordingly. */
351342
err = rd_kafka_share_consumer_close(self->rkshare);
352343
rd_kafka_share_destroy(self->rkshare);
353344
self->rkshare = NULL;
@@ -364,6 +355,36 @@ static PyObject *ShareConsumer_close(ShareConsumerHandle *self) {
364355
}
365356

366357

358+
/**
359+
* @brief Context manager entry — returns self.
360+
*/
361+
static PyObject *ShareConsumer_enter(ShareConsumerHandle *self) {
362+
Py_INCREF(self);
363+
return (PyObject *)self;
364+
}
365+
366+
/**
367+
* @brief Context manager exit — calls close().
368+
*/
369+
static PyObject *ShareConsumer_exit(ShareConsumerHandle *self, PyObject *args) {
370+
PyObject *exc_type, *exc_value, *exc_traceback;
371+
372+
if (!PyArg_UnpackTuple(args, "__exit__", 3, 3, &exc_type, &exc_value,
373+
&exc_traceback))
374+
return NULL;
375+
376+
/* Cleanup: call close() */
377+
if (self->rkshare) {
378+
PyObject *result = ShareConsumer_close(self);
379+
if (!result)
380+
return NULL;
381+
Py_DECREF(result);
382+
}
383+
384+
Py_RETURN_NONE;
385+
}
386+
387+
367388
/**
368389
* @brief ShareConsumer methods.
369390
*/
@@ -399,14 +420,10 @@ static PyMethodDef ShareConsumer_methods[] = {
399420
" :raises RuntimeError: if called on a closed share consumer\n"
400421
"\n"},
401422

402-
{"consume_batch", (PyCFunction)ShareConsumer_consume_batch,
403-
METH_VARARGS | METH_KEYWORDS,
404-
".. py:function:: consume_batch([timeout=-1])\n"
423+
{"poll", (PyCFunction)ShareConsumer_poll, METH_VARARGS | METH_KEYWORDS,
424+
".. py:function:: poll([timeout=-1])\n"
405425
"\n"
406-
" Consume a batch of messages from the share consumer.\n"
407-
"\n"
408-
" This is the ONLY consumption method for ShareConsumer.\n"
409-
" Share consumers do NOT have a poll() method - they are batch-only.\n"
426+
" Poll for a batch of messages from the share consumer.\n"
410427
"\n"
411428
" The application must check each Message object's error() method\n"
412429
" to distinguish between proper messages (error() returns None)\n"
@@ -421,9 +438,8 @@ static PyMethodDef ShareConsumer_methods[] = {
421438
" Default: -1 (infinite)\n"
422439
" :returns: List of Message objects (possibly empty on timeout)\n"
423440
" :rtype: list(Message)\n"
424-
" :raises RuntimeError: if called on a closed share consumer or on "
425-
"fatal error\n"
426-
" :raises KafkaException: on non-fatal errors\n"
441+
" :raises KafkaException: on error\n"
442+
" :raises RuntimeError: if called on a closed share consumer\n"
427443
" :raises KeyboardInterrupt: if Ctrl+C pressed during consumption\n"
428444
"\n"},
429445

@@ -438,6 +454,15 @@ static PyMethodDef ShareConsumer_methods[] = {
438454
" :raises KafkaException: on error\n"
439455
"\n"},
440456

457+
/* TODO KIP-932: Add set_sasl_credentials once librdkafka exposes
458+
* rd_kafka_sasl_set_credentials() (or the underlying rd_kafka_t *)
459+
* for rd_kafka_share_t handles. */
460+
461+
{"__enter__", (PyCFunction)ShareConsumer_enter, METH_NOARGS,
462+
"Context manager entry."},
463+
{"__exit__", (PyCFunction)ShareConsumer_exit, METH_VARARGS,
464+
"Context manager exit. Automatically closes the share consumer."},
465+
441466
{NULL}};
442467

443468

@@ -530,10 +555,6 @@ PyTypeObject ShareConsumerType = {
530555
"assigned to multiple consumers. Messages are delivered to only one "
531556
"consumer.\n"
532557
"\n"
533-
".. note::\n"
534-
" ShareConsumer only supports batch consumption via consume_batch().\n"
535-
" There is NO poll() method for single messages.\n"
536-
"\n"
537558
":param dict config: Configuration properties. At a minimum, "
538559
"``group.id`` **must** be set and ``bootstrap.servers`` **should** be "
539560
"set.\n"

tests/test_ShareConsumer.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def test_unsubscribe():
6666
sc.close()
6767

6868

69-
def test_consume_batch_no_broker():
70-
"""Test consume_batch() returns empty list when no broker available."""
69+
def test_poll_no_broker():
70+
"""Test poll() returns empty list when no broker available."""
7171
sc = ShareConsumer(
7272
{
7373
'group.id': 'test-share-group',
@@ -79,13 +79,33 @@ def test_consume_batch_no_broker():
7979
sc.subscribe(['test-topic'])
8080

8181
# Should timeout and return empty list
82-
messages = sc.consume_batch(timeout=0.1)
82+
messages = sc.poll(timeout=0.1)
8383
assert isinstance(messages, list)
8484
# May be empty or contain error messages
8585

8686
sc.close()
8787

8888

89+
def test_context_manager():
90+
"""Test that ShareConsumer works as a context manager and closes on exit."""
91+
with ShareConsumer(
92+
{
93+
'group.id': 'test-share-group',
94+
'bootstrap.servers': 'localhost:9092',
95+
'socket.timeout.ms': 100,
96+
}
97+
) as sc:
98+
assert sc is not None
99+
sc.subscribe(['test-topic'])
100+
subscription = sc.subscription()
101+
assert 'test-topic' in subscription
102+
103+
# After exiting the context manager, the consumer should be closed
104+
with pytest.raises(RuntimeError) as ex:
105+
sc.subscribe(['test-topic'])
106+
assert ex.match('Share consumer closed')
107+
108+
89109
def test_close_idempotent():
90110
"""Test that close() can be called multiple times."""
91111
sc = ShareConsumer(
@@ -126,7 +146,7 @@ def test_any_method_after_close_throws_exception():
126146
assert ex.match('Share consumer closed')
127147

128148
with pytest.raises(RuntimeError) as ex:
129-
sc.consume_batch(timeout=0.1)
149+
sc.poll(timeout=0.1)
130150
assert ex.match('Share consumer closed')
131151

132152

@@ -159,8 +179,8 @@ def test_concurrent_consumers():
159179
sc1.subscribe(['test-topic'])
160180
sc2.subscribe(['test-topic'])
161181

162-
messages1 = sc1.consume_batch(timeout=2.0)
163-
messages2 = sc2.consume_batch(timeout=2.0)
182+
messages1 = sc1.poll(timeout=2.0)
183+
messages2 = sc2.poll(timeout=2.0)
164184

165185
# Verify no overlap (share group semantics)
166186
offsets1 = {(msg.topic(), msg.partition(), msg.offset()) for msg in messages1 if not msg.error()}
@@ -190,14 +210,14 @@ def my_error_cb(error):
190210
)
191211

192212
sc.subscribe(['test-topic'])
193-
sc.consume_batch(timeout=0.5)
213+
sc.poll(timeout=0.5)
194214

195215
assert len(error_called) > 0, "error_cb should have been called"
196216
sc.close()
197217

198218

199219
def test_error_cb_exception_propagates():
200-
"""Test that an exception raised in error_cb propagates to consume_batch."""
220+
"""Test that an exception raised in error_cb propagates to poll."""
201221
error_called = []
202222

203223
def error_cb_that_raises(error):
@@ -216,7 +236,7 @@ def error_cb_that_raises(error):
216236
sc.subscribe(['test-topic'])
217237

218238
with pytest.raises(RuntimeError) as exc_info:
219-
sc.consume_batch(timeout=0.5)
239+
sc.poll(timeout=0.5)
220240

221241
assert "Test exception from error_cb" in str(exc_info.value)
222242
assert len(error_called) > 0
@@ -245,7 +265,7 @@ def my_stats_cb(stats_json):
245265
)
246266

247267
sc.subscribe(['test-topic'])
248-
sc.consume_batch(timeout=0.5)
268+
sc.poll(timeout=0.5)
249269

250270
assert len(stats_called) > 0, "stats_cb should have been called"
251271
# Verify we got valid JSON string
@@ -257,7 +277,7 @@ def my_stats_cb(stats_json):
257277

258278

259279
def test_stats_cb_exception_propagates():
260-
"""Test that an exception raised in stats_cb propagates to consume_batch."""
280+
"""Test that an exception raised in stats_cb propagates to poll."""
261281
stats_called = []
262282

263283
def stats_cb_that_raises(stats_json):
@@ -277,7 +297,7 @@ def stats_cb_that_raises(stats_json):
277297
sc.subscribe(['test-topic'])
278298

279299
with pytest.raises(RuntimeError) as exc_info:
280-
sc.consume_batch(timeout=0.5)
300+
sc.poll(timeout=0.5)
281301

282302
assert "Test exception from stats_cb" in str(exc_info.value)
283303
assert len(stats_called) > 0
@@ -307,5 +327,5 @@ def my_throttle_cb(event):
307327
sc.subscribe(['test-topic'])
308328

309329
# throttle_cb won't fire without broker throttling — just verify no crash
310-
sc.consume_batch(timeout=0.2)
330+
sc.poll(timeout=0.2)
311331
sc.close()

0 commit comments

Comments
 (0)