diff --git a/sqlite-watcher/build.rs b/sqlite-watcher/build.rs new file mode 100644 index 0000000..a38290d --- /dev/null +++ b/sqlite-watcher/build.rs @@ -0,0 +1,8 @@ +fn main() -> Result<(), Box> { + tonic_build::configure() + .build_client(true) + .build_server(true) + .compile(&["proto/watcher.proto"], &["proto"])?; + println!("cargo:rerun-if-changed=proto/watcher.proto"); + Ok(()) +} diff --git a/sqlite-watcher/proto/watcher.proto b/sqlite-watcher/proto/watcher.proto new file mode 100644 index 0000000..064f3cb --- /dev/null +++ b/sqlite-watcher/proto/watcher.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package sqlitewatcher; + +message HealthCheckRequest {} +message HealthCheckResponse { + string status = 1; +} + +message ListChangesRequest { + uint32 limit = 1; +} + +message Change { + int64 change_id = 1; + string table_name = 2; + string op = 3; + string primary_key = 4; + bytes payload = 5; + string wal_frame = 6; + string cursor = 7; +} + +message ListChangesResponse { + repeated Change changes = 1; +} + +message AckChangesRequest { + int64 up_to_change_id = 1; +} + +message AckChangesResponse { + uint64 acknowledged = 1; +} + +message GetStateRequest { + string table_name = 1; +} + +message GetStateResponse { + bool exists = 1; + int64 last_change_id = 2; + string last_wal_frame = 3; + string cursor = 4; +} + +message SetStateRequest { + string table_name = 1; + int64 last_change_id = 2; + string last_wal_frame = 3; + string cursor = 4; +} + +message SetStateResponse {} + +service Watcher { + rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); + rpc ListChanges(ListChangesRequest) returns (ListChangesResponse); + rpc AckChanges(AckChangesRequest) returns (AckChangesResponse); + rpc GetState(GetStateRequest) returns (GetStateResponse); + rpc SetState(SetStateRequest) returns (SetStateResponse); +} diff --git a/sqlite-watcher/src/change.rs b/sqlite-watcher/src/change.rs new file mode 100644 index 0000000..64059d7 --- /dev/null +++ b/sqlite-watcher/src/change.rs @@ -0,0 +1,49 @@ +use serde_json::Value; + +use crate::queue::{ChangeOperation, NewChange}; + +#[derive(Debug, Clone, PartialEq)] +pub struct RowChange { + pub table_name: String, + pub operation: ChangeOperation, + pub primary_key: String, + pub payload: Option, + pub wal_frame: Option, + pub cursor: Option, +} + +impl RowChange { + pub fn into_new_change(self) -> NewChange { + let payload = self + .payload + .map(|value| serde_json::to_vec(&value).expect("row change payload serializes")); + NewChange { + table_name: self.table_name, + operation: self.operation, + primary_key: self.primary_key, + payload, + wal_frame: self.wal_frame, + cursor: self.cursor, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn converts_to_new_change() { + let row = RowChange { + table_name: "prices".into(), + operation: ChangeOperation::Update, + primary_key: "pk1".into(), + payload: Some(serde_json::json!({"foo": "bar"})), + wal_frame: Some("frame-1".into()), + cursor: Some("cursor".into()), + }; + let change = row.into_new_change(); + assert_eq!(change.table_name, "prices"); + assert!(change.payload.unwrap().contains(&b'b')); + } +} diff --git a/sqlite-watcher/src/decoder.rs b/sqlite-watcher/src/decoder.rs new file mode 100644 index 0000000..9da5268 --- /dev/null +++ b/sqlite-watcher/src/decoder.rs @@ -0,0 +1,48 @@ +use crate::change::RowChange; +use crate::queue::ChangeOperation; +use crate::wal::WalEvent; +use serde_json::json; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Temporary decoder that turns WAL growth bytes into placeholder RowChange events. +/// Placeholder until row-level decoding is implemented. +#[derive(Debug, Default, Clone)] +pub struct WalGrowthDecoder; + +impl WalGrowthDecoder { + pub fn decode(&self, event: &WalEvent) -> Vec { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be >= UNIX epoch"); + vec![RowChange { + table_name: "__wal__".to_string(), + operation: ChangeOperation::Insert, + primary_key: now.as_nanos().to_string(), + payload: Some(json!({ + "kind": "wal_growth", + "bytes_added": event.bytes_added, + "current_size": event.current_size, + "recorded_at": now.as_secs_f64(), + })), + wal_frame: None, + cursor: None, + }] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn produces_placeholder_row_change() { + let decoder = WalGrowthDecoder::default(); + let rows = decoder.decode(&WalEvent { + bytes_added: 1024, + current_size: 2048, + }); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].table_name, "__wal__"); + assert_eq!(rows[0].operation, ChangeOperation::Insert); + } +} diff --git a/sqlite-watcher/src/main.rs b/sqlite-watcher/src/main.rs new file mode 100644 index 0000000..4b2f373 --- /dev/null +++ b/sqlite-watcher/src/main.rs @@ -0,0 +1,371 @@ +use std::fmt; +use std::fs; +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::sync::mpsc; +use std::time::Duration; + +use anyhow::{anyhow, bail, Context, Result}; +use clap::Parser; +use sqlite_watcher::decoder::WalGrowthDecoder; +use sqlite_watcher::queue::ChangeQueue; +#[cfg(unix)] +use sqlite_watcher::server::spawn_unix_server; +use sqlite_watcher::server::{spawn_tcp_server, ServerHandle}; +use sqlite_watcher::wal::{start_wal_watcher, WalWatcherConfig as TailConfig}; +use tracing_subscriber::EnvFilter; + +#[cfg(unix)] +const DEFAULT_LISTEN: &str = "unix:/tmp/sqlite-watcher.sock"; +#[cfg(not(unix))] +const DEFAULT_LISTEN: &str = "tcp:50051"; + +/// Command-line interface definition for sqlite-watcher. +#[derive(Debug, Clone, Parser)] +#[command( + name = "sqlite-watcher", + version, + about = "Tails SQLite WAL files and exposes change streams.", + long_about = None +)] +struct Cli { + /// Path to the SQLite database the watcher should monitor. + #[arg(long = "db", value_name = "PATH")] + db_path: PathBuf, + + /// Listener binding. Accepts unix:/path, tcp:, or pipe:. + #[arg(long, value_name = "ENDPOINT", default_value = DEFAULT_LISTEN)] + listen: String, + + /// Shared-secret token file used to authenticate RPC clients. + #[arg(long = "token-file", value_name = "PATH")] + token_file: Option, + + /// Path to the durable change queue database. + #[arg(long = "queue-db", value_name = "PATH")] + queue_db: Option, + + /// Tracing filter (info,warn,debug,trace). Can also be provided via SQLITE_WATCHER_LOG. + #[arg( + long = "log-level", + value_name = "FILTER", + default_value = "info", + env = "SQLITE_WATCHER_LOG" + )] + log_filter: String, + + /// Interval in milliseconds between WAL file polls. + #[arg( + long = "poll-interval-ms", + default_value_t = 500, + value_parser = clap::value_parser!(u64).range(50..=60_000) + )] + poll_interval_ms: u64, + + /// Minimum WAL byte growth required before emitting an event. + #[arg( + long = "min-event-bytes", + default_value_t = 1, + value_parser = clap::value_parser!(u64).range(1..=10_000_000) + )] + min_event_bytes: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ListenAddress { + Unix(PathBuf), + Tcp { host: String, port: u16 }, + Pipe(String), +} + +impl fmt::Display for ListenAddress { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ListenAddress::Unix(path) => write!(f, "unix:{}", path.display()), + ListenAddress::Tcp { host, port } => write!(f, "tcp:{}:{}", host, port), + ListenAddress::Pipe(name) => write!(f, "pipe:{}", name), + } + } +} + +impl FromStr for ListenAddress { + type Err = anyhow::Error; + + fn from_str(value: &str) -> Result { + if let Some(path) = value.strip_prefix("unix:") { + if cfg!(windows) { + bail!("unix sockets are not supported on Windows"); + } + if path.is_empty() { + bail!("unix listen path cannot be empty"); + } + return Ok(ListenAddress::Unix(PathBuf::from(path))); + } + + if let Some(port) = value.strip_prefix("tcp:") { + let port: u16 = port + .parse() + .map_err(|_| anyhow!("tcp listener must specify a numeric port"))?; + return Ok(ListenAddress::Tcp { + host: "127.0.0.1".to_string(), + port, + }); + } + + if let Some(name) = value.strip_prefix("pipe:") { + if cfg!(not(windows)) { + bail!("named pipes are only valid on Windows"); + } + if name.is_empty() { + bail!("pipe name cannot be empty"); + } + return Ok(ListenAddress::Pipe(name.to_string())); + } + + bail!("listen endpoint must start with unix:/, tcp:, or pipe:"); + } +} + +#[derive(Debug, Clone)] +struct WatcherConfig { + database_path: PathBuf, + listen: ListenAddress, + token_file: PathBuf, + queue_path: PathBuf, + poll_interval: Duration, + min_event_bytes: u64, +} + +impl TryFrom for WatcherConfig { + type Error = anyhow::Error; + + fn try_from(args: Cli) -> Result { + let database_path = ensure_sqlite_file(&args.db_path)?; + let listen = ListenAddress::from_str(args.listen.trim())?; + let token_file = match args.token_file { + Some(path) => expand_home(path)?, + None => default_token_path()?, + }; + let queue_path = match args.queue_db { + Some(path) => expand_home(path)?, + None => default_queue_path()?, + }; + + Ok(Self { + database_path, + listen, + token_file, + queue_path, + poll_interval: Duration::from_millis(args.poll_interval_ms), + min_event_bytes: args.min_event_bytes, + }) + } +} + +fn ensure_sqlite_file(path: &Path) -> Result { + if !path.exists() { + bail!("database path {} does not exist", path.display()); + } + if !path.is_file() { + bail!("database path {} is not a file", path.display()); + } + Ok(path + .canonicalize() + .with_context(|| format!("failed to canonicalize {}", path.display()))?) +} + +fn default_token_path() -> Result { + let home = dirs::home_dir().ok_or_else(|| anyhow!("unable to determine home directory"))?; + Ok(home.join(".seren/sqlite-watcher/token")) +} + +fn default_queue_path() -> Result { + let home = dirs::home_dir().ok_or_else(|| anyhow!("unable to determine home directory"))?; + Ok(home.join(".seren/sqlite-watcher/changes.db")) +} + +fn expand_home(path: PathBuf) -> Result { + let as_str = path.to_string_lossy(); + if let Some(stripped) = as_str.strip_prefix("~/") { + let home = dirs::home_dir().ok_or_else(|| anyhow!("unable to determine home directory"))?; + return Ok(home.join(stripped)); + } + if as_str == "~" { + let home = dirs::home_dir().ok_or_else(|| anyhow!("unable to determine home directory"))?; + return Ok(home); + } + Ok(path) +} + +fn init_tracing(filter: &str) -> Result<()> { + let env_filter = EnvFilter::try_new(filter).or_else(|_| EnvFilter::try_new("info"))?; + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .with_target(false) + .try_init() + .map_err(|err| anyhow!("failed to init tracing subscriber: {err}")) +} + +fn main() -> Result<()> { + let cli = Cli::parse(); + init_tracing(&cli.log_filter)?; + let config = WatcherConfig::try_from(cli)?; + let auth_token = read_token_file(&config.token_file)?; + + tracing::info!( + db = %config.database_path.display(), + listen = %config.listen, + token = %config.token_file.display(), + queue = %config.queue_path.display(), + poll_ms = config.poll_interval.as_millis(), + min_event_bytes = config.min_event_bytes, + "sqlite-watcher starting" + ); + + let queue = ChangeQueue::open(&config.queue_path)?; + let decoder = WalGrowthDecoder::default(); + let server_handle = start_grpc_server(&config.listen, &config.queue_path, &auth_token)?; + let (event_tx, event_rx) = mpsc::channel(); + let _wal_handle = start_wal_watcher( + &config.database_path, + TailConfig { + poll_interval: config.poll_interval, + min_event_bytes: config.min_event_bytes, + }, + event_tx, + )?; + + for event in event_rx { + match process_wal_event(&decoder, &queue, &event) { + Ok(change_ids) if !change_ids.is_empty() => { + tracing::info!( + bytes_added = event.bytes_added, + wal_size = event.current_size, + change_count = change_ids.len(), + "queued wal growth event" + ); + } + Err(err) => { + tracing::warn!(error = %err, "failed to persist wal event to queue"); + } + _ => {} + } + } + + drop(server_handle); + Ok(()) +} + +fn process_wal_event( + decoder: &WalGrowthDecoder, + queue: &ChangeQueue, + event: &sqlite_watcher::wal::WalEvent, +) -> Result> { + let mut ids = Vec::new(); + for row in decoder.decode(event) { + ids.push(queue.enqueue(&row.into_new_change())?); + } + Ok(ids) +} + +fn read_token_file(path: &Path) -> Result { + let contents = fs::read_to_string(path) + .with_context(|| format!("failed to read token file {}", path.display()))?; + let token = contents.trim().to_string(); + if token.is_empty() { + bail!("token file {} is empty", path.display()); + } + Ok(token) +} + +fn start_grpc_server( + listen: &ListenAddress, + queue_path: &Path, + token: &str, +) -> Result> { + match listen { + ListenAddress::Tcp { host, port } => { + let addr: SocketAddr = format!("{}:{}", host, port) + .parse() + .with_context(|| format!("invalid tcp listen address {host}:{port}"))?; + let handle = spawn_tcp_server(addr, queue_path.to_path_buf(), token.to_string())?; + Ok(Some(handle)) + } + ListenAddress::Unix(path) => { + #[cfg(unix)] + { + let handle = spawn_unix_server(path, queue_path.to_path_buf(), token.to_string())?; + Ok(Some(handle)) + } + #[cfg(not(unix))] + { + bail!("unix sockets are not supported on this platform") + } + } + ListenAddress::Pipe(name) => { + tracing::warn!(pipe = name, "named pipe transport is not yet implemented"); + Ok(None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + use sqlite_watcher::queue::ChangeQueue; + use tempfile::{tempdir, NamedTempFile}; + + #[test] + fn parses_tcp_listener() { + let tmp = NamedTempFile::new().unwrap(); + let cli = Cli::try_parse_from([ + "sqlite-watcher", + "--db", + tmp.path().to_str().unwrap(), + "--listen", + "tcp:6000", + "--token-file", + "./token", + "--log-level", + "debug", + ]) + .expect("cli parsing failed"); + + let config = WatcherConfig::try_from(cli).expect("config conversion failed"); + assert!(matches!( + config.listen, + ListenAddress::Tcp { host, port } if host == "127.0.0.1" && port == 6000 + )); + assert!(config.token_file.ends_with("token")); + assert!(config.queue_path.ends_with("changes.db")); + } + + #[test] + #[cfg(unix)] + fn parses_unix_listener_default() { + let tmp = NamedTempFile::new().unwrap(); + let cli = + Cli::try_parse_from(["sqlite-watcher", "--db", tmp.path().to_str().unwrap()]).unwrap(); + let config = WatcherConfig::try_from(cli).unwrap(); + assert!(matches!(config.listen, ListenAddress::Unix(_))); + } + + #[test] + fn persist_wal_events_into_queue() { + let dir = tempdir().unwrap(); + let queue_path = dir.path().join("queue.db"); + let queue = ChangeQueue::open(&queue_path).unwrap(); + let decoder = WalGrowthDecoder::default(); + + let event = sqlite_watcher::wal::WalEvent { + bytes_added: 2048, + current_size: 4096, + }; + let change_ids = process_wal_event(&decoder, &queue, &event).unwrap(); + let batch = queue.fetch_batch(10).unwrap(); + assert_eq!(batch.len(), change_ids.len()); + assert_eq!(batch[0].table_name, "__wal__"); + } +} diff --git a/sqlite-watcher/src/server.rs b/sqlite-watcher/src/server.rs new file mode 100644 index 0000000..d41004f --- /dev/null +++ b/sqlite-watcher/src/server.rs @@ -0,0 +1,282 @@ +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; + +use anyhow::{Context, Result}; +use tokio::runtime::Builder; +use tokio::sync::oneshot; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::service::Interceptor; +use tonic::{transport::Server, Request, Response, Status}; + +use crate::queue::{ChangeQueue, QueueState}; +use crate::watcher_proto::watcher_server::{Watcher, WatcherServer}; +use crate::watcher_proto::{ + AckChangesRequest, AckChangesResponse, Change, GetStateRequest, GetStateResponse, + HealthCheckRequest, HealthCheckResponse, ListChangesRequest, ListChangesResponse, + SetStateRequest, SetStateResponse, +}; + +#[cfg(unix)] +use tokio::net::UnixListener; +#[cfg(unix)] +use tokio_stream::wrappers::UnixListenerStream; + +pub struct ServerHandle { + shutdown: Option>, + thread: Option>>, + #[cfg(unix)] + unix_path: Option, +} + +impl Drop for ServerHandle { + fn drop(&mut self) { + if let Some(tx) = self.shutdown.take() { + let _ = tx.send(()); + } + if let Some(handle) = self.thread.take() { + let _ = handle.join(); + } + #[cfg(unix)] + if let Some(path) = self.unix_path.take() { + let _ = std::fs::remove_file(path); + } + } +} + +pub fn spawn_tcp_server( + addr: SocketAddr, + queue_path: PathBuf, + token: String, +) -> Result { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let thread = thread::spawn(move || -> Result<()> { + let runtime = Builder::new_multi_thread() + .enable_all() + .build() + .context("failed to build tokio runtime")?; + runtime.block_on(async move { + let listener = tokio::net::TcpListener::bind(addr) + .await + .context("failed to bind tcp listener")?; + let queue_path = Arc::new(queue_path); + let svc = WatcherService::new(queue_path); + let interceptor = AuthInterceptor { + token: Arc::new(token), + }; + Server::builder() + .add_service(WatcherServer::with_interceptor(svc, interceptor)) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async move { + let _ = shutdown_rx.await; + }) + .await + .context("grpc server exited with error")?; + Ok::<(), anyhow::Error>(()) + })?; + Ok(()) + }); + + Ok(ServerHandle { + shutdown: Some(shutdown_tx), + thread: Some(thread), + #[cfg(unix)] + unix_path: None, + }) +} + +#[cfg(unix)] +pub fn spawn_unix_server(path: &Path, queue_path: PathBuf, token: String) -> Result { + if path.exists() { + std::fs::remove_file(path) + .with_context(|| format!("failed to remove stale unix socket {}", path.display()))?; + } + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("failed to create unix socket dir {}", parent.display()))?; + } + let path_buf = path.to_path_buf(); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let path_for_drop = path_buf.clone(); + let thread = thread::spawn(move || -> Result<()> { + let runtime = Builder::new_multi_thread() + .enable_all() + .build() + .context("failed to build tokio runtime")?; + runtime.block_on(async move { + let listener = UnixListener::bind(&path_buf).context("failed to bind unix socket")?; + let queue_path = Arc::new(queue_path); + let svc = WatcherService::new(queue_path); + let interceptor = AuthInterceptor { + token: Arc::new(token), + }; + Server::builder() + .add_service(WatcherServer::with_interceptor(svc, interceptor)) + .serve_with_incoming_shutdown(UnixListenerStream::new(listener), async move { + let _ = shutdown_rx.await; + }) + .await + .context("grpc server exited with error")?; + Ok::<(), anyhow::Error>(()) + })?; + Ok(()) + }); + + Ok(ServerHandle { + shutdown: Some(shutdown_tx), + thread: Some(thread), + unix_path: Some(path_for_drop), + }) +} + +struct WatcherService { + queue_path: Arc, +} + +impl WatcherService { + fn new(queue_path: Arc) -> Self { + Self { queue_path } + } + + fn open_queue(&self) -> Result { + ChangeQueue::open(&*self.queue_path) + } +} + +#[derive(Clone)] +struct AuthInterceptor { + token: Arc, +} + +impl Interceptor for AuthInterceptor { + fn call(&mut self, request: Request<()>) -> Result, Status> { + let header = request + .metadata() + .get("authorization") + .ok_or_else(|| Status::unauthenticated("missing authorization header"))?; + let expected = format!("Bearer {}", self.token.as_ref()); + if header + .to_str() + .map(|value| value == expected) + .unwrap_or(false) + { + Ok(request) + } else { + Err(Status::unauthenticated("invalid authorization header")) + } + } +} + +#[tonic::async_trait] +impl Watcher for WatcherService { + async fn health_check( + &self, + _: Request, + ) -> Result, Status> { + Ok(Response::new(HealthCheckResponse { + status: "ok".to_string(), + })) + } + + async fn list_changes( + &self, + request: Request, + ) -> Result, Status> { + let limit = request.get_ref().limit.max(1).min(10_000) as usize; + let queue = self + .open_queue() + .map_err(|err| Status::internal(err.to_string()))?; + let rows = queue + .fetch_batch(limit) + .map_err(|err| Status::internal(err.to_string()))?; + let changes = rows.into_iter().map(change_to_proto).collect(); + Ok(Response::new(ListChangesResponse { changes })) + } + + async fn ack_changes( + &self, + request: Request, + ) -> Result, Status> { + let upto = request.get_ref().up_to_change_id; + let queue = self + .open_queue() + .map_err(|err| Status::internal(err.to_string()))?; + let count = queue + .ack_up_to(upto) + .map_err(|err| Status::internal(err.to_string()))?; + Ok(Response::new(AckChangesResponse { + acknowledged: count, + })) + } + + async fn get_state( + &self, + request: Request, + ) -> Result, Status> { + let name = request.get_ref().table_name.clone(); + let queue = self + .open_queue() + .map_err(|err| Status::internal(err.to_string()))?; + let state = queue + .get_state(&name) + .map_err(|err| Status::internal(err.to_string()))?; + let resp = match state { + Some(state) => GetStateResponse { + exists: true, + last_change_id: state.last_change_id, + last_wal_frame: state.last_wal_frame.unwrap_or_default(), + cursor: state.cursor.unwrap_or_default(), + }, + None => GetStateResponse { + exists: false, + last_change_id: 0, + last_wal_frame: String::new(), + cursor: String::new(), + }, + }; + Ok(Response::new(resp)) + } + + async fn set_state( + &self, + request: Request, + ) -> Result, Status> { + let payload = request.into_inner(); + if payload.table_name.is_empty() { + return Err(Status::invalid_argument("table_name is required")); + } + let queue = self + .open_queue() + .map_err(|err| Status::internal(err.to_string()))?; + let state = QueueState { + table_name: payload.table_name, + last_change_id: payload.last_change_id, + last_wal_frame: if payload.last_wal_frame.is_empty() { + None + } else { + Some(payload.last_wal_frame) + }, + cursor: if payload.cursor.is_empty() { + None + } else { + Some(payload.cursor) + }, + }; + queue + .set_state(&state) + .map_err(|err| Status::internal(err.to_string()))?; + Ok(Response::new(SetStateResponse {})) + } +} + +fn change_to_proto(row: crate::queue::ChangeRecord) -> Change { + Change { + change_id: row.change_id, + table_name: row.table_name, + op: row.operation.as_str().to_string(), + primary_key: row.primary_key, + payload: row.payload.unwrap_or_default(), + wal_frame: row.wal_frame.unwrap_or_default(), + cursor: row.cursor.unwrap_or_default(), + } +} diff --git a/sqlite-watcher/src/wal.rs b/sqlite-watcher/src/wal.rs new file mode 100644 index 0000000..13472c0 --- /dev/null +++ b/sqlite-watcher/src/wal.rs @@ -0,0 +1,236 @@ +use std::ffi::OsString; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::Sender; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; +use std::time::Duration; + +use anyhow::{Context, Result}; +use tracing::{debug, warn}; + +#[derive(Debug, Clone, Copy)] +pub struct WalWatcherConfig { + pub poll_interval: Duration, + pub min_event_bytes: u64, +} + +impl Default for WalWatcherConfig { + fn default() -> Self { + Self { + poll_interval: Duration::from_millis(500), + min_event_bytes: 0, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WalEvent { + pub bytes_added: u64, + pub current_size: u64, +} + +pub struct WalWatcherHandle { + stop: Arc, + thread: Option>, +} + +impl Drop for WalWatcherHandle { + fn drop(&mut self) { + self.stop.store(true, Ordering::SeqCst); + if let Some(handle) = self.thread.take() { + let _ = handle.join(); + } + } +} + +pub fn start_wal_watcher>( + db_path: P, + options: WalWatcherConfig, + sender: Sender, +) -> Result { + let db_path = db_path.as_ref().canonicalize().with_context(|| { + format!( + "failed to canonicalize database path {}", + db_path.as_ref().display() + ) + })?; + if !db_path.is_file() { + anyhow::bail!("database path {} is not a file", db_path.display()); + } + + let wal_path = wal_file_path(&db_path); + let poll_interval = options.poll_interval; + let min_event_bytes = options.min_event_bytes; + let stop_flag = Arc::new(AtomicBool::new(false)); + let thread_stop = Arc::clone(&stop_flag); + + let handle = thread::spawn(move || { + let mut last_len = wal_file_size(&wal_path).unwrap_or(0); + debug!( + wal = %wal_path.display(), + last_len, + "wal watcher started" + ); + while !thread_stop.load(Ordering::SeqCst) { + match wal_file_size(&wal_path) { + Ok(len) => { + if len < last_len { + debug!( + wal = %wal_path.display(), + prev = last_len, + current = len, + "wal truncated; resetting baseline" + ); + last_len = len; + } else if len > last_len { + let delta = len - last_len; + last_len = len; + if delta >= min_event_bytes { + let event = WalEvent { + bytes_added: delta, + current_size: len, + }; + if sender.send(event).is_err() { + debug!("wal watcher stopping because receiver closed"); + break; + } + } + } + } + Err(err) => { + if err.kind() == std::io::ErrorKind::NotFound { + last_len = 0; + } else { + warn!( + wal = %wal_path.display(), + error = %err, + "failed to read wal metadata" + ); + } + } + } + + thread::sleep(poll_interval); + } + + debug!("wal watcher exiting"); + }); + + Ok(WalWatcherHandle { + stop: stop_flag, + thread: Some(handle), + }) +} + +fn wal_file_path(db_path: &Path) -> PathBuf { + let mut os_string = OsString::from(db_path.as_os_str()); + os_string.push("-wal"); + PathBuf::from(os_string) +} + +fn wal_file_size(path: &Path) -> std::io::Result { + std::fs::metadata(path).map(|m| m.len()) +} + +#[cfg(test)] +mod tests { + use super::*; + use rusqlite::Connection; + use std::sync::mpsc::channel; + use std::time::{Duration, Instant}; + use tempfile::tempdir; + + #[test] + fn emits_event_when_wal_grows() { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("watch.sqlite"); + let writer = Connection::open(&db_path).unwrap(); + writer.pragma_update(None, "journal_mode", &"wal").unwrap(); + writer + .pragma_update(None, "wal_autocheckpoint", &0i64) + .unwrap(); + writer + .execute( + "CREATE TABLE changes(id INTEGER PRIMARY KEY, value TEXT)", + [], + ) + .unwrap(); + + let (tx, rx) = channel(); + let handle = start_wal_watcher( + &db_path, + WalWatcherConfig { + poll_interval: Duration::from_millis(50), + min_event_bytes: 1, + }, + tx, + ) + .unwrap(); + + for i in 0..50 { + writer + .execute( + "INSERT INTO changes(value) VALUES (?1)", + [format!("value-{i}")], + ) + .unwrap(); + } + + let event = rx.recv_timeout(Duration::from_secs(5)).unwrap(); + assert!(event.bytes_added > 0); + assert!(event.current_size >= event.bytes_added); + + drop(handle); + } + + #[test] + fn handles_wal_truncation() { + let dir = tempdir().unwrap(); + let db_path = dir.path().join("truncate.sqlite"); + let writer = Connection::open(&db_path).unwrap(); + writer.pragma_update(None, "journal_mode", &"wal").unwrap(); + writer + .pragma_update(None, "wal_autocheckpoint", &0i64) + .unwrap(); + writer + .execute("CREATE TABLE stuff(id INTEGER PRIMARY KEY, value TEXT)", []) + .unwrap(); + + let (tx, rx) = channel(); + let handle = start_wal_watcher( + &db_path, + WalWatcherConfig { + poll_interval: Duration::from_millis(25), + min_event_bytes: 1, + }, + tx, + ) + .unwrap(); + + for i in 0..10 { + writer + .execute("INSERT INTO stuff(value) VALUES (?1)", [format!("row-{i}")]) + .unwrap(); + } + + rx.recv_timeout(Duration::from_secs(5)).unwrap(); + + writer + .execute_batch("PRAGMA wal_checkpoint(TRUNCATE);") + .unwrap(); + + // Ensure watcher does not send negative deltas (would panic or overflow) + let start = Instant::now(); + loop { + if rx.recv_timeout(Duration::from_millis(100)).is_ok() { + break; + } + if start.elapsed() > Duration::from_millis(500) { + break; + } + } + + drop(handle); + } +} diff --git a/sqlite-watcher/tests/server_tests.rs b/sqlite-watcher/tests/server_tests.rs new file mode 100644 index 0000000..0fe5526 --- /dev/null +++ b/sqlite-watcher/tests/server_tests.rs @@ -0,0 +1,64 @@ +use std::net::SocketAddr; +use std::time::Duration; + +use sqlite_watcher::server::spawn_tcp_server; +#[cfg(unix)] +use sqlite_watcher::server::spawn_unix_server; +use sqlite_watcher::watcher_proto::watcher_client::WatcherClient; +use sqlite_watcher::watcher_proto::HealthCheckRequest; +use tempfile::tempdir; +use tokio::time::sleep; +use tonic::metadata::MetadataValue; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn health_check_responds_ok() { + let dir = tempdir().unwrap(); + let queue_path = dir.path().join("queue.db"); + let addr: SocketAddr = "127.0.0.1:55051".parse().unwrap(); + let token = "secret-token".to_string(); + + let _handle = spawn_tcp_server(addr, queue_path, token.clone()).unwrap(); + sleep(Duration::from_millis(200)).await; + + let channel = tonic::transport::Channel::from_shared(format!("http://{}", addr)) + .unwrap() + .connect() + .await + .unwrap(); + let mut client = WatcherClient::new(channel); + let mut req = tonic::Request::new(HealthCheckRequest {}); + let header = MetadataValue::try_from(format!("Bearer {}", token)).unwrap(); + req.metadata_mut().insert("authorization", header); + let resp = client.health_check(req).await.unwrap(); + assert_eq!(resp.into_inner().status, "ok"); +} +#[cfg(unix)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn health_check_over_unix_socket() { + use tokio::net::UnixStream; + use tonic::transport::Endpoint; + use tower::service_fn; + + let dir = tempdir().unwrap(); + let queue_path = dir.path().join("queue.db"); + let socket_path = dir.path().join("watcher.sock"); + let token = "secret-token".to_string(); + + let _handle = spawn_unix_server(&socket_path, queue_path, token.clone()).unwrap(); + sleep(Duration::from_millis(200)).await; + + let endpoint = Endpoint::try_from("http://[::]:50051").unwrap(); + let channel = endpoint + .connect_with_connector(service_fn(move |_: tonic::transport::Uri| { + let path = socket_path.clone(); + async move { UnixStream::connect(path).await } + })) + .await + .unwrap(); + let mut client = WatcherClient::new(channel); + let mut req = tonic::Request::new(HealthCheckRequest {}); + let header = MetadataValue::try_from(format!("Bearer {}", token)).unwrap(); + req.metadata_mut().insert("authorization", header); + let resp = client.health_check(req).await.unwrap(); + assert_eq!(resp.into_inner().status, "ok"); +}