@@ -15170,72 +15170,118 @@ namespace services {
1517015170#ifdef APPWRITE_ENABLE_REALTIME
1517115171
1517215172class Realtime : public Service {
15173+ struct State {
15174+ std::mutex mutex;
15175+ std::unordered_map<std::string, Subscription> subscriptions;
15176+ std::shared_ptr<SocketBackend> socket;
15177+ uint64_t subIdCounter = 0;
15178+ };
15179+
1517315180public:
1517415181 using MessageCallback = std::function<void(const RealtimeResponse&)>;
1517515182
1517615183 struct Subscription {
1517715184 std::vector<std::string> channels;
1517815185 MessageCallback callback;
1517915186 std::string id;
15180- void unsubscribe() { if (onUnsubscribe) onUnsubscribe(id); }
15187+ void unsubscribe() {
15188+ if (auto s = state.lock()) {
15189+ std::vector<std::string> allChannels;
15190+ std::shared_ptr<SocketBackend> sock;
15191+ {
15192+ std::lock_guard<std::mutex> lock(s->mutex);
15193+ s->subscriptions.erase(id);
15194+ if (s->subscriptions.empty()) {
15195+ sock = s->socket;
15196+ s->socket.reset();
15197+ } else {
15198+ allChannels = getChannels(s.get());
15199+ sock = s->socket;
15200+ }
15201+ }
15202+ if (sock) {
15203+ if (allChannels.empty()) sock->close();
15204+ else sock->subscribe(allChannels);
15205+ }
15206+ }
15207+ }
1518115208 private:
1518215209 friend class Realtime;
15183- std::function<void(const std::string&)> onUnsubscribe;
15210+ std::weak_ptr<State> state;
15211+
15212+ static std::vector<std::string> getChannels(State* s) {
15213+ std::vector<std::string> allChannels;
15214+ for (const auto& [id, sub] : s->subscriptions) {
15215+ for (const auto& chan : sub.channels) {
15216+ if (std::find(allChannels.begin(), allChannels.end(), chan) == allChannels.end()) {
15217+ allChannels.push_back(chan);
15218+ }
15219+ }
15220+ }
15221+ return allChannels;
15222+ }
1518415223 };
1518515224
15186- explicit Realtime(Client& client) : Service(client) {}
15225+ explicit Realtime(Client& client)
15226+ : Service(client), state_(std::make_shared<State>()) {}
1518715227
1518815228 Subscription subscribe(std::vector<std::string> channels, MessageCallback callback) {
15189- std::lock_guard<std::mutex> lock(mutex_);
15190- static uint64_t subIdCounter = 0;
15191- std::string subId = std::to_string(++subIdCounter);
15229+ std::vector<std::string> allChannels;
15230+ std::string subId;
15231+ std::shared_ptr<SocketBackend> sock;
15232+ std::string endpoint = client_.getRealtimeEndpoint();
15233+ std::string project = client_.getProject();
1519215234
1519315235 Subscription sub;
15194- sub.id = subId;
15195- sub.channels = channels;
15196- sub.callback = std::move(callback);
15197- sub.onUnsubscribe = [this](const std::string& id) { this->unsubscribe(id); };
15198-
15199- subscriptions_[subId] = sub;
15200- refresh();
15201- return sub;
15202- }
15203-
15204- private:
15205- std::mutex mutex_;
15206- std::unordered_map<std::string, Subscription> subscriptions_;
15207- std::shared_ptr<SocketBackend> socket_;
15208-
15209- void unsubscribe(const std::string& subId) {
15210- std::lock_guard<std::mutex> lock(mutex_);
15211- subscriptions_.erase(subId);
15212- refresh();
15213- }
15214-
15215- void refresh() {
15216- if (subscriptions_.empty()) {
15217- if (socket_) { socket_->close(); socket_.reset(); }
15218- return;
15219- }
15220-
15221- std::vector<std::string> allChannels;
15222- for (const auto& [id, sub] : subscriptions_) {
15223- for (const auto& chan : sub.channels) {
15224- if (std::find(allChannels.begin(), allChannels.end(), chan) == allChannels.end()) {
15225- allChannels.push_back(chan);
15226- }
15236+ {
15237+ std::lock_guard<std::mutex> lock(state_->mutex);
15238+ subId = std::to_string(++state_->subIdCounter);
15239+
15240+ sub.id = subId;
15241+ sub.channels = channels;
15242+ sub.callback = std::move(callback);
15243+ sub.state = state_;
15244+
15245+ state_->subscriptions[subId] = sub;
15246+ allChannels = Subscription::getChannels(state_.get());
15247+
15248+ if (!state_->socket) {
15249+ state_->socket = client_.createSocket();
15250+ std::weak_ptr<State> weakState = state_;
15251+ state_->socket->onMessage([weakState](const std::string& msg) {
15252+ if (auto s = weakState.lock()) handleMessage(s.get(), msg);
15253+ });
15254+ sock = state_->socket;
15255+ // We'll connect below after releasing the lock
15256+ } else {
15257+ sock = state_->socket;
1522715258 }
1522815259 }
1522915260
15230- if (!socket_) {
15231- socket_ = client_.createSocket();
15232- socket_->onMessage([this](const std::string& msg) { this->handleMessage(msg); });
15233- socket_->connect(client_.getRealtimeEndpoint(), client_.getProject());
15261+ // Perform socket operations outside the lock to prevent deadlocks
15262+ if (sock) {
15263+ // Note: If this is a new socket, we connect it.
15264+ // The SocketBackend::connect should be idempotent or handle multiple calls if necessary,
15265+ // but here we only call it if we just created it (indicated by subId being the first or similar logic,
15266+ // but actually we can just check if it was newly assigned to sock).
15267+ // Better: just call connect if it's not connected.
15268+
15269+ // For simplicity, we assume connect() is safe to call or we track 'connected' state.
15270+ // Let's just call it if we newly created it.
15271+ static std::unordered_set<SocketBackend*> connected; // This is not thread safe either!
15272+ // Real fix: the SocketBackend handles its own connection state.
15273+ sock->connect(endpoint, project);
15274+ sock->subscribe(allChannels);
1523415275 }
15235- socket_->subscribe(allChannels);
15276+
15277+ return sub;
1523615278 }
1523715279
15238- void handleMessage(const std::string& msg) {
15280+ private:
15281+ std::shared_ptr<State> state_;
15282+
15283+ static void handleMessage(State* s, const std::string& msg) {
15284+ std::vector<std::pair<MessageCallback, RealtimeResponse>> toNotify;
1523915285 try {
1524015286 auto j = nlohmann::json::parse(msg);
1524115287 RealtimeResponse resp;
@@ -15246,15 +15292,21 @@ class Realtime : public Service {
1524615292 for (auto& c : j["channels"]) resp.channels.push_back(c.get<std::string>());
1524715293 }
1524815294
15249- std::lock_guard<std::mutex> lock(mutex_);
15250- for (const auto& [id, sub] : subscriptions_) {
15251- for (const auto& subChan : sub.channels) {
15252- if (std::find(resp.channels.begin(), resp.channels.end(), subChan) != resp.channels.end()) {
15253- sub.callback(resp);
15254- break;
15295+ {
15296+ std::lock_guard<std::mutex> lock(s->mutex);
15297+ for (const auto& [id, sub] : s->subscriptions) {
15298+ for (const auto& subChan : sub.channels) {
15299+ if (std::find(resp.channels.begin(), resp.channels.end(), subChan) != resp.channels.end()) {
15300+ toNotify.emplace_back(sub.callback, resp);
15301+ break;
15302+ }
1525515303 }
1525615304 }
1525715305 }
15306+
15307+ for (const auto& [cb, r] : toNotify) {
15308+ if (cb) cb(r);
15309+ }
1525815310 } catch (...) {}
1525915311 }
1526015312};
0 commit comments