Skip to content

Commit 50118ca

Browse files
JasMehta08 authored and jblomer committed
[curl] Add SendPutReq to RCurlConnection with tests
1 parent 439fa90 commit 50118ca

4 files changed

Lines changed: 313 additions & 1 deletion

File tree

net/curl/inc/ROOT/RCurlConnection.hxx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ public:
117117
/// a valid batching of requests into multiple multi-range requests takes place automatically.
118118
/// The fNBytesRecv member of the ranges is only well-defined on success.
119119
RStatus SendRangesReq(std::size_t N, RUserRange *ranges);
120+
/// Uploads `length` bytes starting at `data` to the URL using an HTTP PUT request.
121+
RStatus SendPutReq(const unsigned char *data, std::size_t length);
120122

121123
const std::string &GetEscapedUrl() const { return fEscapedUrl; }
122124

net/curl/src/RCurlConnection.cxx

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,52 @@ void ReverseDisplacements(std::vector<std::size_t> &displacements, ROOT::Interna
552552
}
553553
}
554554

555+
/// Cursor over the buffer that a PUT request uploads. The read callback advances fOffset as
/// bytes are handed to curl; the seek callback repositions it.
struct RUploadState {
   const unsigned char *fData{nullptr}; ///< First byte of the upload buffer
   std::size_t fLength{0};              ///< Total number of bytes in the upload buffer
   std::size_t fOffset{0};              ///< Position of the next byte to hand out
};
561+
562+
/// CURLOPT_READFUNCTION callback for PUT uploads. Copies up to `size * nmemb` bytes from the
563+
/// upload buffer into `buffer` and advances the offset. Returns the number of bytes copied,
564+
/// i.e. min(requested, remaining), or 0 at end-of-data to signal that the upload is complete.
565+
std::size_t CallbackPutRead(char *buffer, std::size_t size, std::size_t nmemb, void *userdata)
566+
{
567+
auto *state = static_cast<RUploadState *>(userdata);
568+
R__ASSERT(state->fOffset <= state->fLength);
569+
570+
std::size_t remaining = state->fLength - state->fOffset;
571+
if (remaining == 0)
572+
return 0;
573+
574+
std::size_t requested = size * nmemb;
575+
// CURL_READFUNC_ABORT (0x10000000) collides with a valid byte count at 256 MiB;
576+
// assert that curl never asks for that much in a single callback invocation.
577+
R__ASSERT(requested < CURL_READFUNC_ABORT);
578+
std::size_t nbytes = std::min(requested, remaining);
579+
memcpy(buffer, state->fData + state->fOffset, nbytes);
580+
state->fOffset += nbytes;
581+
return nbytes;
582+
}
583+
584+
/// CURLOPT_SEEKFUNCTION callback for PUT uploads. Required because CURLOPT_FOLLOWLOCATION
585+
/// is enabled: on a redirect curl needs to rewind the upload data before resending.
586+
/// Returns CURL_SEEKFUNC_OK (0) on success, CURL_SEEKFUNC_FAIL (1) for invalid offsets
587+
/// which aborts the transfer, or CURL_SEEKFUNC_CANTSEEK (2) for unsupported seek origins
588+
/// which lets curl try to work around it.
589+
int CallbackPutSeek(void *userdata, curl_off_t offset, int origin)
590+
{
591+
auto *state = static_cast<RUploadState *>(userdata);
592+
// curl documents that it will only use SEEK_SET; guard against anything else defensively.
593+
if (origin != SEEK_SET)
594+
return CURL_SEEKFUNC_CANTSEEK;
595+
if (offset < 0 || static_cast<std::size_t>(offset) > state->fLength)
596+
return CURL_SEEKFUNC_FAIL;
597+
state->fOffset = static_cast<std::size_t>(offset);
598+
return CURL_SEEKFUNC_OK;
599+
}
600+
555601
/// Wrapper around curl_easy_setopt that asserts on failure. Most option-setting calls in this
556602
/// file use valid options and values by construction, so failure indicates a programming error.
557603
template <typename T>
@@ -655,7 +701,6 @@ void ROOT::Internal::RCurlConnection::SetOptions()
655701
static const std::string kUserAgent = GetUserAgentString();
656702
SetCurlOption(fHandle, CURLOPT_USERAGENT, kUserAgent.c_str());
657703
SetCurlOption(fHandle, CURLOPT_FOLLOWLOCATION, 1);
658-
SetCurlOption(fHandle, CURLOPT_WRITEFUNCTION, CallbackData);
659704
}
660705

