Skip to main content

futu_backend/
conn.rs

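// Backend TCP connection management (direct connection to Futu backend servers).
//
// Encryption:
// - Login commands (1001/6001/26001): not encrypted.
// - All other commands: AES-encrypted with the session key,
//   body = encrypt(sec_data (4 bytes, big-endian) + proto_data).
//   The key length is chosen by the server: typically 16 bytes (AES-128) for
//   Platform, possibly 32 bytes (AES-256) for Broker.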
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU32, Ordering};
10
11use bytes::Bytes;
12use futures::{SinkExt, StreamExt};
13use parking_lot::Mutex;
14use tokio::net::TcpStream;
15use tokio::sync::{mpsc, oneshot};
16use tokio_util::codec::Framed;
17
18use futu_core::error::{FutuError, Result};
19use futu_net::encrypt::{aes_cbc_md5_decrypt_var, aes_cbc_md5_encrypt_var};
20
21use crate::nn_codec::{NNCodec, NNFrame, NNHeader, should_skip_encryption};
22
/// Connection to a Futu backend server.
///
/// Holds the sending half of the command queue drained by the background send
/// task, plus the per-connection counters and login state stamped into every
/// outbound frame.
pub struct BackendConn {
    /// Last request serial number handed out; a fresh value is taken for every
    /// outbound frame.
    serial_no: AtomicU32,
    /// Counter prepended (4 bytes, big-endian) to every encrypted request body.
    sec_data: AtomicU32,
    /// Shared session key — the receive task and the sending side use the same `Arc`.
    /// Mirrors C++ `Logger::session_key_`, which is a `std::string` (`logger.h:152`):
    /// its length is decided by the server, typically 16 bytes (AES-128) for
    /// Platform and possibly 32 bytes (AES-256) for Broker. Before v1.4.7 this was a
    /// fixed `[u8; 16]`, and truncating a broker session key made the server reply
    /// `CONN decrypt failed`.
    session_key: Arc<Mutex<Option<Vec<u8>>>>,
    /// Queue feeding the background send task.
    cmd_tx: mpsc::Sender<BackendCmd>,
    /// User id written into every outbound header.
    pub user_id: AtomicU32,
    /// RspEncryptData.client_ip(field 14), set after TCP login.
    ///
    /// C++ `logger.cpp:511-516` stores this server-observed public IP on the
    /// working TCP client, then `logger.cpp:1122-1125` echoes it in CMD20147.
    client_ip: Mutex<String>,
    /// Client type written into every outbound header.
    pub client_type: u8,
    /// Client version written into every outbound header.
    pub client_ver: u16,
    /// Language id written into every outbound header.
    pub lang_id: u8,
}
44
/// Commands consumed by the background send task.
enum BackendCmd {
    /// Write the frame to the backend. When a `oneshot::Sender` is supplied it is
    /// registered under the frame's serial number so the matching reply frame can
    /// be handed back to the requester.
    Send(NNFrame, Option<oneshot::Sender<NNFrame>>),
}
48
49/// 后端推送回调
50pub type PushCallback = Arc<dyn Fn(u16, Bytes) + Send + Sync + 'static>;
51
/// What the receive task should do with a frame after body decoding.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InboundFrameDecision {
    /// Body decoded (or needed no decoding); route the frame onward.
    Deliver,
    /// Decryption or decompression failed; discard the frame.
    Drop,
}
57
58fn decode_inbound_frame_body(
59    frame: &mut NNFrame,
60    session_key: Option<&[u8]>,
61) -> InboundFrameDecision {
62    if !should_skip_encryption(frame.header.cmd_id)
63        && let Some(key) = session_key
64    {
65        let body_len = frame.body.len();
66        // C++ `logger.cpp:1840-1855` 的约束:加密包 body_len 必须 >= 32
67        // 且是 16 字节对齐(AES-CBC-MD5 最小输出 32 字节)。
68        // C++ 还有特例:body_len == 32 认为是空包,直接清空。
69        if body_len >= 32 && body_len.is_multiple_of(16) {
70            match aes_cbc_md5_decrypt_var(key, &frame.body) {
71                Ok(decrypted) => {
72                    // 后端响应不含 sec_data 前缀(仅 client→server 方向有)
73                    frame.body = Bytes::from(decrypted);
74                }
75                Err(e) => {
76                    // C++ `NNTCPConnBase.cpp:347-351` 会先用 current session
77                    // key 解密,再用 old session key 重试;两次失败后
78                    // `If_OMWarn_ReturnVoid`,不会把原始密文交给业务层。
79                    tracing::warn!(
80                        cmd_id = frame.header.cmd_id,
81                        body_len = body_len,
82                        key_len = key.len(),
83                        error = %e,
84                        "decrypt failed, dropping inbound frame"
85                    );
86                    return InboundFrameDecision::Drop;
87                }
88            }
89        } else {
90            tracing::debug!(
91                cmd_id = frame.header.cmd_id,
92                body_len = body_len,
93                "body not encrypted (len not aligned to 16)"
94            );
95        }
96    }
97
98    if frame.header.is_compressed() {
99        let compressed_body_len = frame.body.len();
100        match crate::ftlogin_wire::decode_inbound_body(true, frame.body.as_ref()) {
101            Ok(decompressed) => {
102                frame.body = Bytes::from(decompressed);
103                frame.header.body_len = frame.body.len() as u32;
104                tracing::debug!(
105                    cmd_id = frame.header.cmd_id,
106                    serial_no = frame.header.serial_no,
107                    compressed_body_len,
108                    body_len = frame.body.len(),
109                    "decompressed inbound frame after decrypt"
110                );
111            }
112            Err(e) => {
113                // C++ FTLogin `channel_impl.cpp:1905-1954` marks
114                // decompression errors as `kSdkMsgRecvDataError`; it does not
115                // deliver the compressed raw body as a successful business
116                // payload.
117                tracing::warn!(
118                    cmd_id = frame.header.cmd_id,
119                    serial_no = frame.header.serial_no,
120                    body_len = compressed_body_len,
121                    error = %e,
122                    "decompress failed after decrypt, dropping inbound frame"
123                );
124                return InboundFrameDecision::Drop;
125            }
126        }
127    }
128
129    InboundFrameDecision::Deliver
130}
131
132fn release_pending_on_inbound_decode_error(
133    frame: &NNFrame,
134    pending: &Arc<Mutex<HashMap<u32, oneshot::Sender<NNFrame>>>>,
135) {
136    if !frame.header.is_push {
137        pending.lock().remove(&frame.header.serial_no);
138    }
139}
140
141impl BackendConn {
142    /// 连接 TCP 的超时时间。
143    ///
144    /// Linux 默认 `tcp_syn_retries=6` 时 `TcpStream::connect` 等到 `ETIMEDOUT`
145    /// 需要约 127 秒——用户启动 opend 时如果选到一个不通的 IP 就卡 2 分钟后
146    /// 才报错 offline mode(某位 Rocky Linux 用户踩过)。加 10s 超时快速失败,
147    /// 让上层(`bridge.rs` 的 connect 循环)有机会 fallback 到下一个候选 IP。
148    pub const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
149
150    /// Backend 侧可识别的 Rust OpenD 客户端版本号。
151    ///
152    /// Rust OpenD 的后台诊断版本号。
153    ///
154    /// 该值与普通 C++ OpenD 拉开,便于后台排查时区分 Rust 链路。服务端若要求
155    /// version >= 800,低版本会返回 `error_code=45 "当前应用版本过低"`。
156    pub const CLIENT_VER_FTGTW: u16 = 1030;
157
158    /// 建立底层 TcpStream —— 带超时 + set_nodelay。不 spawn 任何 task。
159    async fn establish_stream(addr: &str, timeout: std::time::Duration) -> Result<TcpStream> {
160        let stream = match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
161            Ok(Ok(s)) => s,
162            Ok(Err(e)) => return Err(e.into()),
163            Err(_elapsed) => {
164                return Err(FutuError::Network(std::io::Error::new(
165                    std::io::ErrorKind::TimedOut,
166                    format!("connect to {addr} timed out after {}s", timeout.as_secs()),
167                )));
168            }
169        };
170        stream.set_nodelay(true)?;
171        Ok(stream)
172    }
173
174    /// 连接到后端服务器(带 10s 超时)
175    pub async fn connect(addr: &str, push_callback: PushCallback) -> Result<Self> {
176        let stream = Self::establish_stream(addr, Self::CONNECT_TIMEOUT).await?;
177        tracing::info!(addr = addr, "connected to backend");
178        Self::from_stream(stream, push_callback)
179    }
180
181    /// 并发连接多个候选地址,谁先通用谁(对齐 C++ `connector.cpp:175-189`
182    /// `ConnectStrategyAddr` 的 concurrency_ip 语义)。
183    ///
184    /// - 每个候选独立带 `CONNECT_TIMEOUT`(10s)超时
185    /// - 第一个 `Ok(stream)` 胜出,其余 pending task drop 时会关闭半连接
186    /// - 全部失败返回最后一个错误
187    ///
188    /// 返回 `(BackendConn, winner_addr)`,调用方用 winner_addr 做登录协议里的
189    /// host_ip/host_port 字段。
190    pub async fn connect_race(
191        addrs: &[String],
192        push_callback: PushCallback,
193    ) -> Result<(Self, String)> {
194        use futures::stream::{FuturesUnordered, StreamExt};
195
196        if addrs.is_empty() {
197            return Err(FutuError::Network(std::io::Error::new(
198                std::io::ErrorKind::InvalidInput,
199                "connect_race: empty address list",
200            )));
201        }
202
203        tracing::info!(
204            candidates = addrs.len(),
205            addrs = ?addrs,
206            "racing parallel connects"
207        );
208
209        let mut attempts: FuturesUnordered<_> = addrs
210            .iter()
211            .cloned()
212            .map(|addr| async move {
213                let result = Self::establish_stream(&addr, Self::CONNECT_TIMEOUT).await;
214                (addr, result)
215            })
216            .collect();
217
218        let mut last_err: Option<FutuError> = None;
219        while let Some((addr, result)) = attempts.next().await {
220            match result {
221                Ok(stream) => {
222                    tracing::info!(
223                        addr = %addr,
224                        remaining_losers = attempts.len(),
225                        "connect race winner"
226                    );
227                    drop(attempts); // 其他 FuturesUnordered 里的 future drop 即取消
228                    let conn = Self::from_stream(stream, push_callback)?;
229                    return Ok((conn, addr));
230                }
231                Err(e) => {
232                    tracing::debug!(addr = %addr, error = %e, "candidate failed");
233                    last_err = Some(e);
234                }
235            }
236        }
237
238        Err(last_err.unwrap_or_else(|| {
239            FutuError::Network(std::io::Error::other("connect_race: all candidates failed"))
240        }))
241    }
242
243    /// v1.4.70 D1: test-only 从 `tokio::io::DuplexStream` 构造 BackendConn
244    ///
245    /// 用于 integration tests(`crates/futu-gateway/tests/common/mock_backend.rs`)
246    /// 替代真 `TcpStream`。生产路径通过 `connect()` → `from_stream()` 不变。
247    #[cfg(feature = "test-util")]
248    pub fn from_duplex(
249        stream: tokio::io::DuplexStream,
250        push_callback: PushCallback,
251    ) -> Result<Self> {
252        Self::from_stream_inner(stream, push_callback)
253    }
254
255    /// 从已建立的 TcpStream 构造 BackendConn(spawn recv/send task)
256    fn from_stream(stream: TcpStream, push_callback: PushCallback) -> Result<Self> {
257        Self::from_stream_inner(stream, push_callback)
258    }
259
260    /// v1.4.70 D1: 泛型化的 stream → BackendConn 构造(内部实现),
261    /// 生产路径用 `TcpStream`,test 用 `DuplexStream`。
262    pub(crate) fn from_stream_inner<S>(stream: S, push_callback: PushCallback) -> Result<Self>
263    where
264        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static,
265    {
266        let framed = Framed::new(stream, NNCodec);
267        let (mut sink, mut stream_rx) = framed.split();
268
269        let (cmd_tx, mut cmd_rx) = mpsc::channel::<BackendCmd>(256);
270
271        let pending = Arc::new(Mutex::new(HashMap::<u32, oneshot::Sender<NNFrame>>::new()));
272        let pending_recv = pending.clone();
273        let pending_send = pending.clone();
274        let session_key: Arc<Mutex<Option<Vec<u8>>>> = Arc::new(Mutex::new(None));
275        let session_key_for_recv = session_key.clone();
276
277        // 接收任务
278        tokio::spawn(async move {
279            while let Some(Ok(mut frame)) = stream_rx.next().await {
280                // 解密(非登录命令)
281                tracing::debug!(
282                    cmd_id = frame.header.cmd_id,
283                    serial_no = frame.header.serial_no,
284                    is_push = frame.header.is_push,
285                    is_compressed = frame.header.is_compressed,
286                    ex_head_len = frame.header.ex_head_len,
287                    body_len = frame.header.body_len,
288                    actual_body_len = frame.body.len(),
289                    "recv frame"
290                );
291
292                // 如果 ex_head 里有业务错误(err_info.cmd_result != 0),
293                // 把它 log 出来 —— 服务端 body 空时会通过 ex_head 返回错误。
294                //
295                // 软失败特判:某些返回 code 是正常的"无此服务/无数据"信号(例如
296                // 800000 账号在 Futu HK broker 上没有授权账户时 CMD 2298 会收到
297                // `code=-102 CONN can not find command service`)。这类日志降到
298                // DEBUG 避免噪声;其他错误保持 WARN。
299                if let Some(err) = frame.parse_ex_head_error()
300                    && (err.cmd_result != 0
301                        || err.code != 0
302                        || !err.message.is_empty()
303                        || !err.source.is_empty())
304                {
305                    // `code=-102` 是服务端的软失败信号:"此账户/通道不支持该 cmd"。
306                    // 服务端实际把 "CONN can not find command service" 放在 `source`
307                    // 字段里,`message` 为空(和 C++ ErrorInfo 字段 2/4 定义略有偏差)。
308                    // 例如:800000 账号 Futu HK broker 通道不认 CMD 2298 / CMD 1003 heartbeat
309                    // 都会走到这里。降级到 debug 避免日志噪声;其他 code 保持 warn。
310                    let is_soft_fail = err.code == -102;
311                    if is_soft_fail {
312                        tracing::debug!(
313                            cmd_id = frame.header.cmd_id,
314                            code = err.code,
315                            source = %err.source,
316                            message = %err.message,
317                            "server: cmd not available on this channel (soft-fail)"
318                        );
319                    } else {
320                        tracing::warn!(
321                            cmd_id = frame.header.cmd_id,
322                            cmd_result = err.cmd_result,
323                            code = err.code,
324                            source = %err.source,
325                            message = %err.message,
326                            "server returned err_info in ex_head"
327                        );
328                    }
329                }
330
331                let session_key = session_key_for_recv.lock().clone();
332                if decode_inbound_frame_body(&mut frame, session_key.as_deref())
333                    == InboundFrameDecision::Drop
334                {
335                    release_pending_on_inbound_decode_error(&frame, &pending_recv);
336                    continue;
337                }
338
339                // 判断是否为推送帧:
340                // 1. flags.push_ == 1 (标准推送)
341                // 2. flags.push_ == 0 但 serial_no == 0 (后端首次订阅快照,
342                //    以 Reply 形式发送, 如 CMD6212 的初始摆盘/报价)
343                let is_push = frame.header.is_push
344                    || (frame.header.serial_no == 0 && !pending_recv.lock().contains_key(&0));
345                if is_push {
346                    tracing::debug!(
347                        cmd_id = frame.header.cmd_id,
348                        body_len = frame.body.len(),
349                        is_push = frame.header.is_push,
350                        is_compressed = frame.header.is_compressed,
351                        reserved = ?frame.header.reserved,
352                        "backend push received"
353                    );
354                    push_callback(frame.header.cmd_id, frame.body);
355                } else {
356                    let tx = pending_recv.lock().remove(&frame.header.serial_no);
357                    if let Some(tx) = tx {
358                        let _ = tx.send(frame);
359                    }
360                }
361            }
362            // 连接断开 — 清理所有 pending 请求 (C++ OnDisConnectRelpyAll)
363            tracing::warn!("backend connection closed");
364            let mut pending = pending_recv.lock();
365            let count = pending.len();
366            if count > 0 {
367                tracing::warn!(
368                    pending_count = count,
369                    "aborting pending requests due to disconnect"
370                );
371            }
372            // 清空 pending HashMap 会 drop 所有 oneshot::Sender
373            // 这会导致对应的 oneshot::Receiver 返回 RecvError
374            // 从而让 request() 方法返回 "response channel closed" 错误
375            pending.clear();
376        });
377
378        // 发送任务
379        tokio::spawn(async move {
380            while let Some(cmd) = cmd_rx.recv().await {
381                match cmd {
382                    BackendCmd::Send(frame, resp_tx) => {
383                        if let Some(tx) = resp_tx {
384                            pending_send.lock().insert(frame.header.serial_no, tx);
385                        }
386                        if let Err(e) = sink.send(frame).await {
387                            tracing::error!(error = %e, "backend send failed");
388                            break;
389                        }
390                    }
391                }
392            }
393        });
394
395        Ok(Self {
396            serial_no: AtomicU32::new(0),
397            sec_data: AtomicU32::new(1),
398            session_key, // 与接收任务共享同一个 Arc
399            cmd_tx,
400            user_id: AtomicU32::new(0),
401            client_ip: Mutex::new(String::new()),
402            client_type: 40, // NN_ClientType_ApiGateway
403            client_ver: Self::CLIENT_VER_FTGTW,
404            lang_id: 0,
405        })
406    }
407
408    /// 设置 session key(登录成功后调用)。接受变长字节,16/24/32 分别对应
409    /// AES-128/192/256。对齐 C++ `Logger::session_key_` 是 `std::string`,
410    /// 长度取决于服务端下发的 `RspEncryptData.session_key` 字段原始长度。
411    pub fn set_session_key(&self, key: Vec<u8>) {
412        *self.session_key.lock() = Some(key);
413    }
414
415    /// 设置 sec_data 初始值(登录成功后调用)
416    pub fn set_sec_data(&self, val: u32) {
417        self.sec_data.store(val, Ordering::Relaxed);
418    }
419
420    /// 设置登录响应里的客户端外网 IP。
421    pub fn set_client_ip(&self, ip: String) {
422        *self.client_ip.lock() = ip;
423    }
424
425    /// 读取登录响应里的客户端外网 IP。
426    pub fn client_ip(&self) -> String {
427        self.client_ip.lock().clone()
428    }
429
430    /// 发送请求并等待响应
431    pub async fn request(&self, cmd_id: u16, body: Vec<u8>) -> Result<NNFrame> {
432        self.request_with_reserved(cmd_id, body, [0u8; 10]).await
433    }
434
435    /// 发送请求并等待响应(带 header reserved 字段,用于行情命令传递市场类型)
436    pub async fn request_with_reserved(
437        &self,
438        cmd_id: u16,
439        body: Vec<u8>,
440        reserved: [u8; 10],
441    ) -> Result<NNFrame> {
442        let frame = self.build_outbound_frame(cmd_id, body, reserved)?;
443
444        let (resp_tx, resp_rx) = oneshot::channel();
445        self.cmd_tx
446            .send(BackendCmd::Send(frame, Some(resp_tx)))
447            .await
448            .map_err(|_| FutuError::NotInitialized)?;
449
450        let resp = tokio::time::timeout(std::time::Duration::from_secs(10), resp_rx)
451            .await
452            .map_err(|_| FutuError::Timeout)?
453            .map_err(|_| FutuError::Codec("response channel closed".into()))?;
454
455        Ok(resp)
456    }
457
458    /// 发送无需等待响应的消息
459    pub async fn send_fire_and_forget(&self, cmd_id: u16, body: Vec<u8>) -> Result<()> {
460        let frame = self.build_outbound_frame(cmd_id, body, [0u8; 10])?;
461        self.cmd_tx
462            .send(BackendCmd::Send(frame, None))
463            .await
464            .map_err(|_| FutuError::NotInitialized)?;
465
466        Ok(())
467    }
468
469    fn build_outbound_frame(
470        &self,
471        cmd_id: u16,
472        body: Vec<u8>,
473        reserved: [u8; 10],
474    ) -> Result<NNFrame> {
475        let serial = self.next_serial();
476        let mut header = NNHeader::new(cmd_id, serial);
477        header.user_id = self.user_id.load(Ordering::Relaxed);
478        header.client_type = self.client_type;
479        header.client_ver = self.client_ver;
480        header.lang_id = self.lang_id;
481        // 行情命令实际用到的只有 reserved[0..2](market_type + ex_type),
482        // 后 2 字节(原 [8..10])在 C++ protocol header 里是 ex_head_len 位置,
483        // 我们不发 ex_head 所以这 2 字节固定为 0。
484        header.reserved.copy_from_slice(&reserved[..8]);
485
486        let final_body = self.encode_outbound_body(cmd_id, body)?;
487        header.body_len = final_body.len() as u32;
488
489        Ok(NNFrame {
490            header,
491            body: Bytes::from(final_body),
492            ex_head: Bytes::new(),
493        })
494    }
495
496    fn encode_outbound_body(&self, cmd_id: u16, body: Vec<u8>) -> Result<Vec<u8>> {
497        if should_skip_encryption(cmd_id) {
498            return Ok(body);
499        }
500
501        let key = self.session_key.lock().clone();
502        match key {
503            Some(key) => {
504                // C++ 先 m_nSecData++ 再使用,所以要用 +1 后的值。
505                let sec = self.sec_data.fetch_add(1, Ordering::Relaxed) + 1;
506                let mut plaintext = Vec::with_capacity(4 + body.len());
507                plaintext.extend_from_slice(&sec.to_be_bytes());
508                plaintext.extend_from_slice(&body);
509                // var 版支持 16/24/32 字节 key —— broker session_key 可能是 32 字节。
510                aes_cbc_md5_encrypt_var(&key, &plaintext)
511            }
512            None => Ok(body),
513        }
514    }
515
516    fn next_serial(&self) -> u32 {
517        self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
518    }
519}
520
521#[cfg(test)]
522mod tests;