1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::Wallet;
6use crate::persist::models::{RoundStateId, StoredRoundState};
7
8#[derive(Clone)]
9pub(crate) struct RoundStateLockIndex {
10 locked: Arc<parking_lot::Mutex<HashSet<RoundStateId>>>,
11}
12
13impl RoundStateLockIndex {
14 pub fn new() -> Self {
15 Self {
16 locked: Arc::new(parking_lot::Mutex::new(HashSet::new())),
17 }
18 }
19
20 pub(crate) fn try_lock(&self, round_state: RoundStateId) -> Option<RoundStateGuard> {
21 let mut index_lock = self.locked.lock();
22 if index_lock.insert(round_state) {
23 Some(RoundStateGuard { index: self.clone(), round_state })
24 } else {
25 None
26 }
27 }
28
29 pub(crate) async fn wait_lock(&self, round_state: RoundStateId) -> anyhow::Result<RoundStateGuard> {
31 let mut attempts = 0;
32 loop {
33 if let Some(guard) = self.try_lock(round_state) {
34 return Ok(guard);
35 }
36 attempts += 1;
37 if attempts > 100 {
39 bail!("Timed out waiting for lock on round state {}", round_state);
40 }
41 tokio::time::sleep(Duration::from_millis(100)).await;
42 }
43 }
44}
45
46pub struct RoundStateGuard {
47 index: RoundStateLockIndex,
48 round_state: RoundStateId,
49}
50
51impl std::ops::Drop for RoundStateGuard {
52 fn drop(&mut self) {
53 self.index.locked.lock().remove(&self.round_state);
54 }
55}
56
57impl Wallet {
58 pub async fn lock_wait_round_state(&self, id: RoundStateId) -> anyhow::Result<Option<StoredRoundState>> {
63 let guard = self.round_state_lock_index.wait_lock(id).await?;
64
65 if let Some(state) = self.db.get_round_state_by_id(id).await? {
66 return Ok(Some(state.lock(guard)));
67 }
68
69 Ok(None)
70 }
71}
72
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises `try_lock` exclusivity, release on drop, independence of
    /// different ids, and lock sharing between clones of the index.
    #[test]
    fn round_state_lock() {
        let index = RoundStateLockIndex::new();

        let guard = index.try_lock(RoundStateId(1));
        assert!(guard.is_some(), "first lock should succeed");

        let guard2 = index.try_lock(RoundStateId(1));
        assert!(guard2.is_none(), "second lock should fail");

        drop(guard);
        // The temporary guard below is dropped at the end of the statement, so
        // id 1 is free again afterwards.
        assert!(index.try_lock(RoundStateId(1)).is_some(), "lock should succeed after drop");

        let guard3 = index.try_lock(RoundStateId(2));
        assert!(guard3.is_some(), "lock on a different round state should succeed");

        let cloned = index.clone();
        let id = RoundStateId(1);
        let guard4 = cloned.try_lock(id);
        assert!(guard4.is_some(), "cloned index should share lock state");
        assert!(index.try_lock(id).is_none(), "original should prevent lock");

        drop(guard4);
        let guard5 = index.try_lock(id);
        assert!(guard5.is_some(), "lock should succeed on original index after drop");
        assert!(cloned.try_lock(id).is_none(), "cloned index should prevent lock");
    }

    /// A free round state is acquired without entering the polling sleep.
    #[tokio::test]
    async fn wait_lock_acquires_free_lock_immediately() {
        let index = RoundStateLockIndex::new();
        let id = RoundStateId(3);

        let guard = tokio::time::timeout(Duration::from_millis(50), index.wait_lock(id))
            .await
            .expect("uncontended wait_lock should not sleep")
            .expect("uncontended wait_lock should succeed");
        assert!(index.try_lock(id).is_none(), "guard from wait_lock should hold the lock");

        drop(guard);
        assert!(index.try_lock(id).is_some(), "lock should be free after dropping the guard");
    }

    /// A waiter blocked in `wait_lock` acquires the lock once the holder drops
    /// its guard.
    #[tokio::test]
    async fn lock_wait_succeeds_after_guard_dropped() {
        let index = RoundStateLockIndex::new();
        let guard = index.try_lock(RoundStateId(1)).unwrap();

        let cloned = index.clone();
        let handle = tokio::spawn(async move { cloned.wait_lock(RoundStateId(1)).await });

        // Let the waiter poll at least once while the lock is still held.
        tokio::time::sleep(Duration::from_millis(150)).await;
        drop(guard);

        // Unwrap every layer: the timeout, a possible task panic (JoinError),
        // and the `wait_lock` result itself. Only checking the timeout would
        // let a failing `wait_lock` pass unnoticed.
        let acquired = tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("lock_wait should complete after guard is dropped")
            .expect("wait_lock task panicked")
            .expect("wait_lock should acquire the lock once it is released");
        assert!(
            index.try_lock(RoundStateId(1)).is_none(),
            "guard returned by wait_lock should hold the lock"
        );
        drop(acquired);
    }
}