Skip to main content

futu_server/
conn.rs

1// 单连接管理:状态机、帧收发、加密、心跳超时
2
3use std::collections::HashSet;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::Instant;
6
7use bytes::Bytes;
8use futures::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio_util::codec::Framed;
12
13use futu_auth::Scope;
14use futu_codec::FutuCodec;
15use futu_codec::frame::FutuFrame;
16use futu_codec::header::ProtoFmtType;
17use futu_core::error::FutuError;
18use futu_net::encrypt;
19
20/// 连接状态
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22#[non_exhaustive]
23pub enum ConnState {
24    /// TCP / WebSocket 刚建立,尚未完成 InitConnect 握手
25    Connected,
26    /// 已完成 InitConnect,等待首次业务请求
27    Initialized,
28    /// 正常交互中(业务请求 / KeepAlive / push 均已流转)
29    Active,
30    /// 连接已断开(客户端主动关闭 / 被动超时 / IO 错误)
31    Disconnected,
32}
33
34/// 单个客户端连接
35pub struct ClientConn {
36    /// 随机连接 ID(对应 C++ `GetRand_MilliTimeAndU22`)
37    pub conn_id: u64,
38    /// 连接状态:InitConnect 前 / 后 / 已断开
39    pub state: ConnState,
40    /// 随机 AES-128 key(InitConnect 响应里下发给客户端)
41    pub aes_key: [u8; 16],
42    /// AES 加解密已启用(InitConnect 完成且配置了 RSA 时为 true)
43    pub aes_encrypt_enabled: bool,
44    /// 该连接协商的 proto 格式(Protobuf / JSON)
45    pub proto_fmt_type: ProtoFmtType,
46    /// 上次收到 KeepAlive 的时间,用于超时检查
47    pub last_keepalive: Instant,
48    /// InitConnect.C2S.recvNotify:此连接是否接收市场状态 / 交易解锁等通知。
49    ///
50    /// C++ 在 `APIServer_InitConnect.cpp` 里把该字段写入 ConnInfo;
51    /// `RegQotPush` / `Qot_Sub(isRegOrUnRegPush)` 不会修改这个开关。
52    pub recv_notify: bool,
53    /// 已收到的 KeepAlive 计数(监控用)
54    pub keepalive_count: AtomicU32,
55    /// 发送帧到此连接
56    pub tx: mpsc::Sender<FutuFrame>,
57
58    // ---- v1.0 WS per-message scope 鉴权(raw TCP legacy 兼容:scopes 空集全放行)----
59    /// 该连接绑定的 API key id;WS 握手时填,未配 keys.json 时为 None
60    pub key_id: Option<String>,
61    /// 该连接持有的 scope 集合;空集 = legacy 模式 / TCP 直连,scope 检查放行
62    pub scopes: HashSet<Scope>,
63    /// v1.4.105 D3 (Phase 4) T-B2: 该连接 caller key 的 `allowed_markets` 硬
64    /// 限额 (大写字符串 set, e.g. {"HK","US"}). `None` = 无限制 (legacy 模式
65    /// / TCP 直连默认全开 / 未配 allowed_markets); `Some(set)` 非空 → push 端
66    /// 应过滤 trd_market 不在 set 中的 trade event.
67    ///
68    /// **触发**: WS handshake 时从 `KeyRecord.allowed_markets` 拷贝过来.
69    /// `PushDispatcher::push_trd_acc` 端 Layer 3 filter 检查. 与
70    /// `caller_allowed_acc_ids` (Layer 1, per-call snapshot in IncomingRequest)
71    /// 区别: 本字段是 per-conn snapshot (handshake 时一次性), 不随 per-call
72    /// 重读 — KeyRecord SIGHUP reload 后**仅新建连接生效**, 老连接保持 snapshot
73    /// (与 `scopes` / `caller_allowed_acc_ids` 的 snapshot 语义一致).
74    pub allowed_markets: Option<std::sync::Arc<HashSet<String>>>,
75    /// codex round 1 F4 (P2) v1.4.105: 该连接 caller key 的 `allowed_acc_ids`
76    /// 硬限额 (per-conn snapshot, handshake 时一次性). `None` / `Some(empty)` =
77    /// 无限制 (legacy 模式 / TCP 直连默认全开 / 未配 allowed_acc_ids);
78    /// `Some(non-empty set)` →
79    /// `PushDispatcher::push_trd_acc` 端 push-time 硬过滤 acc_id 不在 set 中
80    /// 的 trade event (Layer 1, 与 `allowed_markets` 的 Layer 3 互补).
81    /// Deny-all 使用 sentinel `{0}`,不使用空集合。
82    ///
83    /// **触发**: codex F4 指出 raw TCP push 端只查 `acc:read` scope +
84    /// `allowed_markets`, 不查 `allowed_acc_ids`. 即使 request-time
85    /// `SubAccPushHandler` 已阻止越权订阅, stale subscription / KeyRecord
86    /// reload 后窄化的 acc 范围 / 历史 bug 留下的 conn→acc 关系 仍可能让 push
87    /// 漏 leak. 本字段提供第二层 push-time 兜底.
88    ///
89    /// 与 `caller_allowed_acc_ids` (IncomingRequest, per-call) 区别: 本字段
90    /// 在 push-time 用 (无 IncomingRequest), per-conn snapshot 与 `scopes` /
91    /// `allowed_markets` 的 snapshot 语义一致.
92    pub allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
93}
94
95/// 从连接接收到的请求
96#[derive(Debug)]
97pub struct IncomingRequest {
98    /// 发送请求的连接 ID(用于响应路由 + SubscriptionManager / cache 记账)
99    ///
100    /// **跨 surface 命名空间分配**(v1.4.106 codex 0517 ζ25-redo F2 沉淀):
101    /// - raw TCP listener: `ClientConn::generate_conn_id()` 派生(u32 范围)
102    /// - REST: `crates/futu-rest/src/routes/qot.rs::REST_SHARED_CONN`
103    ///   = `0xFFFF_FFFE`(u32 上限附近, 单值共享)
104    /// - gRPC: `crates/futu-grpc/src/auth.rs::GRPC_STABLE_CONN_NAMESPACE`
105    ///   = `0x4000_0000_0000_0000`(bit 62 namespace, 按 caller 派生)
106    /// - WS / MCP: 通常派生自所属物理 TCP 连接的 conn_id
107    ///
108    /// 各 surface 不重叠. 加新 surface 时分配一个 namespace base, 不要与
109    /// 上述 4 个段重合.
110    pub conn_id: u64,
111    /// 协议 ID(对齐 C++ `NN_ProtoCmd_*`)
112    pub proto_id: u32,
113    /// 序列号(和 Response 配对,供 client 端请求-响应匹配)
114    pub serial_no: u32,
115    /// 请求体 proto 格式(Protobuf / JSON)
116    pub proto_fmt_type: ProtoFmtType,
117    /// 请求 body(已解密后的明文)
118    pub body: Bytes,
119    /// v1.4.38 Phase 4: 订单幂等 key(由 REST `Idempotency-Key` header / gRPC
120    /// metadata / WS envelope / MCP tool args 填入)。None 表示客户端未传,
121    /// handler 走无幂等直通 path(backward-compat)。
122    pub idempotency_key: Option<String>,
123    /// v1.4.106 codex 0920 F1 (P1): caller key id 副本 (per-call snapshot,
124    /// 由 surface adapter 层从 KeyRecord 读取后填入).
125    ///
126    /// **目标**: idempotency cache key namespace 必须含 caller key id, 否则
127    /// 不同 caller 用同 Idempotency-Key 会跨 caller 命中老 response —— 严重
128    /// 跨账户数据泄漏 + 重复下单 silent fail.
129    ///
130    /// `None` = 无 caller 标识 (legacy TCP / 未 auth) → namespace 用 `<no_key>`
131    /// 占位符. `Some("alice")` = WS / MCP / REST 已 auth 的 caller —— namespace
132    /// 用 `<caller_key_id="alice">`, 不与其他 caller 串.
133    pub caller_key_id: Option<String>,
134    /// v1.4.105 D2 contract-hardening 补丁: caller key 的 `allowed_acc_ids` 硬限额
135    /// 副本 (per-call snapshot, 在 surface adapter 层从 KeyRecord 读取后填入).
136    ///
137    /// **目标**: 让 dispatch-time handlers (e.g. `SubAccPushHandler` 注册 acc_id
138    /// 到 SubscriptionManager) 也能 enforce per-acc whitelist — 即使上游 pipeline
139    /// body-aware step 已 enforce, 让 handler 自己 defense-in-depth 防 future
140    /// regression (新 surface 加进来漏调 pipeline body-aware).
141    ///
142    /// `None` / `Some(empty)` = caller 无 acc_id 限制 (legacy mode 或 unrestricted
143    /// key) → handler 不 filter; `Some(non-empty set)` → handler 应 reject 不在
144    /// set 中的 acc_id. Deny-all 使用 sentinel `{0}`,不使用空集合。
145    pub caller_allowed_acc_ids: Option<std::sync::Arc<std::collections::HashSet<u64>>>,
146}
147
148impl IncomingRequest {
149    /// codex 0522 F4 v1.4.106: cross-surface 单测 hook. 构 IncomingRequest
150    /// 并填 caller scope (`caller_key_id` + `caller_allowed_acc_ids`) — 让
151    /// REST / gRPC / WS / MCP 等 surface 的 adapter 都用同一构造路径, 防
152    /// "某个 surface 漏填字段" silent regression.
153    ///
154    /// 之前 4 surface 各写一份 struct literal, 加新字段需逐个改, 漏一个就
155    /// 出现 silent None — 与坑 #54 schema-only fix 同模式 (实装符号 vs 真
156    /// 行为差距). 本 helper 是 single point, 加新字段 schema 自动 propagate.
157    ///
158    /// **注意**: 本 helper 不 take ownership of body — caller 已 own bytes.
159    /// idempotency_key / caller_key_id 接 String 而非 &str 让 caller 决定
160    /// 是 clone 还是 move.
161    pub fn builder(
162        conn_id: u64,
163        proto_id: u32,
164        serial_no: u32,
165        proto_fmt_type: ProtoFmtType,
166        body: Bytes,
167    ) -> IncomingRequestBuilder {
168        IncomingRequestBuilder {
169            request: Self {
170                conn_id,
171                proto_id,
172                serial_no,
173                proto_fmt_type,
174                body,
175                idempotency_key: None,
176                caller_allowed_acc_ids: None,
177                caller_key_id: None,
178            },
179        }
180    }
181}
182
183/// Thin builder for `IncomingRequest`.
184///
185/// The base request shape is the wire envelope; idempotency and caller scope are
186/// optional per-surface decorations. Keeping those defaults here avoids every
187/// REST / gRPC / raw WS / MCP adapter spelling out `None` independently.
188#[derive(Debug)]
189pub struct IncomingRequestBuilder {
190    request: IncomingRequest,
191}
192
193impl IncomingRequestBuilder {
194    pub fn with_idempotency_key(mut self, idempotency_key: Option<String>) -> Self {
195        self.request.idempotency_key = idempotency_key;
196        self
197    }
198
199    pub fn with_caller_scope(
200        mut self,
201        caller_allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
202        caller_key_id: Option<String>,
203    ) -> Self {
204        self.request.caller_allowed_acc_ids = caller_allowed_acc_ids;
205        self.request.caller_key_id = caller_key_id;
206        self
207    }
208
209    pub fn build(self) -> IncomingRequest {
210        self.request
211    }
212}
213
214impl From<IncomingRequestBuilder> for IncomingRequest {
215    fn from(builder: IncomingRequestBuilder) -> Self {
216        builder.build()
217    }
218}
219
220impl ClientConn {
221    /// 生成随机连接 ID(与 C++ 的 GetRand_MilliTimeAndU22 对应)
222    pub fn generate_conn_id() -> u64 {
223        use std::time::{SystemTime, UNIX_EPOCH};
224        let millis = SystemTime::now()
225            .duration_since(UNIX_EPOCH)
226            .unwrap_or_default()
227            .as_millis() as u64;
228        let rand_part: u32 = rand::random();
229        (millis << 22) | (rand_part as u64 & 0x3FFFFF)
230    }
231
232    /// 生成随机 AES key(16 字节 hex 字符串的 ASCII 字节)
233    pub fn generate_aes_key() -> [u8; 16] {
234        let rand_val: u64 = rand::random();
235        let hex = format!("{rand_val:016X}");
236        let mut key = [0u8; 16];
237        key.copy_from_slice(hex.as_bytes());
238        key
239    }
240
241    /// 创建发送帧,自动处理 AES 加密
242    ///
243    /// 当 aes_encrypt_enabled 为 true 时:
244    /// - SHA1 基于明文计算(FutuFrame::new 自动处理)
245    /// - body 使用 AES-128 ECB 加密
246    /// - header.body_len 更新为密文长度
247    ///
248    /// 对应 C++ APIServerCS_Conn::OnSendPacketData 的加密逻辑
249    pub fn make_frame(&self, proto_id: u32, serial_no: u32, body: Bytes) -> FutuFrame {
250        if self.aes_encrypt_enabled {
251            // 先用明文构建 frame(SHA1 基于明文计算)
252            let mut frame = FutuFrame::new(proto_id, serial_no, body);
253            // 再加密 body
254            let encrypted = encrypt::aes_ecb_encrypt(&self.aes_key, &frame.body);
255            frame.header.body_len = encrypted.len() as u32;
256            frame.body = Bytes::from(encrypted);
257            frame
258        } else {
259            FutuFrame::new(proto_id, serial_no, body)
260        }
261    }
262
263    /// 解密请求 body(如果启用了 AES 加密)
264    ///
265    /// 对应 C++ APIServerCS_Conn::OnRecvPacket 的解密逻辑
266    pub fn decrypt_body(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
267        if self.aes_encrypt_enabled {
268            encrypt::aes_ecb_decrypt(&self.aes_key, body).map_err(|e| {
269                tracing::warn!(conn_id = self.conn_id, error = %e, "AES decrypt body failed");
270                e
271            })
272        } else {
273            Ok(body.to_vec())
274        }
275    }
276
277    /// 处理 InitConnect 请求,返回 InitConnect 响应 body
278    ///
279    /// 当配置了 RSA 私钥时:
280    /// - C2S 请求 body 使用 RSA 公钥加密(需要用私钥解密)
281    /// - S2C 响应 body 使用 RSA 公钥加密(客户端用私钥解密)
282    ///
283    /// 对应 C++ APIServer::OnRecvInitConnect
284    pub fn handle_init_connect(
285        &mut self,
286        body: &[u8],
287        server_ver: i32,
288        login_user_id: u64,
289        keepalive_interval: i32,
290        rsa_private_key: Option<&str>,
291    ) -> Result<Vec<u8>, FutuError> {
292        // 1. 解密 C2S(如果配置了 RSA)
293        let decrypted_body;
294        let req_body = if let Some(rsa_key) = rsa_private_key {
295            decrypted_body =
296                futu_net::encrypt::rsa_private_decrypt_blocks(rsa_key, body).map_err(|e| {
297                    tracing::warn!(error = %e, "RSA decrypt InitConnect C2S failed");
298                    e
299                })?;
300            tracing::debug!(
301                encrypted_len = body.len(),
302                decrypted_len = decrypted_body.len(),
303                "RSA decrypted InitConnect C2S"
304            );
305            &decrypted_body[..]
306        } else {
307            body
308        };
309
310        let req: futu_proto::init_connect::Request =
311            prost::Message::decode(req_body).map_err(FutuError::Proto)?;
312
313        self.state = ConnState::Initialized;
314        self.recv_notify = req.c2s.recv_notify.unwrap_or(false);
315
316        // 当配置了 RSA 时,后续所有帧使用 AES 加解密
317        if rsa_private_key.is_some() {
318            self.aes_encrypt_enabled = true;
319            tracing::debug!(conn_id = self.conn_id, "AES body encryption enabled");
320        }
321
322        let aes_key_str = std::str::from_utf8(&self.aes_key)
323            .map_err(|e| FutuError::Codec(format!("invalid InitConnect conn_aes_key: {e}")))?
324            .to_string();
325
326        let resp = futu_proto::init_connect::Response {
327            ret_type: 0,
328            ret_msg: None,
329            err_code: None,
330            s2c: Some(futu_proto::init_connect::S2c {
331                server_ver,
332                login_user_id,
333                conn_id: self.conn_id,
334                conn_aes_key: aes_key_str,
335                keep_alive_interval: keepalive_interval,
336                aes_cb_civ: None,
337                user_attribution: None,
338            }),
339        };
340
341        let resp_body = prost::Message::encode_to_vec(&resp);
342
343        // 2. 加密 S2C(如果配置了 RSA)
344        if let Some(rsa_key) = rsa_private_key {
345            let encrypted = futu_net::encrypt::rsa_public_encrypt_blocks(rsa_key, &resp_body)
346                .map_err(|e| {
347                    tracing::warn!(error = %e, "RSA encrypt InitConnect S2C failed");
348                    e
349                })?;
350            tracing::debug!(
351                plaintext_len = resp_body.len(),
352                encrypted_len = encrypted.len(),
353                "RSA encrypted InitConnect S2C"
354            );
355            Ok(encrypted)
356        } else {
357            Ok(resp_body)
358        }
359    }
360
361    /// 处理 KeepAlive 请求
362    pub fn handle_keepalive(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
363        let _req: futu_proto::keep_alive::Request =
364            prost::Message::decode(body).map_err(FutuError::Proto)?;
365
366        self.keepalive_count.fetch_add(1, Ordering::Relaxed);
367
368        let resp = futu_proto::keep_alive::Response {
369            ret_type: 0,
370            ret_msg: None,
371            err_code: None,
372            s2c: Some(futu_proto::keep_alive::S2c {
373                time: chrono::Utc::now().timestamp(),
374            }),
375        };
376
377        Ok(prost::Message::encode_to_vec(&resp))
378    }
379}
380
381/// 连接断开通知
382pub struct DisconnectNotify {
383    /// 被断开的连接 ID(订阅 / push / auth 状态清理用)
384    pub conn_id: u64,
385}
386
387/// 运行单个连接的收发循环
388///
389/// 返回 (接收请求的 channel, 连接信息)
390pub async fn run_connection(
391    stream: TcpStream,
392    conn_id: u64,
393    _aes_key: [u8; 16],
394    req_tx: mpsc::UnboundedSender<IncomingRequest>,
395    disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
396) -> mpsc::Sender<FutuFrame> {
397    let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
398
399    let framed = Framed::new(stream, FutuCodec);
400    let (mut sink, mut stream) = framed.split();
401
402    // 发送任务
403    tokio::spawn(async move {
404        while let Some(frame) = frame_rx.recv().await {
405            if let Err(e) = sink.send(frame).await {
406                tracing::warn!(conn_id = conn_id, error = %e, "send failed");
407                break;
408            }
409        }
410    });
411
412    // 接收任务
413    tokio::spawn(async move {
414        while let Some(result) = stream.next().await {
415            match result {
416                Ok(frame) => {
417                    // TCP wire currently has no idempotency or key scope fields.
418                    // The builder keeps legacy defaults explicit in one place.
419                    let req = IncomingRequest::builder(
420                        conn_id,
421                        frame.header.proto_id,
422                        frame.header.serial_no,
423                        frame.header.proto_fmt_type,
424                        frame.body,
425                    )
426                    .build();
427                    if req_tx.send(req).is_err() {
428                        break;
429                    }
430                }
431                Err(e) => {
432                    tracing::warn!(conn_id = conn_id, error = %e, "recv error");
433                    break;
434                }
435            }
436        }
437        tracing::info!(conn_id = conn_id, "connection closed");
438        // 通知 listener 清理连接
439        let _ = disconnect_tx.send(DisconnectNotify { conn_id });
440    });
441
442    frame_tx
443}
444
445#[cfg(test)]
446mod tests;