Skip to content

Commit 54658f0

Browse files
authored
Merge pull request #4237 from ProvableHQ/feat/rest_enhancements
[Perf] REST enhancements
2 parents de417fe + 5fc1a9d commit 54658f0

3 files changed

Lines changed: 70 additions & 26 deletions

File tree

node/rest/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ workspace = true
6161
features = [ "parking_lot" ]
6262
optional = true
6363

64+
[dependencies.lru]
65+
workspace = true
66+
6467
[dependencies.once_cell]
6568
workspace = true
6669

node/rest/src/lib.rs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,17 @@ use axum::{
5151
use axum_extra::response::ErasedJson;
5252
#[cfg(feature = "locktick")]
5353
use locktick::parking_lot::Mutex;
54+
use lru::LruCache;
5455
#[cfg(not(feature = "locktick"))]
5556
use parking_lot::Mutex;
56-
use std::{net::SocketAddr, sync::Arc};
57+
use std::{net::SocketAddr, num::NonZeroUsize, sync::Arc, time::Duration};
5758
use tokio::{net::TcpListener, sync::Semaphore, task::JoinHandle};
5859
use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
5960
use tower_http::{
6061
cors::{Any, CorsLayer},
6162
trace::TraceLayer,
6263
};
64+
use tracing::Span;
6365

6466
/// The default port used for the REST API
6567
pub const DEFAULT_REST_PORT: u16 = 3030;
@@ -68,6 +70,9 @@ pub const DEFAULT_REST_PORT: u16 = 3030;
6870
pub const API_VERSION_V1: &str = "v1";
6971
pub const API_VERSION_V2: &str = "v2";
7072

73+
/// The capacity of the LRU holding recently requested blocks.
74+
const BLOCK_CACHE_SIZE: usize = 128;
75+
7176
/// A REST API server for the ledger.
7277
#[derive(Clone)]
7378
pub struct Rest<N: Network, C: ConsensusStorage<N>, R: Routing<N>> {
@@ -89,6 +94,8 @@ pub struct Rest<N: Network, C: ConsensusStorage<N>, R: Routing<N>> {
8994
num_verifying_executions: Arc<Semaphore>,
9095
/// The number of ongoing solution verifications via REST.
9196
num_verifying_solutions: Arc<Semaphore>,
97+
/// A cache containing recently requested blocks.
98+
block_cache: Arc<Mutex<LruCache<N::BlockHash, ErasedJson>>>,
9299
}
93100

94101
impl<N: Network, C: 'static + ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
@@ -113,6 +120,7 @@ impl<N: Network, C: 'static + ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R>
113120
num_verifying_deploys: Arc::new(Semaphore::new(VM::<N, C>::MAX_PARALLEL_DEPLOY_VERIFICATIONS)),
114121
num_verifying_executions: Arc::new(Semaphore::new(VM::<N, C>::MAX_PARALLEL_EXECUTE_VERIFICATIONS)),
115122
num_verifying_solutions: Arc::new(Semaphore::new(N::MAX_SOLUTIONS)),
123+
block_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(BLOCK_CACHE_SIZE).unwrap()))),
116124
};
117125
// Spawn the server.
118126
server.spawn_server(rest_ip, rest_rps).await?;
@@ -263,20 +271,41 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
263271
#[cfg(feature = "history-staking-rewards")]
264272
let routes = routes.route("/staking/rewards/{address}/{height}", get(Self::get_staking_reward));
265273

274+
let trace_layer = TraceLayer::new_for_http()
275+
.make_span_with(|request: &Request<_>| {
276+
let addr = request
277+
.extensions()
278+
.get::<ConnectInfo<SocketAddr>>()
279+
.map(|ConnectInfo(addr)| addr.to_string())
280+
.unwrap_or_else(|| "unknown".to_string());
281+
282+
// Create a span that includes method, path, and our extracted IP
283+
tracing::info_span!(
284+
"REST",
285+
method = %request.method(),
286+
uri = %request.uri().path(),
287+
addr = %addr,
288+
)
289+
})
290+
.on_request(|_request: &Request<_>, _span: &Span| {
291+
info!("Received a request");
292+
})
293+
.on_response(|_response: &Response<_>, latency: Duration, _span: &Span| {
294+
info!("Finished request in {:?}", latency);
295+
});
296+
266297
routes
267298
// Pass in `Rest` to make things convenient.
268299
.with_state(self.clone())
269-
// Enable tower-http tracing.
270-
.layer(TraceLayer::new_for_http())
271-
// Custom logging.
272-
.layer(middleware::map_request(log_middleware))
273-
// Enable CORS.
274-
.layer(cors)
275300
// Cap the request body size at 512KiB.
276301
.layer(DefaultBodyLimit::max(512 * 1024))
277302
.layer(GovernorLayer {
278303
config: governor_config.into(),
279304
})
305+
// Enable CORS.
306+
.layer(cors)
307+
// Enable tower-http tracing.
308+
.layer(trace_layer)
280309
}
281310

