bark/round/lock.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::Wallet;
6use crate::persist::models::{RoundStateId, StoredRoundState};
7
8#[derive(Clone)]
9pub(crate) struct RoundStateLockIndex {
10	locked: Arc<parking_lot::Mutex<HashSet<RoundStateId>>>,
11}
12
13impl RoundStateLockIndex {
14	pub fn new() -> Self {
15		Self {
16			locked: Arc::new(parking_lot::Mutex::new(HashSet::new())),
17		}
18	}
19
20	pub(crate) fn try_lock(&self, round_state: RoundStateId) -> Option<RoundStateGuard> {
21		let mut index_lock = self.locked.lock();
22		if index_lock.insert(round_state) {
23			Some(RoundStateGuard { index: self.clone(), round_state })
24		} else {
25			None
26		}
27	}
28
29	/// Try to lock the given round state, waiting until it becomes available.
30	pub(crate) async fn wait_lock(&self, round_state: RoundStateId) -> anyhow::Result<RoundStateGuard> {
31		let mut attempts = 0;
32		loop {
33			if let Some(guard) = self.try_lock(round_state) {
34				return Ok(guard);
35			}
36			attempts += 1;
37			// tries for 10 seconds, enough for a round to complete
38			if attempts > 100 {
39				bail!("Timed out waiting for lock on round state {}", round_state);
40			}
41			tokio::time::sleep(Duration::from_millis(100)).await;
42		}
43	}
44}
45
46pub struct RoundStateGuard {
47	index: RoundStateLockIndex,
48	round_state: RoundStateId,
49}
50
51impl std::ops::Drop for RoundStateGuard {
52	fn drop(&mut self) {
53		self.index.locked.lock().remove(&self.round_state);
54	}
55}
56
57impl Wallet {
58	/// Load and lock a single given round state (by id), waiting for the lock.
59	///
60	/// Returns `Some(state, guard)` if the round state is found and locked, `None`
61	/// if it is not found after waiting for the lock.
62	pub async fn lock_wait_round_state(&self, id: RoundStateId) -> anyhow::Result<Option<StoredRoundState>> {
63		let guard = self.round_state_lock_index.wait_lock(id).await?;
64
65		if let Some(state) = self.db.get_round_state_by_id(id).await? {
66			return Ok(Some(state.lock(guard)));
67		}
68
69		Ok(None)
70	}
71}
72
#[cfg(test)]
mod test {
	use super::*;

	#[test]
	fn round_state_lock() {
		let index = RoundStateLockIndex::new();

		// returns guard on first acquisition
		let guard = index.try_lock(RoundStateId(1));
		assert!(guard.is_some(), "first lock should succeed");

		// returns none on second acquisition
		let guard2 = index.try_lock(RoundStateId(1));
		assert!(guard2.is_none(), "second lock should fail");

		// dropping guard releases lock
		drop(guard);
		assert!(index.try_lock(RoundStateId(1)).is_some(), "lock should succeed after drop");

		// different ids lock independently
		let guard3 = index.try_lock(RoundStateId(2));
		assert!(guard3.is_some(), "locking a different id should succeed");

		// cloned index shares lock state
		let cloned = index.clone();
		let id = RoundStateId(1);
		let guard4 = cloned.try_lock(id);
		assert!(guard4.is_some(), "cloned index should share lock state");
		assert!(index.try_lock(id).is_none(), "original should prevent lock");

		// dropping guard releases lock
		drop(guard4);
		let guard5 = index.try_lock(id);
		assert!(guard5.is_some(), "lock should succeed on original index after drop");
		assert!(cloned.try_lock(id).is_none(), "cloned index should prevent lock");
	}

	#[tokio::test]
	async fn lock_wait_succeeds_after_guard_dropped() {
		let index = RoundStateLockIndex::new();
		let guard = index.try_lock(RoundStateId(1)).unwrap();

		let cloned = index.clone();
		let handle = tokio::spawn(async move {
			cloned.wait_lock(RoundStateId(1)).await
		});

		// While the lock is held, the waiter must stay blocked.
		tokio::time::sleep(Duration::from_millis(150)).await;
		assert!(!handle.is_finished(), "lock_wait should not complete while the lock is held");

		// Release the lock so lock_wait can acquire it.
		drop(guard);

		// Check every layer: no timeout, no panic in the task, and an Ok result.
		let waited = tokio::time::timeout(Duration::from_secs(2), handle).await
			.expect("lock_wait should complete after guard is dropped")
			.expect("lock_wait task should not panic")
			.expect("lock_wait should return a guard after the lock is released");

		// The returned guard really holds the lock...
		assert!(index.try_lock(RoundStateId(1)).is_none(), "guard from lock_wait should hold the lock");
		// ...and releases it when dropped.
		drop(waited);
		assert!(index.try_lock(RoundStateId(1)).is_some(), "lock should be free after dropping the waited guard");
	}
}