futu_server/conn.rs
1// 单连接管理:状态机、帧收发、加密、心跳超时
2
3use std::collections::HashSet;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::Instant;
6
7use bytes::Bytes;
8use futures::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::mpsc;
11use tokio_util::codec::Framed;
12
13use futu_auth::Scope;
14use futu_codec::FutuCodec;
15use futu_codec::frame::FutuFrame;
16use futu_codec::header::ProtoFmtType;
17use futu_core::error::FutuError;
18use futu_net::encrypt;
19
20/// 连接状态
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22#[non_exhaustive]
23pub enum ConnState {
24 /// TCP / WebSocket 刚建立,尚未完成 InitConnect 握手
25 Connected,
26 /// 已完成 InitConnect,等待首次业务请求
27 Initialized,
28 /// 正常交互中(业务请求 / KeepAlive / push 均已流转)
29 Active,
30 /// 连接已断开(客户端主动关闭 / 被动超时 / IO 错误)
31 Disconnected,
32}
33
34/// 单个客户端连接
35pub struct ClientConn {
36 /// 随机连接 ID(对应 C++ `GetRand_MilliTimeAndU22`)
37 pub conn_id: u64,
38 /// 连接状态:InitConnect 前 / 后 / 已断开
39 pub state: ConnState,
40 /// 随机 AES-128 key(InitConnect 响应里下发给客户端)
41 pub aes_key: [u8; 16],
42 /// AES 加解密已启用(InitConnect 完成且配置了 RSA 时为 true)
43 pub aes_encrypt_enabled: bool,
44 /// 该连接协商的 proto 格式(Protobuf / JSON)
45 pub proto_fmt_type: ProtoFmtType,
46 /// 上次收到 KeepAlive 的时间,用于超时检查
47 pub last_keepalive: Instant,
48 /// InitConnect.C2S.recvNotify:此连接是否接收市场状态 / 交易解锁等通知。
49 ///
50 /// C++ 在 `APIServer_InitConnect.cpp` 里把该字段写入 ConnInfo;
51 /// `RegQotPush` / `Qot_Sub(isRegOrUnRegPush)` 不会修改这个开关。
52 pub recv_notify: bool,
53 /// 已收到的 KeepAlive 计数(监控用)
54 pub keepalive_count: AtomicU32,
55 /// 发送帧到此连接
56 pub tx: mpsc::Sender<FutuFrame>,
57
58 // ---- v1.0 WS per-message scope 鉴权(raw TCP legacy 兼容:scopes 空集全放行)----
59 /// 该连接绑定的 API key id;WS 握手时填,未配 keys.json 时为 None
60 pub key_id: Option<String>,
61 /// 该连接持有的 scope 集合;空集 = legacy 模式 / TCP 直连,scope 检查放行
62 pub scopes: HashSet<Scope>,
63 /// v1.4.105 D3 (Phase 4) T-B2: 该连接 caller key 的 `allowed_markets` 硬
64 /// 限额 (大写字符串 set, e.g. {"HK","US"}). `None` = 无限制 (legacy 模式
65 /// / TCP 直连默认全开 / 未配 allowed_markets); `Some(set)` 非空 → push 端
66 /// 应过滤 trd_market 不在 set 中的 trade event.
67 ///
68 /// **触发**: WS handshake 时从 `KeyRecord.allowed_markets` 拷贝过来.
69 /// `PushDispatcher::push_trd_acc` 端 Layer 3 filter 检查. 与
70 /// `caller_allowed_acc_ids` (Layer 1, per-call snapshot in IncomingRequest)
71 /// 区别: 本字段是 per-conn snapshot (handshake 时一次性), 不随 per-call
72 /// 重读 — KeyRecord SIGHUP reload 后**仅新建连接生效**, 老连接保持 snapshot
73 /// (与 `scopes` / `caller_allowed_acc_ids` 的 snapshot 语义一致).
74 pub allowed_markets: Option<std::sync::Arc<HashSet<String>>>,
75 /// codex round 1 F4 (P2) v1.4.105: 该连接 caller key 的 `allowed_acc_ids`
76 /// 硬限额 (per-conn snapshot, handshake 时一次性). `None` / `Some(empty)` =
77 /// 无限制 (legacy 模式 / TCP 直连默认全开 / 未配 allowed_acc_ids);
78 /// `Some(non-empty set)` →
79 /// `PushDispatcher::push_trd_acc` 端 push-time 硬过滤 acc_id 不在 set 中
80 /// 的 trade event (Layer 1, 与 `allowed_markets` 的 Layer 3 互补).
81 /// Deny-all 使用 sentinel `{0}`,不使用空集合。
82 ///
83 /// **触发**: codex F4 指出 raw TCP push 端只查 `acc:read` scope +
84 /// `allowed_markets`, 不查 `allowed_acc_ids`. 即使 request-time
85 /// `SubAccPushHandler` 已阻止越权订阅, stale subscription / KeyRecord
86 /// reload 后窄化的 acc 范围 / 历史 bug 留下的 conn→acc 关系 仍可能让 push
87 /// 漏 leak. 本字段提供第二层 push-time 兜底.
88 ///
89 /// 与 `caller_allowed_acc_ids` (IncomingRequest, per-call) 区别: 本字段
90 /// 在 push-time 用 (无 IncomingRequest), per-conn snapshot 与 `scopes` /
91 /// `allowed_markets` 的 snapshot 语义一致.
92 pub allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
93}
94
95/// 从连接接收到的请求
96#[derive(Debug)]
97pub struct IncomingRequest {
98 /// 发送请求的连接 ID(用于响应路由 + SubscriptionManager / cache 记账)
99 ///
100 /// **跨 surface 命名空间分配**(v1.4.106 codex 0517 ζ25-redo F2 沉淀):
101 /// - raw TCP listener: `ClientConn::generate_conn_id()` 派生(u32 范围)
102 /// - REST: `crates/futu-rest/src/routes/qot.rs::REST_SHARED_CONN`
103 /// = `0xFFFF_FFFE`(u32 上限附近, 单值共享)
104 /// - gRPC: `crates/futu-grpc/src/auth.rs::GRPC_STABLE_CONN_NAMESPACE`
105 /// = `0x4000_0000_0000_0000`(bit 62 namespace, 按 caller 派生)
106 /// - WS / MCP: 通常派生自所属物理 TCP 连接的 conn_id
107 ///
108 /// 各 surface 不重叠. 加新 surface 时分配一个 namespace base, 不要与
109 /// 上述 4 个段重合.
110 pub conn_id: u64,
111 /// 协议 ID(对齐 C++ `NN_ProtoCmd_*`)
112 pub proto_id: u32,
113 /// 序列号(和 Response 配对,供 client 端请求-响应匹配)
114 pub serial_no: u32,
115 /// 请求体 proto 格式(Protobuf / JSON)
116 pub proto_fmt_type: ProtoFmtType,
117 /// 请求 body(已解密后的明文)
118 pub body: Bytes,
119 /// v1.4.38 Phase 4: 订单幂等 key(由 REST `Idempotency-Key` header / gRPC
120 /// metadata / WS envelope / MCP tool args 填入)。None 表示客户端未传,
121 /// handler 走无幂等直通 path(backward-compat)。
122 pub idempotency_key: Option<String>,
123 /// v1.4.106 codex 0920 F1 (P1): caller key id 副本 (per-call snapshot,
124 /// 由 surface adapter 层从 KeyRecord 读取后填入).
125 ///
126 /// **目标**: idempotency cache key namespace 必须含 caller key id, 否则
127 /// 不同 caller 用同 Idempotency-Key 会跨 caller 命中老 response —— 严重
128 /// 跨账户数据泄漏 + 重复下单 silent fail.
129 ///
130 /// `None` = 无 caller 标识 (legacy TCP / 未 auth) → namespace 用 `<no_key>`
131 /// 占位符. `Some("alice")` = WS / MCP / REST 已 auth 的 caller —— namespace
132 /// 用 `<caller_key_id="alice">`, 不与其他 caller 串.
133 pub caller_key_id: Option<String>,
134 /// v1.4.105 D2 contract-hardening 补丁: caller key 的 `allowed_acc_ids` 硬限额
135 /// 副本 (per-call snapshot, 在 surface adapter 层从 KeyRecord 读取后填入).
136 ///
137 /// **目标**: 让 dispatch-time handlers (e.g. `SubAccPushHandler` 注册 acc_id
138 /// 到 SubscriptionManager) 也能 enforce per-acc whitelist — 即使上游 pipeline
139 /// body-aware step 已 enforce, 让 handler 自己 defense-in-depth 防 future
140 /// regression (新 surface 加进来漏调 pipeline body-aware).
141 ///
142 /// `None` / `Some(empty)` = caller 无 acc_id 限制 (legacy mode 或 unrestricted
143 /// key) → handler 不 filter; `Some(non-empty set)` → handler 应 reject 不在
144 /// set 中的 acc_id. Deny-all 使用 sentinel `{0}`,不使用空集合。
145 pub caller_allowed_acc_ids: Option<std::sync::Arc<std::collections::HashSet<u64>>>,
146}
147
148impl IncomingRequest {
149 /// codex 0522 F4 v1.4.106: cross-surface 单测 hook. 构 IncomingRequest
150 /// 并填 caller scope (`caller_key_id` + `caller_allowed_acc_ids`) — 让
151 /// REST / gRPC / WS / MCP 等 surface 的 adapter 都用同一构造路径, 防
152 /// "某个 surface 漏填字段" silent regression.
153 ///
154 /// 之前 4 surface 各写一份 struct literal, 加新字段需逐个改, 漏一个就
155 /// 出现 silent None — 与坑 #54 schema-only fix 同模式 (实装符号 vs 真
156 /// 行为差距). 本 helper 是 single point, 加新字段 schema 自动 propagate.
157 ///
158 /// **注意**: 本 helper 不 take ownership of body — caller 已 own bytes.
159 /// idempotency_key / caller_key_id 接 String 而非 &str 让 caller 决定
160 /// 是 clone 还是 move.
161 pub fn builder(
162 conn_id: u64,
163 proto_id: u32,
164 serial_no: u32,
165 proto_fmt_type: ProtoFmtType,
166 body: Bytes,
167 ) -> IncomingRequestBuilder {
168 IncomingRequestBuilder {
169 request: Self {
170 conn_id,
171 proto_id,
172 serial_no,
173 proto_fmt_type,
174 body,
175 idempotency_key: None,
176 caller_allowed_acc_ids: None,
177 caller_key_id: None,
178 },
179 }
180 }
181}
182
183/// Thin builder for `IncomingRequest`.
184///
185/// The base request shape is the wire envelope; idempotency and caller scope are
186/// optional per-surface decorations. Keeping those defaults here avoids every
187/// REST / gRPC / raw WS / MCP adapter spelling out `None` independently.
188#[derive(Debug)]
189pub struct IncomingRequestBuilder {
190 request: IncomingRequest,
191}
192
193impl IncomingRequestBuilder {
194 pub fn with_idempotency_key(mut self, idempotency_key: Option<String>) -> Self {
195 self.request.idempotency_key = idempotency_key;
196 self
197 }
198
199 pub fn with_caller_scope(
200 mut self,
201 caller_allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
202 caller_key_id: Option<String>,
203 ) -> Self {
204 self.request.caller_allowed_acc_ids = caller_allowed_acc_ids;
205 self.request.caller_key_id = caller_key_id;
206 self
207 }
208
209 pub fn build(self) -> IncomingRequest {
210 self.request
211 }
212}
213
214impl From<IncomingRequestBuilder> for IncomingRequest {
215 fn from(builder: IncomingRequestBuilder) -> Self {
216 builder.build()
217 }
218}
219
220impl ClientConn {
221 /// 生成随机连接 ID(与 C++ 的 GetRand_MilliTimeAndU22 对应)
222 pub fn generate_conn_id() -> u64 {
223 use std::time::{SystemTime, UNIX_EPOCH};
224 let millis = SystemTime::now()
225 .duration_since(UNIX_EPOCH)
226 .unwrap_or_default()
227 .as_millis() as u64;
228 let rand_part: u32 = rand::random();
229 (millis << 22) | (rand_part as u64 & 0x3FFFFF)
230 }
231
232 /// 生成随机 AES key(16 字节 hex 字符串的 ASCII 字节)
233 pub fn generate_aes_key() -> [u8; 16] {
234 let rand_val: u64 = rand::random();
235 let hex = format!("{rand_val:016X}");
236 let mut key = [0u8; 16];
237 key.copy_from_slice(hex.as_bytes());
238 key
239 }
240
241 /// 创建发送帧,自动处理 AES 加密
242 ///
243 /// 当 aes_encrypt_enabled 为 true 时:
244 /// - SHA1 基于明文计算(FutuFrame::new 自动处理)
245 /// - body 使用 AES-128 ECB 加密
246 /// - header.body_len 更新为密文长度
247 ///
248 /// 对应 C++ APIServerCS_Conn::OnSendPacketData 的加密逻辑
249 pub fn make_frame(&self, proto_id: u32, serial_no: u32, body: Bytes) -> FutuFrame {
250 if self.aes_encrypt_enabled {
251 // 先用明文构建 frame(SHA1 基于明文计算)
252 let mut frame = FutuFrame::new(proto_id, serial_no, body);
253 // 再加密 body
254 let encrypted = encrypt::aes_ecb_encrypt(&self.aes_key, &frame.body);
255 frame.header.body_len = encrypted.len() as u32;
256 frame.body = Bytes::from(encrypted);
257 frame
258 } else {
259 FutuFrame::new(proto_id, serial_no, body)
260 }
261 }
262
263 /// 解密请求 body(如果启用了 AES 加密)
264 ///
265 /// 对应 C++ APIServerCS_Conn::OnRecvPacket 的解密逻辑
266 pub fn decrypt_body(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
267 if self.aes_encrypt_enabled {
268 encrypt::aes_ecb_decrypt(&self.aes_key, body).map_err(|e| {
269 tracing::warn!(conn_id = self.conn_id, error = %e, "AES decrypt body failed");
270 e
271 })
272 } else {
273 Ok(body.to_vec())
274 }
275 }
276
277 /// 处理 InitConnect 请求,返回 InitConnect 响应 body
278 ///
279 /// 当配置了 RSA 私钥时:
280 /// - C2S 请求 body 使用 RSA 公钥加密(需要用私钥解密)
281 /// - S2C 响应 body 使用 RSA 公钥加密(客户端用私钥解密)
282 ///
283 /// 对应 C++ APIServer::OnRecvInitConnect
284 pub fn handle_init_connect(
285 &mut self,
286 body: &[u8],
287 server_ver: i32,
288 login_user_id: u64,
289 keepalive_interval: i32,
290 rsa_private_key: Option<&str>,
291 ) -> Result<Vec<u8>, FutuError> {
292 // 1. 解密 C2S(如果配置了 RSA)
293 let decrypted_body;
294 let req_body = if let Some(rsa_key) = rsa_private_key {
295 decrypted_body =
296 futu_net::encrypt::rsa_private_decrypt_blocks(rsa_key, body).map_err(|e| {
297 tracing::warn!(error = %e, "RSA decrypt InitConnect C2S failed");
298 e
299 })?;
300 tracing::debug!(
301 encrypted_len = body.len(),
302 decrypted_len = decrypted_body.len(),
303 "RSA decrypted InitConnect C2S"
304 );
305 &decrypted_body[..]
306 } else {
307 body
308 };
309
310 let req: futu_proto::init_connect::Request =
311 prost::Message::decode(req_body).map_err(FutuError::Proto)?;
312
313 self.state = ConnState::Initialized;
314 self.recv_notify = req.c2s.recv_notify.unwrap_or(false);
315
316 // 当配置了 RSA 时,后续所有帧使用 AES 加解密
317 if rsa_private_key.is_some() {
318 self.aes_encrypt_enabled = true;
319 tracing::debug!(conn_id = self.conn_id, "AES body encryption enabled");
320 }
321
322 let aes_key_str = std::str::from_utf8(&self.aes_key)
323 .map_err(|e| FutuError::Codec(format!("invalid InitConnect conn_aes_key: {e}")))?
324 .to_string();
325
326 let resp = futu_proto::init_connect::Response {
327 ret_type: 0,
328 ret_msg: None,
329 err_code: None,
330 s2c: Some(futu_proto::init_connect::S2c {
331 server_ver,
332 login_user_id,
333 conn_id: self.conn_id,
334 conn_aes_key: aes_key_str,
335 keep_alive_interval: keepalive_interval,
336 aes_cb_civ: None,
337 user_attribution: None,
338 }),
339 };
340
341 let resp_body = prost::Message::encode_to_vec(&resp);
342
343 // 2. 加密 S2C(如果配置了 RSA)
344 if let Some(rsa_key) = rsa_private_key {
345 let encrypted = futu_net::encrypt::rsa_public_encrypt_blocks(rsa_key, &resp_body)
346 .map_err(|e| {
347 tracing::warn!(error = %e, "RSA encrypt InitConnect S2C failed");
348 e
349 })?;
350 tracing::debug!(
351 plaintext_len = resp_body.len(),
352 encrypted_len = encrypted.len(),
353 "RSA encrypted InitConnect S2C"
354 );
355 Ok(encrypted)
356 } else {
357 Ok(resp_body)
358 }
359 }
360
361 /// 处理 KeepAlive 请求
362 pub fn handle_keepalive(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
363 let _req: futu_proto::keep_alive::Request =
364 prost::Message::decode(body).map_err(FutuError::Proto)?;
365
366 self.keepalive_count.fetch_add(1, Ordering::Relaxed);
367
368 let resp = futu_proto::keep_alive::Response {
369 ret_type: 0,
370 ret_msg: None,
371 err_code: None,
372 s2c: Some(futu_proto::keep_alive::S2c {
373 time: chrono::Utc::now().timestamp(),
374 }),
375 };
376
377 Ok(prost::Message::encode_to_vec(&resp))
378 }
379}
380
381/// 连接断开通知
382pub struct DisconnectNotify {
383 /// 被断开的连接 ID(订阅 / push / auth 状态清理用)
384 pub conn_id: u64,
385}
386
387/// 运行单个连接的收发循环
388///
389/// 返回 (接收请求的 channel, 连接信息)
390pub async fn run_connection(
391 stream: TcpStream,
392 conn_id: u64,
393 _aes_key: [u8; 16],
394 req_tx: mpsc::UnboundedSender<IncomingRequest>,
395 disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
396) -> mpsc::Sender<FutuFrame> {
397 let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
398
399 let framed = Framed::new(stream, FutuCodec);
400 let (mut sink, mut stream) = framed.split();
401
402 // 发送任务
403 tokio::spawn(async move {
404 while let Some(frame) = frame_rx.recv().await {
405 if let Err(e) = sink.send(frame).await {
406 tracing::warn!(conn_id = conn_id, error = %e, "send failed");
407 break;
408 }
409 }
410 });
411
412 // 接收任务
413 tokio::spawn(async move {
414 while let Some(result) = stream.next().await {
415 match result {
416 Ok(frame) => {
417 // TCP wire currently has no idempotency or key scope fields.
418 // The builder keeps legacy defaults explicit in one place.
419 let req = IncomingRequest::builder(
420 conn_id,
421 frame.header.proto_id,
422 frame.header.serial_no,
423 frame.header.proto_fmt_type,
424 frame.body,
425 )
426 .build();
427 if req_tx.send(req).is_err() {
428 break;
429 }
430 }
431 Err(e) => {
432 tracing::warn!(conn_id = conn_id, error = %e, "recv error");
433 break;
434 }
435 }
436 }
437 tracing::info!(conn_id = conn_id, "connection closed");
438 // 通知 listener 清理连接
439 let _ = disconnect_tx.send(DisconnectNotify { conn_id });
440 });
441
442 frame_tx
443}
444
445#[cfg(test)]
446mod tests;