futu_server/
conn.rs

1// 单连接管理:状态机、帧收发、加密、心跳超时
2
3use std::collections::HashSet;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::Instant;
6
7use bytes::Bytes;
8use futures::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio_util::codec::Framed;
12
13use futu_auth::Scope;
14use futu_codec::frame::FutuFrame;
15use futu_codec::header::ProtoFmtType;
16use futu_codec::FutuCodec;
17use futu_core::error::FutuError;
18use futu_net::encrypt;
19
/// Lifecycle state of a client connection.
///
/// In this file only the `Connected` -> `Initialized` transition is visible (performed by
/// `ClientConn::handle_init_connect`); `Active` and `Disconnected` are presumably set by
/// the listener/dispatcher — TODO confirm against callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnState {
    /// Socket accepted; InitConnect not yet handled (assumed initial state).
    Connected,
    /// InitConnect handled successfully.
    Initialized,
    /// Connection in normal service (set outside this file).
    Active,
    /// Connection closed or being torn down (set outside this file).
    Disconnected,
}
28
29/// 单个客户端连接
30pub struct ClientConn {
31    pub conn_id: u64,
32    pub state: ConnState,
33    pub aes_key: [u8; 16],
34    /// AES 加解密已启用(InitConnect 完成且配置了 RSA 时为 true)
35    pub aes_encrypt_enabled: bool,
36    pub proto_fmt_type: ProtoFmtType,
37    pub last_keepalive: Instant,
38    pub keepalive_count: AtomicU32,
39    /// 发送帧到此连接
40    pub tx: mpsc::Sender<FutuFrame>,
41
42    // ---- v1.0 WS per-message scope 鉴权(TCP 连接暂不启用,scopes 留空当"全放行")----
43    /// 该连接绑定的 API key id;WS 握手时填,未配 keys.json 时为 None
44    pub key_id: Option<String>,
45    /// 该连接持有的 scope 集合;空集 = legacy 模式 / TCP 直连,scope 检查放行
46    pub scopes: HashSet<Scope>,
47}
48
49/// 从连接接收到的请求
50#[derive(Debug)]
51pub struct IncomingRequest {
52    pub conn_id: u64,
53    pub proto_id: u32,
54    pub serial_no: u32,
55    pub proto_fmt_type: ProtoFmtType,
56    pub body: Bytes,
57}
58
59impl ClientConn {
60    /// 生成随机连接 ID(与 C++ 的 GetRand_MilliTimeAndU22 对应)
61    pub fn generate_conn_id() -> u64 {
62        use std::time::{SystemTime, UNIX_EPOCH};
63        let millis = SystemTime::now()
64            .duration_since(UNIX_EPOCH)
65            .unwrap_or_default()
66            .as_millis() as u64;
67        let rand_part: u32 = rand::random();
68        (millis << 22) | (rand_part as u64 & 0x3FFFFF)
69    }
70
71    /// 生成随机 AES key(16 字节 hex 字符串的 ASCII 字节)
72    pub fn generate_aes_key() -> [u8; 16] {
73        let rand_val: u64 = rand::random();
74        let hex = format!("{rand_val:016X}");
75        let mut key = [0u8; 16];
76        key.copy_from_slice(hex.as_bytes());
77        key
78    }
79
80    /// 创建发送帧,自动处理 AES 加密
81    ///
82    /// 当 aes_encrypt_enabled 为 true 时:
83    /// - SHA1 基于明文计算(FutuFrame::new 自动处理)
84    /// - body 使用 AES-128 ECB 加密
85    /// - header.body_len 更新为密文长度
86    ///
87    /// 对应 C++ APIServerCS_Conn::OnSendPacketData 的加密逻辑
88    pub fn make_frame(&self, proto_id: u32, serial_no: u32, body: Bytes) -> FutuFrame {
89        if self.aes_encrypt_enabled {
90            // 先用明文构建 frame(SHA1 基于明文计算)
91            let mut frame = FutuFrame::new(proto_id, serial_no, body);
92            // 再加密 body
93            let encrypted = encrypt::aes_ecb_encrypt(&self.aes_key, &frame.body);
94            frame.header.body_len = encrypted.len() as u32;
95            frame.body = Bytes::from(encrypted);
96            frame
97        } else {
98            FutuFrame::new(proto_id, serial_no, body)
99        }
100    }
101
102    /// 解密请求 body(如果启用了 AES 加密)
103    ///
104    /// 对应 C++ APIServerCS_Conn::OnRecvPacket 的解密逻辑
105    pub fn decrypt_body(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
106        if self.aes_encrypt_enabled {
107            encrypt::aes_ecb_decrypt(&self.aes_key, body).map_err(|e| {
108                tracing::warn!(conn_id = self.conn_id, error = %e, "AES decrypt body failed");
109                e
110            })
111        } else {
112            Ok(body.to_vec())
113        }
114    }
115
116    /// 处理 InitConnect 请求,返回 InitConnect 响应 body
117    ///
118    /// 当配置了 RSA 私钥时:
119    /// - C2S 请求 body 使用 RSA 公钥加密(需要用私钥解密)
120    /// - S2C 响应 body 使用 RSA 公钥加密(客户端用私钥解密)
121    ///
122    /// 对应 C++ APIServer::OnRecvInitConnect
123    pub fn handle_init_connect(
124        &mut self,
125        body: &[u8],
126        server_ver: i32,
127        login_user_id: u64,
128        keepalive_interval: i32,
129        rsa_private_key: Option<&str>,
130    ) -> Result<Vec<u8>, FutuError> {
131        // 1. 解密 C2S(如果配置了 RSA)
132        let decrypted_body;
133        let req_body = if let Some(rsa_key) = rsa_private_key {
134            decrypted_body =
135                futu_net::encrypt::rsa_private_decrypt_blocks(rsa_key, body).map_err(|e| {
136                    tracing::warn!(error = %e, "RSA decrypt InitConnect C2S failed");
137                    e
138                })?;
139            tracing::debug!(
140                encrypted_len = body.len(),
141                decrypted_len = decrypted_body.len(),
142                "RSA decrypted InitConnect C2S"
143            );
144            &decrypted_body[..]
145        } else {
146            body
147        };
148
149        let _req: futu_proto::init_connect::Request =
150            prost::Message::decode(req_body).map_err(FutuError::Proto)?;
151
152        self.state = ConnState::Initialized;
153
154        // 当配置了 RSA 时,后续所有帧使用 AES 加解密
155        if rsa_private_key.is_some() {
156            self.aes_encrypt_enabled = true;
157            tracing::debug!(conn_id = self.conn_id, "AES body encryption enabled");
158        }
159
160        let aes_key_str = std::str::from_utf8(&self.aes_key).unwrap_or("").to_string();
161
162        let resp = futu_proto::init_connect::Response {
163            ret_type: 0,
164            ret_msg: None,
165            err_code: None,
166            s2c: Some(futu_proto::init_connect::S2c {
167                server_ver,
168                login_user_id,
169                conn_id: self.conn_id,
170                conn_aes_key: aes_key_str,
171                keep_alive_interval: keepalive_interval,
172                aes_cb_civ: None,
173                user_attribution: None,
174            }),
175        };
176
177        let resp_body = prost::Message::encode_to_vec(&resp);
178
179        // 2. 加密 S2C(如果配置了 RSA)
180        if let Some(rsa_key) = rsa_private_key {
181            let encrypted = futu_net::encrypt::rsa_public_encrypt_blocks(rsa_key, &resp_body)
182                .map_err(|e| {
183                    tracing::warn!(error = %e, "RSA encrypt InitConnect S2C failed");
184                    e
185                })?;
186            tracing::debug!(
187                plaintext_len = resp_body.len(),
188                encrypted_len = encrypted.len(),
189                "RSA encrypted InitConnect S2C"
190            );
191            Ok(encrypted)
192        } else {
193            Ok(resp_body)
194        }
195    }
196
197    /// 处理 KeepAlive 请求
198    pub fn handle_keepalive(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
199        let _req: futu_proto::keep_alive::Request =
200            prost::Message::decode(body).map_err(FutuError::Proto)?;
201
202        self.keepalive_count.fetch_add(1, Ordering::Relaxed);
203
204        let resp = futu_proto::keep_alive::Response {
205            ret_type: 0,
206            ret_msg: None,
207            err_code: None,
208            s2c: Some(futu_proto::keep_alive::S2c {
209                time: chrono::Utc::now().timestamp(),
210            }),
211        };
212
213        Ok(prost::Message::encode_to_vec(&resp))
214    }
215}
216
/// Notification that a connection's read loop has ended and the connection should be
/// cleaned up by the listener.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DisconnectNotify {
    /// ID of the connection that closed.
    pub conn_id: u64,
}
221
222/// 运行单个连接的收发循环
223///
224/// 返回 (接收请求的 channel, 连接信息)
225pub async fn run_connection(
226    stream: TcpStream,
227    conn_id: u64,
228    _aes_key: [u8; 16],
229    req_tx: mpsc::UnboundedSender<IncomingRequest>,
230    disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
231) -> mpsc::Sender<FutuFrame> {
232    let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
233
234    let framed = Framed::new(stream, FutuCodec);
235    let (mut sink, mut stream) = framed.split();
236
237    // 发送任务
238    tokio::spawn(async move {
239        while let Some(frame) = frame_rx.recv().await {
240            if let Err(e) = sink.send(frame).await {
241                tracing::warn!(conn_id = conn_id, error = %e, "send failed");
242                break;
243            }
244        }
245    });
246
247    // 接收任务
248    tokio::spawn(async move {
249        while let Some(result) = stream.next().await {
250            match result {
251                Ok(frame) => {
252                    let req = IncomingRequest {
253                        conn_id,
254                        proto_id: frame.header.proto_id,
255                        serial_no: frame.header.serial_no,
256                        proto_fmt_type: frame.header.proto_fmt_type,
257                        body: frame.body,
258                    };
259                    if req_tx.send(req).is_err() {
260                        break;
261                    }
262                }
263                Err(e) => {
264                    tracing::warn!(conn_id = conn_id, error = %e, "recv error");
265                    break;
266                }
267            }
268        }
269        tracing::info!(conn_id = conn_id, "connection closed");
270        // 通知 listener 清理连接
271        let _ = disconnect_tx.send(DisconnectNotify { conn_id });
272    });
273
274    frame_tx
275}