661706
/// Reset method-specific sticky curl options so that the easy handle is in a clean state
@@ -666,6 +711,8 @@ void ROOT::Internal::RCurlConnection::ResetHandle()
666711
SetCurlOption(fHandle, CURLOPT_HTTPGET, 0L);
667712
SetCurlOption(fHandle, CURLOPT_UPLOAD, 0L);
668713
SetCurlOption(fHandle, CURLOPT_RANGE, static_cast<const char *>(nullptr));
714+
SetCurlOption(fHandle, CURLOPT_WRITEFUNCTION, static_cast<curl_write_callback>(nullptr));
715+
SetCurlOption(fHandle, CURLOPT_WRITEDATA, static_cast<void *>(nullptr));
669716
SetCurlOption(fHandle, CURLOPT_READFUNCTION, static_cast<curl_read_callback>(nullptr));
670717
SetCurlOption(fHandle, CURLOPT_READDATA, static_cast<void *>(nullptr));
671718
SetCurlOption(fHandle, CURLOPT_SEEKFUNCTION, static_cast<curl_seek_callback>(nullptr));
@@ -791,6 +838,7 @@ ROOT::Internal::RCurlConnection::SendRangesReq(std::size_t N, RUserRange *ranges
791838
SetCurlOption(fHandle, CURLOPT_HTTPGET, 1);
792839

793840
RTransferState transfer(ranges, order, fHandle);
841+
SetCurlOption(fHandle, CURLOPT_WRITEFUNCTION, CallbackData);
794842
SetCurlOption(fHandle, CURLOPT_WRITEDATA, &transfer);
795843

796844
#ifndef HAS_CURL_EASY_HEADER
@@ -854,6 +902,26 @@ ROOT::Internal::RCurlConnection::SendRangesReq(std::size_t N, RUserRange *ranges
854902
return status;
855903
}
856904

905+
/// Uploads `length` bytes starting at `data` to the connection's URL with an HTTP PUT request.
/// `data` may only be null if `length` is zero. The upload state lives on the stack of this call,
/// so the buffer is read directly from `data` while the request is performed.
/// Returns the status filled in by Perform().
ROOT::Internal::RCurlConnection::RStatus
ROOT::Internal::RCurlConnection::SendPutReq(const unsigned char *data, std::size_t length)
{
   // The read callback memcpy's straight out of `data`; a null buffer is only valid for an empty upload.
   R__ASSERT(data != nullptr || length == 0);

   ResetHandle();

   SetCurlOption(fHandle, CURLOPT_UPLOAD, 1L);
   SetCurlOption(fHandle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(length));

   RUploadState uploadState{data, length, 0};
   SetCurlOption(fHandle, CURLOPT_READFUNCTION, CallbackPutRead);
   SetCurlOption(fHandle, CURLOPT_READDATA, &uploadState);
   SetCurlOption(fHandle, CURLOPT_SEEKFUNCTION, CallbackPutSeek);
   SetCurlOption(fHandle, CURLOPT_SEEKDATA, &uploadState);

   // ResetHandle() sets both CURLOPT_WRITEFUNCTION and CURLOPT_WRITEDATA to null. Per the libcurl
   // documentation a null write function selects the built-in fwrite(), which would then be called
   // with a null stream as soon as the server sends a response body (e.g. an error document).
   // Install a sink that swallows the response body instead.
   SetCurlOption(fHandle, CURLOPT_WRITEFUNCTION,
                 static_cast<curl_write_callback>(
                    [](char *, std::size_t size, std::size_t nmemb, void *) -> std::size_t { return size * nmemb; }));

   RStatus status;
   Perform(status);

   return status;
}
924+
857925
void ROOT::Internal::RCurlConnection::SetCredentials(const RS3Credentials &credentials)
858926
{
859927
ClearCredentials();

net/curl/test/curl_connection.cxx

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,67 @@
44

55
#include "TServerSocket.h"
66

7+
#include <algorithm>
8+
#include <cctype>
79
#include <cstdint>
810
#include <cstring>
911
#include <string>
1012
#include <thread>
13+
#include <vector>
14+
15+
/// Lower-case every character of `s` with std::tolower and return the result.
/// The argument is taken by value, so the caller's string is left untouched.
static std::string ToLower(std::string s)
{
   for (char &ch : s) {
      // Go through unsigned char: std::tolower is undefined for negative char values.
      ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
   }
   return s;
}
21+
22+
/// Accept a PUT request: read headers + body, optionally respond to Expect: 100-continue, send 200 OK.
23+
static void TaskRecvPut(TServerSocket *serverSocket, std::string *requestHeaders, std::string *requestBody)
24+
{
25+
requestHeaders->clear();
26+
requestBody->clear();
27+
auto sock = serverSocket->Accept();
28+
29+
const char *eof = "\r\n\r\n";
30+
const std::size_t eofLen = strlen(eof);
31+
std::size_t nextInEof = 0;
32+
char c;
33+
while (sock->RecvRaw(&c, 1)) {
34+
requestHeaders->push_back(c);
35+
if (c == eof[nextInEof]) {
36+
if (++nextInEof == eofLen)
37+
break;
38+
} else {
39+
nextInEof = 0;
40+
}
41+
}
42+
43+
// If the client sent Expect: 100-continue, respond with HTTP 100 before reading the body
44+
std::string headersLower = ToLower(*requestHeaders);
45+
if (headersLower.find("expect: 100-continue") != std::string::npos) {
46+
const char *continueResponse = "HTTP/1.1 100 Continue\r\n\r\n";
47+
sock->SendRaw(continueResponse, strlen(continueResponse));
48+
}
49+
50+
// Parse content-length (case-insensitive)
51+
std::size_t contentLength = 0;
52+
auto pos = headersLower.find("content-length: ");
53+
if (pos != std::string::npos) {
54+
auto valStart = pos + strlen("content-length: ");
55+
auto valEnd = headersLower.find("\r\n", valStart);
56+
contentLength = std::stoul(headersLower.substr(valStart, valEnd - valStart));
57+
}
58+
59+
if (contentLength > 0) {
60+
requestBody->resize(contentLength);
61+
sock->RecvRaw(&(*requestBody)[0], contentLength);
62+
}
63+
64+
const char *response = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
65+
sock->SendRaw(response, strlen(response));
66+
sock->Close();
67+
}
1168

1269
static void TaskRecv(TServerSocket *serverSocket, std::string *request)
1370
{
@@ -63,3 +120,131 @@ TEST(RCurlConnection, Cred)
63120
threadRecv.join();
64121
EXPECT_EQ(std::string::npos, request.find("\r\nAuthorization: "));
65122
}
123+
124+
/// Small PUT against the local mock server: the request line, the Content-Length header and the
/// received body must all match what was uploaded.
TEST(RCurlConnection, Put)
{
   TServerSocket sock(0, false, TServerSocket::kDefaultBacklog, -1, ESocketBindOption::kInaddrLoopback);
   const std::string url =
      std::string("http://") + sock.GetLocalInetAddress().GetHostAddress() + ":" + std::to_string(sock.GetLocalPort());

   const unsigned char payload[] = "Hello, S3!";
   const std::size_t payloadLen = sizeof(payload) - 1; // the string literal's trailing '\0' is not uploaded
   const std::string expectedBody(reinterpret_cast<const char *>(payload), payloadLen);

   // The mock server runs in its own thread and fills these two strings.
   std::string receivedHeaders;
   std::string receivedBody;
   std::thread serverThread(TaskRecvPut, &sock, &receivedHeaders, &receivedBody);

   ROOT::Internal::RCurlConnection conn(url);
   auto status = conn.SendPutReq(payload, payloadLen);
   serverThread.join();

   EXPECT_TRUE(static_cast<bool>(status));
   EXPECT_EQ(0u, receivedHeaders.find("PUT "));

   // Header names are matched case-insensitively by lower-casing the whole header section first.
   const std::string wantedHeader = "content-length: " + std::to_string(payloadLen);
   auto clPos = ToLower(receivedHeaders).find(wantedHeader);
   ASSERT_NE(std::string::npos, clPos) << "content-length header not found in request";

   EXPECT_EQ(expectedBody, receivedBody);
}
151+
152+
/// GET (range read) after PUT on the same handle — verifies that WRITEFUNCTION is set correctly
153+
/// in SendRangesReq after a PUT cleared it.
154+
TEST(RCurlConnection, GetAfterPut)
155+
{
156+
TServerSocket sock(0, false, TServerSocket::kDefaultBacklog, -1, ESocketBindOption::kInaddrLoopback);
157+
const std::string url =
158+
std::string("http://") + sock.GetLocalInetAddress().GetHostAddress() + ":" + std::to_string(sock.GetLocalPort());
159+
160+
// First: do a PUT
161+
const unsigned char putPayload[] = "put-data";
162+
const std::size_t putPayloadLen = sizeof(putPayload) - 1;
163+
164+
std::string putHeaders;
165+
std::string putBody;
166+
std::thread threadRecvPut(TaskRecvPut, &sock, &putHeaders, &putBody);
167+
168+
ROOT::Internal::RCurlConnection conn(url);
169+
auto putStatus = conn.SendPutReq(putPayload, putPayloadLen);
170+
171+
threadRecvPut.join();
172+
EXPECT_TRUE(static_cast<bool>(putStatus));
173+
EXPECT_EQ(0u, putHeaders.find("PUT "));
174+
175+
// Second: do a GET (SendRangesReq) on the same handle.
176+
// The server sends a plain 200 response with the body "response-from-get".
177+
const std::string expectedBody = "response-from-get";
178+
std::string getHeaders;
179+
auto taskRecvGet = [&](TServerSocket *serverSocket) {
180+
getHeaders.clear();
181+
auto s = serverSocket->Accept();
182+
183+
const char *eof = "\r\n\r\n";
184+
const std::size_t eofLen = strlen(eof);
185+
std::size_t nextInEof = 0;
186+
char c;
187+
while (s->RecvRaw(&c, 1)) {
188+
getHeaders.push_back(c);
189+
if (c == eof[nextInEof]) {
190+
if (++nextInEof == eofLen)
191+
break;
192+
} else {
193+
nextInEof = 0;
194+
}
195+
}
196+
197+
std::string response = "HTTP/1.1 200 OK\r\nContent-Length: " + std::to_string(expectedBody.size()) +
198+
"\r\n\r\n" + expectedBody;
199+
s->SendRaw(response.data(), response.size());
200+
s->Close();
201+
};
202+
std::thread threadRecvGet(taskRecvGet, &sock);
203+
204+
std::vector<unsigned char> readBuf(expectedBody.size(), 0);
205+
ROOT::Internal::RCurlConnection::RUserRange range;
206+
range.fDestination = readBuf.data();
207+
range.fOffset = 0;
208+
range.fLength = expectedBody.size();
209+
auto getStatus = conn.SendRangesReq(1, &range);
210+
211+
threadRecvGet.join();
212+
EXPECT_TRUE(static_cast<bool>(getStatus));
213+
EXPECT_EQ(0u, getHeaders.find("GET "));
214+
EXPECT_EQ(expectedBody.size(), range.fNBytesRecv);
215+
std::string received(reinterpret_cast<char *>(readBuf.data()), range.fNBytesRecv);
216+
EXPECT_EQ(expectedBody, received);
217+
}
218+
219+
/// PUT with a payload larger than libcurl's internal Expect: 100-continue threshold (1 MB since curl 7.69).
220+
/// Verifies that the server-side 100 Continue handshake works and all bytes arrive correctly.
221+
TEST(RCurlConnection, PutLargeExpect100)
222+
{
223+
TServerSocket sock(0, false, TServerSocket::kDefaultBacklog, -1, ESocketBindOption::kInaddrLoopback);
224+
const std::string url =
225+
std::string("http://") + sock.GetLocalInetAddress().GetHostAddress() + ":" + std::to_string(sock.GetLocalPort());
226+
227+
// 2 MB payload with a known repeating pattern
228+
const std::size_t payloadLen = 2 * 1024 * 1024;
229+
std::vector<unsigned char> payload(payloadLen);
230+
for (std::size_t i = 0; i < payloadLen; ++i)
231+
payload[i] = static_cast<unsigned char>(i & 0xFF);
232+
233+
std::string headers;
234+
std::string body;
235+
std::thread threadRecv(TaskRecvPut, &sock, &headers, &body);
236+
237+
ROOT::Internal::RCurlConnection conn(url);
238+
auto status = conn.SendPutReq(payload.data(), payloadLen);
239+
240+
threadRecv.join();
241+
EXPECT_TRUE(static_cast<bool>(status));
242+
EXPECT_EQ(0u, headers.find("PUT "));
243+
244+
std::string headersLower = ToLower(headers);
245+
EXPECT_NE(std::string::npos, headersLower.find("expect: 100-continue"))
246+
<< "large upload should include Expect: 100-continue header";
247+
EXPECT_NE(std::string::npos, headersLower.find("content-length: " + std::to_string(payloadLen)));
248+
ASSERT_EQ(payloadLen, body.size());
249+
EXPECT_EQ(0, memcmp(body.data(), payload.data(), payloadLen));
250+
}

net/curl/test/curl_env.cxx

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
#include "TCurlFile.h"
1212
#include "TSystem.h"
1313

14+
#include <cstring>
1415
#include <memory>
1516
#include <utility>
17+
#include <vector>
1618

1719
TEST(RCurlConnection, CredFromEnv)
1820
{
@@ -93,3 +95,58 @@ TEST(CurlFile, S3Credentials)
9395
gSystem->Unsetenv("S3_ACCESS_KEY");
9496
gSystem->Unsetenv("S3_SECRET_KEY");
9597
}
98+
99+
/// Round trip against the S3 test bucket: PUT a small object, then check its size with HEAD and
/// its content with a range GET. Every step uses a fresh connection. Skipped when the
/// ROOT_TEST_S3_ACCESS_KEY / ROOT_TEST_S3_SECRET_KEY environment variables are unset or empty,
/// and for libcurl versions up to 7.81.
TEST(CurlFile, S3PutAndRead)
{
   const auto accessKey = std::getenv("ROOT_TEST_S3_ACCESS_KEY");
   const auto secretKey = std::getenv("ROOT_TEST_S3_SECRET_KEY");
   const bool haveCredentials = accessKey && accessKey[0] != '\0' && secretKey && secretKey[0] != '\0';
   if (!haveCredentials) {
      GTEST_SKIP() << "Missing S3 test credentials <ROOT_TEST_S3_[ACCESS|SECRET]_KEY>, skipping";
   }
   if (ROOT::Internal::RCurlConnection::GetCurlVersion() <= 0x078100) {
      GTEST_SKIP() << "libcurl <= 7.81 is known to produce an AWSv4 signature incompatible with Ceph S3";
   }

   const std::string url = "https://root-project-s3test.s3.cern.ch/test-curl-put-roundtrip.bin";

   ROOT::Internal::RS3Credentials creds;
   creds.fAccessKey = accessKey;
   creds.fSecretKey = secretKey;

   const unsigned char payload[] = "RCurlConnection::SendPutReq round-trip test";
   const std::size_t payloadLen = sizeof(payload) - 1; // without the literal's trailing '\0'

   // Step 1: upload the payload
   {
      ROOT::Internal::RCurlConnection putConn(url);
      putConn.SetCredentials(creds);
      auto putStatus = putConn.SendPutReq(payload, payloadLen);
      ASSERT_TRUE(static_cast<bool>(putStatus)) << "PUT failed: " << putStatus.fStatusMsg;
   }

   // Step 2: the object size reported by HEAD must equal the payload size
   {
      ROOT::Internal::RCurlConnection headConn(url);
      headConn.SetCredentials(creds);
      std::uint64_t remoteSize = 0;
      auto headStatus = headConn.SendHeadReq(remoteSize);
      ASSERT_TRUE(static_cast<bool>(headStatus)) << "HEAD failed: " << headStatus.fStatusMsg;
      EXPECT_EQ(static_cast<std::uint64_t>(payloadLen), remoteSize);
   }

   // Step 3: read the whole object back with a single range request and compare the bytes
   {
      std::vector<unsigned char> readback(payloadLen, 0);
      ROOT::Internal::RCurlConnection::RUserRange range;
      range.fDestination = readback.data();
      range.fOffset = 0;
      range.fLength = payloadLen;

      ROOT::Internal::RCurlConnection getConn(url);
      getConn.SetCredentials(creds);
      auto getStatus = getConn.SendRangesReq(1, &range);
      ASSERT_TRUE(static_cast<bool>(getStatus)) << "GET failed: " << getStatus.fStatusMsg;
      EXPECT_EQ(payloadLen, range.fNBytesRecv);
      EXPECT_EQ(0, memcmp(readback.data(), payload, payloadLen));
   }
}

0 commit comments

Comments
 (0)