rustjr-account-management/tests/common/mock_repositories.rs
tangweijie d7f81893c5 Initial commit: 完整的 Rust 账户管理系统
- 实现账户管理改进设计文档中的所有核心功能
- 三科目余额管理 (个人余额、劳动报酬、冻结余额)
- 交易状态机 (created → pending → bank_submitted → success/failed/timeout → reversed)
- 三键幂等体系 (JZTxId/BankTxId/SourceKey)
- 优先级扣款规则 (先个人后劳动)
- 在途资金管理 (可用→在途→结转/回退)
- 三账对账闭环 (总账 = 银行账 + 在途净额)
- 补偿服务域 (超时检测、重试、死信队列)
- 虚拟银行模拟器用于业务测试
- 完整的集成测试套件 (133 个测试全部通过)
- Docker 容器化部署配置
- 前端 Vue3 + TypeScript 项目结构
2026-01-05 17:56:01 +08:00

664 lines
21 KiB
Rust

//! 内存仓储实现
//!
//! 提供基于内存的仓储实现,用于单元测试和集成测试,
//! 避免依赖真实数据库
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use rustjr::domain::ledger::entity::{AccountBalance, DeductionResult};
use rustjr::domain::account::AccountType;
use rustjr::domain::ledger::repository::AccountBalanceRepository;
use rustjr::domain::transaction::entity::{
SystemTransaction, TransactionStatus, TransactionType, CreateSystemTransactionRequest,
};
use rustjr::domain::transaction::repository::SystemTransactionRepository;
use rustjr::domain::compensation::{
CompensationTask, CompensationTaskStatus, CompensationTaskType, CompensationTaskRepository,
};
use rustjr::error::Result;
// ==================== 账户余额内存仓储 ====================
/// In-memory account-balance repository, keyed by `(account_id, account_type)`.
///
/// Every field sits behind `Arc<RwLock<..>>`; each method takes the locks it
/// needs and releases them before returning.
pub struct InMemoryAccountBalanceRepository {
    /// Stored balances, one per `(account_id, account_type)` pair.
    balances: Arc<RwLock<HashMap<(i64, AccountType), AccountBalance>>>,
    /// Next id handed out when `get_or_create` creates a row; starts at 1.
    next_id: Arc<RwLock<i64>>,
}

impl InMemoryAccountBalanceRepository {
    /// Creates an empty repository whose first allocated id is 1.
    pub fn new() -> Self {
        Self {
            balances: Arc::new(RwLock::new(HashMap::new())),
            next_id: Arc::new(RwLock::new(1)),
        }
    }

    /// Inserts (or replaces) a preset balance (test helper).
    ///
    /// Also moves the id allocator past `balance.id`. Without this, a later
    /// `get_or_create` for a different account could hand out an id that a
    /// preset row already uses.
    pub fn insert(&self, balance: AccountBalance) {
        // Lock order matches `get_or_create`: balances -> next_id.
        let mut balances = self.balances.write().unwrap();
        let mut next_id = self.next_id.write().unwrap();
        if balance.id >= *next_id {
            *next_id = balance.id.saturating_add(1);
        }
        balances.insert((balance.account_id, balance.account_type), balance);
    }

    /// Returns a snapshot of every stored balance, in unspecified order (test helper).
    pub fn get_all(&self) -> Vec<AccountBalance> {
        let balances = self.balances.read().unwrap();
        balances.values().cloned().collect()
    }
}

impl Default for InMemoryAccountBalanceRepository {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl AccountBalanceRepository for InMemoryAccountBalanceRepository {
    /// Returns a copy of the balance stored for `(account_id, account_type)`,
    /// or `None` when no row exists.
    async fn find_by_account(
        &self,
        account_id: i64,
        account_type: AccountType,
    ) -> Result<Option<AccountBalance>> {
        let guard = self.balances.read().unwrap();
        let hit = guard.get(&(account_id, account_type));
        Ok(hit.cloned())
    }

    /// Overwrites (or inserts) the row under `balance`'s key with a copy of it.
    async fn update(&self, balance: &AccountBalance) -> Result<()> {
        let key = (balance.account_id, balance.account_type);
        let mut guard = self.balances.write().unwrap();
        guard.insert(key, balance.clone());
        Ok(())
    }

