futu_net/
client.rs

1use std::sync::atomic::{AtomicU32, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4
5use bytes::Bytes;
6use dashmap::DashMap;
7use tokio::sync::{mpsc, oneshot};
8
9use futu_codec::frame::FutuFrame;
10use futu_core::error::{FutuError, Result};
11use futu_core::proto_id;
12
13use crate::connection::Connection;
14use crate::encrypt;
15use crate::reconnect::ReconnectPolicy;
16
/// Client configuration.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Gateway address handed to `Connection::connect`.
    pub addr: String,
    /// Client version string.
    /// NOTE(review): not read by `FutuClient::connect`, which sends a fixed `client_ver` of 100.
    pub client_ver: String,
    /// Client identifier sent in the InitConnect request.
    pub client_id: String,
    /// Value of the InitConnect `recv_notify` flag.
    pub recv_notify: bool,
    /// RSA private key (PEM). When present, encryption is enabled:
    /// - InitConnect is RSA-encrypted / RSA-decrypted;
    /// - subsequent requests use AES-128 ECB with the key returned by InitConnect.
    ///
    /// When absent, all traffic is plaintext.
    pub rsa_key: Option<String>,
}
30
31/// 推送消息
32#[derive(Debug, Clone)]
33pub struct PushMessage {
34    pub proto_id: u32,
35    pub body: Bytes,
36}
37
38/// 请求上下文(等待响应的 oneshot sender)
39struct PendingRequest {
40    tx: oneshot::Sender<FutuFrame>,
41}
42
43/// FutuOpenD 高层客户端
44///
45/// 功能:
46/// - InitConnect 握手
47/// - 自动心跳 KeepAlive
48/// - 请求-响应匹配(serial number)
49/// - 推送消息分发
50/// - 断线重连
51pub struct FutuClient {
52    config: ClientConfig,
53    serial_no: AtomicU32,
54    pending: Arc<DashMap<u32, PendingRequest>>,
55    push_tx: mpsc::UnboundedSender<PushMessage>,
56    cmd_tx: Option<mpsc::Sender<ClientCommand>>,
57    conn_aes_key: parking_lot::Mutex<Option<[u8; 16]>>,
58}
59
60enum ClientCommand {
61    Send(FutuFrame, oneshot::Sender<()>),
62    #[expect(dead_code)]
63    Shutdown,
64}
65
66impl FutuClient {
67    /// 创建客户端(未连接状态)
68    ///
69    /// 返回 (client, push_rx),其中 push_rx 用于接收服务端推送。
70    pub fn new(config: ClientConfig) -> (Self, mpsc::UnboundedReceiver<PushMessage>) {
71        let (push_tx, push_rx) = mpsc::unbounded_channel();
72
73        let client = Self {
74            config,
75            serial_no: AtomicU32::new(0),
76            pending: Arc::new(DashMap::new()),
77            push_tx,
78            cmd_tx: None,
79            conn_aes_key: parking_lot::Mutex::new(None),
80        };
81
82        (client, push_rx)
83    }
84
85    /// 连接到 OpenD 网关并完成 InitConnect 握手
86    pub async fn connect(&mut self) -> Result<InitConnectInfo> {
87        let mut conn = Connection::connect(&self.config.addr).await?;
88
89        // 发送 InitConnect 请求
90        let serial = self.next_serial();
91        let req = futu_proto::init_connect::Request {
92            c2s: futu_proto::init_connect::C2s {
93                client_ver: 100,
94                client_id: self.config.client_id.clone(),
95                recv_notify: Some(self.config.recv_notify),
96                packet_enc_algo: Some(0),
97                push_proto_fmt: Some(0),
98                programming_language: Some(String::new()),
99            },
100        };
101        let raw_body = prost::Message::encode_to_vec(&req);
102
103        // 如果有 RSA 密钥,用 RSA 公钥加密 InitConnect 请求 body
104        let send_body = if let Some(ref rsa_key) = self.config.rsa_key {
105            tracing::debug!("encrypting InitConnect with RSA");
106            encrypt::rsa_public_encrypt(rsa_key, &raw_body)?
107        } else {
108            raw_body.clone()
109        };
110
111        // SHA1 基于明文
112        let mut frame = Connection::build_frame(proto_id::INIT_CONNECT, serial, send_body);
113        {
114            use sha1::{Digest, Sha1};
115            let mut hasher = Sha1::new();
116            hasher.update(&raw_body);
117            let hash = hasher.finalize();
118            frame.header.body_sha1.copy_from_slice(&hash);
119        }
120        conn.send(frame).await?;
121
122        // 接收 InitConnect 响应
123        let resp_frame = conn.recv().await?.ok_or(FutuError::Codec(
124            "connection closed during InitConnect".into(),
125        ))?;
126
127        // 如果有 RSA 密钥,用 RSA 私钥解密 InitConnect 响应 body
128        let resp_body = if let Some(ref rsa_key) = self.config.rsa_key {
129            tracing::debug!("decrypting InitConnect response with RSA");
130            encrypt::rsa_private_decrypt(rsa_key, &resp_frame.body)?
131        } else {
132            resp_frame.body.to_vec()
133        };
134
135        let resp: futu_proto::init_connect::Response =
136            prost::Message::decode(resp_body.as_slice()).map_err(FutuError::Proto)?;
137
138        // 检查返回码
139        let ret_type = resp.ret_type;
140        if ret_type != 0 {
141            return Err(FutuError::ServerError {
142                ret_type,
143                msg: resp.ret_msg.unwrap_or_default(),
144            });
145        }
146
147        let s2c = resp.s2c.ok_or(FutuError::Codec(
148            "missing s2c in InitConnect response".into(),
149        ))?;
150
151        let info = InitConnectInfo {
152            server_ver: s2c.server_ver,
153            login_user_id: s2c.login_user_id,
154            conn_id: s2c.conn_id,
155            conn_aes_key: s2c.conn_aes_key.clone(),
156            keep_alive_interval: s2c.keep_alive_interval,
157        };
158
159        // 保存 AES key
160        if !info.conn_aes_key.is_empty() {
161            let key_bytes = info.conn_aes_key.as_bytes();
162            tracing::debug!(
163                key_len = key_bytes.len(),
164                key_hex = hex_str(key_bytes),
165                "received AES key"
166            );
167            if key_bytes.len() == 16 {
168                // key 直接是 16 字节原始值
169                let mut key = [0u8; 16];
170                key.copy_from_slice(key_bytes);
171                *self.conn_aes_key.lock() = Some(key);
172            } else if key_bytes.len() == 32 {
173                // key 是 32 字符十六进制编码的 16 字节
174                if let Some(key) = hex_decode_16(key_bytes) {
175                    *self.conn_aes_key.lock() = Some(key);
176                } else {
177                    tracing::warn!(
178                        "AES key is 32 chars but not valid hex, using raw first 16 bytes"
179                    );
180                    let mut key = [0u8; 16];
181                    key.copy_from_slice(&key_bytes[..16]);
182                    *self.conn_aes_key.lock() = Some(key);
183                }
184            } else {
185                tracing::warn!(
186                    key_len = key_bytes.len(),
187                    "unexpected AES key length, using raw bytes (truncated/padded to 16)"
188                );
189                let mut key = [0u8; 16];
190                let copy_len = key_bytes.len().min(16);
191                key[..copy_len].copy_from_slice(&key_bytes[..copy_len]);
192                *self.conn_aes_key.lock() = Some(key);
193            }
194        }
195
196        tracing::info!(
197            server_ver = info.server_ver,
198            conn_id = info.conn_id,
199            keep_alive_interval = info.keep_alive_interval,
200            "InitConnect succeeded"
201        );
202
203        // 启动后台任务:心跳、消息接收
204        let keep_alive_interval = Duration::from_secs(info.keep_alive_interval as u64);
205        self.start_background_tasks(conn, keep_alive_interval);
206
207        Ok(info)
208    }
209
210    /// 发送请求并等待响应
211    pub async fn request(&self, proto_id: u32, body: Vec<u8>) -> Result<FutuFrame> {
212        let serial = self.next_serial();
213
214        // 加密 body(如果有 AES key)
215        let (final_body, sha1) = self.prepare_body(&body);
216
217        let mut frame = FutuFrame::new(proto_id, serial, Bytes::from(final_body));
218        // 使用明文的 SHA1
219        frame.header.body_sha1 = sha1;
220
221        let (resp_tx, resp_rx) = oneshot::channel();
222        self.pending.insert(serial, PendingRequest { tx: resp_tx });
223
224        // 通过 command channel 发送
225        if let Some(cmd_tx) = &self.cmd_tx {
226            let (ack_tx, ack_rx) = oneshot::channel();
227            cmd_tx
228                .send(ClientCommand::Send(frame, ack_tx))
229                .await
230                .map_err(|_| FutuError::NotInitialized)?;
231            ack_rx
232                .await
233                .map_err(|_| FutuError::Codec("send ack failed".into()))?;
234        } else {
235            self.pending.remove(&serial);
236            return Err(FutuError::NotInitialized);
237        }
238
239        // 等待响应(带超时)
240        let resp = tokio::time::timeout(Duration::from_secs(12), resp_rx)
241            .await
242            .map_err(|_| {
243                self.pending.remove(&serial);
244                FutuError::Timeout
245            })?
246            .map_err(|_| FutuError::Codec("response channel closed".into()))?;
247
248        Ok(resp)
249    }
250
251    fn next_serial(&self) -> u32 {
252        self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
253    }
254
255    fn prepare_body(&self, plaintext: &[u8]) -> (Vec<u8>, [u8; 20]) {
256        use sha1::{Digest, Sha1};
257
258        // SHA1 始终基于明文
259        let mut hasher = Sha1::new();
260        hasher.update(plaintext);
261        let sha1_result = hasher.finalize();
262        let mut sha1 = [0u8; 20];
263        sha1.copy_from_slice(&sha1_result);
264
265        // 仅在有 RSA 密钥(即启用加密)时才用 AES 加密
266        let body = if self.config.rsa_key.is_some() {
267            let key = self.conn_aes_key.lock();
268            match key.as_ref() {
269                Some(k) => encrypt::aes_ecb_encrypt(k, plaintext),
270                None => plaintext.to_vec(),
271            }
272        } else {
273            plaintext.to_vec()
274        };
275
276        (body, sha1)
277    }
278
279    fn start_background_tasks(&mut self, conn: Connection, keep_alive_interval: Duration) {
280        let (cmd_tx, cmd_rx) = mpsc::channel(256);
281        self.cmd_tx = Some(cmd_tx.clone());
282
283        let pending = Arc::clone(&self.pending);
284        let push_tx = self.push_tx.clone();
285        let aes_key = if self.config.rsa_key.is_some() {
286            *self.conn_aes_key.lock()
287        } else {
288            None
289        };
290
291        tokio::spawn(async move {
292            run_event_loop(conn, cmd_rx, pending, push_tx, aes_key, keep_alive_interval).await;
293        });
294    }
295}
296
297/// 后台事件循环:处理接收、心跳、发送
298async fn run_event_loop(
299    mut conn: Connection,
300    mut cmd_rx: mpsc::Receiver<ClientCommand>,
301    pending: Arc<DashMap<u32, PendingRequest>>,
302    push_tx: mpsc::UnboundedSender<PushMessage>,
303    aes_key: Option<[u8; 16]>,
304    keep_alive_interval: Duration,
305) {
306    let mut heartbeat = tokio::time::interval(keep_alive_interval);
307    heartbeat.tick().await; // 跳过第一次立即触发
308    let mut heartbeat_serial: u32 = 10_000_000; // 心跳用独立序列号空间
309
310    loop {
311        tokio::select! {
312            // 接收服务端消息
313            result = conn.recv() => {
314                match result {
315                    Ok(Some(frame)) => {
316                        // 解密 body
317                        let body = match &aes_key {
318                            Some(key) if !frame.body.is_empty() => {
319                                match encrypt::aes_ecb_decrypt(key, &frame.body) {
320                                    Ok(decrypted) => Bytes::from(decrypted),
321                                    Err(e) => {
322                                        tracing::warn!(error = %e, "decrypt failed, using raw body");
323                                        frame.body.clone()
324                                    }
325                                }
326                            }
327                            _ => frame.body.clone(),
328                        };
329
330                        let proto_id = frame.header.proto_id;
331
332                        if proto_id::is_push_proto(proto_id) {
333                            // 推送消息
334                            let _ = push_tx.send(PushMessage { proto_id, body });
335                        } else {
336                            // 响应消息:匹配 serial number
337                            let serial = frame.header.serial_no;
338                            if let Some((_, req)) = pending.remove(&serial) {
339                                let resp_frame = FutuFrame {
340                                    header: frame.header,
341                                    body,
342                                };
343                                let _ = req.tx.send(resp_frame);
344                            } else {
345                                tracing::debug!(serial = serial, proto_id = proto_id, "unmatched response");
346                            }
347                        }
348                    }
349                    Ok(None) => {
350                        tracing::warn!("connection closed by server");
351                        break;
352                    }
353                    Err(e) => {
354                        tracing::error!(error = %e, "recv error");
355                        break;
356                    }
357                }
358            }
359
360            // 发送命令
361            cmd = cmd_rx.recv() => {
362                match cmd {
363                    Some(ClientCommand::Send(frame, ack)) => {
364                        let result = conn.send(frame).await;
365                        if let Err(e) = &result {
366                            tracing::error!(error = %e, "send failed");
367                        }
368                        let _ = ack.send(());
369                        if result.is_err() {
370                            break;
371                        }
372                    }
373                    Some(ClientCommand::Shutdown) | None => {
374                        tracing::info!("shutting down event loop");
375                        break;
376                    }
377                }
378            }
379
380            // 心跳
381            _ = heartbeat.tick() => {
382                heartbeat_serial += 1;
383                let req = futu_proto::keep_alive::Request {
384                    c2s: futu_proto::keep_alive::C2s {
385                        time: chrono::Utc::now().timestamp(),
386                    },
387                };
388                let body = prost::Message::encode_to_vec(&req);
389                let frame = Connection::build_frame(
390                    proto_id::KEEP_ALIVE,
391                    heartbeat_serial,
392                    body,
393                );
394                if let Err(e) = conn.send(frame).await {
395                    tracing::error!(error = %e, "heartbeat send failed");
396                    break;
397                }
398                tracing::trace!("heartbeat sent");
399            }
400        }
401    }
402
403    // 清理所有 pending 请求
404    pending.clear();
405    tracing::info!("event loop exited");
406}
407
/// Information returned by a successful InitConnect handshake.
#[derive(Clone, Debug)]
pub struct InitConnectInfo {
    /// Server version reported by OpenD.
    pub server_ver: i32,
    /// Logged-in user ID reported by OpenD.
    pub login_user_id: u64,
    /// Connection ID assigned by OpenD.
    pub conn_id: u64,
    /// Connection AES key exactly as received (may be empty).
    pub conn_aes_key: String,
    /// Keep-alive interval in seconds, as reported by the server.
    pub keep_alive_interval: i32,
}
417
418/// 带自动重连的客户端包装
419pub struct ReconnectingClient {
420    config: ClientConfig,
421    policy: ReconnectPolicy,
422}
423
424impl ReconnectingClient {
425    pub fn new(config: ClientConfig) -> Self {
426        Self {
427            config,
428            policy: ReconnectPolicy::default_policy(),
429        }
430    }
431
432    pub fn with_policy(mut self, policy: ReconnectPolicy) -> Self {
433        self.policy = policy;
434        self
435    }
436
437    /// 带重连的连接循环
438    ///
439    /// 返回成功连接的 (FutuClient, push_rx, InitConnectInfo)。
440    /// 如果达到最大重试次数则返回错误。
441    pub async fn connect(
442        &mut self,
443    ) -> Result<(
444        FutuClient,
445        mpsc::UnboundedReceiver<PushMessage>,
446        InitConnectInfo,
447    )> {
448        loop {
449            let (mut client, push_rx) = FutuClient::new(self.config.clone());
450            match client.connect().await {
451                Ok(info) => {
452                    self.policy.reset();
453                    return Ok((client, push_rx, info));
454                }
455                Err(e) => {
456                    tracing::warn!(
457                        error = %e,
458                        attempt = self.policy.attempts(),
459                        "connection failed"
460                    );
461                    match self.policy.next_delay() {
462                        Some(delay) => {
463                            tracing::info!(delay_ms = delay.as_millis(), "reconnecting...");
464                            tokio::time::sleep(delay).await;
465                        }
466                        None => {
467                            return Err(FutuError::Codec(format!(
468                                "max retries reached after {} attempts",
469                                self.policy.attempts()
470                            )));
471                        }
472                    }
473                }
474            }
475        }
476    }
477}
478
/// Renders `bytes` as lowercase hex, two digits per byte.
fn hex_str(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
482
/// Decodes exactly 32 ASCII hex digits (either case) into 16 bytes.
///
/// Returns `None` if the input is not 32 bytes long or contains any byte that is
/// not an ASCII hex digit. Decoding works on raw bytes. Non-ASCII input (which the
/// previous `&str` slicing could panic on) is therefore rejected, and so is a `+`
/// sign (which `u8::from_str_radix` would accept).
fn hex_decode_16(hex_bytes: &[u8]) -> Option<[u8; 16]> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    if hex_bytes.len() != 32 {
        return None;
    }
    let mut key = [0u8; 16];
    for (byte, pair) in key.iter_mut().zip(hex_bytes.chunks_exact(2)) {
        *byte = (nibble(pair[0])? << 4) | nibble(pair[1])?;
    }
    Some(key)
}