Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions grpc-circuitbreaker-utils/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
plugins {
`java-library`
jacoco
id("org.hypertrace.publish-plugin")
id("org.hypertrace.jacoco-report-plugin")
}

dependencies {

api(platform("io.grpc:grpc-bom:1.68.3"))
api("io.grpc:grpc-context")
api("io.grpc:grpc-api")
api("io.grpc:grpc-inprocess")
api(platform("io.netty:netty-bom:4.1.118.Final"))
constraints {
api("com.google.protobuf:protobuf-java:3.25.5") {
because("https://nvd.nist.gov/vuln/detail/CVE-2024-7254")
}
}

implementation(project(":grpc-context-utils"))
implementation("org.slf4j:slf4j-api:1.7.36")
implementation("io.grpc:grpc-core")
implementation("io.github.resilience4j:resilience4j-circuitbreaker:1.7.1")
implementation("com.typesafe:config:1.4.2")
implementation("com.google.inject:guice:7.0.0")
implementation("org.hypertrace.core.serviceframework:platform-metrics:0.1.87")

annotationProcessor("org.projectlombok:lombok:1.18.24")
compileOnly("org.projectlombok:lombok:1.18.24")

testImplementation("org.junit.jupiter:junit-jupiter:5.8.2")
testImplementation("org.mockito:mockito-core:5.8.0")
testRuntimeOnly("io.grpc:grpc-netty")
}