    /// Returns the stored balance for the key. If none exists, creates, stores
    /// and returns an all-zero balance (version 1) with a freshly allocated id.
    async fn get_or_create(&self, account_id: i64, account_type: AccountType) -> Result<AccountBalance> {
        let mut guard = self.balances.write().unwrap();
        let row = guard
            .entry((account_id, account_type))
            .or_insert_with(|| {
                // The id counter is only touched when a row is actually
                // created, and always while the balances lock is held.
                let mut counter = self.next_id.write().unwrap();
                let fresh_id = *counter;
                *counter += 1;
                AccountBalance {
                    id: fresh_id,
                    account_id,
                    account_type,
                    personal_balance: Decimal::ZERO,
                    labor_balance: Decimal::ZERO,
                    frozen_balance: Decimal::ZERO,
                    bank_balance: Decimal::ZERO,
                    transit_amount: Decimal::ZERO,
                    system_balance: Decimal::ZERO,
                    available_balance: Decimal::ZERO,
                    frozen_amount: Decimal::ZERO,
                    version: 1,
                    updated_at: chrono::Utc::now(),
                }
            });
        Ok(row.clone())
    }

    /// Upserts every balance in slice order; a later entry with the same key
    /// replaces an earlier one.
    async fn batch_update(&self, balances: &[AccountBalance]) -> Result<()> {
        let mut stored = self.balances.write().unwrap();
        stored.extend(
            balances
                .iter()
                .map(|b| ((b.account_id, b.account_type), b.clone())),
        );
        Ok(())
    }

    /// Calls `AccountBalance::freeze(amount)` on the stored row.
    /// Unknown accounts are silently ignored.
    async fn freeze(&self, account_id: i64, account_type: AccountType, amount: Decimal) -> Result<()> {
        let mut guard = self.balances.write().unwrap();
        // NOTE(review): any value returned by `freeze` is discarded — confirm intended.
        if let Some(row) = guard.get_mut(&(account_id, account_type)) {
            row.freeze(amount);
        }
        Ok(())
    }

    /// Calls `AccountBalance::unfreeze(amount)` on the stored row.
    /// Unknown accounts are silently ignored.
    async fn unfreeze(&self, account_id: i64, account_type: AccountType, amount: Decimal) -> Result<()> {
        let mut guard = self.balances.write().unwrap();
        // NOTE(review): any value returned by `unfreeze` is discarded — confirm intended.
        if let Some(row) = guard.get_mut(&(account_id, account_type)) {
            row.unfreeze(amount);
        }
        Ok(())
    }
}
// ==================== 系统交易内存仓储 ====================
/// 内存系统交易仓储
pub struct InMemorySystemTransactionRepository {
transactions: Arc<RwLock<HashMap<i64, SystemTransaction>>>,
by_txn_no: Arc<RwLock<HashMap<String, i64>>>,
by_source_key: Arc<RwLock<HashMap<String, i64>>>,
next_id: Arc<RwLock<i64>>,
}
impl InMemorySystemTransactionRepository {
pub fn new() -> Self {
Self {
transactions: Arc::new(RwLock::new(HashMap::new())),
by_txn_no: Arc::new(RwLock::new(HashMap::new())),
by_source_key: Arc::new(RwLock::new(HashMap::new())),
next_id: Arc::new(RwLock::new(1)),
}
}
/// 插入预设交易(测试用)
pub fn insert(&self, txn: SystemTransaction) {
let mut transactions = self.transactions.write().unwrap();
let mut by_txn_no = self.by_txn_no.write().unwrap();
let mut by_source_key = self.by_source_key.write().unwrap();
by_txn_no.insert(txn.txn_no.clone(), txn.id);
if let Some(ref key) = txn.source_key {
by_source_key.insert(key.clone(), txn.id);
}
transactions.insert(txn.id, txn);
}
/// 获取所有交易(测试用)
pub fn get_all(&self) -> Vec<SystemTransaction> {
let transactions = self.transactions.read().unwrap();
transactions.values().cloned().collect()
}
}
impl Default for InMemorySystemTransactionRepository {
fn default() -> Self {
Self::new()
}
}
/// Clones every transaction matching `pred`, ordered by ascending id.
///
/// `HashMap` iteration order is unspecified. Sorting keeps list queries, and
/// the pagination in `query`, deterministic between runs.
fn collect_txns_by_id<F>(
    transactions: &HashMap<i64, SystemTransaction>,
    pred: F,
) -> Vec<SystemTransaction>
where
    F: Fn(&SystemTransaction) -> bool,
{
    let mut result: Vec<_> = transactions
        .values()
        .filter(|&t| pred(t))
        .cloned()
        .collect();
    result.sort_by_key(|t| t.id);
    result
}

#[async_trait]
impl SystemTransactionRepository for InMemorySystemTransactionRepository {
    /// Creates a transaction in `Created` status with an auto-assigned id.
    /// Its transaction number is `TXN` followed by the id zero-padded to 10 digits.
    async fn create(&self, request: &CreateSystemTransactionRequest) -> Result<SystemTransaction> {
        // Lock order: transactions -> by_txn_no -> by_source_key -> next_id.
        // Readers below never hold two of these locks at once.
        let mut transactions = self.transactions.write().unwrap();
        let mut by_txn_no = self.by_txn_no.write().unwrap();
        let mut by_source_key = self.by_source_key.write().unwrap();
        let mut next_id = self.next_id.write().unwrap();
        let id = *next_id;
        *next_id += 1;
        let txn_no = format!("TXN{:010}", id);
        let txn = SystemTransaction {
            id,
            txn_no: txn_no.clone(),
            txn_type: request.txn_type.clone(),
            from_account_id: request.from_account_id,
            to_account_id: request.to_account_id,
            amount: request.amount,
            status: TransactionStatus::Created,
            bank_ref_no: None,
            source_key: request.source_key.clone(),
            remark: request.remark.clone(),
            created_at: chrono::Utc::now(),
            confirmed_at: None,
            submitted_at: None,
            version: 1,
        };
        by_txn_no.insert(txn_no, id);
        if let Some(key) = &txn.source_key {
            by_source_key.insert(key.clone(), id);
        }
        transactions.insert(id, txn.clone());
        Ok(txn)
    }

