Solves websocket memory leak (#1080)

* removes raw pointers to websocket connection
* refactor do_write
* checks the right buffer for early termination
This commit is contained in:
Stefano Petrilli
2025-09-17 08:09:00 +02:00
committed by GitHub
parent ef67fd1e73
commit 9609ba2196
3 changed files with 90 additions and 83 deletions

View File

@@ -627,12 +627,12 @@ namespace crow
}
void add_websocket(crow::websocket::connection* conn)
void add_websocket(std::shared_ptr<websocket::connection> conn)
{
websockets_.push_back(conn);
}
void remove_websocket(crow::websocket::connection* conn)
void remove_websocket(std::shared_ptr<websocket::connection> conn)
{
websockets_.erase(std::remove(websockets_.begin(), websockets_.end(), conn), websockets_.end());
}
@@ -846,7 +846,7 @@ namespace crow
bool server_started_{false};
std::condition_variable cv_started_;
std::mutex start_mutex_;
std::vector<crow::websocket::connection*> websockets_;
std::vector<std::shared_ptr<websocket::connection>> websockets_;
};
/// \brief Alias of Crow<Middlewares...>. Useful if you want

View File

@@ -445,17 +445,19 @@ namespace crow // NOTE: Already documented in "crow/app.h"
void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
new crow::websocket::Connection<SocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<SocketAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}
void handle_upgrade(const request& req, response&, UnixSocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
new crow::websocket::Connection<UnixSocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<UnixSocketAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}
#ifdef CROW_ENABLE_SSL
void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
{
new crow::websocket::Connection<SSLAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<SSLAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}
#endif

View File