tasks.test {
useJUnitPlatform()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package org.hypertrace.circuitbreaker.grpcutils;

import com.typesafe.config.Config;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class CircuitBreakerConfigProvider {

public static final String CIRCUIT_BREAKER_CONFIG = "circuit.breaker.config";
public static final String DEFAULT_CONFIG_KEY = "default";

// Whether to enable circuit breaker or not.
private static final String ENABLED = "enabled";

// Percentage of failures to trigger OPEN state
private static final String FAILURE_RATE_THRESHOLD = "failureRateThreshold";
// Percentage of slow calls to trigger OPEN state
private static final String SLOW_CALL_RATE_THRESHOLD = "slowCallRateThreshold";
// Define what a "slow" call is
private static final String SLOW_CALL_DURATION_THRESHOLD = "slowCallDurationThreshold";
// Number of calls to consider in the sliding window
private static final String SLIDING_WINDOW_SIZE = "slidingWindowSize";
// Time before retrying after OPEN state
private static final String WAIT_DURATION_IN_OPEN_STATE = "waitDurationInOpenState";
// Minimum calls before evaluating failure rate
private static final String MINIMUM_NUMBER_OF_CALLS = "minimumNumberOfCalls";
// Calls allowed in HALF_OPEN state before deciding to
// CLOSE or OPEN again
private static final String PERMITTED_NUMBER_OF_CALLS_IN_HALF_OPEN_STATE =
"permittedNumberOfCallsInHalfOpenState";
private static final String SLIDING_WINDOW_TYPE = "slidingWindowType";

// Cache for storing CircuitBreakerConfig instances
private static final ConcurrentHashMap<String, CircuitBreakerConfig> configCache =
new ConcurrentHashMap<>();

// Global flag for circuit breaker enablement
private boolean circuitBreakerEnabled = false;

public CircuitBreakerConfigProvider(Config config) {
initialize(config);
}

public CircuitBreakerConfigProvider() {}

/** Initializes and caches all CircuitBreaker configurations. */
public void initialize(Config config) {
if (!config.hasPath(CIRCUIT_BREAKER_CONFIG)) {
log.warn("No circuit breaker configurations found in the config file.");
return;
}

Config circuitBreakerConfig = config.getConfig(CIRCUIT_BREAKER_CONFIG);

// Read global enabled flag (default to false if not provided)
circuitBreakerEnabled =
circuitBreakerConfig.hasPath(ENABLED) && circuitBreakerConfig.getBoolean(ENABLED);

// Load all circuit breaker configurations and cache them
Map<String, CircuitBreakerConfig> allConfigs =
circuitBreakerConfig.root().keySet().stream()
.filter(key -> !key.equals(ENABLED)) // Ignore the global enabled flag
.collect(
Collectors.toMap(
key -> key, // Circuit breaker key
key -> createCircuitBreakerConfig(circuitBreakerConfig.getConfig(key))));

// Store in cache
configCache.putAll(allConfigs);

log.info(
"Loaded {} circuit breaker configurations, Global Enabled: {}. Configs: {}",
allConfigs.size(),
circuitBreakerEnabled,
allConfigs);
}

/**
* Retrieves the CircuitBreakerConfig for a specific key. Falls back to default if key-specific
* config is not found.
*/
public CircuitBreakerConfig getConfig(String circuitBreakerKey) {
return configCache.getOrDefault(circuitBreakerKey, configCache.get(DEFAULT_CONFIG_KEY));
}

/** Checks if Circuit Breaker is globally enabled. */
public boolean isCircuitBreakerEnabled() {
return circuitBreakerEnabled;
}

private CircuitBreakerConfig createCircuitBreakerConfig(Config config) {
return CircuitBreakerConfig.custom()
.failureRateThreshold((float) config.getDouble(FAILURE_RATE_THRESHOLD))
.slowCallRateThreshold((float) config.getDouble(SLOW_CALL_RATE_THRESHOLD))
.slowCallDurationThreshold(config.getDuration(SLOW_CALL_DURATION_THRESHOLD))
.slidingWindowType(getSlidingWindowType(config.getString(SLIDING_WINDOW_TYPE)))
.slidingWindowSize(config.getInt(SLIDING_WINDOW_SIZE))
.waitDurationInOpenState(config.getDuration(WAIT_DURATION_IN_OPEN_STATE))
.permittedNumberOfCallsInHalfOpenState(
config.getInt(PERMITTED_NUMBER_OF_CALLS_IN_HALF_OPEN_STATE))
.minimumNumberOfCalls(config.getInt(MINIMUM_NUMBER_OF_CALLS))
.build();
}

private CircuitBreakerConfig.SlidingWindowType getSlidingWindowType(String slidingWindowType) {
return CircuitBreakerConfig.SlidingWindowType.valueOf(slidingWindowType);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.hypertrace.circuitbreaker.grpcutils;

import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class CircuitBreakerEventListener {
private static final Set<String> attachedCircuitBreakers = ConcurrentHashMap.newKeySet();

public static synchronized void attachListeners(CircuitBreaker circuitBreaker) {
if (!attachedCircuitBreakers.add(
circuitBreaker.getName())) { // Ensures only one listener is attached
return;
}
circuitBreaker
.getEventPublisher()
.onStateTransition(
event ->
log.info(
"State transition: {} for circuit breaker {} ",
event.getStateTransition(),
event.getCircuitBreakerName()))
.onCallNotPermitted(
event ->
log.debug(
"Call not permitted: Circuit is OPEN for circuit breaker {} ",
event.getCircuitBreakerName()))
.onEvent(
event -> {
log.debug(
"Circuit breaker event type {} for circuit breaker name {} ",
event.getEventType(),
event.getCircuitBreakerName());
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package org.hypertrace.circuitbreaker.grpcutils;

import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class CircuitBreakerInterceptor implements ClientInterceptor {

public static final CallOptions.Key<String> CIRCUIT_BREAKER_KEY =
CallOptions.Key.createWithDefault("circuitBreakerKey", "default");
private final CircuitBreakerRegistry circuitBreakerRegistry;
private final CircuitBreakerConfigProvider circuitBreakerConfigProvider;
private final CircuitBreakerMetricsNotifier circuitBreakerMetricsNotifier;

public CircuitBreakerInterceptor(
CircuitBreakerRegistry circuitBreakerRegistry,
CircuitBreakerConfigProvider circuitBreakerConfigProvider,
CircuitBreakerMetricsNotifier circuitBreakerMetricsNotifier) {
this.circuitBreakerRegistry = circuitBreakerRegistry;
this.circuitBreakerConfigProvider = circuitBreakerConfigProvider;
this.circuitBreakerMetricsNotifier = circuitBreakerMetricsNotifier;
}

// Intercepts the call and applies circuit breaker logic
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
if (!circuitBreakerConfigProvider.isCircuitBreakerEnabled()) {
return next.newCall(method, callOptions);
}

// Get circuit breaker key from CallOptions
String circuitBreakerKey = callOptions.getOption(CIRCUIT_BREAKER_KEY);
CircuitBreaker circuitBreaker = getCircuitBreaker(circuitBreakerKey);
return new ForwardingClientCall.SimpleForwardingClientCall<>(
next.newCall(method, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
long startTime = System.nanoTime();

// Wrap response listener to track failures
Listener<RespT> wrappedListener =
new ForwardingClientCallListener.SimpleForwardingClientCallListener<>(
responseListener) {
@Override
public void onClose(Status status, Metadata trailers) {
long duration = System.nanoTime() - startTime;
if (status.isOk()) {
circuitBreaker.onSuccess(duration, TimeUnit.NANOSECONDS);
} else {
log.debug(
"Circuit Breaker '{}' detected failure. Status: {}, Description: {}",
circuitBreaker.getName(),
status.getCode(),
status.getDescription());
circuitBreaker.onError(
duration, TimeUnit.NANOSECONDS, status.asRuntimeException());
}
super.onClose(status, trailers);
}
};

super.start(wrappedListener, headers);
}

@Override
public void sendMessage(ReqT message) {
if (!circuitBreaker.tryAcquirePermission()) {
handleCircuitBreakerRejection(circuitBreakerKey, circuitBreaker);
String rejectionReason =
circuitBreaker.getState() == CircuitBreaker.State.HALF_OPEN
? "Circuit Breaker is HALF-OPEN and rejecting excess requests"
: "Circuit Breaker is OPEN and blocking requests";
throw Status.UNAVAILABLE.withDescription(rejectionReason).asRuntimeException();
}
super.sendMessage(message);
}
};
}

private void handleCircuitBreakerRejection(
String circuitBreakerKey, CircuitBreaker circuitBreaker) {
String tenantId = getTenantId(circuitBreakerKey);
if (circuitBreaker.getState() == CircuitBreaker.State.HALF_OPEN) {
circuitBreakerMetricsNotifier.incrementCount(tenantId, "circuitbreaker.halfopen.rejected");
log.debug(
"Circuit Breaker '{}' is HALF-OPEN and rejecting excess requests for tenant '{}'.",
circuitBreakerKey,
tenantId);
} else if (circuitBreaker.getState() == CircuitBreaker.State.OPEN) {
circuitBreakerMetricsNotifier.incrementCount(tenantId, "circuitbreaker.open.blocked");
log.debug(
"Circuit Breaker '{}' is OPEN. Blocking request for tenant '{}'.",
circuitBreakerKey,
tenantId);
} else {
log.debug( // Added unexpected state handling for safety
"Unexpected Circuit Breaker state '{}' for '{}'. Blocking request.",
circuitBreaker.getState(),
circuitBreakerKey);
}
}

private static String getTenantId(String circuitBreakerKey) {
if (!circuitBreakerKey.contains(".")) {
return "Unknown";
}
return circuitBreakerKey.split("\\.", 2)[0]; // Ensures only the first split
}

/** Retrieve the Circuit Breaker based on the key. */
private CircuitBreaker getCircuitBreaker(String circuitBreakerKey) {
CircuitBreaker circuitBreaker =
circuitBreakerRegistry.circuitBreaker(
circuitBreakerKey, circuitBreakerConfigProvider.getConfig(circuitBreakerKey));
CircuitBreakerEventListener.attachListeners(circuitBreaker);
return circuitBreaker;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.hypertrace.circuitbreaker.grpcutils;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Meter;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.noop.NoopCounter;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.hypertrace.core.serviceframework.metrics.PlatformMetricsRegistry;

public class CircuitBreakerMetricsNotifier {
private static final ConcurrentHashMap<String, Counter> counterMap = new ConcurrentHashMap<>();
public static final String UNKNOWN_TENANT = "unknown";

public void incrementCount(String tenantId, String counterName) {
getCounter(tenantId, counterName).increment();
}

public Counter getCounter(String tenantId, String counterName) {
if (tenantId == null || tenantId.equals(UNKNOWN_TENANT)) {
return getNoopCounter();
}
return counterMap.computeIfAbsent(
tenantId + counterName,
(unused) ->
PlatformMetricsRegistry.registerCounter(counterName, Map.of("tenantId", tenantId)));
}

private NoopCounter getNoopCounter() {
Meter.Id dummyId = new Meter.Id("noopCounter", Tags.empty(), null, null, Meter.Type.COUNTER);
return new NoopCounter(dummyId);
}
}
Loading
Loading