Skip to main content

futu_net/
client.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use bytes::Bytes;
6use dashmap::DashMap;
7use tokio::sync::{mpsc, oneshot};
8
9use futu_codec::frame::FutuFrame;
10use futu_core::error::{FutuError, Result};
11use futu_core::proto_id;
12
13use crate::connection::Connection;
14use crate::encrypt;
15use crate::reconnect::ReconnectPolicy;
16
/// Client configuration.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Gateway address handed to `Connection::connect`.
    pub addr: String,
    /// Client version string supplied by the caller.
    pub client_ver: String,
    /// Client identifier sent in the InitConnect request.
    pub client_id: String,
    /// Forwarded as `recv_notify` in the InitConnect request.
    pub recv_notify: bool,
    /// RSA private key (PEM format). When provided, encryption is enabled:
    /// - InitConnect is encrypted/decrypted with RSA
    /// - subsequent requests are encrypted/decrypted with AES-128 ECB
    ///
    /// When absent, the whole session is plaintext.
    pub rsa_key: Option<String>,
}
30
31/// 推送消息
32#[derive(Debug, Clone)]
33pub struct PushMessage {
34    pub proto_id: u32,
35    pub body: Bytes,
36}
37
38/// 请求上下文(等待响应的 oneshot sender)
39struct PendingRequest {
40    tx: oneshot::Sender<FutuFrame>,
41}
42
43/// FutuOpenD 高层客户端
44///
45/// 功能:
46/// - InitConnect 握手
47/// - 自动心跳 KeepAlive
48/// - 请求-响应匹配(serial number)
49/// - 推送消息分发
50/// - 断线重连
51pub struct FutuClient {
52    config: ClientConfig,
53    serial_no: AtomicU32,
54    pending: Arc<DashMap<u32, PendingRequest>>,
55    push_tx: mpsc::UnboundedSender<PushMessage>,
56    cmd_tx: Option<mpsc::Sender<ClientCommand>>,
57    conn_aes_key: parking_lot::Mutex<Option<[u8; 16]>>,
58    conn_id: AtomicU64,
59}
60
61enum ClientCommand {
62    Send(FutuFrame, oneshot::Sender<()>),
63}
64
65fn decode_client_inbound_body(frame: &FutuFrame, aes_key: Option<&[u8; 16]>) -> Result<Bytes> {
66    match aes_key {
67        Some(key) if !frame.body.is_empty() => {
68            encrypt::aes_ecb_decrypt(key, &frame.body).map(Bytes::from)
69        }
70        _ => Ok(frame.body.clone()),
71    }
72}
73
74fn release_pending_on_inbound_decode_error(
75    frame: &FutuFrame,
76    pending: &DashMap<u32, PendingRequest>,
77) {
78    if !proto_id::is_push_proto(frame.header.proto_id) {
79        pending.remove(&frame.header.serial_no);
80    }
81}
82
83impl FutuClient {
84    /// 创建客户端(未连接状态)
85    ///
86    /// 返回 (client, push_rx),其中 push_rx 用于接收服务端推送。
87    pub fn new(config: ClientConfig) -> (Self, mpsc::UnboundedReceiver<PushMessage>) {
88        let (push_tx, push_rx) = mpsc::unbounded_channel();
89
90        let client = Self {
91            config,
92            serial_no: AtomicU32::new(0),
93            pending: Arc::new(DashMap::new()),
94            push_tx,
95            cmd_tx: None,
96            conn_aes_key: parking_lot::Mutex::new(None),
97            conn_id: AtomicU64::new(0),
98        };
99
100        (client, push_rx)
101    }
102
103    /// 连接到 OpenD 网关并完成 InitConnect 握手
104    pub async fn connect(&mut self) -> Result<InitConnectInfo> {
105        let mut conn = Connection::connect(&self.config.addr).await?;
106
107        // 发送 InitConnect 请求
108        let serial = self.next_serial();
109        let req = futu_proto::init_connect::Request {
110            c2s: futu_proto::init_connect::C2s {
111                client_ver: 100,
112                client_id: self.config.client_id.clone(),
113                recv_notify: Some(self.config.recv_notify),
114                packet_enc_algo: Some(0),
115                push_proto_fmt: Some(0),
116                programming_language: Some(String::new()),
117            },
118        };
119        let raw_body = prost::Message::encode_to_vec(&req);
120
121        // 如果有 RSA 密钥,用 RSA 公钥加密 InitConnect 请求 body
122        let send_body = if let Some(ref rsa_key) = self.config.rsa_key {
123            tracing::debug!("encrypting InitConnect with RSA");
124            encrypt::rsa_public_encrypt(rsa_key, &raw_body)?
125        } else {
126            raw_body.clone()
127        };
128
129        // SHA1 基于明文
130        let mut frame = Connection::build_frame(proto_id::INIT_CONNECT, serial, send_body);
131        {
132            use sha1::{Digest, Sha1};
133            let mut hasher = Sha1::new();
134            hasher.update(&raw_body);
135            let hash = hasher.finalize();
136            frame.header.body_sha1.copy_from_slice(&hash);
137        }
138        conn.send(frame).await?;
139
140        // 接收 InitConnect 响应
141        let resp_frame = conn.recv().await?.ok_or(FutuError::Codec(
142            "connection closed during InitConnect".into(),
143        ))?;
144
145        // 如果有 RSA 密钥,用 RSA 私钥解密 InitConnect 响应 body
146        let resp_body = if let Some(ref rsa_key) = self.config.rsa_key {
147            tracing::debug!("decrypting InitConnect response with RSA");
148            encrypt::rsa_private_decrypt(rsa_key, &resp_frame.body)?
149        } else {
150            resp_frame.body.to_vec()
151        };
152
153        let resp: futu_proto::init_connect::Response =
154            prost::Message::decode(resp_body.as_slice()).map_err(FutuError::Proto)?;
155
156        // 检查返回码
157        let ret_type = resp.ret_type;
158        if ret_type != 0 {
159            return Err(FutuError::ServerError {
160                ret_type,
161                msg: resp.ret_msg.unwrap_or_default(),
162            });
163        }
164
165        let s2c = resp.s2c.ok_or(FutuError::Codec(
166            "missing s2c in InitConnect response".into(),
167        ))?;
168
169        let info = InitConnectInfo {
170            server_ver: s2c.server_ver,
171            login_user_id: s2c.login_user_id,
172            conn_id: s2c.conn_id,
173            conn_aes_key: s2c.conn_aes_key.clone(),
174            keep_alive_interval: s2c.keep_alive_interval,
175        };
176        self.conn_id.store(info.conn_id, Ordering::Relaxed);
177
178        // 保存 AES key
179        if !info.conn_aes_key.is_empty() {
180            let key_bytes = info.conn_aes_key.as_bytes();
181            tracing::debug!(
182                key_len = key_bytes.len(),
183                key_hex = hex_str(key_bytes),
184                "received AES key"
185            );
186            if key_bytes.len() == 16 {
187                // key 直接是 16 字节原始值
188                let mut key = [0u8; 16];
189                key.copy_from_slice(key_bytes);
190                *self.conn_aes_key.lock() = Some(key);
191            } else if key_bytes.len() == 32 {
192                // key 是 32 字符十六进制编码的 16 字节
193                if let Some(key) = hex_decode_16(key_bytes) {
194                    *self.conn_aes_key.lock() = Some(key);
195                } else {
196                    tracing::warn!(
197                        "AES key is 32 chars but not valid hex, using raw first 16 bytes"
198                    );
199                    let mut key = [0u8; 16];
200                    key.copy_from_slice(&key_bytes[..16]);
201                    *self.conn_aes_key.lock() = Some(key);
202                }
203            } else {
204                tracing::warn!(
205                    key_len = key_bytes.len(),
206                    "unexpected AES key length, using raw bytes (truncated/padded to 16)"
207                );
208                let mut key = [0u8; 16];
209                let copy_len = key_bytes.len().min(16);
210                key[..copy_len].copy_from_slice(&key_bytes[..copy_len]);
211                *self.conn_aes_key.lock() = Some(key);
212            }
213        }
214
215        tracing::info!(
216            server_ver = info.server_ver,
217            conn_id = info.conn_id,
218            keep_alive_interval = info.keep_alive_interval,
219            "InitConnect succeeded"
220        );
221
222        // 启动后台任务:心跳、消息接收
223        let keep_alive_interval = Duration::from_secs(info.keep_alive_interval as u64);
224        self.start_background_tasks(conn, keep_alive_interval);
225
226        Ok(info)
227    }
228
229    /// 发送请求并等待响应
230    pub async fn request(&self, proto_id: u32, body: Vec<u8>) -> Result<FutuFrame> {
231        let serial = self.next_serial();
232
233        // 加密 body(如果有 AES key)
234        let (final_body, sha1) = self.prepare_body(&body);
235
236        let mut frame = FutuFrame::new(proto_id, serial, Bytes::from(final_body));
237        // 使用明文的 SHA1
238        frame.header.body_sha1 = sha1;
239
240        let (resp_tx, resp_rx) = oneshot::channel();
241        self.pending.insert(serial, PendingRequest { tx: resp_tx });
242
243        // 通过 command channel 发送
244        if let Some(cmd_tx) = &self.cmd_tx {
245            let (ack_tx, ack_rx) = oneshot::channel();
246            cmd_tx
247                .send(ClientCommand::Send(frame, ack_tx))
248                .await
249                .map_err(|_| FutuError::NotInitialized)?;
250            ack_rx
251                .await
252                .map_err(|_| FutuError::Codec("send ack failed".into()))?;
253        } else {
254            self.pending.remove(&serial);
255            return Err(FutuError::NotInitialized);
256        }
257
258        // 等待响应(带超时)
259        let resp = tokio::time::timeout(Duration::from_secs(12), resp_rx)
260            .await
261            .map_err(|_| {
262                self.pending.remove(&serial);
263                FutuError::Timeout
264            })?
265            .map_err(|_| FutuError::Codec("response channel closed".into()))?;
266
267        Ok(resp)
268    }
269
270    /// Server-assigned connection id from InitConnect S2C.
271    ///
272    /// FTAPI trade-write packet ids must echo this value in `PacketID.connID`;
273    /// the gateway replay guard compares it against the actual TCP connection.
274    pub fn conn_id(&self) -> Option<u64> {
275        let conn_id = self.conn_id.load(Ordering::Relaxed);
276        (conn_id != 0).then_some(conn_id)
277    }
278
279    fn next_serial(&self) -> u32 {
280        self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
281    }
282
283    fn prepare_body(&self, plaintext: &[u8]) -> (Vec<u8>, [u8; 20]) {
284        use sha1::{Digest, Sha1};
285
286        // SHA1 始终基于明文
287        let mut hasher = Sha1::new();
288        hasher.update(plaintext);
289        let sha1_result = hasher.finalize();
290        let mut sha1 = [0u8; 20];
291        sha1.copy_from_slice(&sha1_result);
292
293        // 仅在有 RSA 密钥(即启用加密)时才用 AES 加密
294        let body = if self.config.rsa_key.is_some() {
295            let key = self.conn_aes_key.lock();
296            match key.as_ref() {
297                Some(k) => encrypt::aes_ecb_encrypt(k, plaintext),
298                None => plaintext.to_vec(),
299            }
300        } else {
301            plaintext.to_vec()
302        };
303
304        (body, sha1)
305    }
306
307    fn start_background_tasks(&mut self, conn: Connection, keep_alive_interval: Duration) {
308        let (cmd_tx, cmd_rx) = mpsc::channel(256);
309        self.cmd_tx = Some(cmd_tx.clone());
310
311        let pending = Arc::clone(&self.pending);
312        let push_tx = self.push_tx.clone();
313        let aes_key = if self.config.rsa_key.is_some() {
314            *self.conn_aes_key.lock()
315        } else {
316            None
317        };
318
319        tokio::spawn(async move {
320            run_event_loop(conn, cmd_rx, pending, push_tx, aes_key, keep_alive_interval).await;
321        });
322    }
323}
324
325/// 后台事件循环:处理接收、心跳、发送
326async fn run_event_loop(
327    mut conn: Connection,
328    mut cmd_rx: mpsc::Receiver<ClientCommand>,
329    pending: Arc<DashMap<u32, PendingRequest>>,
330    push_tx: mpsc::UnboundedSender<PushMessage>,
331    aes_key: Option<[u8; 16]>,
332    keep_alive_interval: Duration,
333) {
334    let mut heartbeat = tokio::time::interval(keep_alive_interval);
335    heartbeat.tick().await; // 跳过第一次立即触发
336    let mut heartbeat_serial: u32 = 10_000_000; // 心跳用独立序列号空间
337
338    loop {
339        tokio::select! {
340            // 接收服务端消息
341            result = conn.recv() => {
342                match result {
343                    Ok(Some(frame)) => {
344                        let proto_id = frame.header.proto_id;
345                        let body = match decode_client_inbound_body(&frame, aes_key.as_ref()) {
346                            Ok(body) => body,
347                            Err(e) => {
348                                tracing::warn!(
349                                    error = %e,
350                                    serial = frame.header.serial_no,
351                                    proto_id,
352                                    "decrypt failed, dropping inbound frame"
353                                );
354                                release_pending_on_inbound_decode_error(&frame, &pending);
355                                continue;
356                            }
357                        };
358
359                        if proto_id::is_push_proto(proto_id) {
360                            // 推送消息
361                            let _ = push_tx.send(PushMessage { proto_id, body });
362                        } else {
363                            // 响应消息:匹配 serial number
364                            let serial = frame.header.serial_no;
365                            match pending.remove(&serial) { Some((_, req)) => {
366                                let resp_frame = FutuFrame {
367                                    header: frame.header,
368                                    body,
369                                };
370                                let _ = req.tx.send(resp_frame);
371                            } _ => {
372                                tracing::debug!(serial = serial, proto_id = proto_id, "unmatched response");
373                            }}
374                        }
375                    }
376                    Ok(None) => {
377                        tracing::warn!("connection closed by server");
378                        break;
379                    }
380                    Err(e) => {
381                        tracing::error!(error = %e, "recv error");
382                        break;
383                    }
384                }
385            }
386
387            // 发送命令
388            cmd = cmd_rx.recv() => {
389                match cmd {
390                    Some(ClientCommand::Send(frame, ack)) => {
391                        let result = conn.send(frame).await;
392                        if let Err(e) = &result {
393                            tracing::error!(error = %e, "send failed");
394                        }
395                        let _ = ack.send(());
396                        if result.is_err() {
397                            break;
398                        }
399                    }
400                    None => {
401                        tracing::info!("shutting down event loop");
402                        break;
403                    }
404                }
405            }
406
407            // 心跳
408            _ = heartbeat.tick() => {
409                heartbeat_serial += 1;
410                let req = futu_proto::keep_alive::Request {
411                    c2s: futu_proto::keep_alive::C2s {
412                        time: chrono::Utc::now().timestamp(),
413                    },
414                };
415                let body = prost::Message::encode_to_vec(&req);
416                let frame = Connection::build_frame(
417                    proto_id::KEEP_ALIVE,
418                    heartbeat_serial,
419                    body,
420                );
421                if let Err(e) = conn.send(frame).await {
422                    tracing::error!(error = %e, "heartbeat send failed");
423                    break;
424                }
425                tracing::trace!("heartbeat sent");
426            }
427        }
428    }
429
430    // 清理所有 pending 请求
431    pending.clear();
432    tracing::info!("event loop exited");
433}
434
/// Information returned by a successful InitConnect handshake.
#[derive(Clone, Debug)]
pub struct InitConnectInfo {
    /// Server version from the S2C payload.
    pub server_ver: i32,
    /// Id of the logged-in user from the S2C payload.
    pub login_user_id: u64,
    /// Server-assigned connection id.
    pub conn_id: u64,
    /// AES key string exactly as received from the server.
    pub conn_aes_key: String,
    /// Heartbeat interval reported by the server; used as a number of seconds.
    pub keep_alive_interval: i32,
}
444
445/// 带自动重连的客户端包装
446pub struct ReconnectingClient {
447    config: ClientConfig,
448    policy: ReconnectPolicy,
449}
450
451impl ReconnectingClient {
452    pub fn new(config: ClientConfig) -> Self {
453        Self {
454            config,
455            policy: ReconnectPolicy::default_policy(),
456        }
457    }
458
459    pub fn with_policy(mut self, policy: ReconnectPolicy) -> Self {
460        self.policy = policy;
461        self
462    }
463
464    /// 带重连的连接循环
465    ///
466    /// 返回成功连接的 (FutuClient, push_rx, InitConnectInfo)。
467    /// 如果达到最大重试次数则返回错误。
468    pub async fn connect(
469        &mut self,
470    ) -> Result<(
471        FutuClient,
472        mpsc::UnboundedReceiver<PushMessage>,
473        InitConnectInfo,
474    )> {
475        loop {
476            let (mut client, push_rx) = FutuClient::new(self.config.clone());
477            match client.connect().await {
478                Ok(info) => {
479                    self.policy.reset();
480                    return Ok((client, push_rx, info));
481                }
482                Err(e) => {
483                    tracing::warn!(
484                        error = %e,
485                        attempt = self.policy.attempts(),
486                        "connection failed"
487                    );
488                    match self.policy.next_delay() {
489                        Some(delay) => {
490                            tracing::info!(delay_ms = delay.as_millis(), "reconnecting...");
491                            tokio::time::sleep(delay).await;
492                        }
493                        None => {
494                            return Err(FutuError::Codec(format!(
495                                "max retries reached after {} attempts",
496                                self.policy.attempts()
497                            )));
498                        }
499                    }
500                }
501            }
502        }
503    }
504}
505
506#[cfg(test)]
507mod tests;
508
/// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex_str(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
512
/// Decodes exactly 32 ASCII hex digits (either case) into 16 bytes.
///
/// Returns `None` when the input is not 32 bytes long or contains any byte that
/// is not a hex digit. Digits are decoded byte by byte, so sign characters
/// (which `u8::from_str_radix` would accept) are rejected and non-ASCII input
/// cannot trigger a string-slicing panic.
fn hex_decode_16(hex_bytes: &[u8]) -> Option<[u8; 16]> {
    /// Value of a single ASCII hex digit.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    if hex_bytes.len() != 32 {
        return None;
    }

    let mut key = [0u8; 16];
    for (slot, pair) in key.iter_mut().zip(hex_bytes.chunks_exact(2)) {
        *slot = (nibble(pair[0])? << 4) | nibble(pair[1])?;
    }
    Some(key)
}
523}