    /// Looks up a transaction by id.
    async fn find_by_id(&self, id: i64) -> Result<Option<SystemTransaction>> {
        let transactions = self.transactions.read().unwrap();
        Ok(transactions.get(&id).cloned())
    }

    /// Looks up a transaction by its transaction number.
    async fn find_by_txn_no(&self, txn_no: &str) -> Result<Option<SystemTransaction>> {
        // Copy the id out and drop the index guard before locking
        // `transactions`. `create`/`insert` lock `transactions` first, so
        // holding both here in the opposite order could deadlock.
        let id = self.by_txn_no.read().unwrap().get(txn_no).copied();
        let Some(id) = id else {
            return Ok(None);
        };
        let transactions = self.transactions.read().unwrap();
        Ok(transactions.get(&id).cloned())
    }

    /// Looks up a transaction by its source key.
    async fn find_by_source_key(&self, source_key: &str) -> Result<Option<SystemTransaction>> {
        // See `find_by_txn_no`: never hold the index lock while taking `transactions`.
        let id = self.by_source_key.read().unwrap().get(source_key).copied();
        let Some(id) = id else {
            return Ok(None);
        };
        let transactions = self.transactions.read().unwrap();
        Ok(transactions.get(&id).cloned())
    }

    /// Returns all transactions in `status`, ordered by id.
    async fn find_by_status(&self, status: TransactionStatus) -> Result<Vec<SystemTransaction>> {
        let transactions = self.transactions.read().unwrap();
        Ok(collect_txns_by_id(&transactions, |t| t.status == status))
    }

    /// Returns `BankSubmitted` transactions whose `submitted_at` is more than
    /// `threshold_seconds` in the past, ordered by id.
    async fn find_timeout(&self, threshold_seconds: i64) -> Result<Vec<SystemTransaction>> {
        let now = Utc::now();
        let threshold = chrono::Duration::seconds(threshold_seconds);
        let transactions = self.transactions.read().unwrap();
        Ok(collect_txns_by_id(&transactions, |t| {
            t.status == TransactionStatus::BankSubmitted
                && t.submitted_at.is_some_and(|s| now - s > threshold)
        }))
    }

    /// Sets the status of transaction `id`; unknown ids are ignored.
    async fn update_status(&self, id: i64, status: TransactionStatus) -> Result<()> {
        let mut transactions = self.transactions.write().unwrap();
        if let Some(txn) = transactions.get_mut(&id) {
            txn.status = status;
        }
        Ok(())
    }

    /// Records the bank reference number of transaction `id`; unknown ids are ignored.
    async fn set_bank_ref_no(&self, id: i64, bank_ref_no: &str) -> Result<()> {
        let mut transactions = self.transactions.write().unwrap();
        if let Some(txn) = transactions.get_mut(&id) {
            txn.bank_ref_no = Some(bank_ref_no.to_string());
        }
        Ok(())
    }

