//! 内存仓储实现 //! //! 提供基于内存的仓储实现,用于单元测试和集成测试, //! 避免依赖真实数据库 use std::collections::HashMap; use std::sync::{Arc, RwLock}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use rust_decimal::Decimal; use rustjr::domain::ledger::entity::{AccountBalance, DeductionResult}; use rustjr::domain::account::AccountType; use rustjr::domain::ledger::repository::AccountBalanceRepository; use rustjr::domain::transaction::entity::{ SystemTransaction, TransactionStatus, TransactionType, CreateSystemTransactionRequest, }; use rustjr::domain::transaction::repository::SystemTransactionRepository; use rustjr::domain::compensation::{ CompensationTask, CompensationTaskStatus, CompensationTaskType, CompensationTaskRepository, }; use rustjr::error::Result; // ==================== 账户余额内存仓储 ==================== /// 内存账户余额仓储 pub struct InMemoryAccountBalanceRepository { balances: Arc>>, next_id: Arc>, } impl InMemoryAccountBalanceRepository { pub fn new() -> Self { Self { balances: Arc::new(RwLock::new(HashMap::new())), next_id: Arc::new(RwLock::new(1)), } } /// 插入预设余额(测试用) pub fn insert(&self, balance: AccountBalance) { let mut balances = self.balances.write().unwrap(); balances.insert((balance.account_id, balance.account_type), balance); } /// 获取所有余额(测试用) pub fn get_all(&self) -> Vec { let balances = self.balances.read().unwrap(); balances.values().cloned().collect() } } impl Default for InMemoryAccountBalanceRepository { fn default() -> Self { Self::new() } } #[async_trait] impl AccountBalanceRepository for InMemoryAccountBalanceRepository { async fn find_by_account( &self, account_id: i64, account_type: AccountType, ) -> Result> { let balances = self.balances.read().unwrap(); Ok(balances.get(&(account_id, account_type)).cloned()) } async fn update(&self, balance: &AccountBalance) -> Result<()> { let mut balances = self.balances.write().unwrap(); balances.insert((balance.account_id, balance.account_type), balance.clone()); Ok(()) } async fn get_or_create(&self, account_id: i64, account_type: AccountType) -> Result { let mut balances = self.balances.write().unwrap(); let key = (account_id, account_type); if let Some(balance) = balances.get(&key) { return Ok(balance.clone()); } // 创建默认余额 let mut next_id = self.next_id.write().unwrap(); let id = *next_id; *next_id += 1; let balance = AccountBalance { id, account_id, account_type, personal_balance: Decimal::ZERO, labor_balance: Decimal::ZERO, frozen_balance: Decimal::ZERO, bank_balance: Decimal::ZERO, transit_amount: Decimal::ZERO, system_balance: Decimal::ZERO, available_balance: Decimal::ZERO, frozen_amount: Decimal::ZERO, version: 1, updated_at: chrono::Utc::now(), }; balances.insert(key, balance.clone()); Ok(balance) } async fn batch_update(&self, balances: &[AccountBalance]) -> Result<()> { let mut stored = self.balances.write().unwrap(); for balance in balances { let key = (balance.account_id, balance.account_type); stored.insert(key, balance.clone()); } Ok(()) } async fn freeze(&self, account_id: i64, account_type: AccountType, amount: Decimal) -> Result<()> { let mut balances = self.balances.write().unwrap(); let key = (account_id, account_type); if let Some(balance) = balances.get_mut(&key) { balance.freeze(amount); } Ok(()) } async fn unfreeze(&self, account_id: i64, account_type: AccountType, amount: Decimal) -> Result<()> { let mut balances = self.balances.write().unwrap(); let key = (account_id, account_type); if let Some(balance) = balances.get_mut(&key) { balance.unfreeze(amount); } Ok(()) } } // ==================== 系统交易内存仓储 ==================== /// 内存系统交易仓储 pub struct InMemorySystemTransactionRepository { transactions: Arc>>, by_txn_no: Arc>>, by_source_key: Arc>>, next_id: Arc>, } impl InMemorySystemTransactionRepository { pub fn new() -> Self { Self { transactions: Arc::new(RwLock::new(HashMap::new())), by_txn_no: Arc::new(RwLock::new(HashMap::new())), by_source_key: Arc::new(RwLock::new(HashMap::new())), next_id: Arc::new(RwLock::new(1)), } } /// 插入预设交易(测试用) pub fn insert(&self, txn: SystemTransaction) { let mut transactions = self.transactions.write().unwrap(); let mut by_txn_no = self.by_txn_no.write().unwrap(); let mut by_source_key = self.by_source_key.write().unwrap(); by_txn_no.insert(txn.txn_no.clone(), txn.id); if let Some(ref key) = txn.source_key { by_source_key.insert(key.clone(), txn.id); } transactions.insert(txn.id, txn); } /// 获取所有交易(测试用) pub fn get_all(&self) -> Vec { let transactions = self.transactions.read().unwrap(); transactions.values().cloned().collect() } } impl Default for InMemorySystemTransactionRepository { fn default() -> Self { Self::new() } } #[async_trait] impl SystemTransactionRepository for InMemorySystemTransactionRepository { async fn create(&self, request: &CreateSystemTransactionRequest) -> Result { let mut transactions = self.transactions.write().unwrap(); let mut by_txn_no = self.by_txn_no.write().unwrap(); let mut by_source_key = self.by_source_key.write().unwrap(); let mut next_id = self.next_id.write().unwrap(); let id = *next_id; *next_id += 1; // 生成交易号 let txn_no = format!("TXN{:010}", id); let txn = SystemTransaction { id, txn_no: txn_no.clone(), txn_type: request.txn_type.clone(), from_account_id: request.from_account_id, to_account_id: request.to_account_id, amount: request.amount, status: TransactionStatus::Created, bank_ref_no: None, source_key: request.source_key.clone(), remark: request.remark.clone(), created_at: chrono::Utc::now(), confirmed_at: None, submitted_at: None, version: 1, }; by_txn_no.insert(txn_no, id); if let Some(ref key) = txn.source_key { by_source_key.insert(key.clone(), id); } transactions.insert(id, txn.clone()); Ok(txn) } async fn find_by_id(&self, id: i64) -> Result> { let transactions = self.transactions.read().unwrap(); Ok(transactions.get(&id).cloned()) } async fn find_by_txn_no(&self, txn_no: &str) -> Result> { let by_txn_no = self.by_txn_no.read().unwrap(); let transactions = self.transactions.read().unwrap(); if let Some(&id) = by_txn_no.get(txn_no) { return Ok(transactions.get(&id).cloned()); } Ok(None) } async fn find_by_source_key(&self, source_key: &str) -> Result> { let by_source_key = self.by_source_key.read().unwrap(); let transactions = self.transactions.read().unwrap(); if let Some(&id) = by_source_key.get(source_key) { return Ok(transactions.get(&id).cloned()); } Ok(None) } async fn find_by_status(&self, status: TransactionStatus) -> Result> { let transactions = self.transactions.read().unwrap(); let result: Vec<_> = transactions .values() .filter(|t| t.status == status) .cloned() .collect(); Ok(result) } async fn find_timeout(&self, threshold_seconds: i64) -> Result> { let transactions = self.transactions.read().unwrap(); let now = Utc::now(); let threshold = chrono::Duration::seconds(threshold_seconds); let result: Vec<_> = transactions .values() .filter(|t| { t.status == TransactionStatus::BankSubmitted && t.submitted_at.map_or(false, |s| now - s > threshold) }) .cloned() .collect(); Ok(result) } async fn update_status(&self, id: i64, status: TransactionStatus) -> Result<()> { let mut transactions = self.transactions.write().unwrap(); if let Some(txn) = transactions.get_mut(&id) { txn.status = status; } Ok(()) } async fn set_bank_ref_no(&self, id: i64, bank_ref_no: &str) -> Result<()> { let mut transactions = self.transactions.write().unwrap(); if let Some(txn) = transactions.get_mut(&id) { txn.bank_ref_no = Some(bank_ref_no.to_string()); } Ok(()) } async fn set_submitted_at(&self, id: i64, submitted_at: DateTime) -> Result<()> { let mut transactions = self.transactions.write().unwrap(); if let Some(txn) = transactions.get_mut(&id) { txn.submitted_at = Some(submitted_at); } Ok(()) } async fn find_by_bank_ref_no(&self, bank_ref_no: &str) -> Result> { let transactions = self.transactions.read().unwrap(); let txn = transactions.values().find(|t| t.bank_ref_no == Some(bank_ref_no.to_string())); Ok(txn.cloned()) } async fn find_pending(&self) -> Result> { let transactions = self.transactions.read().unwrap(); let result: Vec<_> = transactions .values() .filter(|t| t.status == TransactionStatus::Pending) .cloned() .collect(); Ok(result) } async fn find_needs_reconciliation(&self) -> Result> { let transactions = self.transactions.read().unwrap(); let result: Vec<_> = transactions .values() .filter(|t| matches!( t.status, TransactionStatus::BankSubmitted | TransactionStatus::Timeout )) .cloned() .collect(); Ok(result) } async fn query(&self, query: &rustjr::domain::transaction::entity::TransactionQuery) -> Result> { let transactions = self.transactions.read().unwrap(); let mut result: Vec<_> = transactions.values().cloned().collect(); // 应用过滤条件 if let Some(account_id) = query.account_id { result.retain(|t| t.from_account_id == Some(account_id) || t.to_account_id == Some(account_id)); } if let Some(txn_type) = &query.txn_type { result.retain(|t| t.txn_type == *txn_type); } if let Some(status) = &query.status { result.retain(|t| t.status == *status); } if let Some(start_time) = query.start_time { result.retain(|t| t.created_at >= start_time); } if let Some(end_time) = query.end_time { result.retain(|t| t.created_at <= end_time); } if let Some(min_amount) = query.min_amount { result.retain(|t| t.amount >= min_amount); } if let Some(max_amount) = query.max_amount { result.retain(|t| t.amount <= max_amount); } // 简单的分页 let offset = query.offset.unwrap_or(0) as usize; let limit = query.limit.unwrap_or(50) as usize; if offset < result.len() { result = result[offset..(offset + limit).min(result.len())].to_vec(); } else { result.clear(); } Ok(result) } async fn confirm(&self, id: i64, confirmed_at: DateTime) -> Result<()> { let mut transactions = self.transactions.write().unwrap(); if let Some(txn) = transactions.get_mut(&id) { txn.confirmed_at = Some(confirmed_at); txn.status = TransactionStatus::Success; } Ok(()) } } // ==================== 补偿任务内存仓储 ==================== /// 内存补偿任务仓储 pub struct InMemoryCompensationTaskRepository { tasks: Arc>>, by_txn_no: Arc>>>, next_id: Arc>, } impl InMemoryCompensationTaskRepository { pub fn new() -> Self { Self { tasks: Arc::new(RwLock::new(HashMap::new())), by_txn_no: Arc::new(RwLock::new(HashMap::new())), next_id: Arc::new(RwLock::new(1)), } } /// 获取所有任务(测试用) pub fn get_all(&self) -> Vec { let tasks = self.tasks.read().unwrap(); tasks.values().cloned().collect() } } impl Default for InMemoryCompensationTaskRepository { fn default() -> Self { Self::new() } } #[async_trait] impl CompensationTaskRepository for InMemoryCompensationTaskRepository { async fn create(&self, request: &rustjr::domain::compensation::CreateCompensationTaskRequest) -> Result { let mut tasks = self.tasks.write().unwrap(); let mut by_txn_no = self.by_txn_no.write().unwrap(); let mut next_id = self.next_id.write().unwrap(); let id = *next_id; *next_id += 1; let task = CompensationTask { id, txn_no: request.txn_no.clone(), task_type: request.task_type, status: CompensationTaskStatus::Pending, retry_count: 0, max_retries: request.max_retries.unwrap_or(3), next_retry_at: None, error_message: None, created_at: Utc::now(), updated_at: Utc::now(), completed_at: None, }; by_txn_no.entry(request.txn_no.clone()).or_insert_with(Vec::new).push(id); tasks.insert(id, task.clone()); Ok(task) } async fn find_by_id(&self, id: i64) -> Result> { let tasks = self.tasks.read().unwrap(); Ok(tasks.get(&id).cloned()) } async fn find_by_txn_no(&self, txn_no: &str) -> Result> { let tasks = self.tasks.read().unwrap(); let result: Vec<_> = tasks .values() .filter(|t| t.txn_no == txn_no) .cloned() .collect(); Ok(result) } async fn find_pending(&self, limit: i64) -> Result> { let tasks = self.tasks.read().unwrap(); let mut result: Vec<_> = tasks .values() .filter(|t| t.status == CompensationTaskStatus::Pending) .cloned() .collect(); result.truncate(limit as usize); Ok(result) } async fn find_ready_for_retry(&self, limit: i64) -> Result> { let tasks = self.tasks.read().unwrap(); let now = Utc::now(); let mut result: Vec<_> = tasks .values() .filter(|t| { t.status == CompensationTaskStatus::Failed && t.next_retry_at.map_or(false, |n| n <= now) }) .cloned() .collect(); result.truncate(limit as usize); Ok(result) } async fn find_dead_letter(&self, limit: i64) -> Result> { let tasks = self.tasks.read().unwrap(); let mut result: Vec<_> = tasks .values() .filter(|t| t.status == CompensationTaskStatus::DeadLetter) .cloned() .collect(); result.truncate(limit as usize); Ok(result) } async fn update_status( &self, id: i64, status: CompensationTaskStatus, error_message: Option<&str>, ) -> Result<()> { let mut tasks = self.tasks.write().unwrap(); if let Some(task) = tasks.get_mut(&id) { task.status = status; task.error_message = error_message.map(|s| s.to_string()); task.updated_at = Utc::now(); } Ok(()) } async fn increment_retry( &self, id: i64, next_retry_at: DateTime, ) -> Result<()> { let mut tasks = self.tasks.write().unwrap(); if let Some(task) = tasks.get_mut(&id) { task.retry_count += 1; task.next_retry_at = Some(next_retry_at); if task.retry_count >= task.max_retries { task.status = CompensationTaskStatus::DeadLetter; } else { task.status = CompensationTaskStatus::Failed; } task.updated_at = Utc::now(); } Ok(()) } async fn mark_completed(&self, id: i64) -> Result<()> { let mut tasks = self.tasks.write().unwrap(); if let Some(task) = tasks.get_mut(&id) { task.status = CompensationTaskStatus::Completed; task.completed_at = Some(Utc::now()); task.updated_at = Utc::now(); } Ok(()) } async fn mark_dead_letter(&self, id: i64, error_message: &str) -> Result<()> { let mut tasks = self.tasks.write().unwrap(); if let Some(task) = tasks.get_mut(&id) { task.status = CompensationTaskStatus::DeadLetter; task.error_message = Some(error_message.to_string()); task.updated_at = Utc::now(); } Ok(()) } async fn has_pending_task(&self, txn_no: &str, task_type: CompensationTaskType) -> Result { let tasks = self.tasks.read().unwrap(); let has = tasks.values().any(|t| { t.txn_no == txn_no && t.task_type == task_type && (t.status == CompensationTaskStatus::Pending || t.status == CompensationTaskStatus::Processing || t.status == CompensationTaskStatus::Failed) }); Ok(has) } } // ==================== 测试辅助 ==================== /// 创建测试用系统交易 pub fn create_test_transaction( id: i64, txn_no: &str, amount: Decimal, status: TransactionStatus, ) -> SystemTransaction { SystemTransaction { id, txn_no: txn_no.to_string(), txn_type: TransactionType::Transfer, from_account_id: Some(1), to_account_id: Some(2), amount, status, bank_ref_no: None, source_key: None, remark: None, created_at: Utc::now(), confirmed_at: None, submitted_at: None, version: 1, } } /// 创建测试用账户余额 pub fn create_test_balance( id: i64, account_id: i64, personal: Decimal, labor: Decimal, frozen: Decimal, ) -> AccountBalance { let bank_balance = personal + labor + frozen; AccountBalance { id, account_id, account_type: AccountType::Virtual, personal_balance: personal, labor_balance: labor, frozen_balance: frozen, bank_balance, transit_amount: Decimal::ZERO, system_balance: bank_balance, available_balance: personal + labor, // 可用余额 = 个人 + 劳动(不含冻结) frozen_amount: frozen, version: 1, updated_at: Utc::now(), } } #[cfg(test)] mod tests { use super::*; use rust_decimal_macros::dec; #[tokio::test] async fn test_in_memory_balance_repo() { let repo = InMemoryAccountBalanceRepository::new(); // 使用 get_or_create 创建余额 let balance = repo.get_or_create(1001, AccountType::Virtual).await.unwrap(); // 更新余额 let mut updated = balance.clone(); updated.personal_balance = dec!(1000.00); updated.labor_balance = dec!(500.00); updated.bank_balance = dec!(1500.00); repo.update(&updated).await.unwrap(); let found = repo.find_by_account(1001, AccountType::Virtual).await.unwrap(); assert!(found.is_some()); let found = found.unwrap(); assert_eq!(found.personal_balance, dec!(1000.00)); } #[tokio::test] async fn test_in_memory_txn_repo() { let repo = InMemorySystemTransactionRepository::new(); // 创建交易请求 let request = CreateSystemTransactionRequest { txn_type: TransactionType::Transfer, from_account_id: Some(1), to_account_id: Some(2), amount: dec!(100.00), remark: None, source_key: None, }; let txn = repo.create(&request).await.unwrap(); let found = repo.find_by_txn_no(&txn.txn_no).await.unwrap(); assert!(found.is_some()); assert_eq!(found.unwrap().id, txn.id); } }