Skip to content

Commit 56956c6

Browse files
address review: refresh uses cached per-table session; add RestCatalog session test
1 parent fbfb7e9 commit 56956c6

5 files changed

Lines changed: 302 additions & 4 deletions

File tree

src/iceberg/catalog/rest/rest_catalog.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ Status RestCatalog::RememberTableSession(
162162
}
163163

164164
std::shared_ptr<auth::AuthSession> RestCatalog::SessionFor(
165-
const TableIdentifier& identifier) {
165+
const TableIdentifier& identifier) const {
166166
std::lock_guard<std::mutex> lock(table_sessions_mutex_);
167167
auto it = table_sessions_.find(TableSessionKey(identifier));
168168
return it != table_sessions_.end() ? it->second : catalog_session_;
@@ -512,10 +512,12 @@ Result<std::string> RestCatalog::LoadTableInternal(
512512
params["snapshots"] = "all";
513513
}
514514

515+
// Refresh uses the cached per-table session; the initial load falls back to
516+
// the catalog session (no table session is cached yet).
515517
ICEBERG_ASSIGN_OR_RAISE(
516518
const auto response,
517519
client_->Get(path, params, /*headers=*/{}, *TableErrorHandler::Instance(),
518-
*catalog_session_));
520+
*SessionFor(identifier)));
519521
return response.body();
520522
}
521523

src/iceberg/catalog/rest/rest_catalog.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class ICEBERG_REST_EXPORT RestCatalog : public Catalog,
117117
const std::unordered_map<std::string, std::string>& config);
118118

119119
/// \brief Returns the cached per-table session, or the catalog session.
120-
std::shared_ptr<auth::AuthSession> SessionFor(const TableIdentifier& identifier);
120+
std::shared_ptr<auth::AuthSession> SessionFor(const TableIdentifier& identifier) const;
121121

122122
Result<LoadTableResult> CreateTableInternal(
123123
const TableIdentifier& identifier, const std::shared_ptr<Schema>& schema,
@@ -134,7 +134,7 @@ class ICEBERG_REST_EXPORT RestCatalog : public Catalog,
134134
std::unique_ptr<auth::AuthManager> auth_manager_;
135135
std::shared_ptr<auth::AuthSession> catalog_session_;
136136
SnapshotMode snapshot_mode_;
137-
std::mutex table_sessions_mutex_;
137+
mutable std::mutex table_sessions_mutex_;
138138
std::unordered_map<std::string, std::shared_ptr<auth::AuthSession>> table_sessions_;
139139
};
140140

src/iceberg/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ if(ICEBERG_BUILD_REST)
286286
SOURCES
287287
auth_manager_test.cc
288288
endpoint_test.cc
289+
rest_catalog_session_test.cc
289290
rest_file_io_test.cc
290291
rest_json_serde_test.cc
291292
rest_util_test.cc)

