1use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU32, Ordering};
10
11use bytes::Bytes;
12use futures::{SinkExt, StreamExt};
13use parking_lot::Mutex;
14use tokio::net::TcpStream;
15use tokio::sync::{mpsc, oneshot};
16use tokio_util::codec::Framed;
17
18use futu_core::error::{FutuError, Result};
19use futu_net::encrypt::{aes_cbc_md5_decrypt_var, aes_cbc_md5_encrypt_var};
20
21use crate::nn_codec::{NNCodec, NNFrame, NNHeader, should_skip_encryption};
22
/// Handle to one backend connection.
///
/// The socket itself is owned by two background tasks spawned in
/// `from_stream_inner`; this handle only keeps the state needed to build
/// outbound frames and the channel used to hand them to the writer task.
pub struct BackendConn {
    /// Request counter; `next_serial` increments it and uses the incremented
    /// value as the frame's serial number.
    serial_no: AtomicU32,
    /// Counter whose incremented value is prepended (4 bytes, big-endian) to
    /// the plaintext of every encrypted outbound body. Starts at 1 and can be
    /// overwritten through `set_sec_data`.
    sec_data: AtomicU32,
    /// Session key shared with the reader task. While it is `None`, outbound
    /// bodies are sent as-is and inbound bodies are not decrypted.
    session_key: Arc<Mutex<Option<Vec<u8>>>>,
    /// Queue feeding the writer task.
    cmd_tx: mpsc::Sender<BackendCmd>,
    /// Copied into the `user_id` field of every outbound header.
    pub user_id: AtomicU32,
    /// Value stored by `set_client_ip` and returned by `client_ip`.
    client_ip: Mutex<String>,
    /// Copied into the `client_type` field of every outbound header.
    pub client_type: u8,
    /// Copied into the `client_ver` field of every outbound header.
    pub client_ver: u16,
    /// Copied into the `lang_id` field of every outbound header.
    pub lang_id: u8,
}
44
/// Work item consumed by the writer task.
enum BackendCmd {
    /// Write the frame to the socket. When a sender is supplied, the writer
    /// registers it under the frame's serial number first so the reader task
    /// can route the matching response back to the caller.
    Send(NNFrame, Option<oneshot::Sender<NNFrame>>),
}
48
/// Callback invoked by the reader task for every frame it classifies as a
/// push; it receives the frame's command id and its decoded body.
pub type PushCallback = Arc<dyn Fn(u16, Bytes) + Send + Sync + 'static>;
51
/// Verdict produced by `decode_inbound_frame_body` for one inbound frame.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InboundFrameDecision {
    /// The body was decoded (or needed no decoding); hand the frame on.
    Deliver,
    /// Decryption or decompression failed; the frame must be discarded.
    Drop,
}
57
58fn decode_inbound_frame_body(
59 frame: &mut NNFrame,
60 session_key: Option<&[u8]>,
61) -> InboundFrameDecision {
62 if !should_skip_encryption(frame.header.cmd_id)
63 && let Some(key) = session_key
64 {
65 let body_len = frame.body.len();
66 if body_len >= 32 && body_len.is_multiple_of(16) {
70 match aes_cbc_md5_decrypt_var(key, &frame.body) {
71 Ok(decrypted) => {
72 frame.body = Bytes::from(decrypted);
74 }
75 Err(e) => {
76 tracing::warn!(
80 cmd_id = frame.header.cmd_id,
81 body_len = body_len,
82 key_len = key.len(),
83 error = %e,
84 "decrypt failed, dropping inbound frame"
85 );
86 return InboundFrameDecision::Drop;
87 }
88 }
89 } else {
90 tracing::debug!(
91 cmd_id = frame.header.cmd_id,
92 body_len = body_len,
93 "body not encrypted (len not aligned to 16)"
94 );
95 }
96 }
97
98 if frame.header.is_compressed() {
99 let compressed_body_len = frame.body.len();
100 match crate::ftlogin_wire::decode_inbound_body(true, frame.body.as_ref()) {
101 Ok(decompressed) => {
102 frame.body = Bytes::from(decompressed);
103 frame.header.body_len = frame.body.len() as u32;
104 tracing::debug!(
105 cmd_id = frame.header.cmd_id,
106 serial_no = frame.header.serial_no,
107 compressed_body_len,
108 body_len = frame.body.len(),
109 "decompressed inbound frame after decrypt"
110 );
111 }
112 Err(e) => {
113 tracing::warn!(
118 cmd_id = frame.header.cmd_id,
119 serial_no = frame.header.serial_no,
120 body_len = compressed_body_len,
121 error = %e,
122 "decompress failed after decrypt, dropping inbound frame"
123 );
124 return InboundFrameDecision::Drop;
125 }
126 }
127 }
128
129 InboundFrameDecision::Deliver
130}
131
132fn release_pending_on_inbound_decode_error(
133 frame: &NNFrame,
134 pending: &Arc<Mutex<HashMap<u32, oneshot::Sender<NNFrame>>>>,
135) {
136 if !frame.header.is_push {
137 pending.lock().remove(&frame.header.serial_no);
138 }
139}
140
141impl BackendConn {
/// Upper bound on a single TCP connect attempt; `connect` and
/// `connect_race` pass it to `establish_stream`.
pub const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
149
/// Initial value of `client_ver` on a newly built connection, and therefore
/// the client version written into outbound headers.
pub const CLIENT_VER_FTGTW: u16 = 1030;
157
158 async fn establish_stream(addr: &str, timeout: std::time::Duration) -> Result<TcpStream> {
160 let stream = match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
161 Ok(Ok(s)) => s,
162 Ok(Err(e)) => return Err(e.into()),
163 Err(_elapsed) => {
164 return Err(FutuError::Network(std::io::Error::new(
165 std::io::ErrorKind::TimedOut,
166 format!("connect to {addr} timed out after {}s", timeout.as_secs()),
167 )));
168 }
169 };
170 stream.set_nodelay(true)?;
171 Ok(stream)
172 }
173
174 pub async fn connect(addr: &str, push_callback: PushCallback) -> Result<Self> {
176 let stream = Self::establish_stream(addr, Self::CONNECT_TIMEOUT).await?;
177 tracing::info!(addr = addr, "connected to backend");
178 Self::from_stream(stream, push_callback)
179 }
180
181 pub async fn connect_race(
191 addrs: &[String],
192 push_callback: PushCallback,
193 ) -> Result<(Self, String)> {
194 use futures::stream::{FuturesUnordered, StreamExt};
195
196 if addrs.is_empty() {
197 return Err(FutuError::Network(std::io::Error::new(
198 std::io::ErrorKind::InvalidInput,
199 "connect_race: empty address list",
200 )));
201 }
202
203 tracing::info!(
204 candidates = addrs.len(),
205 addrs = ?addrs,
206 "racing parallel connects"
207 );
208
209 let mut attempts: FuturesUnordered<_> = addrs
210 .iter()
211 .cloned()
212 .map(|addr| async move {
213 let result = Self::establish_stream(&addr, Self::CONNECT_TIMEOUT).await;
214 (addr, result)
215 })
216 .collect();
217
218 let mut last_err: Option<FutuError> = None;
219 while let Some((addr, result)) = attempts.next().await {
220 match result {
221 Ok(stream) => {
222 tracing::info!(
223 addr = %addr,
224 remaining_losers = attempts.len(),
225 "connect race winner"
226 );
227 drop(attempts); let conn = Self::from_stream(stream, push_callback)?;
229 return Ok((conn, addr));
230 }
231 Err(e) => {
232 tracing::debug!(addr = %addr, error = %e, "candidate failed");
233 last_err = Some(e);
234 }
235 }
236 }
237
238 Err(last_err.unwrap_or_else(|| {
239 FutuError::Network(std::io::Error::other("connect_race: all candidates failed"))
240 }))
241 }
242
/// Builds a connection on top of an in-memory duplex stream instead of a TCP
/// socket. Only compiled when the `test-util` feature is enabled.
#[cfg(feature = "test-util")]
pub fn from_duplex(
    stream: tokio::io::DuplexStream,
    push_callback: PushCallback,
) -> Result<Self> {
    Self::from_stream_inner(stream, push_callback)
}
254
/// Builds a connection on top of an already-connected TCP socket by
/// delegating to `from_stream_inner`.
fn from_stream(stream: TcpStream, push_callback: PushCallback) -> Result<Self> {
    Self::from_stream_inner(stream, push_callback)
}
259
260 pub(crate) fn from_stream_inner<S>(stream: S, push_callback: PushCallback) -> Result<Self>
263 where
264 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static,
265 {
266 let framed = Framed::new(stream, NNCodec);
267 let (mut sink, mut stream_rx) = framed.split();
268
269 let (cmd_tx, mut cmd_rx) = mpsc::channel::<BackendCmd>(256);
270
271 let pending = Arc::new(Mutex::new(HashMap::<u32, oneshot::Sender<NNFrame>>::new()));
272 let pending_recv = pending.clone();
273 let pending_send = pending.clone();
274 let session_key: Arc<Mutex<Option<Vec<u8>>>> = Arc::new(Mutex::new(None));
275 let session_key_for_recv = session_key.clone();
276
277 tokio::spawn(async move {
279 while let Some(Ok(mut frame)) = stream_rx.next().await {
280 tracing::debug!(
282 cmd_id = frame.header.cmd_id,
283 serial_no = frame.header.serial_no,
284 is_push = frame.header.is_push,
285 is_compressed = frame.header.is_compressed,
286 ex_head_len = frame.header.ex_head_len,
287 body_len = frame.header.body_len,
288 actual_body_len = frame.body.len(),
289 "recv frame"
290 );
291
292 if let Some(err) = frame.parse_ex_head_error()
300 && (err.cmd_result != 0
301 || err.code != 0
302 || !err.message.is_empty()
303 || !err.source.is_empty())
304 {
305 let is_soft_fail = err.code == -102;
311 if is_soft_fail {
312 tracing::debug!(
313 cmd_id = frame.header.cmd_id,
314 code = err.code,
315 source = %err.source,
316 message = %err.message,
317 "server: cmd not available on this channel (soft-fail)"
318 );
319 } else {
320 tracing::warn!(
321 cmd_id = frame.header.cmd_id,
322 cmd_result = err.cmd_result,
323 code = err.code,
324 source = %err.source,
325 message = %err.message,
326 "server returned err_info in ex_head"
327 );
328 }
329 }
330
331 let session_key = session_key_for_recv.lock().clone();
332 if decode_inbound_frame_body(&mut frame, session_key.as_deref())
333 == InboundFrameDecision::Drop
334 {
335 release_pending_on_inbound_decode_error(&frame, &pending_recv);
336 continue;
337 }
338
339 let is_push = frame.header.is_push
344 || (frame.header.serial_no == 0 && !pending_recv.lock().contains_key(&0));
345 if is_push {
346 tracing::debug!(
347 cmd_id = frame.header.cmd_id,
348 body_len = frame.body.len(),
349 is_push = frame.header.is_push,
350 is_compressed = frame.header.is_compressed,
351 reserved = ?frame.header.reserved,
352 "backend push received"
353 );
354 push_callback(frame.header.cmd_id, frame.body);
355 } else {
356 let tx = pending_recv.lock().remove(&frame.header.serial_no);
357 if let Some(tx) = tx {
358 let _ = tx.send(frame);
359 }
360 }
361 }
362 tracing::warn!("backend connection closed");
364 let mut pending = pending_recv.lock();
365 let count = pending.len();
366 if count > 0 {
367 tracing::warn!(
368 pending_count = count,
369 "aborting pending requests due to disconnect"
370 );
371 }
372 pending.clear();
376 });
377
378 tokio::spawn(async move {
380 while let Some(cmd) = cmd_rx.recv().await {
381 match cmd {
382 BackendCmd::Send(frame, resp_tx) => {
383 if let Some(tx) = resp_tx {
384 pending_send.lock().insert(frame.header.serial_no, tx);
385 }
386 if let Err(e) = sink.send(frame).await {
387 tracing::error!(error = %e, "backend send failed");
388 break;
389 }
390 }
391 }
392 }
393 });
394
395 Ok(Self {
396 serial_no: AtomicU32::new(0),
397 sec_data: AtomicU32::new(1),
398 session_key, cmd_tx,
400 user_id: AtomicU32::new(0),
401 client_ip: Mutex::new(String::new()),
402 client_type: 40, client_ver: Self::CLIENT_VER_FTGTW,
404 lang_id: 0,
405 })
406 }
407
408 pub fn set_session_key(&self, key: Vec<u8>) {
412 *self.session_key.lock() = Some(key);
413 }
414
/// Overwrites the `sec_data` counter. The next encrypted outbound body is
/// prefixed with `val + 1`.
pub fn set_sec_data(&self, val: u32) {
    self.sec_data.store(val, Ordering::Relaxed);
}
419
420 pub fn set_client_ip(&self, ip: String) {
422 *self.client_ip.lock() = ip;
423 }
424
425 pub fn client_ip(&self) -> String {
427 self.client_ip.lock().clone()
428 }
429
/// Sends a request with all-zero reserved bytes and waits for its response.
///
/// See `request_with_reserved` for the error conditions.
pub async fn request(&self, cmd_id: u16, body: Vec<u8>) -> Result<NNFrame> {
    self.request_with_reserved(cmd_id, body, [0u8; 10]).await
}
434
435 pub async fn request_with_reserved(
437 &self,
438 cmd_id: u16,
439 body: Vec<u8>,
440 reserved: [u8; 10],
441 ) -> Result<NNFrame> {
442 let frame = self.build_outbound_frame(cmd_id, body, reserved)?;
443
444 let (resp_tx, resp_rx) = oneshot::channel();
445 self.cmd_tx
446 .send(BackendCmd::Send(frame, Some(resp_tx)))
447 .await
448 .map_err(|_| FutuError::NotInitialized)?;
449
450 let resp = tokio::time::timeout(std::time::Duration::from_secs(10), resp_rx)
451 .await
452 .map_err(|_| FutuError::Timeout)?
453 .map_err(|_| FutuError::Codec("response channel closed".into()))?;
454
455 Ok(resp)
456 }
457
458 pub async fn send_fire_and_forget(&self, cmd_id: u16, body: Vec<u8>) -> Result<()> {
460 let frame = self.build_outbound_frame(cmd_id, body, [0u8; 10])?;
461 self.cmd_tx
462 .send(BackendCmd::Send(frame, None))
463 .await
464 .map_err(|_| FutuError::NotInitialized)?;
465
466 Ok(())
467 }
468
469 fn build_outbound_frame(
470 &self,
471 cmd_id: u16,
472 body: Vec<u8>,
473 reserved: [u8; 10],
474 ) -> Result<NNFrame> {
475 let serial = self.next_serial();
476 let mut header = NNHeader::new(cmd_id, serial);
477 header.user_id = self.user_id.load(Ordering::Relaxed);
478 header.client_type = self.client_type;
479 header.client_ver = self.client_ver;
480 header.lang_id = self.lang_id;
481 header.reserved.copy_from_slice(&reserved[..8]);
485
486 let final_body = self.encode_outbound_body(cmd_id, body)?;
487 header.body_len = final_body.len() as u32;
488
489 Ok(NNFrame {
490 header,
491 body: Bytes::from(final_body),
492 ex_head: Bytes::new(),
493 })
494 }
495
496 fn encode_outbound_body(&self, cmd_id: u16, body: Vec<u8>) -> Result<Vec<u8>> {
497 if should_skip_encryption(cmd_id) {
498 return Ok(body);
499 }
500
501 let key = self.session_key.lock().clone();
502 match key {
503 Some(key) => {
504 let sec = self.sec_data.fetch_add(1, Ordering::Relaxed) + 1;
506 let mut plaintext = Vec::with_capacity(4 + body.len());
507 plaintext.extend_from_slice(&sec.to_be_bytes());
508 plaintext.extend_from_slice(&body);
509 aes_cbc_md5_encrypt_var(&key, &plaintext)
511 }
512 None => Ok(body),
513 }
514 }
515
516 fn next_serial(&self) -> u32 {
517 self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
518 }
519}
520
521#[cfg(test)]
522mod tests;