Skip to main content

futu_server/
protect.rs

1// 限频保护 + 防重放攻击
2
3use std::collections::{HashMap, VecDeque};
4use std::time::{Duration, Instant};
5
6use parking_lot::Mutex;
7
/// Length of the sliding window used for rate limiting (30 seconds).
const FREQ_WINDOW: Duration = Duration::from_secs(30);
10
11/// 默认协议限频配置
12fn default_proto_limits() -> HashMap<u32, u32> {
13    use futu_core::proto_id::*;
14    let mut m = HashMap::new();
15    m.insert(TRD_UNLOCK_TRADE, 10);
16    m.insert(TRD_PLACE_ORDER, 15);
17    m.insert(TRD_MODIFY_ORDER, 20);
18    m.insert(TRD_GET_HISTORY_ORDER_FILL_LIST, 10);
19    m.insert(TRD_GET_HISTORY_ORDER_LIST, 10);
20    m.insert(QOT_GET_SECURITY_SNAPSHOT, 10);
21    m.insert(QOT_GET_PLATE_SET, 10);
22    m.insert(QOT_GET_PLATE_SECURITY, 10);
23    m.insert(QOT_GET_OWNER_PLATE, 10);
24    m.insert(QOT_GET_HOLDING_CHANGE_LIST, 10);
25    m.insert(QOT_GET_OPTION_CHAIN, 10);
26    m.insert(QOT_REQUEST_HISTORY_KL, 10);
27    m
28}
29
/// Rate-limit and anti-replay bookkeeping for a single connection.
struct ConnFreqRecord {
    /// Timestamps of accepted requests, queued per protocol id (oldest at the front).
    proto_times: HashMap<u32, VecDeque<Instant>>,
    /// Anti-replay: the most recently accepted serial number (0 until one is accepted).
    last_serial: u32,
}

impl ConnFreqRecord {
    /// Creates an empty record: no request history and `last_serial` set to 0.
    fn new() -> Self {
        let proto_times = HashMap::new();
        let last_serial = 0;
        Self {
            proto_times,
            last_serial,
        }
    }
}
46
47/// 限频保护器
48pub struct ProtectionManager {
49    records: Mutex<HashMap<u64, ConnFreqRecord>>,
50    proto_limits: HashMap<u32, u32>,
51}
52
53impl ProtectionManager {
54    /// 创建新的 [`ProtectionManager`] 实例。内部自动加载 per-proto_id 默认限
55    /// 频表(对齐 C++ 默认配置),不需额外初始化。
56    pub fn new() -> Self {
57        Self {
58            records: Mutex::new(HashMap::new()),
59            proto_limits: default_proto_limits(),
60        }
61    }
62
63    /// 检查请求是否超频
64    ///
65    /// 返回 true 表示被限频(应拒绝),false 表示通过
66    pub fn check_freq_limit(&self, conn_id: u64, proto_id: u32) -> bool {
67        let limit = match self.proto_limits.get(&proto_id) {
68            Some(&limit) => limit,
69            None => return false, // 无限制的协议
70        };
71
72        let mut records = self.records.lock();
73        let record = records.entry(conn_id).or_insert_with(ConnFreqRecord::new);
74
75        let times = record.proto_times.entry(proto_id).or_default();
76        let now = Instant::now();
77
78        // 清理过期记录
79        while times
80            .front()
81            .is_some_and(|t| now.duration_since(*t) > FREQ_WINDOW)
82        {
83            times.pop_front();
84        }
85
86        if times.len() as u32 >= limit {
87            true // 超频
88        } else {
89            times.push_back(now);
90            false
91        }
92    }
93
94    /// 检查防重放(serial number 必须递增)
95    ///
96    /// 返回 true 表示可能是重放攻击,false 表示正常
97    pub fn check_replay(&self, conn_id: u64, serial_no: u32) -> bool {
98        let mut records = self.records.lock();
99        let record = records.entry(conn_id).or_insert_with(ConnFreqRecord::new);
100
101        if serial_no <= record.last_serial {
102            true // 重放
103        } else {
104            record.last_serial = serial_no;
105            false
106        }
107    }
108
109    /// 连接断开时清理
110    pub fn on_disconnect(&self, conn_id: u64) {
111        self.records.lock().remove(&conn_id);
112    }
113}
114
115impl Default for ProtectionManager {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121#[cfg(test)]
122mod tests;