282311
async fn spawn_server(&mut self, rest_ip: SocketAddr, rest_rps: u32) -> Result<()> {
@@ -314,12 +343,6 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
314343
}
315344
}
316345

317-
/// Creates a log message for every HTTP request.
318-
async fn log_middleware(ConnectInfo(addr): ConnectInfo<SocketAddr>, request: Request<Body>) -> Request<Body> {
319-
info!("Received '{} {}' from '{addr}'", request.method(), request.uri());
320-
request
321-
}
322-
323346
/// Converts errors to the old style for the v1 API.
324347
/// The error code will always be 500 and the content a simple string.
325348
async fn v1_error_middleware(response: Response) -> Response {

node/rest/src/routes.rs

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
146146

147147
/// GET /<network>/block/latest
148148
pub(crate) async fn get_block_latest(State(rest): State<Self>) -> ErasedJson {
149-
ErasedJson::pretty(rest.ledger.latest_block())
149+
let block = rest.ledger.latest_block();
150+
let hash = block.hash();
151+
// When present, this is 3x faster than serializing the block from the ledger.
152+
rest.block_cache.lock().get_or_insert(hash, || ErasedJson::pretty(block)).clone()
150153
}
151154

152155
/// GET /<network>/block/{height}
@@ -157,23 +160,38 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
157160
) -> Result<ErasedJson, RestError> {
158161
// Manually parse the height or the height of the hash, axum doesn't support different types
159162
// for the same path param.
160-
let id_name;
161-
let block = if let Ok(height) = height_or_hash.parse::<u32>() {
162-
id_name = "hash";
163-
rest.ledger.try_get_block(height).with_context(|| "Failed to get block by height")?
163+
let hash = if let Ok(height) = height_or_hash.parse::<u32>() {
164+
rest.ledger.get_hash(height).with_context(|| "Failed to get a block's hash")?
164165
} else if let Ok(hash) = height_or_hash.parse::<N::BlockHash>() {
165-
id_name = "height";
166-
rest.ledger.try_get_block_by_hash(&hash).with_context(|| "Failed to get block by hash")?
166+
hash
167167
} else {
168-
return Err(RestError::bad_request(anyhow!(
169-
"invalid input, it is neither a block height nor a block hash"
170-
)));
168+
return Err(RestError::bad_request(anyhow!("invalid input: neither a block height nor a block hash")));
171169
};
172170

173-
match block {
174-
Some(block) => Ok(ErasedJson::pretty(block)),
175-
None => Err(RestError::not_found(anyhow!("No block with {id_name} {height_or_hash} found"))),
171+
// Attempt to find a serialized block in the cache.
172+
if let Some(json_block) = rest.block_cache.lock().get(&hash) {
173+
return Ok(json_block.clone());
176174
}
175+
176+
// Retrieve the block from the database.
177+
let json_block = match tokio::task::spawn_blocking(move || match rest.ledger.try_get_block_by_hash(&hash) {
178+
Ok(Some(block)) => Some(ErasedJson::pretty(block)),
179+
Ok(None) => None,
180+
Err(e) => {
181+
error!("Couldn't find a block: {e}");
182+
None
183+
}
184+
})
185+
.await
186+
{
187+
Ok(Some(block)) => Ok(block),
188+
Ok(None) => Err(RestError::not_found(anyhow!("Couldn't find block {height_or_hash}"))),
189+
Err(e) => Err(RestError::internal_server_error(anyhow!("tokio error: {e}"))),
190+
}?;
191+
192+
rest.block_cache.lock().put(hash, json_block.clone());
193+
194+
Ok(json_block)
177195
}
178196

179197
/// GET /<network>/blocks?start={start_height}&end={end_height}

0 commit comments

Comments
 (0)