    /// Records when transaction `id` was submitted to the bank; unknown ids are ignored.
    async fn set_submitted_at(&self, id: i64, submitted_at: DateTime<Utc>) -> Result<()> {
        let mut transactions = self.transactions.write().unwrap();
        if let Some(txn) = transactions.get_mut(&id) {
            txn.submitted_at = Some(submitted_at);
        }
        Ok(())
    }

    /// Finds the transaction carrying `bank_ref_no` by linear scan. If several
    /// match, the lowest id wins so the result is deterministic.
    async fn find_by_bank_ref_no(&self, bank_ref_no: &str) -> Result<Option<SystemTransaction>> {
        let transactions = self.transactions.read().unwrap();
        // `as_deref` compares without allocating a `String` per row.
        Ok(transactions
            .values()
            .filter(|t| t.bank_ref_no.as_deref() == Some(bank_ref_no))
            .min_by_key(|t| t.id)
            .cloned())
    }

    /// Returns all `Pending` transactions, ordered by id.
    async fn find_pending(&self) -> Result<Vec<SystemTransaction>> {
        let transactions = self.transactions.read().unwrap();
        Ok(collect_txns_by_id(&transactions, |t| {
            t.status == TransactionStatus::Pending
        }))
    }

    /// Returns transactions whose outcome is not yet settled with the bank
    /// (`BankSubmitted` or `Timeout`), ordered by id.
    async fn find_needs_reconciliation(&self) -> Result<Vec<SystemTransaction>> {
        let transactions = self.transactions.read().unwrap();
        Ok(collect_txns_by_id(&transactions, |t| {
            matches!(
                t.status,
                TransactionStatus::BankSubmitted | TransactionStatus::Timeout
            )
        }))
    }

    /// Filters transactions by the optional criteria in `query`, then pages
    /// the id-ordered result with `offset` (default 0) and `limit` (default 50).
    async fn query(&self, query: &rustjr::domain::transaction::entity::TransactionQuery) -> Result<Vec<SystemTransaction>> {
        let transactions = self.transactions.read().unwrap();
        // Sorted by id so pagination is stable. NOTE(review): the production
        // repository's ORDER BY may differ — confirm if tests rely on it.
        let mut result = collect_txns_by_id(&transactions, |_| true);
        drop(transactions);

        if let Some(account_id) = query.account_id {
            result.retain(|t| t.from_account_id == Some(account_id) || t.to_account_id == Some(account_id));
        }
        if let Some(txn_type) = &query.txn_type {
            result.retain(|t| t.txn_type == *txn_type);
        }
        if let Some(status) = &query.status {
            result.retain(|t| t.status == *status);
        }
        if let Some(start_time) = query.start_time {
            result.retain(|t| t.created_at >= start_time);
        }
        if let Some(end_time) = query.end_time {
            result.retain(|t| t.created_at <= end_time);
        }
        if let Some(min_amount) = query.min_amount {
            result.retain(|t| t.amount >= min_amount);
        }
        if let Some(max_amount) = query.max_amount {
            result.retain(|t| t.amount <= max_amount);
        }

        let offset = query.offset.unwrap_or(0) as usize;
        let limit = query.limit.unwrap_or(50) as usize;
        // `skip`/`take` cannot overflow, unlike the old `offset + limit` slice
        // bound (e.g. a negative limit cast to a huge `usize`).
        Ok(result.into_iter().skip(offset).take(limit).collect())
    }

    /// Marks transaction `id` as `Success` and records `confirmed_at`;
    /// unknown ids are ignored.
    async fn confirm(&self, id: i64, confirmed_at: DateTime<Utc>) -> Result<()> {
        let mut transactions = self.transactions.write().unwrap();
        if let Some(txn) = transactions.get_mut(&id) {
            txn.confirmed_at = Some(confirmed_at);
            txn.status = TransactionStatus::Success;
        }
        Ok(())
    }
}
// ==================== 补偿任务内存仓储 ====================
/// In-memory compensation-task repository, backed by a `HashMap` keyed by task id.
pub struct InMemoryCompensationTaskRepository {
    /// All tasks, keyed by id.
    tasks: Arc<RwLock<HashMap<i64, CompensationTask>>>,
    /// Task ids grouped by the transaction number they were created for.
    by_txn_no: Arc<RwLock<HashMap<String, Vec<i64>>>>,
    /// Next id handed out by `create`; starts at 1.
    next_id: Arc<RwLock<i64>>,
}

impl InMemoryCompensationTaskRepository {
    /// Builds an empty repository whose first created task gets id 1.
    pub fn new() -> Self {
        let tasks = Arc::new(RwLock::new(HashMap::new()));
        let by_txn_no = Arc::new(RwLock::new(HashMap::new()));
        let next_id = Arc::new(RwLock::new(1));
        Self {
            tasks,
            by_txn_no,
            next_id,
        }
    }

