1use std::collections::{HashMap, VecDeque};
4use std::time::{Duration, Instant};
5
6use parking_lot::Mutex;
7
/// Length of the sliding window (30 seconds) over which the per-protocol
/// request quotas from `default_proto_limits` are counted.
const FREQ_WINDOW: Duration = Duration::from_millis(30_000);
10
11fn default_proto_limits() -> HashMap<u32, u32> {
13 use futu_core::proto_id::*;
14 let mut m = HashMap::new();
15 m.insert(TRD_UNLOCK_TRADE, 10);
16 m.insert(TRD_PLACE_ORDER, 15);
17 m.insert(TRD_MODIFY_ORDER, 20);
18 m.insert(TRD_GET_HISTORY_ORDER_FILL_LIST, 10);
19 m.insert(TRD_GET_HISTORY_ORDER_LIST, 10);
20 m.insert(QOT_GET_SECURITY_SNAPSHOT, 10);
21 m.insert(QOT_GET_PLATE_SET, 10);
22 m.insert(QOT_GET_PLATE_SECURITY, 10);
23 m.insert(QOT_GET_OWNER_PLATE, 10);
24 m.insert(QOT_GET_HOLDING_CHANGE_LIST, 10);
25 m.insert(QOT_GET_OPTION_CHAIN, 10);
26 m.insert(QOT_REQUEST_HISTORY_KL, 10);
27 m
28}
29
/// Protection state kept for a single connection by `ProtectionManager`.
struct ConnFreqRecord {
    /// Timestamps of accepted requests, keyed by protocol ID. Entries are
    /// appended at the back and pruned from the front, so the oldest
    /// timestamp is first.
    proto_times: HashMap<u32, VecDeque<Instant>>,
    /// Highest serial number accepted so far on this connection; 0 before
    /// any serial has been accepted.
    last_serial: u32,
}

impl ConnFreqRecord {
    /// Returns an empty record: no request history and a serial
    /// high-water mark of zero.
    fn new() -> Self {
        let proto_times = HashMap::default();
        Self {
            last_serial: 0,
            proto_times,
        }
    }
}
46
47pub struct ProtectionManager {
49 records: Mutex<HashMap<u64, ConnFreqRecord>>,
50 proto_limits: HashMap<u32, u32>,
51}
52
53impl ProtectionManager {
54 pub fn new() -> Self {
57 Self {
58 records: Mutex::new(HashMap::new()),
59 proto_limits: default_proto_limits(),
60 }
61 }
62
63 pub fn check_freq_limit(&self, conn_id: u64, proto_id: u32) -> bool {
67 let limit = match self.proto_limits.get(&proto_id) {
68 Some(&limit) => limit,
69 None => return false, };
71
72 let mut records = self.records.lock();
73 let record = records.entry(conn_id).or_insert_with(ConnFreqRecord::new);
74
75 let times = record.proto_times.entry(proto_id).or_default();
76 let now = Instant::now();
77
78 while times
80 .front()
81 .is_some_and(|t| now.duration_since(*t) > FREQ_WINDOW)
82 {
83 times.pop_front();
84 }
85
86 if times.len() as u32 >= limit {
87 true } else {
89 times.push_back(now);
90 false
91 }
92 }
93
94 pub fn check_replay(&self, conn_id: u64, serial_no: u32) -> bool {
98 let mut records = self.records.lock();
99 let record = records.entry(conn_id).or_insert_with(ConnFreqRecord::new);
100
101 if serial_no <= record.last_serial {
102 true } else {
104 record.last_serial = serial_no;
105 false
106 }
107 }
108
109 pub fn on_disconnect(&self, conn_id: u64) {
111 self.records.lock().remove(&conn_id);
112 }
113}
114
115impl Default for ProtectionManager {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
// Unit tests for this module live in the separate `tests` submodule file.
// They are compiled only for test builds.
#[cfg(test)]
mod tests;