Skip to main content

futu_server/
listener.rs

1// TCP 监听器:接受客户端连接,管理连接池
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use dashmap::DashMap;
7use tokio::net::TcpListener;
8use tokio::sync::mpsc;
9
10use futu_codec::header::ProtoFmtType;
11use futu_core::proto_id;
12
13use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
14use crate::metrics::GatewayMetrics;
15use crate::router::RequestRouter;
16
/// Upper bound on pooled client connections: once the pool holds this many,
/// `ApiServer::run` drops newly accepted sockets instead of registering them.
pub const MAX_CONNECTIONS: usize = 128;
19
/// Server configuration.
#[derive(Clone)]
pub struct ServerConfig {
    /// TCP listen address (e.g. `127.0.0.1:11111`).
    pub listen_addr: String,
    /// Server version, sent to clients in the InitConnect response.
    pub server_ver: i32,
    /// Server login user id, sent to clients in the InitConnect response.
    pub login_user_id: u64,
    /// KeepAlive heartbeat interval in seconds, sent to clients in the InitConnect response.
    pub keepalive_interval: i32,
    /// RSA private key PEM contents (optional; when set, InitConnect uses RSA
    /// encryption/decryption).
    pub rsa_private_key: Option<String>,
}

// `Debug` is implemented by hand instead of derived so that the RSA private
// key never ends up in logs or panic messages: only whether a key is
// configured is shown. All other fields are printed as a derive would.
impl std::fmt::Debug for ServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerConfig")
            .field("listen_addr", &self.listen_addr)
            .field("server_ver", &self.server_ver)
            .field("login_user_id", &self.login_user_id)
            .field("keepalive_interval", &self.keepalive_interval)
            .field(
                "rsa_private_key",
                &self.rsa_private_key.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
34
35/// API 服务端
36pub struct ApiServer {
37    config: ServerConfig,
38    connections: Arc<DashMap<u64, ClientConn>>,
39    router: Arc<RequestRouter>,
40    subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
41    metrics: Arc<GatewayMetrics>,
42}
43
44impl ApiServer {
45    /// 创建新的服务端实例。不自动启动,需调用 [`ApiServer::run`] 进入接收循环。
46    pub fn new(config: ServerConfig) -> Self {
47        Self {
48            config,
49            connections: Arc::new(DashMap::new()),
50            router: Arc::new(RequestRouter::new()),
51            subscriptions: None,
52            metrics: Arc::new(GatewayMetrics::new()),
53        }
54    }
55
56    /// 设置订阅管理器,用于连接断开时自动清理订阅关系
57    pub fn set_subscriptions(&mut self, subs: Arc<crate::subscription::SubscriptionManager>) {
58        self.subscriptions = Some(subs);
59    }
60
61    /// 获取路由器引用(用于注册业务处理器)
62    pub fn router(&self) -> &Arc<RequestRouter> {
63        &self.router
64    }
65
66    /// 获取连接池引用(用于推送分发)
67    pub fn connections(&self) -> &Arc<DashMap<u64, ClientConn>> {
68        &self.connections
69    }
70
71    /// 设置外部监控指标(共享同一个 Arc,让 bridge 和 server 使用同一份计数器)
72    pub fn set_metrics(&mut self, metrics: Arc<GatewayMetrics>) {
73        self.metrics = metrics;
74    }
75
76    /// 获取监控指标引用
77    pub fn metrics(&self) -> &Arc<GatewayMetrics> {
78        &self.metrics
79    }
80
81    /// 启动服务端监听
82    pub async fn run(&self) -> anyhow::Result<()> {
83        let listener = TcpListener::bind(&self.config.listen_addr).await?;
84        tracing::info!(addr = %self.config.listen_addr, "API server listening");
85
86        let (req_tx, req_rx) = mpsc::unbounded_channel::<IncomingRequest>();
87        let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
88
89        // 启动请求处理任务
90        let connections = Arc::clone(&self.connections);
91        let router = Arc::clone(&self.router);
92        let config = self.config.clone();
93        let metrics = Arc::clone(&self.metrics);
94        tokio::spawn(async move {
95            process_requests(req_rx, connections, router, config, metrics).await;
96        });
97
98        // 启动连接清理任务(TCP 断开通知)
99        let cleanup_connections = Arc::clone(&self.connections);
100        let cleanup_subs = self.subscriptions.clone();
101        let cleanup_metrics = Arc::clone(&self.metrics);
102        tokio::spawn(async move {
103            while let Some(notify) = disconnect_rx.recv().await {
104                let removed = cleanup_connections.remove(&notify.conn_id);
105                if removed.is_some() {
106                    cleanup_metrics
107                        .total_disconnections
108                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
109                    // 清理该连接的所有订阅关系
110                    if let Some(ref subs) = cleanup_subs {
111                        subs.on_disconnect(notify.conn_id);
112                    }
113                    tracing::info!(
114                        conn_id = notify.conn_id,
115                        remaining = cleanup_connections.len(),
116                        "connection removed from pool"
117                    );
118                }
119            }
120        });
121
122        // 启动 KeepAlive 超时检测任务(对应 C++ OnTimeTicker,每 66 秒无活动断连)
123        let ka_connections = Arc::clone(&self.connections);
124        let ka_subs = self.subscriptions.clone();
125        let ka_metrics = Arc::clone(&self.metrics);
126        tokio::spawn(async move {
127            const CHECK_INTERVAL_SECS: u64 = 15;
128            const TIMEOUT_SECS: u64 = 66;
129            let mut interval =
130                tokio::time::interval(std::time::Duration::from_secs(CHECK_INTERVAL_SECS));
131            interval.tick().await; // 跳过首次立即触发
132            loop {
133                interval.tick().await;
134                let now = Instant::now();
135                let mut timed_out = Vec::new();
136                for entry in ka_connections.iter() {
137                    let conn = entry.value();
138                    if now.duration_since(conn.last_keepalive).as_secs() >= TIMEOUT_SECS {
139                        timed_out.push(conn.conn_id);
140                    }
141                }
142                for conn_id in timed_out {
143                    if ka_connections.remove(&conn_id).is_some() {
144                        ka_metrics
145                            .keepalive_timeouts
146                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147                        ka_metrics
148                            .total_disconnections
149                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
150                        if let Some(ref subs) = ka_subs {
151                            subs.on_disconnect(conn_id);
152                        }
153                        tracing::info!(
154                            conn_id = conn_id,
155                            remaining = ka_connections.len(),
156                            "keepalive timeout, connection removed"
157                        );
158                    }
159                }
160            }
161        });
162
163        // 接受连接循环
164        let connections = Arc::clone(&self.connections);
165        let accept_metrics = Arc::clone(&self.metrics);
166        loop {
167            let (stream, peer_addr) = listener.accept().await?;
168
169            if connections.len() >= MAX_CONNECTIONS {
170                accept_metrics
171                    .rejected_connections
172                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
173                tracing::warn!(
174                    peer = %peer_addr,
175                    "max connections reached ({}), rejecting",
176                    MAX_CONNECTIONS
177                );
178                drop(stream);
179                continue;
180            }
181
182            accept_metrics
183                .total_connections
184                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
185
186            let conn_id = crate::conn::ClientConn::generate_conn_id();
187            let aes_key = crate::conn::ClientConn::generate_aes_key();
188            stream.set_nodelay(true).ok();
189
190            tracing::info!(
191                conn_id = conn_id,
192                peer = %peer_addr,
193                total = connections.len() + 1,
194                "client connected"
195            );
196
197            let tx = crate::conn::run_connection(
198                stream,
199                conn_id,
200                aes_key,
201                req_tx.clone(),
202                disconnect_tx.clone(),
203            )
204            .await;
205
206            let conn = ClientConn {
207                conn_id,
208                state: ConnState::Connected,
209                aes_key,
210                aes_encrypt_enabled: false,
211                proto_fmt_type: ProtoFmtType::Protobuf,
212                last_keepalive: Instant::now(),
213                recv_notify: false,
214                keepalive_count: std::sync::atomic::AtomicU32::new(0),
215                tx,
216                // 原 TCP listener 不做 per-message scope 校验(保持兼容):
217                // key_id=None / scopes=空集 被 ws_listener 的 gate 解释为"legacy 全放行"
218                key_id: None,
219                scopes: std::collections::HashSet::new(),
220                // v1.4.105 D3 (Phase 4) T-B2: TCP listener 同样 legacy 模式 →
221                // allowed_markets None = 无限制 (push_trd_acc Layer 3 不 trigger).
222                allowed_markets: None,
223                // codex round 1 F4 (P2) v1.4.105: 同 legacy 模式, 无 acc_id 限制.
224                allowed_acc_ids: None,
225            };
226
227            connections.insert(conn_id, conn);
228        }
229    }
230
231    /// 向指定连接发送响应(自动处理 AES 加密)
232    pub async fn send_response(
233        connections: &DashMap<u64, ClientConn>,
234        conn_id: u64,
235        proto_id: u32,
236        serial_no: u32,
237        body: Vec<u8>,
238    ) {
239        if let Some(conn) = connections.get(&conn_id) {
240            let frame = conn.make_frame(proto_id, serial_no, bytes::Bytes::from(body));
241            if conn.tx.send(frame).await.is_err() {
242                tracing::warn!(
243                    conn_id = conn_id,
244                    "failed to send response, connection closed"
245                );
246            }
247        }
248    }
249}
250
251/// 处理所有连接的请求
252async fn process_requests(
253    mut req_rx: mpsc::UnboundedReceiver<IncomingRequest>,
254    connections: Arc<DashMap<u64, ClientConn>>,
255    router: Arc<RequestRouter>,
256    config: ServerConfig,
257    metrics: Arc<GatewayMetrics>,
258) {
259    while let Some(mut req) = req_rx.recv().await {
260        let conn_id = req.conn_id;
261        let proto_id_val = req.proto_id;
262        let serial_no = req.serial_no;
263        let req_start = Instant::now();
264
265        metrics
266            .total_requests
267            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
268
269        // 更新 last_keepalive(任何包都算活跃,对应 C++ m_nKeepAlive_Count_Curt++)
270        if let Some(mut conn) = connections.get_mut(&conn_id) {
271            conn.last_keepalive = Instant::now();
272        }
273
274        // v1.4.106 codex 0532 F3 (P2): daemon-internal proto_id (高位
275        // 0x8000_0000 bit) 绝不应从 raw TCP 公开 surface 进入 — 仅 REST
276        // handler 内部合成给 router. 显式 reject + log, 防探测 daemon
277        // 内部 routing.
278        if futu_auth::is_internal_proto_id(proto_id_val) {
279            metrics
280                .total_request_errors
281                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
282            tracing::warn!(
283                conn_id,
284                proto_id = proto_id_val,
285                "rejecting daemon-internal proto_id at raw TCP public surface (codex 0532 F3)"
286            );
287            continue;
288        }
289
290        // 非 InitConnect 请求需要 AES 解密(InitConnect 自身处理 RSA 解密)
291        if proto_id_val != proto_id::INIT_CONNECT
292            && let Some(conn) = connections.get(&conn_id)
293            && conn.aes_encrypt_enabled
294        {
295            match conn.decrypt_body(&req.body) {
296                Ok(decrypted) => {
297                    req.body = bytes::Bytes::from(decrypted);
298                }
299                Err(e) => {
300                    metrics
301                        .total_request_errors
302                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
303                    tracing::warn!(
304                        conn_id = conn_id,
305                        proto_id = proto_id_val,
306                        error = %e,
307                        "AES decrypt request failed, dropping"
308                    );
309                    continue;
310                }
311            }
312        }
313
314        // InitConnect 和 KeepAlive 内部处理
315        let response_body = match proto_id_val {
316            proto_id::INIT_CONNECT => match connections.get_mut(&conn_id) {
317                Some(mut conn) => conn
318                    .handle_init_connect(
319                        &req.body,
320                        config.server_ver,
321                        config.login_user_id,
322                        config.keepalive_interval,
323                        config.rsa_private_key.as_deref(),
324                    )
325                    .ok(),
326                _ => None,
327            },
328            proto_id::KEEP_ALIVE => match connections.get(&conn_id) {
329                Some(conn) => conn.handle_keepalive(&req.body).ok(),
330                _ => None,
331            },
332            _ => {
333                // 委托给路由器
334                router.dispatch(conn_id, &req).await
335            }
336        };
337
338        // 记录延迟
339        metrics.record_latency_ns(req_start.elapsed().as_nanos() as u64);
340
341        // 发送响应
342        if let Some(body) = response_body {
343            metrics
344                .total_response_bytes
345                .fetch_add(body.len() as u64, std::sync::atomic::Ordering::Relaxed);
346            ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
347        } else if proto_id_val != proto_id::INIT_CONNECT && proto_id_val != proto_id::KEEP_ALIVE {
348            metrics
349                .total_request_errors
350                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
351        }
352    }
353}