Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/extensions/filters/http/rate_limit_quota/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ envoy_cc_library(
srcs = ["client_impl.cc"],
hdrs = ["client_impl.h"],
deps = [
":filter_persistence",
":global_client_lib",
":quota_bucket_cache",
"//envoy/grpc:async_client_interface",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstddef>
#include <memory>
#include <utility>

#include "envoy/type/v3/ratelimit_strategy.pb.h"
#include "envoy/type/v3/token_bucket.pb.h"
Expand All @@ -23,8 +24,9 @@ void LocalRateLimitClientImpl::createBucket(
std::chrono::milliseconds fallback_ttl, bool initial_request_allowed) {
// Intentionally crash if the local client is initialized with a null global
// client or TLS slot due to a bug.
global_client_->createBucket(bucket_id, id, default_bucket_action, std::move(fallback_action),
fallback_ttl, initial_request_allowed);
tls_store_->global_client->createBucket(bucket_id, id, default_bucket_action,
std::move(fallback_action), fallback_ttl,
initial_request_allowed);
}

std::shared_ptr<CachedBucket> LocalRateLimitClientImpl::getBucket(size_t id) {
Expand Down
20 changes: 9 additions & 11 deletions source/extensions/filters/http/rate_limit_quota/client_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "source/common/grpc/typed_async_client.h"
#include "source/extensions/filters/http/common/factory_base.h"
#include "source/extensions/filters/http/rate_limit_quota/filter_persistence.h"
#include "source/extensions/filters/http/rate_limit_quota/global_client_impl.h"
#include "source/extensions/filters/http/rate_limit_quota/quota_bucket_cache.h"

Expand All @@ -30,10 +31,8 @@ using GrpcAsyncClient =
class LocalRateLimitClientImpl : public RateLimitClient,
public Logger::Loggable<Logger::Id::rate_limit_quota> {
public:
explicit LocalRateLimitClientImpl(
GlobalRateLimitClientImpl* global_client,
Envoy::ThreadLocal::TypedSlot<ThreadLocalBucketsCache>& buckets_cache_tls)
: global_client_(global_client), buckets_cache_tls_(buckets_cache_tls) {}
explicit LocalRateLimitClientImpl(std::shared_ptr<GlobalTlsStores::TlsStore> tls_store)
: tls_store_(std::move(tls_store)) {}

void createBucket(const BucketId& bucket_id, size_t id, const BucketAction& default_bucket_action,
std::unique_ptr<envoy::type::v3::RateLimitStrategy> fallback_action,
Expand All @@ -45,19 +44,18 @@ class LocalRateLimitClientImpl : public RateLimitClient,

private:
inline std::shared_ptr<BucketsCache> getBucketsCache() {
return (buckets_cache_tls_.get().has_value()) ? buckets_cache_tls_.get()->quota_buckets_
: nullptr;
return (tls_store_->buckets_tls.get().has_value())
? tls_store_->buckets_tls.get()->quota_buckets_
: nullptr;
}

// Lockless access to global resources via TLS.
GlobalRateLimitClientImpl* global_client_;
ThreadLocal::TypedSlot<ThreadLocalBucketsCache>& buckets_cache_tls_;
std::shared_ptr<GlobalTlsStores::TlsStore> tls_store_;
};

inline std::unique_ptr<RateLimitClient>
createLocalRateLimitClient(GlobalRateLimitClientImpl* global_client,
ThreadLocal::TypedSlot<ThreadLocalBucketsCache>& buckets_cache_tls_) {
return std::make_unique<LocalRateLimitClientImpl>(global_client, buckets_cache_tls_);
createLocalRateLimitClient(std::shared_ptr<GlobalTlsStores::TlsStore> tls_store) {
return std::make_unique<LocalRateLimitClientImpl>(std::move(tls_store));
}

} // namespace RateLimitQuota
Expand Down
5 changes: 2 additions & 3 deletions source/extensions/filters/http/rate_limit_quota/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ Http::FilterFactoryCb RateLimitQuotaFilterFactory::createFilterFactoryFromProtoT

return [&, config = std::move(config), config_with_hash_key, tls_store = std::move(tls_store),
matcher = std::move(matcher)](Http::FilterChainFactoryCallbacks& callbacks) -> void {
std::unique_ptr<RateLimitClient> local_client =
createLocalRateLimitClient(tls_store->global_client.get(), tls_store->buckets_tls);
std::unique_ptr<RateLimitClient> local_client = createLocalRateLimitClient(tls_store);

callbacks.addStreamFilter(std::make_shared<RateLimitQuotaFilter>(
config, context, std::move(local_client), config_with_hash_key, matcher));
config, context, tls_store, std::move(local_client), config_with_hash_key, matcher));
};
}

Expand Down
8 changes: 7 additions & 1 deletion source/extensions/filters/http/rate_limit_quota/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "source/common/protobuf/utility.h"
#include "source/extensions/filters/http/common/factory_base.h"
#include "source/extensions/filters/http/common/pass_through_filter.h"
#include "source/extensions/filters/http/rate_limit_quota/filter_persistence.h"
#include "source/extensions/filters/http/rate_limit_quota/global_client_impl.h"
#include "source/extensions/filters/http/rate_limit_quota/matcher.h"
#include "source/extensions/filters/http/rate_limit_quota/quota_bucket_cache.h"
Expand Down Expand Up @@ -50,11 +51,13 @@ class RateLimitQuotaFilter : public Http::PassThroughFilter,
public:
RateLimitQuotaFilter(FilterConfigConstSharedPtr config,
Server::Configuration::FactoryContext& factory_context,
std::shared_ptr<GlobalTlsStores::TlsStore> tls_store,
std::unique_ptr<RateLimitClient> local_client,
Grpc::GrpcServiceConfigWithHashKey config_with_hash_key,
Matcher::MatchTreeSharedPtr<Http::HttpMatchingData> matcher)
: config_(std::move(config)), config_with_hash_key_(config_with_hash_key),
factory_context_(factory_context), matcher_(matcher), client_(std::move(local_client)),
factory_context_(factory_context), matcher_(matcher), tls_store_(std::move(tls_store)),
client_(std::move(local_client)),
time_source_(factory_context.serverFactoryContext().mainThreadDispatcher().timeSource()) {}

Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap&, bool end_stream) override;
Expand Down Expand Up @@ -95,6 +98,9 @@ class RateLimitQuotaFilter : public Http::PassThroughFilter,
// shouldn't be recorded.
bool first_skipped_match_ = true;

