Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lib/BatchMessageContainer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ void BatchMessageContainer::clear() {
}

std::unique_ptr<OpSendMsg> BatchMessageContainer::createOpSendMsg(const FlushCallback& flushCallback) {
auto op = createOpSendMsgHelper(flushCallback, batch_);
auto op = createOpSendMsgHelper(batch_);
if (flushCallback) {
op->addTrackerCallback(flushCallback);
}
clear();
return op;
}
Expand Down
39 changes: 3 additions & 36 deletions lib/BatchMessageContainerBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@
*/
#include "BatchMessageContainerBase.h"

#include "ClientConnection.h"
#include "CompressionCodec.h"
#include "MessageAndCallbackBatch.h"
#include "MessageCrypto.h"
#include "MessageImpl.h"
#include "OpSendMsg.h"
#include "ProducerImpl.h"
#include "PulsarApi.pb.h"
#include "SharedBuffer.h"

namespace pulsar {
Expand All @@ -40,38 +36,9 @@ BatchMessageContainerBase::BatchMessageContainerBase(const ProducerImpl& produce
BatchMessageContainerBase::~BatchMessageContainerBase() {}

std::unique_ptr<OpSendMsg> BatchMessageContainerBase::createOpSendMsgHelper(
const FlushCallback& flushCallback, const MessageAndCallbackBatch& batch) const {
auto sendCallback = batch.createSendCallback(flushCallback);
if (batch.empty()) {
return OpSendMsg::create(ResultOperationNotSupported, std::move(sendCallback));
}

MessageImplPtr impl = batch.msgImpl();
impl->metadata.set_num_messages_in_batch(batch.size());
auto compressionType = producerConfig_.getCompressionType();
if (compressionType != CompressionNone) {
impl->metadata.set_compression(static_cast<proto::CompressionType>(compressionType));
impl->metadata.set_uncompressed_size(impl->payload.readableBytes());
}
impl->payload = CompressionCodecProvider::getCodec(compressionType).encode(impl->payload);

auto msgCrypto = msgCryptoWeakPtr_.lock();
if (msgCrypto && producerConfig_.isEncryptionEnabled()) {
SharedBuffer encryptedPayload;
if (!msgCrypto->encrypt(producerConfig_.getEncryptionKeys(), producerConfig_.getCryptoKeyReader(),
impl->metadata, impl->payload, encryptedPayload)) {
return OpSendMsg::create(ResultCryptoError, std::move(sendCallback));
}
impl->payload = encryptedPayload;
}

if (impl->payload.readableBytes() > ClientConnection::getMaxMessageSize()) {
return OpSendMsg::create(ResultMessageTooBig, std::move(sendCallback));
}

return OpSendMsg::create(impl->metadata, batch.messagesCount(), batch.messagesSize(),
producerConfig_.getSendTimeout(), batch.createSendCallback(flushCallback),
nullptr, producerId_, impl->payload);
MessageAndCallbackBatch& batch) const {
auto crypto = msgCryptoWeakPtr_.lock();
return batch.createOpSendMsg(producerId_, producerConfig_, crypto.get());
}

} // namespace pulsar
3 changes: 1 addition & 2 deletions lib/BatchMessageContainerBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ class BatchMessageContainerBase : public boost::noncopyable {
void updateStats(const Message& msg);
void resetStats();

std::unique_ptr<OpSendMsg> createOpSendMsgHelper(const FlushCallback& flushCallback,
const MessageAndCallbackBatch& batch) const;
std::unique_ptr<OpSendMsg> createOpSendMsgHelper(MessageAndCallbackBatch& flushCallback) const;

virtual void clear() = 0;
};
Expand Down
25 changes: 11 additions & 14 deletions lib/BatchMessageKeyBasedContainer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,19 @@ void BatchMessageKeyBasedContainer::clear() {

std::vector<std::unique_ptr<OpSendMsg>> BatchMessageKeyBasedContainer::createOpSendMsgs(
const FlushCallback& flushCallback) {
// Sorted the batches by sequence id
std::vector<const MessageAndCallbackBatch*> sortedBatches;
for (const auto& kv : batches_) {
sortedBatches.emplace_back(&kv.second);
// Store raw pointers to use std::sort
std::vector<OpSendMsg*> rawOpSendMsgs;
for (auto& kv : batches_) {
rawOpSendMsgs.emplace_back(createOpSendMsgHelper(kv.second).release());
}
std::sort(sortedBatches.begin(), sortedBatches.end(),
[](const MessageAndCallbackBatch* lhs, const MessageAndCallbackBatch* rhs) {
return lhs->sequenceId() < rhs->sequenceId();
});
std::sort(rawOpSendMsgs.begin(), rawOpSendMsgs.end(), [](const OpSendMsg* lhs, const OpSendMsg* rhs) {
return lhs->sendArgs->sequenceId < rhs->sendArgs->sequenceId;
});
rawOpSendMsgs.back()->addTrackerCallback(flushCallback);

std::vector<std::unique_ptr<OpSendMsg>> opSendMsgs{sortedBatches.size()};
for (size_t i = 0; i + 1 < opSendMsgs.size(); i++) {
opSendMsgs[i].reset(createOpSendMsgHelper(nullptr, *sortedBatches[i]).release());
}
if (!opSendMsgs.empty()) {
opSendMsgs.back().reset(createOpSendMsgHelper(flushCallback, *sortedBatches.back()).release());
std::vector<std::unique_ptr<OpSendMsg>> opSendMsgs{rawOpSendMsgs.size()};
for (size_t i = 0; i < opSendMsgs.size(); i++) {
opSendMsgs[i].reset(rawOpSendMsgs[i]);
}
clear();
return opSendMsgs;
Expand Down
5 changes: 3 additions & 2 deletions lib/Commands.cc
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,8 @@ void Commands::initBatchMessageMetadata(const Message& msg, pulsar::proto::Messa
uint64_t Commands::serializeSingleMessageInBatchWithPayload(const Message& msg, SharedBuffer& batchPayLoad,
unsigned long maxMessageSizeInBytes) {
const auto& msgMetadata = msg.impl_->metadata;
SingleMessageMetadata metadata;
thread_local SingleMessageMetadata metadata;
metadata.Clear();
if (msgMetadata.has_partition_key()) {
metadata.set_partition_key(msgMetadata.partition_key());
}
Expand Down Expand Up @@ -868,7 +869,7 @@ uint64_t Commands::serializeSingleMessageInBatchWithPayload(const Message& msg,
int payloadSize = msg.impl_->payload.readableBytes();
metadata.set_payload_size(payloadSize);

int msgMetadataSize = metadata.ByteSize();
auto msgMetadataSize = metadata.ByteSizeLong();

unsigned long requiredSpace = sizeof(uint32_t) + msgMetadataSize + payloadSize;
if (batchPayLoad.writableBytes() <= sizeof(uint32_t) + msgMetadataSize + payloadSize) {
Expand Down
96 changes: 66 additions & 30 deletions lib/MessageAndCallbackBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,59 +22,95 @@

#include "ClientConnection.h"
#include "Commands.h"
#include "LogUtils.h"
#include "MessageImpl.h"

DECLARE_LOG_OBJECT()
#include "CompressionCodec.h"
#include "MessageCrypto.h"
#include "OpSendMsg.h"
#include "PulsarApi.pb.h"

namespace pulsar {

MessageAndCallbackBatch::MessageAndCallbackBatch() {}

MessageAndCallbackBatch::~MessageAndCallbackBatch() {}

void MessageAndCallbackBatch::add(const Message& msg, const SendCallback& callback) {
if (empty()) {
msgImpl_.reset(new MessageImpl);
Commands::initBatchMessageMetadata(msg, msgImpl_->metadata);
if (callbacks_.empty()) {
metadata_.reset(new proto::MessageMetadata);
Commands::initBatchMessageMetadata(msg, *metadata_);
sequenceId_ = metadata_->sequence_id();
}
LOG_DEBUG(" Before serialization payload size in bytes = " << msgImpl_->payload.readableBytes());
sequenceId_ = Commands::serializeSingleMessageInBatchWithPayload(msg, msgImpl_->payload,
ClientConnection::getMaxMessageSize());
LOG_DEBUG(" After serialization payload size in bytes = " << msgImpl_->payload.readableBytes());
messages_.emplace_back(msg);
callbacks_.emplace_back(callback);

++messagesCount_;
messagesSize_ += msg.getLength();
}

std::unique_ptr<OpSendMsg> MessageAndCallbackBatch::createOpSendMsg(
uint64_t producerId, const ProducerConfiguration& producerConfig, MessageCrypto* crypto) {
auto callback = createSendCallback();
if (empty()) {
return OpSendMsg::create(ResultOperationNotSupported, std::move(callback));
}

// The magic number 64 is just an estimated size increment after setting some fields of the
// SingleMessageMetadata. It does not have to be accurate because it's only used to reduce the
// reallocation of the payload buffer.
static const size_t kEstimatedHeaderSize =
Comment thread
BewareMyPower marked this conversation as resolved.
Outdated
sizeof(uint32_t) + proto::MessageMetadata{}.ByteSizeLong() + 64;
const auto maxMessageSize = ClientConnection::getMaxMessageSize();
// Estimate the buffer size just to avoid resizing the buffer
size_t maxBufferSize = kEstimatedHeaderSize * messages_.size();
for (const auto& msg : messages_) {
maxBufferSize += msg.getLength();
}
auto payload = SharedBuffer::allocate(maxBufferSize);
for (const auto& msg : messages_) {
sequenceId_ = Commands::serializeSingleMessageInBatchWithPayload(msg, payload, maxMessageSize);
}
metadata_->set_sequence_id(sequenceId_);
metadata_->set_num_messages_in_batch(messages_.size());
auto compressionType = producerConfig.getCompressionType();
if (compressionType != CompressionNone) {
metadata_->set_compression(static_cast<proto::CompressionType>(compressionType));
metadata_->set_uncompressed_size(payload.readableBytes());
}
payload = CompressionCodecProvider::getCodec(compressionType).encode(payload);

if (producerConfig.isEncryptionEnabled() && crypto) {
SharedBuffer encryptedPayload;
if (!crypto->encrypt(producerConfig.getEncryptionKeys(), producerConfig.getCryptoKeyReader(),
*metadata_, payload, encryptedPayload)) {
return OpSendMsg::create(ResultCryptoError, std::move(callback));
}
payload = encryptedPayload;
}

if (payload.readableBytes() > ClientConnection::getMaxMessageSize()) {
return OpSendMsg::create(ResultMessageTooBig, std::move(callback));
}

auto op = OpSendMsg::create(*metadata_, callbacks_.size(), messagesSize_, producerConfig.getSendTimeout(),
std::move(callback), nullptr, producerId, payload);
clear();
return op;
}

void MessageAndCallbackBatch::clear() {
msgImpl_.reset();
messages_.clear();
callbacks_.clear();
messagesCount_ = 0;
messagesSize_ = 0;
}

static void completeSendCallbacks(const std::vector<SendCallback>& callbacks, Result result,
const MessageId& id) {
int32_t numOfMessages = static_cast<int32_t>(callbacks.size());
LOG_DEBUG("Batch complete [Result = " << result << "] [numOfMessages = " << numOfMessages << "]");
for (int32_t i = 0; i < numOfMessages; i++) {
callbacks[i](result, MessageIdBuilder::from(id).batchIndex(i).batchSize(numOfMessages).build());
}
}

void MessageAndCallbackBatch::complete(Result result, const MessageId& id) const {
completeSendCallbacks(callbacks_, result, id);
}

SendCallback MessageAndCallbackBatch::createSendCallback(const FlushCallback& flushCallback) const {
SendCallback MessageAndCallbackBatch::createSendCallback() const {
const auto& callbacks = callbacks_;
if (flushCallback) {
return [callbacks, flushCallback](Result result, const MessageId& id) {
completeSendCallbacks(callbacks, result, id);
flushCallback(result);
};
} else {
return [callbacks] // save a copy of `callbacks_`
(Result result, const MessageId& id) { completeSendCallbacks(callbacks, result, id); };
}
return [callbacks](Result result, const MessageId& id) { completeSendCallbacks(callbacks, result, id); };
}

} // namespace pulsar
42 changes: 19 additions & 23 deletions lib/MessageAndCallbackBatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,24 @@

#include <atomic>
#include <boost/noncopyable.hpp>
#include <memory>
#include <vector>

namespace pulsar {

class MessageImpl;
using MessageImplPtr = std::shared_ptr<MessageImpl>;
struct OpSendMsg;
class MessageCrypto;
using FlushCallback = std::function<void(Result)>;

class MessageAndCallbackBatch : public boost::noncopyable {
namespace proto {
class MessageMetadata;
}

class MessageAndCallbackBatch final : public boost::noncopyable {
public:
MessageAndCallbackBatch();
~MessageAndCallbackBatch();

// Wrapper methods of STL container
bool empty() const noexcept { return callbacks_.empty(); }
size_t size() const noexcept { return callbacks_.size(); }
Expand All @@ -46,34 +54,22 @@ class MessageAndCallbackBatch : public boost::noncopyable {
*/
void add(const Message& msg, const SendCallback& callback);

/**
* Clear the internal stats
*/
void clear();
std::unique_ptr<OpSendMsg> createOpSendMsg(uint64_t producerId,
const ProducerConfiguration& producerConfig,
MessageCrypto* crypto);

/**
* Complete all the callbacks with given parameters
*
* @param result this batch's send result
* @param id this batch's message id
*/
void complete(Result result, const MessageId& id) const;

SendCallback createSendCallback(const FlushCallback& flushCallback) const;

const MessageImplPtr& msgImpl() const { return msgImpl_; }
uint64_t sequenceId() const noexcept { return sequenceId_; }

uint32_t messagesCount() const { return messagesCount_; }
uint64_t messagesSize() const { return messagesSize_; }
void clear();

private:
MessageImplPtr msgImpl_;
std::unique_ptr<proto::MessageMetadata> metadata_;
std::vector<Message> messages_;
std::vector<SendCallback> callbacks_;
std::atomic<uint64_t> sequenceId_{static_cast<uint64_t>(-1L)};

uint32_t messagesCount_{0};
uint64_t messagesSize_{0ull};

SendCallback createSendCallback() const;
};

} // namespace pulsar
Expand Down
6 changes: 4 additions & 2 deletions lib/OpSendMsg.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct SendArguments {
const uint64_t producerId;
const uint64_t sequenceId;
const proto::MessageMetadata metadata;
const SharedBuffer payload;
SharedBuffer payload;

SendArguments(uint64_t producerId, uint64_t sequenceId, const proto::MessageMetadata& metadata,
const SharedBuffer& payload)
Expand Down Expand Up @@ -73,7 +73,9 @@ struct OpSendMsg {
}

void addTrackerCallback(std::function<void(Result)> trackerCallback) {
trackerCallbacks.emplace_back(trackerCallback);
if (trackerCallback) {
trackerCallbacks.emplace_back(trackerCallback);
}
}

private:
Expand Down