use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// Rate limiter for login attempts: max 5 failures per 60 seconds per IP.
///
/// [`check`](Self::check) reports whether another attempt is currently allowed,
/// and [`record_failure`](Self::record_failure) counts a failed attempt. Clones
/// share the same state, because it lives behind an `Arc`.
///
/// Note: `check` and `record_failure` acquire the lock separately. Concurrent
/// attempts from one IP can therefore all pass `check` before any failure is
/// recorded.
#[derive(Clone, Debug)]
pub struct LoginRateLimiter {
    /// Failure timestamps per client IP. `Instant` is monotonic, so wall-clock
    /// adjustments cannot shorten or extend the window.
    attempts: Arc<RwLock<HashMap<IpAddr, Vec<Instant>>>>,
}

/// Failed attempts allowed per IP within one window.
const MAX_ATTEMPTS: usize = 5;
/// Length of the sliding window, in seconds.
const WINDOW_SECS: u64 = 60;

impl LoginRateLimiter {
    /// Creates a limiter with no recorded failures.
    pub fn new() -> Self {
        Self {
            attempts: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns `true` if `ip` has fewer than `MAX_ATTEMPTS` failures recorded
    /// within the last `WINDOW_SECS` seconds.
    ///
    /// Expired timestamps for `ip` are pruned. An IP with no recent failures
    /// leaves no entry in the map, so checks from many distinct addresses do not
    /// grow it between calls to [`cleanup`](Self::cleanup).
    pub async fn check(&self, ip: IpAddr) -> bool {
        let mut attempts = self.attempts.write().await;
        let now = Instant::now();
        let recent = match attempts.get_mut(&ip) {
            Some(timestamps) => {
                Self::prune(timestamps, now);
                timestamps.len()
            }
            None => 0,
        };
        if recent == 0 {
            // Drop the now-empty entry (a no-op if there was none).
            attempts.remove(&ip);
        }
        recent < MAX_ATTEMPTS
    }

    /// Records one failed login attempt for `ip`.
    ///
    /// Timestamps already outside the window are pruned first. The per-IP list
    /// therefore only holds failures that can still count against the limit.
    pub async fn record_failure(&self, ip: IpAddr) {
        let mut attempts = self.attempts.write().await;
        let now = Instant::now();
        let timestamps = attempts.entry(ip).or_default();
        Self::prune(timestamps, now);
        timestamps.push(now);
    }

    /// Periodic cleanup of expired entries for IPs that are no longer active.
    ///
    /// Drops expired timestamps for every IP and removes IPs left with none.
    pub async fn cleanup(&self) {
        let mut attempts = self.attempts.write().await;
        let now = Instant::now();
        attempts.retain(|_, timestamps| {
            Self::prune(timestamps, now);
            !timestamps.is_empty()
        });
    }

    /// Removes timestamps that are `WINDOW_SECS` or more seconds older than `now`.
    fn prune(timestamps: &mut Vec<Instant>, now: Instant) {
        let window = Duration::from_secs(WINDOW_SECS);
        timestamps.retain(|&t| now.duration_since(t) < window);
    }
}

impl Default for LoginRateLimiter {
    /// Same as [`LoginRateLimiter::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// General-purpose rate limiter for sensitive endpoints.
/// Tracks request counts per (method, IP) with configurable limits and windows.
#[derive(Clone)] pub struct EndpointRateLimiter { /// Map of (method, ip) -> list of request timestamps requests: Arc>>>, // Instant for monotonic rate limiting /// Per-method configuration: (max_requests, window_secs) limits: Arc>, } impl EndpointRateLimiter { pub fn new() -> Self { let mut limits = HashMap::new(); // Financial operations: strict limits limits.insert("wallet.send".to_string(), (5usize, 300u64)); limits.insert("wallet.ecash-send".to_string(), (10, 300)); limits.insert("lnd.sendcoins".to_string(), (5, 300)); limits.insert("lnd.payinvoice".to_string(), (10, 300)); limits.insert("lnd.openchannel".to_string(), (3, 300)); limits.insert("lnd.closechannel".to_string(), (3, 300)); limits.insert("lnd.create-psbt".to_string(), (5, 300)); limits.insert("lnd.finalize-psbt".to_string(), (5, 300)); // Identity/credential operations limits.insert("identity.create".to_string(), (10, 300)); limits.insert("identity.issue-credential".to_string(), (20, 300)); // Backup operations (resource-intensive) limits.insert("backup.create".to_string(), (10, 600)); limits.insert("backup.restore".to_string(), (5, 600)); // Container operations limits.insert("container-install".to_string(), (5, 300)); limits.insert("package.install".to_string(), (5, 300)); // S3 backup operations (resource-intensive) limits.insert("backup.upload-s3".to_string(), (3, 600)); limits.insert("backup.download-s3".to_string(), (3, 600)); // System operations limits.insert("update.apply".to_string(), (2, 600)); limits.insert("system.reboot".to_string(), (2, 300)); limits.insert("system.shutdown".to_string(), (2, 300)); // Password and TOTP changes limits.insert("auth.changePassword".to_string(), (3, 300)); limits.insert("auth.totp.setup".to_string(), (3, 300)); limits.insert("auth.totp.confirm".to_string(), (5, 300)); // Federation join: prevent invite-code brute force limits.insert("federation.join".to_string(), (5, 60)); limits.insert("federation.invite".to_string(), (10, 300)); // Inter-node federation 
RPCs (unauthenticated, need stricter limits) limits.insert("federation.peer-joined".to_string(), (10, 60)); limits.insert("federation.peer-address-changed".to_string(), (10, 60)); limits.insert("federation.peer-did-changed".to_string(), (5, 60)); limits.insert("federation.get-state".to_string(), (30, 60)); // DID rotation: sensitive identity operation limits.insert("node.rotate-did".to_string(), (3, 600)); Self { requests: Arc::new(RwLock::new(HashMap::new())), limits: Arc::new(limits), } } /// Check if a request is allowed. Returns true if within limits. pub async fn check(&self, method: &str, ip: IpAddr) -> bool { let (max_req, window) = match self.limits.get(method) { Some(config) => *config, None => return true, // Not rate-limited }; let key = (method.to_string(), ip); let mut requests = self.requests.write().await; let now = Instant::now(); let entry = requests.entry(key).or_default(); entry.retain(|t| now.duration_since(*t).as_secs() < window); entry.len() < max_req } /// Record a request for rate limiting purposes. pub async fn record(&self, method: &str, ip: IpAddr) { if !self.limits.contains_key(method) { return; // Not rate-limited, skip tracking } let key = (method.to_string(), ip); let mut requests = self.requests.write().await; let entry = requests.entry(key).or_default(); entry.push(Instant::now()); } /// Periodic cleanup of expired entries. 
pub async fn cleanup(&self) { let mut requests = self.requests.write().await; let now = Instant::now(); requests.retain(|(method, _), timestamps| { let window = self .limits .get(method) .map(|(_, w)| *w) .unwrap_or(300); timestamps.retain(|t| now.duration_since(*t).as_secs() < window); !timestamps.is_empty() }); } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_rate_limiter_allows_under_limit() { let limiter = LoginRateLimiter::new(); let ip: IpAddr = "127.0.0.1".parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)); for _ in 0..MAX_ATTEMPTS { assert!(limiter.check(ip).await); limiter.record_failure(ip).await; } } #[tokio::test] async fn test_rate_limiter_blocks_over_limit() { let limiter = LoginRateLimiter::new(); let ip: IpAddr = "127.0.0.1".parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)); for _ in 0..MAX_ATTEMPTS { limiter.record_failure(ip).await; } assert!(!limiter.check(ip).await); } #[tokio::test] async fn test_rate_limiter_different_ips() { let limiter = LoginRateLimiter::new(); let ip1: IpAddr = "127.0.0.1".parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)); let ip2: IpAddr = "192.168.1.1".parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)); for _ in 0..MAX_ATTEMPTS { limiter.record_failure(ip1).await; } // ip1 should be blocked assert!(!limiter.check(ip1).await); // ip2 should still be allowed assert!(limiter.check(ip2).await); } }