// Anchors the lifetime of the global client and its resources for the
// duration of the filter.
std::shared_ptr<GlobalTlsStores::TlsStore> tls_store_;
// Own a local, filter-specific client to provider functions needed by worker
// threads.
std::unique_ptr<RateLimitClient> client_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "envoy/grpc/async_client_manager.h"
#include "envoy/server/factory_context.h"

#include "source/common/common/logger.h"
#include "source/extensions/filters/http/rate_limit_quota/global_client_impl.h"
#include "source/extensions/filters/http/rate_limit_quota/quota_bucket_cache.h"

Expand All @@ -31,7 +32,16 @@ initTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_with_hash_key,
// Quota bucket & global client TLS objects are created with the config and
// kept alive via shared_ptr to a storage struct. The local rate limit client
// in each filter instance assumes that the slot will outlive them.
std::shared_ptr<TlsStore> tls_store = std::make_shared<TlsStore>(context, target_address, domain);
Envoy::Event::Dispatcher& dispatcher = context.serverFactoryContext().mainThreadDispatcher();
auto deleter = [main_dispatcher = &dispatcher](TlsStore* store) {
if (main_dispatcher->isThreadSafe()) {
delete store;
} else {
main_dispatcher->post([store]() { delete store; });
}
};

std::shared_ptr<TlsStore> tls_store(new TlsStore(context, target_address, domain), deleter);
auto tl_buckets_cache =
std::make_shared<ThreadLocalBucketsCache>(std::make_shared<BucketsCache>());
tls_store->buckets_tls.set(
Expand All @@ -46,7 +56,7 @@ initTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_with_hash_key,

// Create the global client resource to be shared via TLS to all worker
// threads (accessed through a filter-specific LocalRateLimitClient).
std::unique_ptr<GlobalRateLimitClientImpl> tl_global_client = createGlobalRateLimitClientImpl(
std::shared_ptr<GlobalRateLimitClientImpl> tl_global_client = createGlobalRateLimitClientImpl(
context, domain, reporting_interval, tls_store->buckets_tls, config_with_hash_key);
tls_store->global_client = std::move(tl_global_client);

Expand All @@ -60,11 +70,8 @@ GlobalTlsStores::getTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_wi
Server::Configuration::FactoryContext& context,
absl::string_view target_address, absl::string_view domain) {
TlsStoreIndex index = std::make_pair(std::string(target_address), std::string(domain));
// Find existing TlsStore or initialize a new one.
auto it = stores().find(index);
if (it != stores().end()) {
ENVOY_LOG(debug, "Found existing cache & RLQS client for target ({}) and domain ({}).",
index.first, index.second);
return it->second.lock();
}
ENVOY_LOG(debug, "Creating a new cache & RLQS client for target ({}) and domain ({}).",
Expand All @@ -76,6 +83,14 @@ GlobalTlsStores::getTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_wi
return tls_store;
}

void GlobalTlsStores::clearTlsStore(const std::pair<std::string, std::string>& index) {
stores().erase(index);
if (stores().empty() && getEmptiedCb() != nullptr) {
getEmptiedCb()();
getEmptiedCb() = nullptr;
}
}

} // namespace RateLimitQuota
} // namespace HttpFilters
} // namespace Extensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ namespace RateLimitQuota {
// GlobalTlsStores holds a singleton hash map of rate_limit_quota TLS stores,
// indexed by their combined RLQS server targets & domains.
//
// This follows the data sharing model of FactoryRegistry, and similarly does
// not guarantee thread-safety. Additions or removals of indices can only be
// done on the main thread, as part of filter factory creation and garbage
// collection respectively.
//
// Note, multiple RLQS clients with different configs (e.g. timeouts) can hit
// the same index (destination + domain). The global map does not guarantee
// which config will be selected for the client creation.
Expand All @@ -52,12 +47,23 @@ class GlobalTlsStores : public Logger::Loggable<Logger::Id::rate_limit_quota> {
// The global client must be cleaned up by the server main thread before
// it shuts down.
if (global_client != nullptr) {
main_dispatcher_.deferredDelete(std::move(global_client));
// SharedClientDeleter wraps a shared_ptr to GlobalRateLimitClientImpl
// to allow it to be used with deferredDelete. It ensures
// deleteIsPending() is called and the shared_ptr is dropped on the main
// thread.
struct SharedClientDeleter : public Event::DeferredDeletable {
SharedClientDeleter(std::shared_ptr<GlobalRateLimitClientImpl> client)
: client_(std::move(client)) {}
void deleteIsPending() override { client_->deleteIsPending(); }
std::shared_ptr<GlobalRateLimitClientImpl> client_;
};
main_dispatcher_.deferredDelete(
std::make_unique<SharedClientDeleter>(std::move(global_client)));
}
GlobalTlsStores::clearTlsStore(std::make_pair(target_address_, domain_));
}

