1use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
11use axum::extract::{Query, State};
12use axum::http::{HeaderMap, StatusCode};
13use axum::response::IntoResponse;
14use chrono::Utc;
15use futures::{SinkExt, StreamExt};
16use tokio::sync::broadcast;
17
18use futu_auth::{KeyRecord, KeyStore, Scope};
19use futu_server::push::ExternalPushSink;
20
21use crate::adapter::RestState;
22
23#[derive(Clone, Debug, serde::Serialize)]
25pub struct WsPushEvent {
26 #[serde(rename = "type")]
28 pub event_type: String,
29 #[serde(skip)]
31 pub required_scope: WsPushScope,
32 pub proto_id: u32,
34 #[serde(skip_serializing_if = "Option::is_none")]
36 pub sec_key: Option<String>,
37 #[serde(skip_serializing_if = "Option::is_none")]
39 pub sub_type: Option<i32>,
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub acc_id: Option<u64>,
43 pub body_b64: String,
45}
46
/// Delivery category of a push event; each category maps to the API-key scope
/// a client must hold to receive it (see `required_scope`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum WsPushScope {
    /// Market-data (quote) pushes. This is the `Default` category.
    #[default]
    Quote,
    /// Broadcast / notification pushes.
    Notify,
    /// Account / trade pushes.
    Trade,
}
59
60impl WsPushScope {
61 pub fn required_scope(&self) -> Scope {
63 match self {
64 WsPushScope::Quote => Scope::QotRead,
65 WsPushScope::Notify => Scope::QotRead,
66 WsPushScope::Trade => Scope::AccRead,
67 }
68 }
69}
70
71#[derive(Clone)]
77pub struct WsBroadcaster {
78 tx: broadcast::Sender<WsPushEvent>,
79}
80
81impl WsBroadcaster {
82 pub fn new(capacity: usize) -> Self {
83 let (tx, _) = broadcast::channel(capacity);
84 Self { tx }
85 }
86
87 pub fn send(&self, event: WsPushEvent) {
89 let _ = self.tx.send(event);
91 }
92
93 pub fn subscribe(&self) -> broadcast::Receiver<WsPushEvent> {
95 self.tx.subscribe()
96 }
97
98 fn encode_body(body: &[u8]) -> String {
99 use base64::Engine;
100 base64::engine::general_purpose::STANDARD.encode(body)
101 }
102
103 pub fn push_quote(&self, sec_key: &str, sub_type: i32, proto_id: u32, body: &[u8]) {
105 self.send(WsPushEvent {
106 event_type: "quote".to_string(),
107 required_scope: WsPushScope::Quote,
108 proto_id,
109 sec_key: Some(sec_key.to_string()),
110 sub_type: Some(sub_type),
111 acc_id: None,
112 body_b64: Self::encode_body(body),
113 });
114 }
115
116 pub fn push_broadcast(&self, proto_id: u32, body: &[u8]) {
118 self.send(WsPushEvent {
119 event_type: "notify".to_string(),
120 required_scope: WsPushScope::Notify,
121 proto_id,
122 sec_key: None,
123 sub_type: None,
124 acc_id: None,
125 body_b64: Self::encode_body(body),
126 });
127 }
128
129 pub fn push_trade(&self, acc_id: u64, proto_id: u32, body: &[u8]) {
131 self.send(WsPushEvent {
132 event_type: "trade".to_string(),
133 required_scope: WsPushScope::Trade,
134 proto_id,
135 sec_key: None,
136 sub_type: None,
137 acc_id: Some(acc_id),
138 body_b64: Self::encode_body(body),
139 });
140 }
141}
142
143impl ExternalPushSink for WsBroadcaster {
145 fn on_quote_push(&self, sec_key: &str, sub_type: i32, proto_id: u32, body: &[u8]) {
146 self.push_quote(sec_key, sub_type, proto_id, body);
147 }
148
149 fn on_broadcast_push(&self, proto_id: u32, body: &[u8]) {
150 self.push_broadcast(proto_id, body);
151 }
152
153 fn on_trade_push(&self, acc_id: u64, proto_id: u32, body: &[u8]) {
154 self.push_trade(acc_id, proto_id, body);
155 }
156}
157
158fn extract_ws_token(headers: &HeaderMap, query: &HashMap<String, String>) -> Option<String> {
163 if let Some(t) = query.get("token") {
164 return Some(t.clone());
165 }
166 headers
167 .get("authorization")
168 .and_then(|v| v.to_str().ok())
169 .and_then(|v| v.strip_prefix("Bearer ").map(|s| s.trim().to_string()))
170}
171
172fn authenticate_ws(
179 key_store: &KeyStore,
180 headers: &HeaderMap,
181 query: &HashMap<String, String>,
182) -> Result<Option<Arc<KeyRecord>>, (StatusCode, &'static str)> {
183 if !key_store.is_configured() {
184 return Ok(None);
185 }
186
187 let Some(token) = extract_ws_token(headers, query) else {
188 futu_auth::audit::reject(
189 "ws",
190 "/ws",
191 "<missing>",
192 "missing token (query or Authorization)",
193 );
194 return Err((StatusCode::UNAUTHORIZED, "missing api key"));
195 };
196
197 let Some(rec) = key_store.verify(&token) else {
198 futu_auth::audit::reject("ws", "/ws", "<invalid>", "invalid api key");
199 return Err((StatusCode::UNAUTHORIZED, "invalid api key"));
200 };
201
202 if rec.is_expired(Utc::now()) {
203 futu_auth::audit::reject("ws", "/ws", &rec.id, "key expired");
204 return Err((StatusCode::UNAUTHORIZED, "key expired"));
205 }
206
207 if !rec.scopes.contains(&Scope::QotRead) {
208 futu_auth::audit::reject("ws", "/ws", &rec.id, "missing qot:read scope");
209 return Err((StatusCode::FORBIDDEN, "missing qot:read scope"));
210 }
211
212 futu_auth::audit::allow("ws", "/ws", &rec.id, Some("qot:read"));
213 Ok(Some(rec))
214}
215
216pub async fn ws_handler(
218 ws: WebSocketUpgrade,
219 headers: HeaderMap,
220 Query(query): Query<HashMap<String, String>>,
221 State(state): State<RestState>,
222) -> impl IntoResponse {
223 let rec = match authenticate_ws(&state.key_store, &headers, &query) {
224 Ok(rec) => rec,
225 Err((code, msg)) => return (code, msg).into_response(),
226 };
227 let scopes: HashSet<Scope> = match &rec {
229 Some(r) => r.scopes.clone(),
230 None => all_scopes(),
231 };
232 let key_id = rec.as_ref().map(|r| r.id.clone());
233 let broadcaster = Arc::clone(&state.ws_broadcaster);
234 ws.on_upgrade(move |socket| handle_ws_connection(socket, broadcaster, scopes, key_id))
235 .into_response()
236}
237
238fn all_scopes() -> HashSet<Scope> {
240 [
241 Scope::QotRead,
242 Scope::AccRead,
243 Scope::TradeSimulate,
244 Scope::TradeReal,
245 ]
246 .into_iter()
247 .collect()
248}
249
250async fn handle_ws_connection(
255 socket: WebSocket,
256 broadcaster: Arc<WsBroadcaster>,
257 scopes: HashSet<Scope>,
258 key_id: Option<String>,
259) {
260 let (mut ws_tx, mut ws_rx) = socket.split();
261 let mut push_rx = broadcaster.subscribe();
262
263 tracing::info!(
264 key_id = ?key_id,
265 scopes = ?scopes,
266 "WebSocket push client connected"
267 );
268
269 let send_scopes = scopes.clone();
271 let send_key_id = key_id.clone().unwrap_or_else(|| "<none>".to_string());
272 let send_task = tokio::spawn(async move {
273 while let Ok(event) = push_rx.recv().await {
274 if !send_scopes.contains(&event.required_scope.required_scope()) {
276 futu_auth::metrics::bump_ws_filtered(&event.event_type, &send_key_id);
278 continue;
279 }
280 let json = match serde_json::to_string(&event) {
281 Ok(j) => j,
282 Err(_) => continue,
283 };
284 if ws_tx.send(Message::Text(json.into())).await.is_err() {
285 break; }
287 }
288 });
289
290 let recv_task = tokio::spawn(async move {
292 while let Some(msg) = ws_rx.next().await {
293 match msg {
294 Ok(Message::Close(_)) | Err(_) => break,
295 Ok(Message::Ping(data)) => {
296 let _ = data;
298 }
299 _ => {} }
301 }
302 });
303
304 tokio::select! {
306 _ = send_task => {}
307 _ = recv_task => {}
308 }
309
310 tracing::info!("WebSocket push client disconnected");
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scope_filter_blocks_trade_for_qot_only_client() {
        let scopes: HashSet<Scope> = [Scope::QotRead].into_iter().collect();

        // Quote and notify pushes only need qot:read...
        assert!(scopes.contains(&WsPushScope::Quote.required_scope()));
        assert!(scopes.contains(&WsPushScope::Notify.required_scope()));
        // ...while trade pushes need acc:read.
        assert!(!scopes.contains(&WsPushScope::Trade.required_scope()));
    }

    #[test]
    fn scope_filter_allows_all_for_qot_plus_acc() {
        let scopes: HashSet<Scope> = [Scope::QotRead, Scope::AccRead].into_iter().collect();
        for s in [WsPushScope::Quote, WsPushScope::Notify, WsPushScope::Trade] {
            assert!(
                scopes.contains(&s.required_scope()),
                "{:?} should be allowed",
                s
            );
        }
    }

    #[test]
    fn legacy_all_scopes_allows_everything() {
        // Auth-disabled connections receive every push category.
        let scopes = all_scopes();
        for s in [WsPushScope::Quote, WsPushScope::Notify, WsPushScope::Trade] {
            assert!(scopes.contains(&s.required_scope()));
        }
    }

    #[test]
    fn event_type_matches_scope_category() {
        let b = WsBroadcaster::new(4);
        let mut rx = b.subscribe();
        b.push_quote("HK.00700", 1, 0, b"x");
        b.push_broadcast(0, b"x");
        b.push_trade(42, 0, b"x");

        let e1 = rx.try_recv().unwrap();
        assert_eq!(e1.event_type, "quote");
        assert_eq!(e1.required_scope, WsPushScope::Quote);

        let e2 = rx.try_recv().unwrap();
        assert_eq!(e2.event_type, "notify");
        assert_eq!(e2.required_scope, WsPushScope::Notify);

        let e3 = rx.try_recv().unwrap();
        assert_eq!(e3.event_type, "trade");
        assert_eq!(e3.required_scope, WsPushScope::Trade);
    }

    #[test]
    fn push_methods_set_routing_keys_and_base64_body() {
        let b = WsBroadcaster::new(4);
        let mut rx = b.subscribe();
        b.push_quote("HK.00700", 1, 3005, b"x");
        b.push_trade(42, 2208, b"x");

        let quote = rx.try_recv().unwrap();
        assert_eq!(quote.proto_id, 3005);
        assert_eq!(quote.sec_key.as_deref(), Some("HK.00700"));
        assert_eq!(quote.sub_type, Some(1));
        assert_eq!(quote.acc_id, None);
        // Standard, padded base64 of b"x".
        assert_eq!(quote.body_b64, "eA==");

        let trade = rx.try_recv().unwrap();
        assert_eq!(trade.proto_id, 2208);
        assert_eq!(trade.acc_id, Some(42));
        assert_eq!(trade.sec_key, None);
        assert_eq!(trade.sub_type, None);
    }

    #[test]
    fn serialized_event_renames_type_and_omits_internal_fields() {
        let b = WsBroadcaster::new(4);
        let mut rx = b.subscribe();
        b.push_broadcast(7, b"x");
        let event = rx.try_recv().unwrap();

        let json = serde_json::to_value(&event).unwrap();
        assert_eq!(json["type"].as_str(), Some("notify"));
        assert_eq!(json["proto_id"].as_u64(), Some(7));
        assert_eq!(json["body_b64"].as_str(), Some("eA=="));

        let obj = json.as_object().unwrap();
        // `required_scope` is server-side only; `event_type` is renamed to `type`.
        assert!(!obj.contains_key("required_scope"));
        assert!(!obj.contains_key("event_type"));
        // `None` routing keys are omitted entirely.
        assert!(!obj.contains_key("sec_key"));
        assert!(!obj.contains_key("sub_type"));
        assert!(!obj.contains_key("acc_id"));
    }

    #[test]
    fn ws_token_query_takes_precedence_over_header() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Bearer from-header".parse().unwrap());
        let mut query = HashMap::new();
        query.insert("token".to_string(), "from-query".to_string());

        assert_eq!(
            extract_ws_token(&headers, &query).as_deref(),
            Some("from-query")
        );
        assert_eq!(
            extract_ws_token(&headers, &HashMap::new()).as_deref(),
            Some("from-header")
        );
    }

    #[test]
    fn ws_token_requires_bearer_header_or_query() {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", "Basic dXNlcjpwYXNz".parse().unwrap());
        assert_eq!(extract_ws_token(&headers, &HashMap::new()), None);
        assert_eq!(extract_ws_token(&HeaderMap::new(), &HashMap::new()), None);
    }
}