1use std::sync::Arc;
4use std::time::Instant;
5
6use dashmap::DashMap;
7use tokio::net::TcpListener;
8use tokio::sync::mpsc;
9
10use futu_codec::header::ProtoFmtType;
11use futu_core::proto_id;
12
13use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
14use crate::metrics::GatewayMetrics;
15use crate::router::RequestRouter;
16
/// Maximum number of client connections held in the pool at once.
///
/// When the pool already holds this many entries, a newly accepted socket is
/// dropped (closed) immediately and counted in `rejected_connections`.
pub const MAX_CONNECTIONS: usize = 128;
19
/// Startup configuration for [`ApiServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the TCP listener binds to (passed to `TcpListener::bind`).
    pub listen_addr: String,
    /// Server version forwarded to `ClientConn::handle_init_connect`.
    pub server_ver: i32,
    /// Login user id forwarded to `ClientConn::handle_init_connect`.
    pub login_user_id: u64,
    /// Keepalive interval forwarded to `ClientConn::handle_init_connect`.
    ///
    /// NOTE(review): the unit is not visible here (presumably seconds — confirm).
    /// The server-side keepalive eviction does not read this value; it uses
    /// its own fixed timeout.
    pub keepalive_interval: i32,
    /// Optional RSA private key, forwarded as `Option<&str>` to
    /// `ClientConn::handle_init_connect`. Presumably PEM text — TODO confirm.
    pub rsa_private_key: Option<String>,
}
34
/// TCP front end that accepts client connections, tracks them in a shared
/// pool, and feeds their requests to a single request-processing task.
pub struct ApiServer {
    /// Settings supplied at construction; cloned into the request processor.
    config: ServerConfig,
    /// Live connections keyed by connection id; shared with background tasks.
    connections: Arc<DashMap<u64, ClientConn>>,
    /// Dispatcher for every proto ID that is not handled inline.
    router: Arc<RequestRouter>,
    /// Optional subscription manager, notified whenever a connection is removed.
    subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
    /// Counters shared by the accept loop and the background tasks.
    metrics: Arc<GatewayMetrics>,
}
43
44impl ApiServer {
45 pub fn new(config: ServerConfig) -> Self {
47 Self {
48 config,
49 connections: Arc::new(DashMap::new()),
50 router: Arc::new(RequestRouter::new()),
51 subscriptions: None,
52 metrics: Arc::new(GatewayMetrics::new()),
53 }
54 }
55
56 pub fn set_subscriptions(&mut self, subs: Arc<crate::subscription::SubscriptionManager>) {
58 self.subscriptions = Some(subs);
59 }
60
61 pub fn router(&self) -> &Arc<RequestRouter> {
63 &self.router
64 }
65
66 pub fn connections(&self) -> &Arc<DashMap<u64, ClientConn>> {
68 &self.connections
69 }
70
71 pub fn set_metrics(&mut self, metrics: Arc<GatewayMetrics>) {
73 self.metrics = metrics;
74 }
75
76 pub fn metrics(&self) -> &Arc<GatewayMetrics> {
78 &self.metrics
79 }
80
81 pub async fn run(&self) -> anyhow::Result<()> {
83 let listener = TcpListener::bind(&self.config.listen_addr).await?;
84 tracing::info!(addr = %self.config.listen_addr, "API server listening");
85
86 let (req_tx, req_rx) = mpsc::unbounded_channel::<IncomingRequest>();
87 let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
88
89 let connections = Arc::clone(&self.connections);
91 let router = Arc::clone(&self.router);
92 let config = self.config.clone();
93 let metrics = Arc::clone(&self.metrics);
94 tokio::spawn(async move {
95 process_requests(req_rx, connections, router, config, metrics).await;
96 });
97
98 let cleanup_connections = Arc::clone(&self.connections);
100 let cleanup_subs = self.subscriptions.clone();
101 let cleanup_metrics = Arc::clone(&self.metrics);
102 tokio::spawn(async move {
103 while let Some(notify) = disconnect_rx.recv().await {
104 let removed = cleanup_connections.remove(¬ify.conn_id);
105 if removed.is_some() {
106 cleanup_metrics
107 .total_disconnections
108 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
109 if let Some(ref subs) = cleanup_subs {
111 subs.on_disconnect(notify.conn_id);
112 }
113 tracing::info!(
114 conn_id = notify.conn_id,
115 remaining = cleanup_connections.len(),
116 "connection removed from pool"
117 );
118 }
119 }
120 });
121
122 let ka_connections = Arc::clone(&self.connections);
124 let ka_subs = self.subscriptions.clone();
125 let ka_metrics = Arc::clone(&self.metrics);
126 tokio::spawn(async move {
127 const CHECK_INTERVAL_SECS: u64 = 15;
128 const TIMEOUT_SECS: u64 = 66;
129 let mut interval =
130 tokio::time::interval(std::time::Duration::from_secs(CHECK_INTERVAL_SECS));
131 interval.tick().await; loop {
133 interval.tick().await;
134 let now = Instant::now();
135 let mut timed_out = Vec::new();
136 for entry in ka_connections.iter() {
137 let conn = entry.value();
138 if now.duration_since(conn.last_keepalive).as_secs() >= TIMEOUT_SECS {
139 timed_out.push(conn.conn_id);
140 }
141 }
142 for conn_id in timed_out {
143 if ka_connections.remove(&conn_id).is_some() {
144 ka_metrics
145 .keepalive_timeouts
146 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147 ka_metrics
148 .total_disconnections
149 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
150 if let Some(ref subs) = ka_subs {
151 subs.on_disconnect(conn_id);
152 }
153 tracing::info!(
154 conn_id = conn_id,
155 remaining = ka_connections.len(),
156 "keepalive timeout, connection removed"
157 );
158 }
159 }
160 }
161 });
162
163 let connections = Arc::clone(&self.connections);
165 let accept_metrics = Arc::clone(&self.metrics);
166 loop {
167 let (stream, peer_addr) = listener.accept().await?;
168
169 if connections.len() >= MAX_CONNECTIONS {
170 accept_metrics
171 .rejected_connections
172 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
173 tracing::warn!(
174 peer = %peer_addr,
175 "max connections reached ({}), rejecting",
176 MAX_CONNECTIONS
177 );
178 drop(stream);
179 continue;
180 }
181
182 accept_metrics
183 .total_connections
184 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
185
186 let conn_id = crate::conn::ClientConn::generate_conn_id();
187 let aes_key = crate::conn::ClientConn::generate_aes_key();
188 stream.set_nodelay(true).ok();
189
190 tracing::info!(
191 conn_id = conn_id,
192 peer = %peer_addr,
193 total = connections.len() + 1,
194 "client connected"
195 );
196
197 let tx = crate::conn::run_connection(
198 stream,
199 conn_id,
200 aes_key,
201 req_tx.clone(),
202 disconnect_tx.clone(),
203 )
204 .await;
205
206 let conn = ClientConn {
207 conn_id,
208 state: ConnState::Connected,
209 aes_key,
210 aes_encrypt_enabled: false,
211 proto_fmt_type: ProtoFmtType::Protobuf,
212 last_keepalive: Instant::now(),
213 recv_notify: false,
214 keepalive_count: std::sync::atomic::AtomicU32::new(0),
215 tx,
216 key_id: None,
219 scopes: std::collections::HashSet::new(),
220 allowed_markets: None,
223 allowed_acc_ids: None,
225 };
226
227 connections.insert(conn_id, conn);
228 }
229 }
230
231 pub async fn send_response(
233 connections: &DashMap<u64, ClientConn>,
234 conn_id: u64,
235 proto_id: u32,
236 serial_no: u32,
237 body: Vec<u8>,
238 ) {
239 if let Some(conn) = connections.get(&conn_id) {
240 let frame = conn.make_frame(proto_id, serial_no, bytes::Bytes::from(body));
241 if conn.tx.send(frame).await.is_err() {
242 tracing::warn!(
243 conn_id = conn_id,
244 "failed to send response, connection closed"
245 );
246 }
247 }
248 }
249}
250
251async fn process_requests(
253 mut req_rx: mpsc::UnboundedReceiver<IncomingRequest>,
254 connections: Arc<DashMap<u64, ClientConn>>,
255 router: Arc<RequestRouter>,
256 config: ServerConfig,
257 metrics: Arc<GatewayMetrics>,
258) {
259 while let Some(mut req) = req_rx.recv().await {
260 let conn_id = req.conn_id;
261 let proto_id_val = req.proto_id;
262 let serial_no = req.serial_no;
263 let req_start = Instant::now();
264
265 metrics
266 .total_requests
267 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
268
269 if let Some(mut conn) = connections.get_mut(&conn_id) {
271 conn.last_keepalive = Instant::now();
272 }
273
274 if futu_auth::is_internal_proto_id(proto_id_val) {
279 metrics
280 .total_request_errors
281 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
282 tracing::warn!(
283 conn_id,
284 proto_id = proto_id_val,
285 "rejecting daemon-internal proto_id at raw TCP public surface (codex 0532 F3)"
286 );
287 continue;
288 }
289
290 if proto_id_val != proto_id::INIT_CONNECT
292 && let Some(conn) = connections.get(&conn_id)
293 && conn.aes_encrypt_enabled
294 {
295 match conn.decrypt_body(&req.body) {
296 Ok(decrypted) => {
297 req.body = bytes::Bytes::from(decrypted);
298 }
299 Err(e) => {
300 metrics
301 .total_request_errors
302 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
303 tracing::warn!(
304 conn_id = conn_id,
305 proto_id = proto_id_val,
306 error = %e,
307 "AES decrypt request failed, dropping"
308 );
309 continue;
310 }
311 }
312 }
313
314 let response_body = match proto_id_val {
316 proto_id::INIT_CONNECT => match connections.get_mut(&conn_id) {
317 Some(mut conn) => conn
318 .handle_init_connect(
319 &req.body,
320 config.server_ver,
321 config.login_user_id,
322 config.keepalive_interval,
323 config.rsa_private_key.as_deref(),
324 )
325 .ok(),
326 _ => None,
327 },
328 proto_id::KEEP_ALIVE => match connections.get(&conn_id) {
329 Some(conn) => conn.handle_keepalive(&req.body).ok(),
330 _ => None,
331 },
332 _ => {
333 router.dispatch(conn_id, &req).await
335 }
336 };
337
338 metrics.record_latency_ns(req_start.elapsed().as_nanos() as u64);
340
341 if let Some(body) = response_body {
343 metrics
344 .total_response_bytes
345 .fetch_add(body.len() as u64, std::sync::atomic::Ordering::Relaxed);
346 ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
347 } else if proto_id_val != proto_id::INIT_CONNECT && proto_id_val != proto_id::KEEP_ALIVE {
348 metrics
349 .total_request_errors
350 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
351 }
352 }
353}