std::unique_ptr<GlobalRateLimitClientImpl> global_client = nullptr;
std::shared_ptr<GlobalRateLimitClientImpl> global_client = nullptr;
ThreadLocal::TypedSlot<ThreadLocalBucketsCache> buckets_tls;

private:
Expand Down Expand Up @@ -105,13 +111,7 @@ class GlobalTlsStores : public Logger::Loggable<Logger::Id::rate_limit_quota> {
}

// Clear a specified index when it is no longer captured by any filter factories.
static void clearTlsStore(const TlsStoreIndex& index) {
stores().erase(index);
if (stores().empty() && getEmptiedCb() != nullptr) {
getEmptiedCb()();
getEmptiedCb() = nullptr;
}
}
static void clearTlsStore(const TlsStoreIndex& index);
};

} // namespace RateLimitQuota
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ void GlobalRateLimitClientImpl::deleteIsPending() {
async_client_->reset();
}

// Atomically read usage counters & reset them to 0.
void getUsageFromBucket(const CachedBucket& cached_bucket, TimeSource& time_source,
BucketQuotaUsage& usage) {
std::shared_ptr<QuotaUsage> cached_usage = cached_bucket.quota_usage;
Expand Down Expand Up @@ -152,12 +153,16 @@ void GlobalRateLimitClientImpl::createBucket(const BucketId& bucket_id, size_t i
std::chrono::milliseconds fallback_ttl,
bool initial_request_allowed) {
// Mutable to move fallback_action ownership into the main thread then into
// the created bucket.
main_dispatcher_.post([&, bucket_id, id, default_bucket_action,
// the created bucket. Captures a weak_ptr to 'this' as the posted operation
// can outlive the global client itself.
main_dispatcher_.post([weak_this = weak_from_this(), bucket_id, id, default_bucket_action,
fallback_action_ptr = std::move(fallback_action), fallback_ttl,
initial_request_allowed]() mutable {
createBucketImpl(bucket_id, id, default_bucket_action, std::move(fallback_action_ptr),
fallback_ttl, initial_request_allowed);
if (auto shared_this = weak_this.lock()) {
shared_this->createBucketImpl(bucket_id, id, default_bucket_action,
std::move(fallback_action_ptr), fallback_ttl,
initial_request_allowed);
}
});
}

Expand Down Expand Up @@ -274,8 +279,13 @@ void GlobalRateLimitClientImpl::onReceiveMessage(RateLimitQuotaResponsePtr&& res
if (response == nullptr) {
return;
}
main_dispatcher_.post(
[&, response = std::move(response)]() { onQuotaResponseImpl(response.get()); });
// Captures a weak_ptr to 'this' as the posted operation can outlive the
// global client itself.
main_dispatcher_.post([weak_this = weak_from_this(), response = std::move(response)]() {
if (auto shared_this = weak_this.lock()) {
shared_this->onQuotaResponseImpl(response.get());
}
});
}

// Updating a cached_bucket shouldn't reset the cached token bucket if the
Expand Down Expand Up @@ -432,12 +442,17 @@ void GlobalRateLimitClientImpl::onSendReportsTimer() {

void GlobalRateLimitClientImpl::startActionExpirationTimer(CachedBucket* cached_bucket, size_t id) {
// Pointer safety as all writes are against the source-of-truth.
cached_bucket->action_expiration_timer = main_dispatcher_.createTimer([&, id, cached_bucket]() {
onActionExpirationTimer(cached_bucket, id);
if (callbacks_ != nullptr) {
callbacks_->onActionExpiration();
}
});
cached_bucket->action_expiration_timer =
main_dispatcher_.createTimer([weak_this = weak_from_this(), id, cached_bucket]() {
auto shared_this = weak_this.lock();
if (!shared_this) {
return;
}
shared_this->onActionExpirationTimer(cached_bucket, id);
if (shared_this->callbacks_ != nullptr) {
shared_this->callbacks_->onActionExpiration();
}
});
std::chrono::milliseconds ttl = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::seconds(cached_bucket->cached_action->quota_assignment_action()
.assignment_time_to_live()
Expand Down Expand Up @@ -521,12 +536,17 @@ void GlobalRateLimitClientImpl::onActionExpirationTimer(CachedBucket* bucket, si
void GlobalRateLimitClientImpl::startFallbackExpirationTimer(CachedBucket* cached_bucket,
size_t id) {
// Pointer safety as all writes are against the source-of-truth.
cached_bucket->fallback_expiration_timer = main_dispatcher_.createTimer([&, id, cached_bucket]() {
onFallbackExpirationTimer(cached_bucket, id);
if (callbacks_ != nullptr) {
callbacks_->onFallbackExpiration();
}
});
cached_bucket->fallback_expiration_timer =
main_dispatcher_.createTimer([weak_this = weak_from_this(), id, cached_bucket]() {
auto shared_this = weak_this.lock();
if (!shared_this) {
return;
}
shared_this->onFallbackExpirationTimer(cached_bucket, id);
if (shared_this->callbacks_ != nullptr) {
shared_this->callbacks_->onFallbackExpiration();
}
});
cached_bucket->fallback_expiration_timer->enableTimer(cached_bucket->fallback_ttl);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class GlobalRateLimitClientCallbacks {
class GlobalRateLimitClientImpl : public Grpc::AsyncStreamCallbacks<
envoy::service::rate_limit_quota::v3::RateLimitQuotaResponse>,
public Event::DeferredDeletable,
// Required to safely capture 'this' in asynchronous callbacks
// posted to the main dispatcher, as they may outlive the client.
public std::enable_shared_from_this<GlobalRateLimitClientImpl>,
public Logger::Loggable<Logger::Id::rate_limit_quota> {
public:
// Note: rlqs_client is owned directly to ensure that it does not outlive the
Expand Down Expand Up @@ -188,14 +191,14 @@ class GlobalRateLimitClientImpl : public Grpc::AsyncStreamCallbacks<
* Create a shared rate limit client. It should be shared to each worker
* thread via TLS.
*/
inline std::unique_ptr<GlobalRateLimitClientImpl>
inline std::shared_ptr<GlobalRateLimitClientImpl>
createGlobalRateLimitClientImpl(Server::Configuration::FactoryContext& context,
absl::string_view domain_name,
std::chrono::milliseconds send_reports_interval,
ThreadLocal::TypedSlot<ThreadLocalBucketsCache>& buckets_tls,
const Grpc::GrpcServiceConfigWithHashKey& config_with_hash_key) {
Envoy::Event::Dispatcher& main_dispatcher = context.serverFactoryContext().mainThreadDispatcher();
return std::make_unique<GlobalRateLimitClientImpl>(config_with_hash_key, context, domain_name,
return std::make_shared<GlobalRateLimitClientImpl>(config_with_hash_key, context, domain_name,
send_reports_interval, buckets_tls,
main_dispatcher);
}
Expand Down
Loading
Loading