// archy/core/archipelago/src/session.rs

use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use zeroize::Zeroize;
/// Lifetime of a fully authenticated session, in seconds (24 hours).
const FULL_SESSION_TTL: u64 = 24 * 60 * 60;
/// Lifetime of a session that is still waiting for its TOTP code, in seconds (5 minutes).
const PENDING_SESSION_TTL: u64 = 5 * 60;
/// Number of TOTP secret lookups allowed on one pending session before it is discarded.
const MAX_TOTP_ATTEMPTS: u8 = 5;
/// What a stored session currently authorizes.
#[derive(Clone)]
enum SessionType {
    /// Fully authenticated; the only kind accepted by `SessionStore::validate`.
    Full,
    /// Password verified, TOTP code still outstanding.
    PendingTotp {
        /// Decrypted TOTP secret cached for verification (zeroized by the `Drop` impl).
        totp_secret: Vec<u8>,
        /// How many times the secret has been handed out for verification.
        attempts: u8,
    },
}
impl Drop for SessionType {
    /// Wipe the cached TOTP secret, if this is a pending session, before its memory is freed.
    fn drop(&mut self) {
        match self {
            SessionType::PendingTotp { totp_secret, .. } => totp_secret.zeroize(),
            SessionType::Full => {}
        }
    }
}
/// One entry in the session map (the map key is the SHA-256 hash of the token).
#[derive(Clone)]
struct Session {
    /// When the session was created or last upgraded; TTLs are measured from this instant.
    created_at: Instant,
    /// Whether the session is fully authenticated or still awaiting TOTP.
    session_type: SessionType,
}
/// In-memory session registry.
///
/// Sessions are keyed by the SHA-256 hash of their token rather than the token itself.
/// Cloning the store is cheap: every clone shares the same `Arc`-wrapped map.
#[derive(Clone)]
pub struct SessionStore {
    sessions: Arc<RwLock<HashMap<[u8; 32], Session>>>,
}
impl SessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Create a full (authenticated) session. Returns the plaintext token.
pub async fn create(&self) -> String {
let token_bytes: [u8; 32] = rand::random();
let token = hex::encode(token_bytes);
let hash = hash_token(&token);
let session = Session {
created_at: Instant::now(),
session_type: SessionType::Full,
};
self.sessions.write().await.insert(hash, session);
token
}
/// Create a pending TOTP session (password verified, awaiting TOTP).
/// Caches the decrypted TOTP secret in memory for verification.
pub async fn create_pending(&self, totp_secret: Vec<u8>) -> String {
let token_bytes: [u8; 32] = rand::random();
let token = hex::encode(token_bytes);
let hash = hash_token(&token);
let session = Session {
created_at: Instant::now(),
session_type: SessionType::PendingTotp {
totp_secret,
attempts: 0,
},
};
self.sessions.write().await.insert(hash, session);
token
}
/// Validate a full session token. Returns true if the session exists and hasn't expired.
pub async fn validate(&self, token: &str) -> bool {
let hash = hash_token(token);
let sessions = self.sessions.read().await;
if let Some(session) = sessions.get(&hash) {
matches!(session.session_type, SessionType::Full)
&& session.created_at.elapsed().as_secs() < FULL_SESSION_TTL
} else {
false
}
}
/// Get the TOTP secret from a pending session. Returns None if not a valid pending session.
/// Increments the attempt counter.
pub async fn get_pending_secret(&self, token: &str) -> Option<Vec<u8>> {
let hash = hash_token(token);
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(&hash) {
if session.created_at.elapsed().as_secs() >= PENDING_SESSION_TTL {
sessions.remove(&hash);
return None;
}
if let SessionType::PendingTotp {
ref totp_secret,
ref mut attempts,
} = session.session_type
{
*attempts += 1;
if *attempts > MAX_TOTP_ATTEMPTS {
sessions.remove(&hash); // Too many attempts, force re-login
return None;
}
return Some(totp_secret.clone());
}
}
None
}
/// Upgrade a pending session to a full session.
pub async fn upgrade_to_full(&self, token: &str) {
let hash = hash_token(token);
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(&hash) {
session.session_type = SessionType::Full;
session.created_at = Instant::now(); // Reset TTL to 24h from now
}
}
pub async fn remove(&self, token: &str) {
let hash = hash_token(token);
self.sessions.write().await.remove(&hash);
}
}
/// SHA-256 digest of `token`, used as the session-map key so plaintext tokens are not
/// stored in the map.
fn hash_token(token: &str) -> [u8; 32] {
    Sha256::digest(token.as_bytes()).into()
}
/// Extract the session token from the request's `Cookie` header(s).
///
/// Every `cookie` header value is scanned, not just the first. HTTP/2 clients may split
/// cookies across several header fields, and `HeaderMap::get` would only see one of them.
/// Values that are not valid visible ASCII are skipped.
///
/// Returns the value of the first `session=` cookie found. Returns `None` if there is no
/// such cookie or its value is empty.
pub fn extract_session_cookie(headers: &hyper::HeaderMap) -> Option<String> {
    headers
        .get_all("cookie")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|cookies| cookies.split(';'))
        .find_map(|pair| pair.trim().strip_prefix("session="))
        .filter(|v| !v.is_empty())
        .map(str::to_string)
}
/// Per-IP limiter for failed login attempts.
///
/// Allows at most `MAX_ATTEMPTS` (5) failures per IP within a sliding `WINDOW_SECS`
/// (60-second) window. Clones share the same attempt history.
#[derive(Clone)]
pub struct LoginRateLimiter {
    /// Timestamps of recorded failures, grouped by client IP.
    attempts: Arc<RwLock<HashMap<IpAddr, Vec<Instant>>>>,
}
/// Failed logins an IP may accumulate inside one window before `check` starts refusing it.
const MAX_ATTEMPTS: usize = 5;
/// Width of the sliding rate-limit window, in seconds.
const WINDOW_SECS: u64 = 60;
impl LoginRateLimiter {
    /// Create a limiter with no recorded failures.
    pub fn new() -> Self {
        Self {
            attempts: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns `true` when `ip` has fewer than `MAX_ATTEMPTS` recorded failures within the
    /// last `WINDOW_SECS` seconds.
    ///
    /// Stale timestamps for `ip` are discarded. The IP's entry is removed once it is empty,
    /// so merely checking an IP never adds it to the map; otherwise every distinct client
    /// address would be retained forever.
    pub async fn check(&self, ip: IpAddr) -> bool {
        let mut attempts = self.attempts.write().await;
        let now = Instant::now();
        let recent = match attempts.get_mut(&ip) {
            Some(times) => {
                times.retain(|t| now.duration_since(*t).as_secs() < WINDOW_SECS);
                times.len()
            }
            None => 0,
        };
        if recent == 0 {
            attempts.remove(&ip);
        }
        recent < MAX_ATTEMPTS
    }

    /// Record one failed login from `ip`.
    ///
    /// Also sweeps out timestamps older than the window for every IP. IPs left with no recent
    /// failures are dropped, so the map only holds IPs with failures inside the window.
    pub async fn record_failure(&self, ip: IpAddr) {
        let mut attempts = self.attempts.write().await;
        let now = Instant::now();
        attempts.retain(|_, times| {
            times.retain(|t| now.duration_since(*t).as_secs() < WINDOW_SECS);
            !times.is_empty()
        });
        attempts.entry(ip).or_default().push(now);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an IP address literal used by the rate-limiter tests.
    fn parse_ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    #[tokio::test]
    async fn test_session_create_and_validate() {
        let store = SessionStore::new();
        let issued = store.create().await;
        assert!(store.validate(&issued).await);
    }

    #[tokio::test]
    async fn test_session_invalid_token() {
        let store = SessionStore::new();
        let unknown = "nonexistent_token";
        assert!(!store.validate(unknown).await);
    }

    #[tokio::test]
    async fn test_session_remove() {
        let store = SessionStore::new();
        let issued = store.create().await;
        assert!(store.validate(&issued).await);
        store.remove(&issued).await;
        // Once removed, the token must no longer authenticate.
        assert!(!store.validate(&issued).await);
    }

    #[tokio::test]
    async fn test_pending_session_upgrade() {
        let store = SessionStore::new();
        let expected_secret: Vec<u8> = vec![1, 2, 3, 4];
        let pending = store.create_pending(expected_secret.clone()).await;

        // Still awaiting TOTP, so it is not a full session yet.
        assert!(!store.validate(&pending).await);

        // The cached secret is handed back for verification.
        let fetched = store.get_pending_secret(&pending).await;
        assert_eq!(fetched, Some(expected_secret));

        // After the upgrade the same token authenticates as a full session.
        store.upgrade_to_full(&pending).await;
        assert!(store.validate(&pending).await);
    }

    #[tokio::test]
    async fn test_pending_session_max_attempts() {
        let store = SessionStore::new();
        let pending = store.create_pending(vec![1, 2, 3]).await;

        // Every attempt up to and including MAX_TOTP_ATTEMPTS still yields the secret...
        for attempt in 1..=MAX_TOTP_ATTEMPTS {
            assert!(
                store.get_pending_secret(&pending).await.is_some(),
                "attempt {} should still be allowed",
                attempt
            );
        }

        // ...and the next one finds the session gone.
        assert!(store.get_pending_secret(&pending).await.is_none());
    }

    #[test]
    fn test_extract_session_cookie() {
        let mut headers = hyper::HeaderMap::new();
        headers.insert("cookie", "session=abc123; other=xyz".parse().unwrap());
        let extracted = extract_session_cookie(&headers);
        assert_eq!(extracted.as_deref(), Some("abc123"));
    }

    #[test]
    fn test_extract_session_cookie_missing() {
        let empty = hyper::HeaderMap::new();
        assert!(extract_session_cookie(&empty).is_none());
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_under_limit() {
        let limiter = LoginRateLimiter::new();
        let client = parse_ip("127.0.0.1");
        for _ in 0..MAX_ATTEMPTS {
            assert!(limiter.check(client).await);
            limiter.record_failure(client).await;
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_blocks_over_limit() {
        let limiter = LoginRateLimiter::new();
        let client = parse_ip("127.0.0.1");
        for _ in 0..MAX_ATTEMPTS {
            limiter.record_failure(client).await;
        }
        assert!(!limiter.check(client).await);
    }

    #[tokio::test]
    async fn test_rate_limiter_different_ips() {
        let limiter = LoginRateLimiter::new();
        let offender = parse_ip("127.0.0.1");
        let bystander = parse_ip("192.168.1.1");
        for _ in 0..MAX_ATTEMPTS {
            limiter.record_failure(offender).await;
        }
        // Failures are tracked per address: only the offender is locked out.
        assert!(!limiter.check(offender).await);
        assert!(limiter.check(bystander).await);
    }
}