1use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32use std::time::Instant;
33
34use bytes::BytesMut;
35use chrono::Utc;
36use dashmap::DashMap;
37use futures::{SinkExt, StreamExt};
38use tokio::net::TcpListener;
39use tokio::sync::mpsc;
40use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
41use tokio_tungstenite::tungstenite::http::StatusCode;
42use tokio_tungstenite::tungstenite::protocol::Message;
43
44use futu_auth::{KeyRecord, KeyStore, RuntimeCounters, Scope};
45use futu_auth_pipeline::{
46 AuthDecision, AuthEnvelope, Credential, Endpoint, FilterRegistry, RejectKind, SurfaceId,
47 authenticate_request,
48};
49
/// Auth-pipeline adapter for the raw WebSocket surface.
///
/// Zero-sized marker type; all behaviour lives in its `SurfaceAdapter` impl.
pub struct WsAdapter;
63
64impl futu_auth_pipeline::SurfaceAdapter for WsAdapter {
65 type WireResponse = ();
66
67 fn surface_id() -> SurfaceId {
68 SurfaceId::Ws
69 }
70
71 fn translate_reject(kind: RejectKind, reason: String) -> Self::WireResponse {
72 let _ = (kind, reason);
75 }
76}
77use futu_codec::frame::FutuFrame;
78use futu_codec::header::{FutuHeader, HEADER_SIZE, ProtoFmtType};
79use futu_core::proto_id;
80
81use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
82use crate::listener::{MAX_CONNECTIONS, ServerConfig};
83use crate::router::RequestRouter;
84
85pub struct WsServer {
87 listen_addr: String,
88 config: ServerConfig,
89 connections: Arc<DashMap<u64, ClientConn>>,
90 router: Arc<RequestRouter>,
91 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
92 key_store: Option<Arc<KeyStore>>,
94 counters: Option<Arc<RuntimeCounters>>,
96 filter_registry: Option<Arc<FilterRegistry>>,
99}
100
101pub struct WsServerDeps {
107 connections: Arc<DashMap<u64, ClientConn>>,
108 router: Arc<RequestRouter>,
109 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
110}
111
112impl WsServerDeps {
113 pub fn new(
114 connections: Arc<DashMap<u64, ClientConn>>,
115 router: Arc<RequestRouter>,
116 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
117 ) -> Self {
118 Self {
119 connections,
120 router,
121 subscriptions,
122 }
123 }
124}
125
126impl WsServer {
127 pub fn new(
129 listen_addr: String,
130 config: ServerConfig,
131 connections: Arc<DashMap<u64, ClientConn>>,
132 router: Arc<RequestRouter>,
133 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
134 ) -> Self {
135 Self::with_auth(
136 listen_addr,
137 config,
138 WsServerDeps::new(connections, router, subscriptions),
139 None,
140 None,
141 )
142 }
143
144 pub fn with_auth(
147 listen_addr: String,
148 config: ServerConfig,
149 deps: WsServerDeps,
150 key_store: Option<Arc<KeyStore>>,
151 counters: Option<Arc<RuntimeCounters>>,
152 ) -> Self {
153 Self {
154 listen_addr,
155 config,
156 connections: deps.connections,
157 router: deps.router,
158 subscriptions: deps.subscriptions,
159 key_store,
160 counters,
161 filter_registry: None,
162 }
163 }
164
165 pub fn with_filter_registry(mut self, registry: Arc<FilterRegistry>) -> Self {
168 self.filter_registry = Some(registry);
169 self
170 }
171
172 pub async fn run(&self) -> anyhow::Result<()> {
174 let listener = TcpListener::bind(&self.listen_addr).await?;
175 tracing::info!(addr = %self.listen_addr, "WebSocket server listening");
176
177 let (req_tx, req_rx) = mpsc::unbounded_channel::<IncomingRequest>();
178 let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
179
180 let connections = Arc::clone(&self.connections);
182 let router = Arc::clone(&self.router);
183 let config = self.config.clone();
184 let key_store_for_process = self
189 .key_store
190 .clone()
191 .unwrap_or_else(|| Arc::new(KeyStore::empty()));
192 let counters_for_process = self
193 .counters
194 .clone()
195 .unwrap_or_else(|| Arc::new(RuntimeCounters::new()));
196 let filter_registry_for_process = self
197 .filter_registry
198 .clone()
199 .unwrap_or_else(|| Arc::new(FilterRegistry::with_defaults()));
200 tokio::spawn(async move {
201 ws_process_requests(
202 req_rx,
203 connections,
204 router,
205 config,
206 counters_for_process,
207 key_store_for_process,
208 filter_registry_for_process,
209 )
210 .await;
211 });
212
213 let cleanup_connections = Arc::clone(&self.connections);
215 let cleanup_subs = self.subscriptions.clone();
216 tokio::spawn(async move {
217 while let Some(notify) = disconnect_rx.recv().await {
218 let removed = cleanup_connections.remove(¬ify.conn_id);
219 if removed.is_some() {
220 if let Some(ref subs) = cleanup_subs {
221 subs.on_disconnect(notify.conn_id);
222 }
223 tracing::info!(
224 conn_id = notify.conn_id,
225 remaining = cleanup_connections.len(),
226 "ws connection removed from pool"
227 );
228 }
229 }
230 });
231
232 let connections = Arc::clone(&self.connections);
234 let key_store_accept = self.key_store.clone();
235 let scope_mode = self.key_store.as_ref().is_some_and(|ks| ks.is_configured());
237 if !scope_mode {
238 tracing::warn!("{}", legacy_mode_warn_tracing_message());
243 eprintln!("{}", legacy_mode_warn_stderr_message());
244 }
245 loop {
246 let (stream, peer_addr) = listener.accept().await?;
247
248 if connections.len() >= MAX_CONNECTIONS {
249 tracing::warn!(
250 peer = %peer_addr,
251 "max connections reached ({}), rejecting ws client",
252 MAX_CONNECTIONS,
253 );
254 drop(stream);
255 continue;
256 }
257
258 let conn_id = ClientConn::generate_conn_id();
259 let aes_key = ClientConn::generate_aes_key();
260 stream.set_nodelay(true).ok();
261
262 tracing::info!(
263 conn_id = conn_id,
264 peer = %peer_addr,
265 total = connections.len() + 1,
266 "ws client connected"
267 );
268
269 let (tx, authed) = run_ws_connection(
270 stream,
271 conn_id,
272 aes_key,
273 req_tx.clone(),
274 disconnect_tx.clone(),
275 key_store_accept.clone(),
276 )
277 .await;
278
279 let Some(authed) = authed else {
281 continue;
282 };
283
284 let (key_id, scopes, allowed_markets, allowed_acc_ids) = match authed {
285 AuthResult::Authenticated(rec) => (
286 Some(rec.id.clone()),
287 rec.scopes.clone(),
288 rec.allowed_markets
291 .as_ref()
292 .map(|s| std::sync::Arc::new(s.clone())),
293 rec.allowed_acc_ids
298 .as_ref()
299 .map(|s| std::sync::Arc::new(s.clone())),
300 ),
301 AuthResult::Legacy => (None, HashSet::new(), None, None),
302 };
303
304 let conn = ClientConn {
305 conn_id,
306 state: ConnState::Connected,
307 aes_key,
308 aes_encrypt_enabled: false,
309 proto_fmt_type: ProtoFmtType::Protobuf,
310 last_keepalive: Instant::now(),
311 recv_notify: false,
312 keepalive_count: std::sync::atomic::AtomicU32::new(0),
313 tx,
314 key_id,
315 scopes,
316 allowed_markets,
317 allowed_acc_ids,
318 };
319
320 connections.insert(conn_id, conn);
321 }
322 }
323}
324
325enum AuthResult {
327 Authenticated(Arc<KeyRecord>),
328 Legacy,
329}
330
331#[allow(clippy::result_large_err)]
337fn store_ws_auth_result(
338 slot: &std::sync::Mutex<Option<AuthResult>>,
339 result: AuthResult,
340 conn_id: u64,
341) -> Result<(), ErrorResponse> {
342 match slot.lock() {
343 Ok(mut guard) => {
344 *guard = Some(result);
345 Ok(())
346 }
347 Err(e) => {
348 tracing::warn!(conn_id = conn_id, error = %e, "ws auth slot poisoned during handshake");
349 Err(make_err_response(
350 StatusCode::INTERNAL_SERVER_ERROR,
351 "ws auth state unavailable",
352 ))
353 }
354 }
355}
356
357fn take_ws_auth_result(
358 slot: &std::sync::Mutex<Option<AuthResult>>,
359 conn_id: u64,
360) -> Option<AuthResult> {
361 match slot.lock() {
362 Ok(mut guard) => match guard.take() {
363 Some(result) => Some(result),
364 None => {
365 tracing::warn!(
366 conn_id = conn_id,
367 "ws handshake succeeded without auth state; closing connection"
368 );
369 None
370 }
371 },
372 Err(e) => {
373 tracing::warn!(conn_id = conn_id, error = %e, "ws auth slot poisoned after handshake");
374 None
375 }
376 }
377}
378
/// Performs the WebSocket upgrade for one accepted TCP stream and, on success, starts the
/// connection's I/O tasks.
///
/// During the handshake the callback:
/// 1. rejects (403) a present `Origin` header that is not valid UTF-8 or fails `ws_check_origin`;
/// 2. records `AuthResult::Legacy` and accepts when there is no key store or it is not configured;
/// 3. otherwise requires a token (`?token=` query or `Authorization: Bearer`) that verifies, is
///    not expired, and carries `Scope::QotRead` (401/403 on failure).
///
/// After a successful upgrade it spawns a writer task, which drains the returned channel and sends
/// each frame as one binary message. It also spawns a reader task, which parses binary messages
/// into `IncomingRequest`s for `req_tx` and sends a `DisconnectNotify` when the stream ends.
///
/// Returns the sending half of the connection's outbound frame channel (capacity 256) and the
/// recorded auth outcome. The outcome is `None` when the handshake failed or no auth state was
/// recorded; in those cases a `DisconnectNotify` for `conn_id` has already been sent.
///
/// `_aes_key` is not used here.
///
/// NOTE(review): the reader task is spawned before the caller inserts the `ClientConn` into the
/// pool. Incoming frames, and even this connection's `DisconnectNotify`, can therefore be processed
/// before the connection is registered. Confirm this race is acceptable.
async fn run_ws_connection(
    stream: tokio::net::TcpStream,
    conn_id: u64,
    _aes_key: [u8; 16],
    req_tx: mpsc::UnboundedSender<IncomingRequest>,
    disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
    key_store: Option<Arc<KeyStore>>,
) -> (mpsc::Sender<FutuFrame>, Option<AuthResult>) {
    let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);

    // The handshake callback can only return the upgrade response or an HTTP error, so it
    // deposits the auth outcome in this shared slot, which is read after the upgrade completes.
    let authed_slot: Arc<std::sync::Mutex<Option<AuthResult>>> =
        Arc::new(std::sync::Mutex::new(None));
    let slot_cb = Arc::clone(&authed_slot);
    let store_cb = key_store.clone();

    // `ErrorResponse` is the error type required by the tungstenite callback signature.
    #[allow(clippy::result_large_err)]
    let callback = move |req: &Request, resp: Response| -> Result<Response, ErrorResponse> {
        // Origin is validated only when the header is present.
        if let Some(origin_hv) = req.headers().get("origin") {
            let origin_str = match origin_hv.to_str() {
                Ok(s) => s,
                Err(e) => {
                    tracing::warn!(conn_id, error = %e, "ws Origin header is not valid UTF-8");
                    futu_auth::audit::reject("ws", "/ws", "<origin>", "invalid Origin header");
                    return Err(make_err_response(
                        StatusCode::FORBIDDEN,
                        "Invalid Origin header",
                    ));
                }
            };
            let allowed = ws_check_origin(origin_str, store_cb.as_deref());
            if !allowed {
                futu_auth::audit::reject("ws", "/ws", "<origin>", "Origin rejected by allowlist");
                return Err(make_err_response(
                    StatusCode::FORBIDDEN,
                    "Origin not allowed (configure FUTU_WS_ALLOWED_ORIGINS env)",
                ));
            }
        }

        // No key store, or one that is not configured: legacy mode, accept without a token.
        let Some(store) = store_cb.as_ref() else {
            store_ws_auth_result(slot_cb.as_ref(), AuthResult::Legacy, conn_id)?;
            return Ok(resp);
        };
        if !store.is_configured() {
            store_ws_auth_result(slot_cb.as_ref(), AuthResult::Legacy, conn_id)?;
            return Ok(resp);
        }

        let token = extract_ws_token(req);
        let Some(token) = token else {
            futu_auth::audit::reject("ws", "/ws", "<missing>", "missing token");
            return Err(make_err_response(
                StatusCode::UNAUTHORIZED,
                "missing api key (use ?token=... or Authorization: Bearer ...)",
            ));
        };

        let Some(rec) = store.verify(&token) else {
            futu_auth::audit::reject("ws", "/ws", "<invalid>", "invalid api key");
            return Err(make_err_response(
                StatusCode::UNAUTHORIZED,
                "invalid api key",
            ));
        };

        if rec.is_expired(Utc::now()) {
            futu_auth::audit::reject("ws", "/ws", &rec.id, "key expired");
            return Err(make_err_response(StatusCode::UNAUTHORIZED, "key expired"));
        }

        // The upgrade itself requires `qot:read`; per-request scopes are checked later in
        // `ws_process_requests` via the auth pipeline.
        if !rec.scopes.contains(&Scope::QotRead) {
            futu_auth::audit::reject("ws", "/ws", &rec.id, "missing qot:read");
            return Err(make_err_response(StatusCode::FORBIDDEN, "forbidden"));
        }

        futu_auth::audit::allow("ws", "/ws", &rec.id, Some("qot:read"));
        store_ws_auth_result(slot_cb.as_ref(), AuthResult::Authenticated(rec), conn_id)?;
        Ok(resp)
    };

    let ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
        Ok(ws) => ws,
        Err(e) => {
            tracing::warn!(conn_id = conn_id, error = %e, "ws handshake failed");
            let _ = disconnect_tx.send(DisconnectNotify { conn_id });
            return (frame_tx, None);
        }
    };

    // A successful upgrade should always have recorded an outcome; missing or poisoned state is
    // treated as a failed connection.
    let Some(authed) = take_ws_auth_result(authed_slot.as_ref(), conn_id) else {
        let _ = disconnect_tx.send(DisconnectNotify { conn_id });
        return (frame_tx, None);
    };

    let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();

    // Writer: encode each outbound frame (header then body) as one binary message. Ends when all
    // senders are dropped or a send fails.
    tokio::spawn(async move {
        while let Some(frame) = frame_rx.recv().await {
            let mut buf = BytesMut::new();
            frame.header.encode(&mut buf);
            buf.extend_from_slice(&frame.body);
            let msg = Message::Binary(buf.freeze());
            if let Err(e) = ws_sink.send(msg).await {
                tracing::warn!(conn_id = conn_id, error = %e, "ws send failed");
                break;
            }
        }
    });

    // Reader: turn binary messages into requests. Malformed messages are logged and skipped;
    // a Close frame, a receive error, or a closed request channel ends the loop.
    tokio::spawn(async move {
        while let Some(result) = ws_stream_rx.next().await {
            match result {
                Ok(msg) => {
                    let data = match msg {
                        Message::Binary(data) => data,
                        Message::Close(_) => {
                            tracing::info!(conn_id = conn_id, "ws client sent close");
                            break;
                        }
                        Message::Ping(_) | Message::Pong(_) => {
                            continue;
                        }
                        // Text and any other message kinds carry no Futu frame; ignore them.
                        _ => {
                            continue;
                        }
                    };

                    if data.len() < HEADER_SIZE {
                        tracing::warn!(
                            conn_id = conn_id,
                            len = data.len(),
                            "ws message too short for futu header"
                        );
                        continue;
                    }

                    // `peek` is called on a BytesMut copy of the whole message.
                    let header_buf = BytesMut::from(&data[..]);
                    let header = match FutuHeader::peek(&header_buf) {
                        Ok(Some(h)) => h,
                        Ok(None) => {
                            tracing::warn!(conn_id = conn_id, "ws header peek returned None");
                            continue;
                        }
                        Err(e) => {
                            tracing::warn!(conn_id = conn_id, error = %e, "ws invalid futu header");
                            continue;
                        }
                    };

                    // One WS message must contain the whole frame; bytes past
                    // header + body_len are ignored.
                    let expected_len = HEADER_SIZE + header.body_len as usize;
                    if data.len() < expected_len {
                        tracing::warn!(
                            conn_id = conn_id,
                            expected = expected_len,
                            actual = data.len(),
                            "ws message shorter than expected frame size"
                        );
                        continue;
                    }

                    let body = bytes::Bytes::copy_from_slice(&data[HEADER_SIZE..expected_len]);

                    let req = IncomingRequest::builder(
                        conn_id,
                        header.proto_id,
                        header.serial_no,
                        header.proto_fmt_type,
                        body,
                    )
                    .build();

                    // The request processor is gone; stop reading.
                    if req_tx.send(req).is_err() {
                        break;
                    }
                }
                Err(e) => {
                    tracing::warn!(conn_id = conn_id, error = %e, "ws recv error");
                    break;
                }
            }
        }
        tracing::info!(conn_id = conn_id, "ws connection closed");
        let _ = disconnect_tx.send(DisconnectNotify { conn_id });
    });

    (frame_tx, Some(authed))
}
598
/// Warning logged via `tracing` when the WS server starts without API-key auth.
pub(crate) const fn legacy_mode_warn_tracing_message() -> &'static str {
    concat!(
        "WS server running WITHOUT API key auth (legacy mode); ",
        "all WS clients accept unauthenticated handshake (no-token / ",
        "wrong-bearer / bogus-query all return success). ",
        "Pass KeyStore via with_auth() to enable. ",
        "v2 will default-reject; migrate to --rest-keys-file / --ws-keys-file for production."
    )
}
610
/// Short operator-facing warning printed to stderr when the WS server starts without API-key auth.
pub(crate) const fn legacy_mode_warn_stderr_message() -> &'static str {
    concat!(
        "⚠️ WS server (legacy mode, no --ws-keys-file): ",
        "unauthenticated handshakes accepted. v2 will default-reject. ",
        "Migrate to --ws-keys-file for production."
    )
}
619
620fn extract_ws_token(req: &Request) -> Option<String> {
630 if let Some(q) = req.uri().query() {
631 let params: HashMap<&str, &str> =
633 q.split('&').filter_map(|kv| kv.split_once('=')).collect();
634 if let Some(v) = params.get("token")
635 && !v.is_empty()
636 {
637 return Some((*v).to_string());
638 }
639 }
640 let header = req.headers().get("authorization")?;
641 let value = header.to_str().ok()?;
642 futu_auth_pipeline::parse_bearer_scheme(value).map(|t| t.to_string())
643}
644
645fn ws_check_origin(origin: &str, key_store: Option<&KeyStore>) -> bool {
655 if let Ok(raw) = std::env::var("FUTU_WS_ALLOWED_ORIGINS") {
656 let trimmed = raw.trim();
657 if !trimmed.is_empty() {
658 for allowed in trimmed
659 .split(',')
660 .map(|s| s.trim())
661 .filter(|s| !s.is_empty())
662 {
663 if allowed == origin {
664 return true;
665 }
666 }
667 return false;
669 }
670 }
671 let auth_enabled = key_store.is_some_and(|ks| ks.is_configured());
673 if !auth_enabled {
674 return is_strict_loopback_origin(origin);
680 }
681 is_strict_loopback_origin(origin)
685}
686
/// Returns `true` when `s` is an Origin pointing at this machine's loopback interface.
///
/// Accepted forms are exactly `http://` or `https://`, then `127.0.0.1`, `localhost` or `[::1]`,
/// optionally followed by `:PORT`. PORT must be plain ASCII decimal digits with a value of
/// 1-65535. Values containing a path, query, fragment, or userinfo are rejected, as is any
/// other host.
fn is_strict_loopback_origin(s: &str) -> bool {
    let Some(authority) = s
        .strip_prefix("http://")
        .or_else(|| s.strip_prefix("https://"))
    else {
        return false;
    };

    // An Origin is only scheme://host[:port]; any of these characters means extra components.
    if authority.contains(|c| matches!(c, '/' | '?' | '#' | '@')) {
        return false;
    }

    let (host, port) = if let Some(rest) = authority.strip_prefix("[::1]") {
        // The bracketed IPv6 literal contains ':' itself, so it cannot be split on the last ':'.
        if rest.is_empty() {
            ("[::1]", None)
        } else if let Some(p) = rest.strip_prefix(':') {
            ("[::1]", Some(p))
        } else {
            return false;
        }
    } else if let Some((h, p)) = authority.rsplit_once(':') {
        (h, Some(p))
    } else {
        (authority, None)
    };

    if !matches!(host, "127.0.0.1" | "localhost" | "[::1]") {
        return false;
    }

    match port {
        None => true,
        // `u16::from_str` also accepts a leading '+', which would make e.g. "localhost:+80"
        // pass; require bare ASCII digits so the check is genuinely strict.
        Some(p) => {
            !p.is_empty()
                && p.bytes().all(|b| b.is_ascii_digit())
                && matches!(p.parse::<u16>(), Ok(n) if n >= 1)
        }
    }
}
731
732fn make_err_response(code: StatusCode, msg: &str) -> ErrorResponse {
734 let body = Some(format!(r#"{{"error":"{msg}"}}"#));
735 let mut resp = tokio_tungstenite::tungstenite::http::Response::new(body);
736 *resp.status_mut() = code;
737 resp.headers_mut().insert(
738 "content-type",
739 tokio_tungstenite::tungstenite::http::HeaderValue::from_static("application/json"),
740 );
741 resp
742}
743
/// Single consumer for requests arriving from every WebSocket connection.
///
/// Requests are handled one at a time: each dispatch and response send is awaited before the
/// next request is received. For each request this:
/// 1. refreshes the connection's `last_keepalive`, even for requests dropped below;
/// 2. drops daemon-internal proto IDs without replying;
/// 3. decrypts the body when the connection has AES enabled (InitConnect excluded), dropping the
///    request if decryption fails;
/// 4. runs the auth pipeline for protos that map to a scope (InitConnect excluded); a rejected
///    request is dropped without any reply;
/// 5. answers InitConnect and KeepAlive on the connection itself and routes everything else via
///    `router`, passing the caller's key id and allowed account ids;
/// 6. filters the response body through `filter_registry` and sends it.
///
/// Returns once every sender for `req_rx` has been dropped.
async fn ws_process_requests(
    mut req_rx: mpsc::UnboundedReceiver<IncomingRequest>,
    connections: Arc<DashMap<u64, ClientConn>>,
    router: Arc<RequestRouter>,
    config: ServerConfig,
    counters: Arc<RuntimeCounters>,
    key_store: Arc<KeyStore>,
    filter_registry: Arc<FilterRegistry>,
) {
    use crate::listener::ApiServer;

    while let Some(mut req) = req_rx.recv().await {
        let conn_id = req.conn_id;
        let proto_id_val = req.proto_id;
        let serial_no = req.serial_no;

        // Any inbound traffic counts as liveness.
        if let Some(mut conn) = connections.get_mut(&conn_id) {
            conn.last_keepalive = Instant::now();
        }

        // Daemon-internal protocol IDs must never be reachable from this public surface.
        if futu_auth::is_internal_proto_id(proto_id_val) {
            tracing::warn!(
                conn_id,
                proto_id = proto_id_val,
                "rejecting daemon-internal proto_id at raw WS public surface (codex 0532 F3)"
            );
            continue;
        }

        // InitConnect is always handled without AES decryption.
        if proto_id_val != proto_id::INIT_CONNECT
            && let Some(conn) = connections.get(&conn_id)
            && conn.aes_encrypt_enabled
        {
            match conn.decrypt_body(&req.body) {
                Ok(decrypted) => {
                    req.body = bytes::Bytes::from(decrypted);
                }
                Err(e) => {
                    tracing::warn!(
                        conn_id = conn_id,
                        proto_id = proto_id_val,
                        error = %e,
                        "ws AES decrypt request failed, dropping"
                    );
                    continue;
                }
            }
        }

        let needed_scope = futu_auth_pipeline::capability::scope_for_proto_id(proto_id_val);
        // `None` for legacy (unauthenticated) connections or if the connection is gone.
        let dispatch_caller_key_id: Option<String> =
            connections.get(&conn_id).and_then(|c| c.key_id.clone());
        // InitConnect and protos without a required scope skip the auth pipeline; otherwise the
        // pipeline's allowed account ids are used for dispatch and response filtering.
        let allowed_acc_ids_for_resp_filter: Option<HashSet<u64>> =
            if proto_id_val == proto_id::INIT_CONNECT || needed_scope.is_none() {
                None
            } else {
                let key_id_snap = dispatch_caller_key_id.clone();
                // The key was verified at handshake time; re-fetch its current record by id.
                let rec_opt = key_id_snap.as_ref().and_then(|id| key_store.get_by_id(id));
                let credential = match rec_opt {
                    Some(rec) => Credential::PreVerified(rec),
                    None => Credential::None,
                };

                let env = AuthEnvelope {
                    surface: SurfaceId::Ws,
                    endpoint: Endpoint::Proto(proto_id_val),
                    needed_scope,
                    credential,
                    proto_id: Some(proto_id_val),
                    body: &req.body,
                    explicit_acc_id: None,
                    explicit_ctx: None,
                    commit_rate: true,
                    audit_emit: true,
                };

                use futu_auth_pipeline::SurfaceAdapter;
                match authenticate_request(&key_store, &counters, env) {
                    AuthDecision::Allow {
                        allowed_acc_ids, ..
                    } => allowed_acc_ids,
                    // WsAdapter's wire response is `()`: the client gets no reply at all.
                    decision @ AuthDecision::Reject { .. } => {
                        let _ = WsAdapter::translate_decision(decision);
                        continue;
                    }
                }
            };

        let response_body = match proto_id_val {
            proto_id::INIT_CONNECT => match connections.get_mut(&conn_id) {
                Some(mut conn) => conn
                    .handle_init_connect(
                        &req.body,
                        config.server_ver,
                        config.login_user_id,
                        config.keepalive_interval,
                        config.rsa_private_key.as_deref(),
                    )
                    .ok(),
                _ => None,
            },
            proto_id::KEEP_ALIVE => match connections.get(&conn_id) {
                Some(conn) => conn.handle_keepalive(&req.body).ok(),
                _ => None,
            },
            _ => {
                // Rebuild the request so the router sees the caller's key id and account scope.
                let dispatch_req = IncomingRequest::builder(
                    req.conn_id,
                    req.proto_id,
                    req.serial_no,
                    req.proto_fmt_type,
                    req.body.clone(),
                )
                .with_idempotency_key(req.idempotency_key.clone())
                .with_caller_scope(
                    allowed_acc_ids_for_resp_filter
                        .as_ref()
                        .map(|s| std::sync::Arc::new(s.clone())),
                    dispatch_caller_key_id.clone(),
                )
                .build();
                router.dispatch(conn_id, &dispatch_req).await
            }
        };

        if let Some(body) = response_body {
            let filtered =
                filter_registry.apply(proto_id_val, body, allowed_acc_ids_for_resp_filter.as_ref());
            ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, filtered)
                .await;
        }
    }
}
928
// Unit tests for this module live in a separate `tests` module file.
#[cfg(test)]
mod tests;