use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use zeroize::Zeroize;

/// Lifetime of a fully authenticated session, in seconds (24 hours).
const FULL_SESSION_TTL: u64 = 86400;
/// Lifetime of a session that is still awaiting its TOTP code, in seconds (5 minutes).
const PENDING_SESSION_TTL: u64 = 300;
/// Number of TOTP lookups allowed on a pending session before it is discarded.
const MAX_TOTP_ATTEMPTS: u8 = 5;

/// Authentication state of a stored session.
#[derive(Clone)]
enum SessionType {
    /// Fully authenticated.
    Full,
    /// Password accepted; waiting for a TOTP code.
    PendingTotp {
        /// Decrypted TOTP secret, wiped from memory when this value is dropped.
        totp_secret: Vec<u8>,
        /// Number of times the secret has been handed out for verification.
        attempts: u8,
    },
}

impl Drop for SessionType {
    /// Zeroizes the cached TOTP secret so it does not linger in freed memory.
    fn drop(&mut self) {
        if let SessionType::PendingTotp { totp_secret, .. } = self {
            totp_secret.zeroize();
        }
    }
}

/// A stored session: when it was created (or last upgraded) and its state.
#[derive(Clone)]
struct Session {
    created_at: Instant,
    session_type: SessionType,
}

impl Session {
    /// Returns true once the session has outlived the TTL that applies to its state.
    fn is_expired(&self) -> bool {
        let ttl = match self.session_type {
            SessionType::Full => FULL_SESSION_TTL,
            SessionType::PendingTotp { .. } => PENDING_SESSION_TTL,
        };
        self.created_at.elapsed().as_secs() >= ttl
    }
}

/// In-memory session store keyed by the SHA-256 hash of each session token.
///
/// Only the hash is stored; the plaintext token is returned to the caller once,
/// at creation time. Clones share the same underlying map.
#[derive(Clone)]
pub struct SessionStore {
    sessions: Arc<RwLock<HashMap<[u8; 32], Session>>>,
}

impl Default for SessionStore {
    fn default() -> Self {
        Self::new()
    }
}

impl SessionStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            sessions: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Create a full (authenticated) session. Returns the plaintext token.
    pub async fn create(&self) -> String {
        self.insert_new(SessionType::Full).await
    }

    /// Create a pending TOTP session (password verified, awaiting TOTP).
    /// Caches the decrypted TOTP secret in memory for verification.
    /// Returns the plaintext token.
    pub async fn create_pending(&self, totp_secret: Vec<u8>) -> String {
        self.insert_new(SessionType::PendingTotp {
            totp_secret,
            attempts: 0,
        })
        .await
    }

    /// Mints a random 32-byte token (hex-encoded), stores a new session of the
    /// given type under the token's hash, and returns the plaintext token.
    async fn insert_new(&self, session_type: SessionType) -> String {
        let token_bytes: [u8; 32] = rand::random();
        let token = hex::encode(token_bytes);
        let hash = hash_token(&token);
        let session = Session {
            created_at: Instant::now(),
            session_type,
        };

        let mut sessions = self.sessions.write().await;
        // Evict expired entries while the write lock is held anyway; otherwise
        // abandoned sessions (and their cached TOTP secrets) would never be
        // dropped and the map would grow without bound.
        sessions.retain(|_, existing| !existing.is_expired());
        sessions.insert(hash, session);
        token
    }

    /// Validate a full session token. Returns true if the session exists, is a
    /// full session, and hasn't expired. Pending sessions never validate.
    pub async fn validate(&self, token: &str) -> bool {
        let hash = hash_token(token);
        let sessions = self.sessions.read().await;
        sessions.get(&hash).map_or(false, |session| {
            matches!(session.session_type, SessionType::Full) && !session.is_expired()
        })
    }

    /// Get the TOTP secret from a pending session, counting the call as one
    /// verification attempt.
    ///
    /// Returns `None` if the token is unknown, belongs to a full session,
    /// the pending session has expired, or more than `MAX_TOTP_ATTEMPTS`
    /// lookups have been made. In the latter two cases the pending session is
    /// removed, forcing a fresh login. Full sessions are left untouched.
    pub async fn get_pending_secret(&self, token: &str) -> Option<Vec<u8>> {
        let hash = hash_token(token);
        let mut sessions = self.sessions.write().await;

        let secret = match sessions.get_mut(&hash) {
            None => return None,
            Some(session) => {
                let expired = session.is_expired();
                match &mut session.session_type {
                    // Not a pending session: must not be judged against the
                    // pending TTL, and must not be deleted by this lookup.
                    SessionType::Full => return None,
                    SessionType::PendingTotp { .. } if expired => None,
                    SessionType::PendingTotp {
                        totp_secret,
                        attempts,
                    } => {
                        *attempts += 1;
                        if *attempts > MAX_TOTP_ATTEMPTS {
                            None // Too many attempts, force re-login
                        } else {
                            Some(totp_secret.clone())
                        }
                    }
                }
            }
        };

        // Reaching here with `None` means the pending session is spent.
        if secret.is_none() {
            sessions.remove(&hash);
        }
        secret
    }

    /// Upgrade a pending session to a full session.
    ///
    /// The session's age is reset so the full-session TTL counts from now.
    /// Unknown tokens and sessions that are already full are left unchanged,
    /// so this cannot be used to extend the lifetime of an existing full session.
    pub async fn upgrade_to_full(&self, token: &str) {
        let hash = hash_token(token);
        let mut sessions = self.sessions.write().await;
        if let Some(session) = sessions.get_mut(&hash) {
            if matches!(session.session_type, SessionType::PendingTotp { .. }) {
                // Replacing the value drops the pending state, which zeroizes
                // the cached TOTP secret.
                session.session_type = SessionType::Full;
                session.created_at = Instant::now(); // Reset TTL to 24h from now
            }
        }
    }

    /// Remove the session for `token`, if any.
    pub async fn remove(&self, token: &str) {
        let hash = hash_token(token);
        self.sessions.write().await.remove(&hash);
    }
}

/// Returns the SHA-256 digest of the token's bytes; used as the map key so
/// plaintext tokens are never stored.
fn hash_token(token: &str) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(token.as_bytes());
    hasher.finalize().into()
}

/// Extract the session token from a Cookie header value.
pub fn extract_session_cookie(headers: &hyper::HeaderMap) -> Option { headers .get("cookie") .and_then(|v| v.to_str().ok()) .and_then(|cookies| { cookies.split(';').find_map(|c| { let c = c.trim(); c.strip_prefix("session=").map(|v| v.to_string()) }) }) .filter(|v| !v.is_empty()) } /// Rate limiter for login attempts: max 5 failures per 60 seconds per IP. #[derive(Clone)] pub struct LoginRateLimiter { attempts: Arc>>>, } const MAX_ATTEMPTS: usize = 5; const WINDOW_SECS: u64 = 60; impl LoginRateLimiter { pub fn new() -> Self { Self { attempts: Arc::new(RwLock::new(HashMap::new())), } } pub async fn check(&self, ip: IpAddr) -> bool { let mut attempts = self.attempts.write().await; let now = Instant::now(); let entry = attempts.entry(ip).or_default(); entry.retain(|t| now.duration_since(*t).as_secs() < WINDOW_SECS); entry.len() < MAX_ATTEMPTS } pub async fn record_failure(&self, ip: IpAddr) { let mut attempts = self.attempts.write().await; let entry = attempts.entry(ip).or_default(); entry.push(Instant::now()); } }