@@ -1,5 +1,6 @@
#pragma once
#include <array>
#include <memory>
#include "crow/logging.h"
#include "crow/socket_adaptors.h"
#include "crow/http_request.h"
@@ -102,36 +103,34 @@ namespace crow // NOTE: Already documented in "crow/app.h"
/// A websocket connection.
template<typename Adaptor, typename Handler>
class Connection : public connection
class Connection : public connection, public std::enable_shared_from_this<Connection<Adaptor, Handler>>
{
public:
/// Constructor for a connection.
/// Factory for a connection.
///
/// Requires a request with an "Upgrade: websocket" header.<br>
/// Automatically handles the handshake.
Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler,
uint64_t max_payload, const std::vector<std::string>& subprotocols,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler,
bool mirror_protocols):
adaptor_(std::move(adaptor)),
handler_(handler),
max_payload_bytes_(max_payload),
open_handler_(std::move(open_handler)),
message_handler_(std::move(message_handler)),
close_handler_(std::move(close_handler)),
error_handler_(std::move(error_handler)),
accept_handler_(std::move(accept_handler))
static void create(const crow::request& req, Adaptor adaptor, Handler* handler,
uint64_t max_payload, const std::vector<std::string>& subprotocols,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler,
bool mirror_protocols)
{
auto conn = std::shared_ptr<Connection>(new Connection(std::move(adaptor),
handler, max_payload,
std::move(open_handler),
std::move(message_handler),
std::move(close_handler),
std::move(error_handler),
std::move(accept_handler)));
// Perform handshake validation
if (!utility::string_equals(req.get_header_value("upgrade"), "websocket"))
{
adaptor_.close();
handler_->remove_websocket(this);
delete this;
conn->adaptor_.close();
return;
}
@@ -142,26 +141,24 @@ namespace crow // NOTE: Already documented in "crow/app.h"
auto subprotocol = utility::find_first_of(subprotocols.begin(), subprotocols.end(), requested_subprotocols.begin(), requested_subprotocols.end());
if (subprotocol != subprotocols.end())
{
subprotocol_ = *subprotocol;
conn->subprotocol_ = *subprotocol;
}
}
if (mirror_protocols & !requested_subprotocols_header.empty())
{
subprotocol_ = requested_subprotocols_header;
conn->subprotocol_ = requested_subprotocols_header;
}
if (accept_handler_)
if (conn->accept_handler_)
{
void* ud = nullptr;
if (!accept_handler_(req, &ud))
if (!conn->accept_handler_(req, &ud))
{
adaptor_.close();
handler_->remove_websocket(this);
delete this;
conn->adaptor_.close();
return;
}
userdata(ud);
conn->userdata(ud);
}
// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
@@ -172,22 +169,11 @@ namespace crow // NOTE: Already documented in "crow/app.h"
uint8_t digest[20];
s.getDigestBytes(digest);
start(crow::utility::base64encode((unsigned char*)digest, 20));
conn->handler_->add_websocket(conn);
conn->start(crow::utility::base64encode((unsigned char*)digest, 20));
}
~Connection() noexcept override
{
// Do not modify anchor_ here since writing shared_ptr is not atomic.
auto watch = std::weak_ptr<void>{anchor_};
// Wait until all unhandled asynchronous operations to join.
// As the deletion occurs inside 'check_destroy()', which already locks
// anchor, use count can be 1 on valid deletion context.
while (watch.use_count() > 2) // 1 for 'check_destroy() routine', 1 for 'this->anchor_'
{
std::this_thread::yield();
}
}
~Connection() noexcept override = default;
template<typename Callable>
struct WeakWrappedMessage
@@ -717,38 +703,38 @@ namespace crow // NOTE: Already documented in "crow/app.h"
/// Also destroys the object if the Close flag is set.
void do_write()
{
if (sending_buffers_.empty())
{
sending_buffers_.swap(write_buffers_);
std::vector<asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for (auto& s : sending_buffers_)
{
buffers.emplace_back(asio::buffer(s));
}
auto watch = std::weak_ptr<void>{anchor_};
asio::async_write(
adaptor_.socket(), buffers,
[&, watch](const error_code& ec, std::size_t /*bytes_transferred*/) {
if (!ec && !close_connection_)
{
sending_buffers_.clear();
if (!write_buffers_.empty())
do_write();
if (has_sent_close_)
close_connection_ = true;
}
else
{
auto anchor = watch.lock();
if (anchor == nullptr) { return; }
if (write_buffers_.empty()) return;
sending_buffers_.clear();
close_connection_ = true;
check_destroy();
}
});
sending_buffers_.swap(write_buffers_);
std::vector<asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for (auto& s : sending_buffers_)
{
buffers.emplace_back(asio::buffer(s));
}
auto watch = std::weak_ptr<void>{anchor_};
asio::async_write(
adaptor_.socket(), buffers,
[this, watch](const error_code& ec, std::size_t /*bytes_transferred*/) {
auto anchor = watch.lock();
if (anchor == nullptr)
return;
if (!ec && !close_connection_)
{
sending_buffers_.clear();
if (!write_buffers_.empty())
do_write();
if (has_sent_close_)
close_connection_ = true;
}
else
{
sending_buffers_.clear();
close_connection_ = true;
check_destroy();
}
});
}
/// Destroy the Connection.
@@ -757,11 +743,14 @@ namespace crow // NOTE: Already documented in "crow/app.h"
// Note that if the close handler was not yet called at this point we did not receive a close packet (or send one)
// and thus we use ClosedAbnormally unless instructed otherwise
if (!is_close_handler_called_)
{
if (close_handler_)
{
close_handler_(*this, "uncleanly", code);
handler_->remove_websocket(this);
if (sending_buffers_.empty() && !is_reading)
delete this;
}
}
handler_->remove_websocket(this->shared_from_this());
}
@@ -796,6 +785,22 @@ namespace crow // NOTE: Already documented in "crow/app.h"
}
private:
Connection(Adaptor&& adaptor, Handler* handler, uint64_t max_payload,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler):
adaptor_(std::move(adaptor)),
handler_(handler),
max_payload_bytes_(max_payload),
open_handler_(std::move(open_handler)),
message_handler_(std::move(message_handler)),
close_handler_(std::move(close_handler)),
error_handler_(std::move(error_handler)),
accept_handler_(std::move(accept_handler))
{}
Adaptor adaptor_;
Handler* handler_;