1use std::collections::HashSet;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::Instant;
6
7use bytes::Bytes;
8use futures::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio_util::codec::Framed;
12
13use futu_auth::Scope;
14use futu_codec::frame::FutuFrame;
15use futu_codec::header::ProtoFmtType;
16use futu_codec::FutuCodec;
17use futu_core::error::FutuError;
18use futu_net::encrypt;
19
/// Lifecycle state of a client connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnState {
    /// NOTE(review): presumably the initial state right after the TCP
    /// connection is accepted (not set in the code visible here) — confirm.
    Connected,
    /// The InitConnect request was decoded successfully; set by
    /// `ClientConn::handle_init_connect`.
    Initialized,
    /// NOTE(review): not set in the code visible here; presumably the
    /// connection is authenticated and serving requests — confirm.
    Active,
    /// NOTE(review): not set in the code visible here; presumably the socket
    /// has closed — confirm.
    Disconnected,
}
28
/// Per-connection state for one client socket.
pub struct ClientConn {
    /// Connection identifier; `ClientConn::generate_conn_id` produces values for it.
    pub conn_id: u64,
    /// Lifecycle state; `handle_init_connect` moves it to `ConnState::Initialized`.
    pub state: ConnState,
    /// 16-byte AES key. It is sent to the client as the `conn_aes_key` string in
    /// the InitConnect response and used by `make_frame` / `decrypt_body`.
    pub aes_key: [u8; 16],
    /// When true, `make_frame` encrypts and `decrypt_body` decrypts bodies with
    /// AES-ECB under `aes_key`. `handle_init_connect` enables it when an RSA
    /// private key is configured.
    pub aes_encrypt_enabled: bool,
    /// NOTE(review): not read in this file; presumably the body serialization
    /// format negotiated with the client — confirm.
    pub proto_fmt_type: ProtoFmtType,
    /// NOTE(review): `handle_keepalive` does not update this (it takes `&self`);
    /// presumably the caller refreshes it — confirm.
    pub last_keepalive: Instant,
    /// Number of KeepAlive requests handled; incremented by `handle_keepalive`.
    pub keepalive_count: AtomicU32,
    /// Outbound frame queue; presumably the sender returned by `run_connection` — confirm.
    pub tx: mpsc::Sender<FutuFrame>,