src/iceberg/test/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ if get_option('rest').enabled()
128128
'sources': files(
129129
'auth_manager_test.cc',
130130
'endpoint_test.cc',
131+
'rest_catalog_session_test.cc',
131132
'rest_file_io_test.cc',
132133
'rest_json_serde_test.cc',
133134
'rest_util_test.cc',
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <gtest/gtest.h>
21+
22+
#ifndef _WIN32
23+
24+
# include <unistd.h>
25+
26+
# include <atomic>
27+
# include <memory>
28+
# include <mutex>
29+
# include <string>
30+
# include <thread>
31+
# include <unordered_map>
32+
# include <vector>
33+
34+
# include <netinet/in.h>
35+
# include <sys/socket.h>
36+
37+
# include "iceberg/catalog/rest/auth/auth_manager.h"
38+
# include "iceberg/catalog/rest/auth/auth_managers.h"
39+
# include "iceberg/catalog/rest/auth/auth_properties.h"
40+
# include "iceberg/catalog/rest/auth/auth_session.h"
41+
# include "iceberg/catalog/rest/catalog_properties.h"
42+
# include "iceberg/catalog/rest/rest_catalog.h"
43+
# include "iceberg/file_io.h"
44+
# include "iceberg/file_io_registry.h"
45+
# include "iceberg/table_identifier.h"
46+
# include "iceberg/table_requirement.h"
47+
# include "iceberg/table_update.h"
48+
# include "iceberg/test/matchers.h"
49+
50+
namespace iceberg::rest {
51+
52+
namespace {
53+
54+
constexpr std::string_view kMetadataJson =
55+
R"({"format-version":2,"table-uuid":"test-uuid-1234","location":"s3://bucket/test",)"
56+
R"("last-sequence-number":0,"last-updated-ms":0,"last-column-id":1,)"
57+
R"("schemas":[{"type":"struct","schema-id":1,"fields":[{"id":1,"name":"id","type":"int","required":true}]}],)"
58+
R"("current-schema-id":1,"partition-specs":[{"spec-id":0,"fields":[]}],"default-spec-id":0,)"
59+
R"("last-partition-id":0,"sort-orders":[{"order-id":0,"fields":[]}],"default-sort-order-id":0})";
60+
61+
struct RecordedRequest {
62+
std::string method;
63+
std::string path;
64+
std::string auth_marker;
65+
};
66+
67+
class MiniRestServer {
68+
public:
69+
bool Start() {
70+
listen_fd_ = ::socket(AF_INET, SOCK_STREAM, 0);
71+
if (listen_fd_ < 0) return false;
72+
int reuse = 1;
73+
::setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
74+
sockaddr_in addr{};
75+
addr.sin_family = AF_INET;
76+
addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
77+
addr.sin_port = 0;
78+
if (::bind(listen_fd_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {
79+
return false;
80+
}
81+
socklen_t len = sizeof(addr);
82+
::getsockname(listen_fd_, reinterpret_cast<sockaddr*>(&addr), &len);
83+
port_ = ntohs(addr.sin_port);
84+
if (::listen(listen_fd_, 8) < 0) return false;
85+
server_thread_ = std::thread([this] { Loop(); });
86+
return true;
87+
}
88+
89+
void Stop() {
90+
stopping_ = true;
91+
if (listen_fd_ >= 0) {
92+
::shutdown(listen_fd_, SHUT_RDWR);
93+
::close(listen_fd_);
94+
listen_fd_ = -1;
95+
}
96+
if (server_thread_.joinable()) server_thread_.join();
97+
}
98+
99+
int port() const { return port_; }
100+
101+
std::vector<RecordedRequest> requests() {
102+
std::lock_guard<std::mutex> lock(mutex_);
103+
return requests_;
104+
}
105+
106+
private:
107+
void Loop() {
108+
while (!stopping_) {
109+
int fd = ::accept(listen_fd_, nullptr, nullptr);
110+
if (fd < 0) break;
111+
HandleConnection(fd);
112+
::close(fd);
113+
}
114+
}
115+
116+
void HandleConnection(int fd) {
117+
std::string raw;
118+
char buf[4096];
119+
size_t header_end = std::string::npos;
120+
while (header_end == std::string::npos) {
121+
ssize_t n = ::read(fd, buf, sizeof(buf));
122+
if (n <= 0) return;
123+
raw.append(buf, static_cast<size_t>(n));
124+
header_end = raw.find("\r\n\r\n");
125+
}
126+
size_t content_length = 0;
127+
{
128+
std::string lower;
129+
lower.reserve(header_end);
130+
for (size_t i = 0; i < header_end; ++i) {
131+
lower.push_back(static_cast<char>(std::tolower(raw[i])));
132+
}
133+
auto pos = lower.find("content-length:");
134+
if (pos != std::string::npos) {
135+
content_length = std::stoul(lower.substr(pos + 15));
136+
}
137+
}
138+
while (raw.size() < header_end + 4 + content_length) {
139+
ssize_t n = ::read(fd, buf, sizeof(buf));
140+
if (n <= 0) break;
141+
raw.append(buf, static_cast<size_t>(n));
142+
}
143+
144+
auto line_end = raw.find("\r\n");
145+
auto request_line = raw.substr(0, line_end);
146+
auto sp1 = request_line.find(' ');
147+
auto sp2 = request_line.find(' ', sp1 + 1);
148+
RecordedRequest req;
149+
req.method = request_line.substr(0, sp1);
150+
req.path = request_line.substr(sp1 + 1, sp2 - sp1 - 1);
151+
req.auth_marker = HeaderValue(raw.substr(0, header_end), "x-test-auth");
152+
{
153+
std::lock_guard<std::mutex> lock(mutex_);
154+
requests_.push_back(req);
155+
}
156+
157+
Respond(fd, BodyFor(req));
158+
}
159+
160+
static std::string HeaderValue(const std::string& headers, std::string_view name) {
161+
std::string lower;
162+
lower.reserve(headers.size());
163+
for (char c : headers) lower.push_back(static_cast<char>(std::tolower(c)));
164+
auto pos = lower.find(std::string(name) + ":");
165+
if (pos == std::string::npos) return "";
166+
auto value_start = pos + name.size() + 1;
167+
auto value_end = headers.find("\r\n", value_start);
168+
auto value = headers.substr(value_start, value_end - value_start);
169+
auto first = value.find_first_not_of(' ');
170+
return first == std::string::npos ? "" : value.substr(first);
171+
}
172+
173+
std::string BodyFor(const RecordedRequest& req) {
174+
if (req.path.find("/v1/config") != std::string::npos) {
175+
return R"({"defaults":{},"overrides":{}})";
176+
}
177+
if (req.method == "GET" && req.path.find("/tables/") != std::string::npos) {
178+
return std::string(R"({"metadata-location":"s3://bucket/meta/v1.json",)") +
179+
R"("metadata":)" + std::string(kMetadataJson) +
180+
R"(,"config":{"token":"tbl-token-1"}})";
181+
}
182+
if (req.method == "POST" && req.path.find("/tables/") != std::string::npos) {
183+
return std::string(R"({"metadata-location":"s3://bucket/meta/v2.json",)") +
184+
R"("metadata":)" + std::string(kMetadataJson) + "}";
185+
}
186+
return "{}";
187+
}
188+
189+
static void Respond(int fd, const std::string& body) {
190+
std::string response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n";
191+
response += "Content-Length: " + std::to_string(body.size()) + "\r\n";
192+
response += "Connection: close\r\n\r\n";
193+
response += body;
194+
size_t sent = 0;
195+
while (sent < response.size()) {
196+
ssize_t n = ::write(fd, response.data() + sent, response.size() - sent);
197+
if (n <= 0) break;
198+
sent += static_cast<size_t>(n);
199+
}
200+
}
201+
202+
int listen_fd_ = -1;
203+
int port_ = 0;
204+
std::atomic<bool> stopping_{false};
205+
std::thread server_thread_;
206+
std::mutex mutex_;
207+
std::vector<RecordedRequest> requests_;
208+
};
209+
210+
class RecordingAuthManager : public auth::AuthManager {
211+
public:
212+
Result<std::shared_ptr<auth::AuthSession>> InitSession(
213+
HttpClient& /*init_client*/,
214+
const std::unordered_map<std::string, std::string>& /*properties*/) override {
215+
return auth::AuthSession::MakeDefault({{"x-test-auth", "init"}});
216+
}
217+
218+
Result<std::shared_ptr<auth::AuthSession>> CatalogSession(
219+
HttpClient& /*shared_client*/,
220+
const std::unordered_map<std::string, std::string>& /*properties*/) override {
221+
return auth::AuthSession::MakeDefault({{"x-test-auth", "catalog"}});
222+
}
223+
224+
Result<std::shared_ptr<auth::AuthSession>> TableSession(
225+
const TableIdentifier& /*table*/,
226+
const std::unordered_map<std::string, std::string>& properties,
227+
std::shared_ptr<auth::AuthSession> parent) override {
228+
auto token = properties.find("token");
229+
if (token == properties.end()) {
230+
return parent;
231+
}
232+
return auth::AuthSession::MakeDefault({{"x-test-auth", "table:" + token->second}});
233+
}
234+
};
235+
236+
class MockFileIO : public FileIO {};
237+
238+
} // namespace
239+
240+
TEST(RestCatalogSessionTest, RefreshAndCommitUseTableSessionFromResponseConfig) {
241+
MiniRestServer server;
242+
ASSERT_TRUE(server.Start());
243+
244+
auth::AuthManagers::Register(
245+
"test-session-recorder",
246+
[](std::string_view /*name*/,
247+
const std::unordered_map<std::string, std::string>& /*properties*/)
248+
-> Result<std::unique_ptr<auth::AuthManager>> {
249+
return std::make_unique<RecordingAuthManager>();
250+
});
251+
FileIORegistry::Register(
252+
"test.SessionMockFileIO",
253+
[](const std::unordered_map<std::string, std::string>& /*properties*/)
254+
-> Result<std::unique_ptr<FileIO>> { return std::make_unique<MockFileIO>(); });
255+
256+
auto config = RestCatalogProperties::FromMap(
257+
{{"uri", "http://127.0.0.1:" + std::to_string(server.port())},
258+
{auth::AuthProperties::kAuthType, "test-session-recorder"},
259+
{"io-impl", "test.SessionMockFileIO"}});
260+
261+
{
262+
auto catalog_result = RestCatalog::Make(config);
263+
ASSERT_THAT(catalog_result, IsOk());
264+
auto catalog = catalog_result.value();
265+
266+
TableIdentifier identifier{.ns = Namespace{{"ns1"}}, .name = "tbl1"};
267+
ASSERT_THAT(catalog->LoadTable(identifier), IsOk());
268+
ASSERT_THAT(catalog->LoadTable(identifier), IsOk());
269+
ASSERT_THAT(catalog->UpdateTable(identifier, {}, {}), IsOk());
270+
}
271+
272+
server.Stop();
273+
274+
auto requests = server.requests();
275+
ASSERT_EQ(requests.size(), 4);
276+
EXPECT_TRUE(requests[0].path.find("/v1/config") != std::string::npos);
277+
EXPECT_EQ(requests[0].auth_marker, "init");
278+
EXPECT_EQ(requests[1].method, "GET");
279+
EXPECT_EQ(requests[1].auth_marker, "catalog");
280+
EXPECT_EQ(requests[2].method, "GET");
281+
EXPECT_EQ(requests[2].auth_marker, "table:tbl-token-1");
282+
EXPECT_EQ(requests[3].method, "POST");
283+
EXPECT_EQ(requests[3].auth_marker, "table:tbl-token-1");
284+
}
285+
286+
} // namespace iceberg::rest
287+
288+
#else
289+
290+
TEST(RestCatalogSessionTest, RefreshAndCommitUseTableSessionFromResponseConfig) {
291+
GTEST_SKIP() << "POSIX-socket test server is not available on Windows";
292+
}
293+
294+
#endif // _WIN32

0 commit comments

Comments
 (0)