Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions newrelic/hooks/messagebroker_confluentkafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
import logging
import sys
import threading
import weakref

from newrelic.api.application import application_instance
from newrelic.api.error_trace import wrap_error_trace
Expand All @@ -33,6 +35,57 @@
HEARTBEAT_SESSION_TIMEOUT = "MessageBroker/Kafka/Heartbeat/SessionTimeout"
HEARTBEAT_POLL_TIMEOUT = "MessageBroker/Kafka/Heartbeat/PollTimeout"

KAFKA_CLUSTER_METRIC_PRODUCE = "MessageBroker/Kafka/Cluster/{0}/Topic/{1}/Produce"
KAFKA_CLUSTER_METRIC_CONSUME = "MessageBroker/Kafka/Cluster/{0}/Topic/{1}/Consume"


_nr_cluster_id_cache = {}
_nr_cluster_id_cache_lock = threading.Lock()


def _fetch_cluster_id(instance):
servers = getattr(instance, "_nr_bootstrap_servers", None)
# Sort so that equivalent broker sets with different orderings share the same key.
cache_key = ",".join(sorted(servers)) if servers else None

if cache_key:
with _nr_cluster_id_cache_lock:
cached = _nr_cluster_id_cache.get(cache_key)
if cached:
instance._nr_cluster_id = cached
return
if cached is not None:
return
_nr_cluster_id_cache[cache_key] = ""

# Hold only a weak reference so the thread closure does not extend the
# lifetime of a Producer/Consumer that the caller has already abandoned.
instance_ref = weakref.ref(instance)

def _run():
inst = instance_ref()
if inst is None:
# Instance was GC'd before the thread ran; clean up sentinel and exit.
if cache_key:
with _nr_cluster_id_cache_lock:
_nr_cluster_id_cache.pop(cache_key, None)
return
try:
meta = inst.list_topics(timeout=5)
cluster_id = getattr(meta, "cluster_id", None)
if cluster_id:
inst._nr_cluster_id = cluster_id
if cache_key:
with _nr_cluster_id_cache_lock:
_nr_cluster_id_cache[cache_key] = cluster_id
except Exception as e:
_logger.debug("NR Kafka cluster ID fetch failed", exc_info=True)
if cache_key:
with _nr_cluster_id_cache_lock:
_nr_cluster_id_cache.pop(cache_key, None)

threading.Thread(target=_run, daemon=True, name="NR-Kafka-ClusterId").start()


def wrap_Producer_produce(wrapped, instance, args, kwargs):
transaction = current_transaction()
Expand Down Expand Up @@ -63,6 +116,17 @@ def wrap_Producer_produce(wrapped, instance, args, kwargs):
for server_name in instance._nr_bootstrap_servers:
transaction.record_custom_metric(f"MessageBroker/Kafka/Nodes/{server_name}/Produce/{topic}", 1)

cluster_id = getattr(instance, "_nr_cluster_id", None)
if not cluster_id and hasattr(instance, "_nr_bootstrap_servers"):
_cache_key = ",".join(sorted(instance._nr_bootstrap_servers))
cluster_id = _nr_cluster_id_cache.get(_cache_key) or None
if cluster_id:
instance._nr_cluster_id = cluster_id # cache on instance for future calls
if cluster_id:
transaction.record_custom_metric(
KAFKA_CLUSTER_METRIC_PRODUCE.format(cluster_id, topic), 1
)

with MessageTrace(
library="Kafka", operation="Produce", destination_type="Topic", destination_name=topic, source=wrapped
):
Expand Down Expand Up @@ -171,6 +235,16 @@ def wrap_Consumer_poll(wrapped, instance, args, kwargs):
transaction.record_custom_metric(
f"MessageBroker/Kafka/Nodes/{server_name}/Consume/{destination_name}", 1
)
cluster_id = getattr(instance, "_nr_cluster_id", None)
if not cluster_id and hasattr(instance, "_nr_bootstrap_servers"):
_cache_key = ",".join(sorted(instance._nr_bootstrap_servers))
cluster_id = _nr_cluster_id_cache.get(_cache_key) or None
if cluster_id:
instance._nr_cluster_id = cluster_id
if cluster_id:
transaction.record_custom_metric(
KAFKA_CLUSTER_METRIC_CONSUME.format(cluster_id, destination_name), 1
)
transaction.add_messagebroker_info("Confluent-Kafka", get_package_version("confluent-kafka"))

return record
Expand Down Expand Up @@ -213,6 +287,16 @@ def wrap_SerializingProducer_init(wrapped, instance, args, kwargs):
if hasattr(instance, "_value_serializer") and callable(instance._value_serializer):
instance._value_serializer = wrap_serializer("Serialization/Value", "MessageBroker")(instance._value_serializer)

try:
conf = kwargs.get("conf") or (args[0] if args else {})
servers = conf.get("bootstrap.servers") if isinstance(conf, dict) else None
if servers:
instance._nr_bootstrap_servers = servers.split(",")
except Exception:
pass

_fetch_cluster_id(instance)


def wrap_DeserializingConsumer_init(wrapped, instance, args, kwargs):
wrapped(*args, **kwargs)
Expand All @@ -223,6 +307,16 @@ def wrap_DeserializingConsumer_init(wrapped, instance, args, kwargs):
if hasattr(instance, "_value_deserializer") and callable(instance._value_deserializer):
instance._value_deserializer = wrap_serializer("Deserialization/Value", "Message")(instance._value_deserializer)

try:
conf = kwargs.get("conf") or (args[0] if args else {})
servers = conf.get("bootstrap.servers") if isinstance(conf, dict) else None
if servers:
instance._nr_bootstrap_servers = servers.split(",")
except Exception:
pass

_fetch_cluster_id(instance)


def wrap_Producer_init(wrapped, instance, args, kwargs):
wrapped(*args, **kwargs)
Expand All @@ -236,6 +330,8 @@ def wrap_Producer_init(wrapped, instance, args, kwargs):
except Exception:
pass

_fetch_cluster_id(instance)


def wrap_Consumer_init(wrapped, instance, args, kwargs):
wrapped(*args, **kwargs)
Expand All @@ -249,6 +345,8 @@ def wrap_Consumer_init(wrapped, instance, args, kwargs):
except Exception:
pass

_fetch_cluster_id(instance)


def wrap_immutable_class(module, class_name):
# Wrap immutable binary extension class with a mutable Python subclass
Expand Down
Loading
Loading