    /// NOTE(review): presumably the id of the API key this connection
    /// authenticated with, `None` when unauthenticated — confirm.
    pub key_id: Option<String>,
    /// NOTE(review): presumably the permission scopes granted to this
    /// connection — confirm.
    pub scopes: HashSet<Scope>,
}
48
/// One decoded inbound frame, sent on `req_tx` by the reader task spawned in
/// `run_connection`.
#[derive(Debug)]
pub struct IncomingRequest {
    /// Identifier of the connection the frame arrived on.
    pub conn_id: u64,
    /// Protocol id taken from the frame header.
    pub proto_id: u32,
    /// Serial number taken from the frame header (presumably echoed back in
    /// the response — confirm).
    pub serial_no: u32,
    /// Body format taken from the frame header.
    pub proto_fmt_type: ProtoFmtType,
    /// Raw frame body exactly as received. `run_connection` does not decrypt
    /// it; when AES is enabled it presumably still has to go through
    /// `ClientConn::decrypt_body`.
    pub body: Bytes,
}
58
59impl ClientConn {
60 pub fn generate_conn_id() -> u64 {
62 use std::time::{SystemTime, UNIX_EPOCH};
63 let millis = SystemTime::now()
64 .duration_since(UNIX_EPOCH)
65 .unwrap_or_default()
66 .as_millis() as u64;
67 let rand_part: u32 = rand::random();
68 (millis << 22) | (rand_part as u64 & 0x3FFFFF)
69 }
70
71 pub fn generate_aes_key() -> [u8; 16] {
73 let rand_val: u64 = rand::random();
74 let hex = format!("{rand_val:016X}");
75 let mut key = [0u8; 16];
76 key.copy_from_slice(hex.as_bytes());
77 key
78 }
79
80 pub fn make_frame(&self, proto_id: u32, serial_no: u32, body: Bytes) -> FutuFrame {
89 if self.aes_encrypt_enabled {
90 let mut frame = FutuFrame::new(proto_id, serial_no, body);
92 let encrypted = encrypt::aes_ecb_encrypt(&self.aes_key, &frame.body);
94 frame.header.body_len = encrypted.len() as u32;
95 frame.body = Bytes::from(encrypted);
96 frame
97 } else {
98 FutuFrame::new(proto_id, serial_no, body)
99 }
100 }
101
102 pub fn decrypt_body(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
106 if self.aes_encrypt_enabled {
107 encrypt::aes_ecb_decrypt(&self.aes_key, body).map_err(|e| {
108 tracing::warn!(conn_id = self.conn_id, error = %e, "AES decrypt body failed");
109 e
110 })
111 } else {
112 Ok(body.to_vec())
113 }
114 }
115
116 pub fn handle_init_connect(
124 &mut self,
125 body: &[u8],
126 server_ver: i32,
127 login_user_id: u64,
128 keepalive_interval: i32,
129 rsa_private_key: Option<&str>,
130 ) -> Result<Vec<u8>, FutuError> {
131 let decrypted_body;
133 let req_body = if let Some(rsa_key) = rsa_private_key {
134 decrypted_body =
135 futu_net::encrypt::rsa_private_decrypt_blocks(rsa_key, body).map_err(|e| {
136 tracing::warn!(error = %e, "RSA decrypt InitConnect C2S failed");
137 e
138 })?;
139 tracing::debug!(
140 encrypted_len = body.len(),
141 decrypted_len = decrypted_body.len(),
142 "RSA decrypted InitConnect C2S"
143 );
144 &decrypted_body[..]
145 } else {
146 body
147 };
148
149 let _req: futu_proto::init_connect::Request =
150 prost::Message::decode(req_body).map_err(FutuError::Proto)?;
151
152 self.state = ConnState::Initialized;
153
154 if rsa_private_key.is_some() {
156 self.aes_encrypt_enabled = true;
157 tracing::debug!(conn_id = self.conn_id, "AES body encryption enabled");
158 }
159
160 let aes_key_str = std::str::from_utf8(&self.aes_key).unwrap_or("").to_string();
161
162 let resp = futu_proto::init_connect::Response {
163 ret_type: 0,
164 ret_msg: None,
165 err_code: None,
166 s2c: Some(futu_proto::init_connect::S2c {
167 server_ver,
168 login_user_id,
169 conn_id: self.conn_id,
170 conn_aes_key: aes_key_str,
171 keep_alive_interval: keepalive_interval,
172 aes_cb_civ: None,
173 user_attribution: None,
174 }),
175 };
176
177 let resp_body = prost::Message::encode_to_vec(&resp);
178
179 if let Some(rsa_key) = rsa_private_key {
181 let encrypted = futu_net::encrypt::rsa_public_encrypt_blocks(rsa_key, &resp_body)
182 .map_err(|e| {
183 tracing::warn!(error = %e, "RSA encrypt InitConnect S2C failed");
184 e
185 })?;
186 tracing::debug!(
187 plaintext_len = resp_body.len(),
188 encrypted_len = encrypted.len(),
189 "RSA encrypted InitConnect S2C"
190 );
191 Ok(encrypted)
192 } else {
193 Ok(resp_body)
194 }
195 }
196
197 pub fn handle_keepalive(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
199 let _req: futu_proto::keep_alive::Request =
200 prost::Message::decode(body).map_err(FutuError::Proto)?;
201
202 self.keepalive_count.fetch_add(1, Ordering::Relaxed);
203
204 let resp = futu_proto::keep_alive::Response {
205 ret_type: 0,
206 ret_msg: None,
207 err_code: None,
208 s2c: Some(futu_proto::keep_alive::S2c {
209 time: chrono::Utc::now().timestamp(),
210 }),
211 };
212
213 Ok(prost::Message::encode_to_vec(&resp))
214 }
215}
216
/// Notification sent on the disconnect channel by a connection's reader task
/// once that connection's inbound stream has ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DisconnectNotify {
    /// Identifier of the connection that closed.
    pub conn_id: u64,
}
221
222pub async fn run_connection(
226 stream: TcpStream,
227 conn_id: u64,
228 _aes_key: [u8; 16],
229 req_tx: mpsc::UnboundedSender<IncomingRequest>,
230 disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
231) -> mpsc::Sender<FutuFrame> {
232 let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
233
234 let framed = Framed::new(stream, FutuCodec);
235 let (mut sink, mut stream) = framed.split();
236
237 tokio::spawn(async move {
239 while let Some(frame) = frame_rx.recv().await {
240 if let Err(e) = sink.send(frame).await {
241 tracing::warn!(conn_id = conn_id, error = %e, "send failed");
242 break;
243 }
244 }
245 });
246
247 tokio::spawn(async move {
249 while let Some(result) = stream.next().await {
250 match result {
251 Ok(frame) => {
252 let req = IncomingRequest {
253 conn_id,
254 proto_id: frame.header.proto_id,
255 serial_no: frame.header.serial_no,
256 proto_fmt_type: frame.header.proto_fmt_type,
257 body: frame.body,
258 };
259 if req_tx.send(req).is_err() {
260 break;
261 }
262 }
263 Err(e) => {
264 tracing::warn!(conn_id = conn_id, error = %e, "recv error");
265 break;
266 }
267 }
268 }
269 tracing::info!(conn_id = conn_id, "connection closed");
270 let _ = disconnect_tx.send(DisconnectNotify { conn_id });
272 });
273
274 frame_tx
275}