    /// Returns a snapshot of every stored task, in unspecified order (test helper).
    pub fn get_all(&self) -> Vec<CompensationTask> {
        self.tasks.read().unwrap().values().cloned().collect()
    }
}

impl Default for InMemoryCompensationTaskRepository {
    /// Same as [`InMemoryCompensationTaskRepository::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl CompensationTaskRepository for InMemoryCompensationTaskRepository {
async fn create(&self, request: &rustjr::domain::compensation::CreateCompensationTaskRequest) -> Result<CompensationTask> {
let mut tasks = self.tasks.write().unwrap();
let mut by_txn_no = self.by_txn_no.write().unwrap();
let mut next_id = self.next_id.write().unwrap();
let id = *next_id;
*next_id += 1;
let task = CompensationTask {
id,
txn_no: request.txn_no.clone(),
task_type: request.task_type,
status: CompensationTaskStatus::Pending,
retry_count: 0,
max_retries: request.max_retries.unwrap_or(3),
next_retry_at: None,
error_message: None,
created_at: Utc::now(),
updated_at: Utc::now(),
completed_at: None,
};
by_txn_no.entry(request.txn_no.clone()).or_insert_with(Vec::new).push(id);
tasks.insert(id, task.clone());
Ok(task)
}
async fn find_by_id(&self, id: i64) -> Result<Option<CompensationTask>> {
let tasks = self.tasks.read().unwrap();
Ok(tasks.get(&id).cloned())
}
async fn find_by_txn_no(&self, txn_no: &str) -> Result<Vec<CompensationTask>> {
let tasks = self.tasks.read().unwrap();
let result: Vec<_> = tasks
.values()
.filter(|t| t.txn_no == txn_no)
.cloned()
.collect();
Ok(result)
}
async fn find_pending(&self, limit: i64) -> Result<Vec<CompensationTask>> {
let tasks = self.tasks.read().unwrap();
let mut result: Vec<_> = tasks
.values()
.filter(|t| t.status == CompensationTaskStatus::Pending)
.cloned()
.collect();
result.truncate(limit as usize);
Ok(result)
}
async fn find_ready_for_retry(&self, limit: i64) -> Result<Vec<CompensationTask>> {
let tasks = self.tasks.read().unwrap();
let now = Utc::now();
let mut result: Vec<_> = tasks
.values()
.filter(|t| {
t.status == CompensationTaskStatus::Failed
&& t.next_retry_at.map_or(false, |n| n <= now)
})
.cloned()
.collect();
result.truncate(limit as usize);
Ok(result)
}
async fn find_dead_letter(&self, limit: i64) -> Result<Vec<CompensationTask>> {
let tasks = self.tasks.read().unwrap();
let mut result: Vec<_> = tasks
.values()
.filter(|t| t.status == CompensationTaskStatus::DeadLetter)
.cloned()
.collect();
result.truncate(limit as usize);
Ok(result)
}
async fn update_status(
&self,
id: i64,
status: CompensationTaskStatus,
error_message: Option<&str>,
) -> Result<()> {
let mut tasks = self.tasks.write().unwrap();
if let Some(task) = tasks.get_mut(&id) {
task.status = status;
task.error_message = error_message.map(|s| s.to_string());
task.updated_at = Utc::now();
}
Ok(())
}
async fn increment_retry(
&self,
id: i64,
next_retry_at: DateTime<Utc>,
) -> Result<()> {
let mut tasks = self.tasks.write().unwrap();
if let Some(task) = tasks.get_mut(&id) {
task.retry_count += 1;
task.next_retry_at = Some(next_retry_at);
if task.retry_count >= task.max_retries {
task.status = CompensationTaskStatus::DeadLetter;
} else {
task.status = CompensationTaskStatus::Failed;
}
task.updated_at = Utc::now();
}
Ok(())
}
async fn mark_completed(&self, id: i64) -> Result<()> {
let mut tasks = self.tasks.write().unwrap();
if let Some(task) = tasks.get_mut(&id) {
task.status = CompensationTaskStatus::Completed;
task.completed_at = Some(Utc::now());
task.updated_at = Utc::now();
}
Ok(())
}
async fn mark_dead_letter(&self, id: i64, error_message: &str) -> Result<()> {
let mut tasks = self.tasks.write().unwrap();
if let Some(task) = tasks.get_mut(&id) {
task.status = CompensationTaskStatus::DeadLetter;
task.error_message = Some(error_message.to_string());
task.updated_at = Utc::now();
}
Ok(())
}
async fn has_pending_task(&self, txn_no: &str, task_type: CompensationTaskType) -> Result<bool> {
let tasks = self.tasks.read().unwrap();
let has = tasks.values().any(|t| {
t.txn_no == txn_no
&& t.task_type == task_type
&& (t.status == CompensationTaskStatus::Pending || t.status == CompensationTaskStatus::Processing || t.status == CompensationTaskStatus::Failed)
});
Ok(has)
}
}
// ==================== 测试辅助 ====================
/// Builds a `Transfer` transaction from account 1 to account 2 for tests.
///
/// Only `id`, `txn_no`, `amount` and `status` vary. Every optional field is
/// `None`, `version` is 1 and `created_at` is the current time.
pub fn create_test_transaction(
    id: i64,
    txn_no: &str,
    amount: Decimal,
    status: TransactionStatus,
) -> SystemTransaction {
    let created_at = Utc::now();
    SystemTransaction {
        version: 1,
        id,
        txn_no: String::from(txn_no),
        txn_type: TransactionType::Transfer,
        from_account_id: Some(1),
        to_account_id: Some(2),
        amount,
        status,
        bank_ref_no: None,
        source_key: None,
        remark: None,
        created_at,
        submitted_at: None,
        confirmed_at: None,
    }
}
/// Builds a `Virtual` account balance for tests from its three buckets
/// (personal, labor, frozen).
///
/// Derived fields:
/// - bank and system balances = personal + labor + frozen;
/// - available = personal + labor;
/// - frozen amount = frozen;
/// - transit = 0, version = 1.
pub fn create_test_balance(
    id: i64,
    account_id: i64,
    personal: Decimal,
    labor: Decimal,
    frozen: Decimal,
) -> AccountBalance {
    // Spendable funds exclude the frozen portion.
    let spendable = personal + labor;
    // Every bucket, frozen funds included.
    let total = spendable + frozen;
    AccountBalance {
        id,
        account_id,
        account_type: AccountType::Virtual,
        personal_balance: personal,
        labor_balance: labor,
        frozen_balance: frozen,
        bank_balance: total,
        transit_amount: Decimal::ZERO,
        // Nothing is in transit, so the system ledger equals the bank balance.
        system_balance: total,
        available_balance: spendable,
        frozen_amount: frozen,
        version: 1,
        updated_at: Utc::now(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// A balance created by `get_or_create` and then updated can be read back.
    #[tokio::test]
    async fn test_in_memory_balance_repo() {
        let repo = InMemoryAccountBalanceRepository::new();
        let created = repo.get_or_create(1001, AccountType::Virtual).await.unwrap();

        let updated = AccountBalance {
            personal_balance: dec!(1000.00),
            labor_balance: dec!(500.00),
            bank_balance: dec!(1500.00),
            ..created
        };
        repo.update(&updated).await.unwrap();

        let found = repo
            .find_by_account(1001, AccountType::Virtual)
            .await
            .unwrap()
            .expect("balance should exist after update");
        assert_eq!(found.personal_balance, dec!(1000.00));
    }

    /// A created transaction can be found again by its generated transaction number.
    #[tokio::test]
    async fn test_in_memory_txn_repo() {
        let repo = InMemorySystemTransactionRepository::new();
        let created = repo
            .create(&CreateSystemTransactionRequest {
                txn_type: TransactionType::Transfer,
                from_account_id: Some(1),
                to_account_id: Some(2),
                amount: dec!(100.00),
                remark: None,
                source_key: None,
            })
            .await
            .unwrap();

        let found = repo.find_by_txn_no(&created.txn_no).await.unwrap();
        assert_eq!(found.map(|t| t.id), Some(created.id));
    }
}