futu_server/
subscription.rs

1// 订阅管理:行情订阅 + 交易账户推送订阅 + 通知订阅
2
3use std::collections::{HashMap, HashSet};
4
5use parking_lot::RwLock;
6
7/// 订阅管理器
8pub struct SubscriptionManager {
9    /// 通知订阅:哪些连接订阅了系统通知
10    notify_subs: RwLock<HashSet<u64>>,
11
12    /// 交易账户推送订阅:acc_id → Set<conn_id>
13    trd_acc_subs: RwLock<HashMap<u64, HashSet<u64>>>,
14
15    /// 行情订阅:"market_code:sub_type" → Set<conn_id>
16    qot_subs: RwLock<HashMap<String, HashSet<u64>>>,
17
18    /// 每连接的行情订阅额度使用
19    conn_quota: RwLock<HashMap<u64, u32>>,
20}
21
/// Upper limit on the total quote-subscription quota.
///
/// NOTE(review): nothing in `SubscriptionManager` enforces this limit; presumably
/// callers compare it against `get_total_used_quota` before subscribing — confirm.
pub const TOTAL_QUOTA: u32 = 4_000;
24
25impl SubscriptionManager {
26    pub fn new() -> Self {
27        Self {
28            notify_subs: RwLock::new(HashSet::new()),
29            trd_acc_subs: RwLock::new(HashMap::new()),
30            qot_subs: RwLock::new(HashMap::new()),
31            conn_quota: RwLock::new(HashMap::new()),
32        }
33    }
34
35    // ===== 通知订阅 =====
36
37    pub fn subscribe_notify(&self, conn_id: u64) {
38        self.notify_subs.write().insert(conn_id);
39    }
40
41    pub fn unsubscribe_notify(&self, conn_id: u64) {
42        self.notify_subs.write().remove(&conn_id);
43    }
44
45    pub fn is_subscribed_notify(&self, conn_id: u64) -> bool {
46        self.notify_subs.read().contains(&conn_id)
47    }
48
49    // ===== 交易账户推送 =====
50
51    pub fn subscribe_trd_acc(&self, conn_id: u64, acc_id: u64) {
52        self.trd_acc_subs
53            .write()
54            .entry(acc_id)
55            .or_default()
56            .insert(conn_id);
57    }
58
59    pub fn unsubscribe_trd_acc(&self, conn_id: u64, acc_id: u64) {
60        if let Some(subs) = self.trd_acc_subs.write().get_mut(&acc_id) {
61            subs.remove(&conn_id);
62        }
63    }
64
65    pub fn get_acc_subscribers(&self, acc_id: u64) -> Vec<u64> {
66        self.trd_acc_subs
67            .read()
68            .get(&acc_id)
69            .map(|s| s.iter().copied().collect())
70            .unwrap_or_default()
71    }
72
73    // ===== 行情订阅 =====
74
75    /// 生成行情订阅 key
76    pub fn make_qot_key(market: i32, code: &str, sub_type: i32) -> String {
77        format!("{market}_{code}:{sub_type}")
78    }
79
80    /// 订阅行情
81    pub fn subscribe_qot(&self, conn_id: u64, security_key: &str, sub_type: i32) {
82        let key = format!("{security_key}:{sub_type}");
83        self.qot_subs
84            .write()
85            .entry(key)
86            .or_default()
87            .insert(conn_id);
88
89        // 更新额度
90        *self.conn_quota.write().entry(conn_id).or_insert(0) += 1;
91    }
92
93    /// 退订行情
94    pub fn unsubscribe_qot(&self, conn_id: u64, security_key: &str, sub_type: i32) {
95        let key = format!("{security_key}:{sub_type}");
96        if let Some(subs) = self.qot_subs.write().get_mut(&key) {
97            if subs.remove(&conn_id) {
98                let mut quota = self.conn_quota.write();
99                if let Some(q) = quota.get_mut(&conn_id) {
100                    *q = q.saturating_sub(1);
101                }
102            }
103        }
104    }
105
106    /// 获取订阅了指定行情的连接列表
107    pub fn get_qot_subscribers(&self, security_key: &str, sub_type: i32) -> Vec<u64> {
108        let key = format!("{security_key}:{sub_type}");
109        self.qot_subs
110            .read()
111            .get(&key)
112            .map(|s| s.iter().copied().collect())
113            .unwrap_or_default()
114    }
115
116    /// 获取连接的已用订阅额度
117    pub fn get_conn_used_quota(&self, conn_id: u64) -> u32 {
118        self.conn_quota.read().get(&conn_id).copied().unwrap_or(0)
119    }
120
121    /// 获取总已用额度
122    pub fn get_total_used_quota(&self) -> u32 {
123        self.conn_quota.read().values().sum()
124    }
125
126    /// 获取指定连接订阅的行情列表: sub_type → Vec<security_key>
127    pub fn get_conn_qot_subs(&self, conn_id: u64) -> HashMap<i32, Vec<String>> {
128        let qot = self.qot_subs.read();
129        let mut result: HashMap<i32, Vec<String>> = HashMap::new();
130        for (key, conn_ids) in qot.iter() {
131            if conn_ids.contains(&conn_id) {
132                // key 格式: "market_code:sub_type"
133                if let Some(colon) = key.rfind(':') {
134                    if let Ok(sub_type) = key[colon + 1..].parse::<i32>() {
135                        let sec_key = &key[..colon];
136                        result
137                            .entry(sub_type)
138                            .or_default()
139                            .push(sec_key.to_string());
140                    }
141                }
142            }
143        }
144        result
145    }
146
147    /// 获取所有连接 ID(有行情订阅的)
148    pub fn get_all_qot_conn_ids(&self) -> HashSet<u64> {
149        let qot = self.qot_subs.read();
150        let mut ids = HashSet::new();
151        for conn_ids in qot.values() {
152            ids.extend(conn_ids);
153        }
154        ids
155    }
156
157    // ===== 连接断开清理 =====
158
159    pub fn on_disconnect(&self, conn_id: u64) {
160        self.notify_subs.write().remove(&conn_id);
161
162        // 清理交易账户订阅
163        let mut trd = self.trd_acc_subs.write();
164        for subs in trd.values_mut() {
165            subs.remove(&conn_id);
166        }
167
168        // 清理行情订阅
169        let mut qot = self.qot_subs.write();
170        for subs in qot.values_mut() {
171            subs.remove(&conn_id);
172        }
173
174        self.conn_quota.write().remove(&conn_id);
175    }
176}
177
178impl Default for SubscriptionManager {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    const TENCENT: &str = "1_00700";
    const BASIC: i32 = 1;
    const ORDER_BOOK: i32 = 2;

    #[test]
    fn test_notify_subscription() {
        let manager = SubscriptionManager::new();
        let conn = 1;

        assert!(!manager.is_subscribed_notify(conn));

        manager.subscribe_notify(conn);
        assert!(manager.is_subscribed_notify(conn));

        manager.unsubscribe_notify(conn);
        assert!(!manager.is_subscribed_notify(conn));
    }

    #[test]
    fn test_trd_acc_subscription() {
        let manager = SubscriptionManager::new();
        for (conn, acc) in [(1, 100), (2, 100), (1, 200)] {
            manager.subscribe_trd_acc(conn, acc);
        }

        // Subscriber order is unspecified, so sort before comparing.
        let mut subscribers = manager.get_acc_subscribers(100);
        subscribers.sort_unstable();
        assert_eq!(subscribers, vec![1, 2]);

        manager.unsubscribe_trd_acc(1, 100);
        assert_eq!(manager.get_acc_subscribers(100), vec![2]);
    }

    #[test]
    fn test_qot_subscription() {
        let manager = SubscriptionManager::new();
        for (conn, sub_type) in [(1, BASIC), (2, BASIC), (1, ORDER_BOOK)] {
            manager.subscribe_qot(conn, TENCENT, sub_type);
        }

        assert_eq!(manager.get_qot_subscribers(TENCENT, BASIC).len(), 2);
        assert_eq!(manager.get_qot_subscribers(TENCENT, ORDER_BOOK).len(), 1);
        assert_eq!(manager.get_conn_used_quota(1), 2);
    }

    #[test]
    fn test_disconnect_cleanup() {
        let manager = SubscriptionManager::new();
        let conn = 1;
        manager.subscribe_notify(conn);
        manager.subscribe_trd_acc(conn, 100);
        manager.subscribe_qot(conn, TENCENT, BASIC);

        manager.on_disconnect(conn);

        assert!(!manager.is_subscribed_notify(conn));
        assert!(manager.get_acc_subscribers(100).is_empty());
        assert!(manager.get_qot_subscribers(TENCENT, BASIC).is_empty());
        assert_eq!(manager.get_conn_used_quota(conn), 0);
    }
}