1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use bytes::Bytes;
6use dashmap::DashMap;
7use tokio::sync::{mpsc, oneshot};
8
9use futu_codec::frame::FutuFrame;
10use futu_core::error::{FutuError, Result};
11use futu_core::proto_id;
12
13use crate::connection::Connection;
14use crate::encrypt;
15use crate::reconnect::ReconnectPolicy;
16
/// Settings used to establish and initialise a client connection.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Address handed to `Connection::connect`.
    pub addr: String,
    /// Client version string.
    ///
    /// NOTE(review): the connect path visible in this file sends a fixed
    /// version number and does not read this field — confirm that is intended.
    pub client_ver: String,
    /// Client identifier sent in the InitConnect request.
    pub client_id: String,
    /// Value forwarded as `recv_notify` in the InitConnect request.
    pub recv_notify: bool,
    /// Optional RSA key. When present, InitConnect is RSA-encrypted and later
    /// request/response bodies go through the AES session key.
    pub rsa_key: Option<String>,
}
30
31#[derive(Debug, Clone)]
33pub struct PushMessage {
34 pub proto_id: u32,
35 pub body: Bytes,
36}
37
38struct PendingRequest {
40 tx: oneshot::Sender<FutuFrame>,
41}
42
43pub struct FutuClient {
52 config: ClientConfig,
53 serial_no: AtomicU32,
54 pending: Arc<DashMap<u32, PendingRequest>>,
55 push_tx: mpsc::UnboundedSender<PushMessage>,
56 cmd_tx: Option<mpsc::Sender<ClientCommand>>,
57 conn_aes_key: parking_lot::Mutex<Option<[u8; 16]>>,
58 conn_id: AtomicU64,
59}
60
61enum ClientCommand {
62 Send(FutuFrame, oneshot::Sender<()>),
63}
64
65fn decode_client_inbound_body(frame: &FutuFrame, aes_key: Option<&[u8; 16]>) -> Result<Bytes> {
66 match aes_key {
67 Some(key) if !frame.body.is_empty() => {
68 encrypt::aes_ecb_decrypt(key, &frame.body).map(Bytes::from)
69 }
70 _ => Ok(frame.body.clone()),
71 }
72}
73
74fn release_pending_on_inbound_decode_error(
75 frame: &FutuFrame,
76 pending: &DashMap<u32, PendingRequest>,
77) {
78 if !proto_id::is_push_proto(frame.header.proto_id) {
79 pending.remove(&frame.header.serial_no);
80 }
81}
82
83impl FutuClient {
84 pub fn new(config: ClientConfig) -> (Self, mpsc::UnboundedReceiver<PushMessage>) {
88 let (push_tx, push_rx) = mpsc::unbounded_channel();
89
90 let client = Self {
91 config,
92 serial_no: AtomicU32::new(0),
93 pending: Arc::new(DashMap::new()),
94 push_tx,
95 cmd_tx: None,
96 conn_aes_key: parking_lot::Mutex::new(None),
97 conn_id: AtomicU64::new(0),
98 };
99
100 (client, push_rx)
101 }
102
103 pub async fn connect(&mut self) -> Result<InitConnectInfo> {
105 let mut conn = Connection::connect(&self.config.addr).await?;
106
107 let serial = self.next_serial();
109 let req = futu_proto::init_connect::Request {
110 c2s: futu_proto::init_connect::C2s {
111 client_ver: 100,
112 client_id: self.config.client_id.clone(),
113 recv_notify: Some(self.config.recv_notify),
114 packet_enc_algo: Some(0),
115 push_proto_fmt: Some(0),
116 programming_language: Some(String::new()),
117 },
118 };
119 let raw_body = prost::Message::encode_to_vec(&req);
120
121 let send_body = if let Some(ref rsa_key) = self.config.rsa_key {
123 tracing::debug!("encrypting InitConnect with RSA");
124 encrypt::rsa_public_encrypt(rsa_key, &raw_body)?
125 } else {
126 raw_body.clone()
127 };
128
129 let mut frame = Connection::build_frame(proto_id::INIT_CONNECT, serial, send_body);
131 {
132 use sha1::{Digest, Sha1};
133 let mut hasher = Sha1::new();
134 hasher.update(&raw_body);
135 let hash = hasher.finalize();
136 frame.header.body_sha1.copy_from_slice(&hash);
137 }
138 conn.send(frame).await?;
139
140 let resp_frame = conn.recv().await?.ok_or(FutuError::Codec(
142 "connection closed during InitConnect".into(),
143 ))?;
144
145 let resp_body = if let Some(ref rsa_key) = self.config.rsa_key {
147 tracing::debug!("decrypting InitConnect response with RSA");
148 encrypt::rsa_private_decrypt(rsa_key, &resp_frame.body)?
149 } else {
150 resp_frame.body.to_vec()
151 };
152
153 let resp: futu_proto::init_connect::Response =
154 prost::Message::decode(resp_body.as_slice()).map_err(FutuError::Proto)?;
155
156 let ret_type = resp.ret_type;
158 if ret_type != 0 {
159 return Err(FutuError::ServerError {
160 ret_type,
161 msg: resp.ret_msg.unwrap_or_default(),
162 });
163 }
164
165 let s2c = resp.s2c.ok_or(FutuError::Codec(
166 "missing s2c in InitConnect response".into(),
167 ))?;
168
169 let info = InitConnectInfo {
170 server_ver: s2c.server_ver,
171 login_user_id: s2c.login_user_id,
172 conn_id: s2c.conn_id,
173 conn_aes_key: s2c.conn_aes_key.clone(),
174 keep_alive_interval: s2c.keep_alive_interval,
175 };
176 self.conn_id.store(info.conn_id, Ordering::Relaxed);
177
178 if !info.conn_aes_key.is_empty() {
180 let key_bytes = info.conn_aes_key.as_bytes();
181 tracing::debug!(
182 key_len = key_bytes.len(),
183 key_hex = hex_str(key_bytes),
184 "received AES key"
185 );
186 if key_bytes.len() == 16 {
187 let mut key = [0u8; 16];
189 key.copy_from_slice(key_bytes);
190 *self.conn_aes_key.lock() = Some(key);
191 } else if key_bytes.len() == 32 {
192 if let Some(key) = hex_decode_16(key_bytes) {
194 *self.conn_aes_key.lock() = Some(key);
195 } else {
196 tracing::warn!(
197 "AES key is 32 chars but not valid hex, using raw first 16 bytes"
198 );
199 let mut key = [0u8; 16];
200 key.copy_from_slice(&key_bytes[..16]);
201 *self.conn_aes_key.lock() = Some(key);
202 }
203 } else {
204 tracing::warn!(
205 key_len = key_bytes.len(),
206 "unexpected AES key length, using raw bytes (truncated/padded to 16)"
207 );
208 let mut key = [0u8; 16];
209 let copy_len = key_bytes.len().min(16);
210 key[..copy_len].copy_from_slice(&key_bytes[..copy_len]);
211 *self.conn_aes_key.lock() = Some(key);
212 }
213 }
214
215 tracing::info!(
216 server_ver = info.server_ver,
217 conn_id = info.conn_id,
218 keep_alive_interval = info.keep_alive_interval,
219 "InitConnect succeeded"
220 );
221
222 let keep_alive_interval = Duration::from_secs(info.keep_alive_interval as u64);
224 self.start_background_tasks(conn, keep_alive_interval);
225
226 Ok(info)
227 }
228
229 pub async fn request(&self, proto_id: u32, body: Vec<u8>) -> Result<FutuFrame> {
231 let serial = self.next_serial();
232
233 let (final_body, sha1) = self.prepare_body(&body);
235
236 let mut frame = FutuFrame::new(proto_id, serial, Bytes::from(final_body));
237 frame.header.body_sha1 = sha1;
239
240 let (resp_tx, resp_rx) = oneshot::channel();
241 self.pending.insert(serial, PendingRequest { tx: resp_tx });
242
243 if let Some(cmd_tx) = &self.cmd_tx {
245 let (ack_tx, ack_rx) = oneshot::channel();
246 cmd_tx
247 .send(ClientCommand::Send(frame, ack_tx))
248 .await
249 .map_err(|_| FutuError::NotInitialized)?;
250 ack_rx
251 .await
252 .map_err(|_| FutuError::Codec("send ack failed".into()))?;
253 } else {
254 self.pending.remove(&serial);
255 return Err(FutuError::NotInitialized);
256 }
257
258 let resp = tokio::time::timeout(Duration::from_secs(12), resp_rx)
260 .await
261 .map_err(|_| {
262 self.pending.remove(&serial);
263 FutuError::Timeout
264 })?
265 .map_err(|_| FutuError::Codec("response channel closed".into()))?;
266
267 Ok(resp)
268 }
269
270 pub fn conn_id(&self) -> Option<u64> {
275 let conn_id = self.conn_id.load(Ordering::Relaxed);
276 (conn_id != 0).then_some(conn_id)
277 }
278
279 fn next_serial(&self) -> u32 {
280 self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
281 }
282
283 fn prepare_body(&self, plaintext: &[u8]) -> (Vec<u8>, [u8; 20]) {
284 use sha1::{Digest, Sha1};
285
286 let mut hasher = Sha1::new();
288 hasher.update(plaintext);
289 let sha1_result = hasher.finalize();
290 let mut sha1 = [0u8; 20];
291 sha1.copy_from_slice(&sha1_result);
292
293 let body = if self.config.rsa_key.is_some() {
295 let key = self.conn_aes_key.lock();
296 match key.as_ref() {
297 Some(k) => encrypt::aes_ecb_encrypt(k, plaintext),
298 None => plaintext.to_vec(),
299 }
300 } else {
301 plaintext.to_vec()
302 };
303
304 (body, sha1)
305 }
306
307 fn start_background_tasks(&mut self, conn: Connection, keep_alive_interval: Duration) {
308 let (cmd_tx, cmd_rx) = mpsc::channel(256);
309 self.cmd_tx = Some(cmd_tx.clone());
310
311 let pending = Arc::clone(&self.pending);
312 let push_tx = self.push_tx.clone();
313 let aes_key = if self.config.rsa_key.is_some() {
314 *self.conn_aes_key.lock()
315 } else {
316 None
317 };
318
319 tokio::spawn(async move {
320 run_event_loop(conn, cmd_rx, pending, push_tx, aes_key, keep_alive_interval).await;
321 });
322 }
323}
324
325async fn run_event_loop(
327 mut conn: Connection,
328 mut cmd_rx: mpsc::Receiver<ClientCommand>,
329 pending: Arc<DashMap<u32, PendingRequest>>,
330 push_tx: mpsc::UnboundedSender<PushMessage>,
331 aes_key: Option<[u8; 16]>,
332 keep_alive_interval: Duration,
333) {
334 let mut heartbeat = tokio::time::interval(keep_alive_interval);
335 heartbeat.tick().await; let mut heartbeat_serial: u32 = 10_000_000; loop {
339 tokio::select! {
340 result = conn.recv() => {
342 match result {
343 Ok(Some(frame)) => {
344 let proto_id = frame.header.proto_id;
345 let body = match decode_client_inbound_body(&frame, aes_key.as_ref()) {
346 Ok(body) => body,
347 Err(e) => {
348 tracing::warn!(
349 error = %e,
350 serial = frame.header.serial_no,
351 proto_id,
352 "decrypt failed, dropping inbound frame"
353 );
354 release_pending_on_inbound_decode_error(&frame, &pending);
355 continue;
356 }
357 };
358
359 if proto_id::is_push_proto(proto_id) {
360 let _ = push_tx.send(PushMessage { proto_id, body });
362 } else {
363 let serial = frame.header.serial_no;
365 match pending.remove(&serial) { Some((_, req)) => {
366 let resp_frame = FutuFrame {
367 header: frame.header,
368 body,
369 };
370 let _ = req.tx.send(resp_frame);
371 } _ => {
372 tracing::debug!(serial = serial, proto_id = proto_id, "unmatched response");
373 }}
374 }
375 }
376 Ok(None) => {
377 tracing::warn!("connection closed by server");
378 break;
379 }
380 Err(e) => {
381 tracing::error!(error = %e, "recv error");
382 break;
383 }
384 }
385 }
386
387 cmd = cmd_rx.recv() => {
389 match cmd {
390 Some(ClientCommand::Send(frame, ack)) => {
391 let result = conn.send(frame).await;
392 if let Err(e) = &result {
393 tracing::error!(error = %e, "send failed");
394 }
395 let _ = ack.send(());
396 if result.is_err() {
397 break;
398 }
399 }
400 None => {
401 tracing::info!("shutting down event loop");
402 break;
403 }
404 }
405 }
406
407 _ = heartbeat.tick() => {
409 heartbeat_serial += 1;
410 let req = futu_proto::keep_alive::Request {
411 c2s: futu_proto::keep_alive::C2s {
412 time: chrono::Utc::now().timestamp(),
413 },
414 };
415 let body = prost::Message::encode_to_vec(&req);
416 let frame = Connection::build_frame(
417 proto_id::KEEP_ALIVE,
418 heartbeat_serial,
419 body,
420 );
421 if let Err(e) = conn.send(frame).await {
422 tracing::error!(error = %e, "heartbeat send failed");
423 break;
424 }
425 tracing::trace!("heartbeat sent");
426 }
427 }
428 }
429
430 pending.clear();
432 tracing::info!("event loop exited");
433}
434
/// Values copied out of a successful InitConnect response.
#[derive(Debug, Clone)]
pub struct InitConnectInfo {
    /// Server version reported in the response.
    pub server_ver: i32,
    /// Id of the logged-in user reported in the response.
    pub login_user_id: u64,
    /// Connection id assigned by the server.
    pub conn_id: u64,
    /// Key material for the AES session key; may be empty.
    pub conn_aes_key: String,
    /// Heartbeat interval requested by the server, in seconds.
    pub keep_alive_interval: i32,
}
444
445pub struct ReconnectingClient {
447 config: ClientConfig,
448 policy: ReconnectPolicy,
449}
450
451impl ReconnectingClient {
452 pub fn new(config: ClientConfig) -> Self {
453 Self {
454 config,
455 policy: ReconnectPolicy::default_policy(),
456 }
457 }
458
459 pub fn with_policy(mut self, policy: ReconnectPolicy) -> Self {
460 self.policy = policy;
461 self
462 }
463
464 pub async fn connect(
469 &mut self,
470 ) -> Result<(
471 FutuClient,
472 mpsc::UnboundedReceiver<PushMessage>,
473 InitConnectInfo,
474 )> {
475 loop {
476 let (mut client, push_rx) = FutuClient::new(self.config.clone());
477 match client.connect().await {
478 Ok(info) => {
479 self.policy.reset();
480 return Ok((client, push_rx, info));
481 }
482 Err(e) => {
483 tracing::warn!(
484 error = %e,
485 attempt = self.policy.attempts(),
486 "connection failed"
487 );
488 match self.policy.next_delay() {
489 Some(delay) => {
490 tracing::info!(delay_ms = delay.as_millis(), "reconnecting...");
491 tokio::time::sleep(delay).await;
492 }
493 None => {
494 return Err(FutuError::Codec(format!(
495 "max retries reached after {} attempts",
496 self.policy.attempts()
497 )));
498 }
499 }
500 }
501 }
502 }
503 }
504}
505
506#[cfg(test)]
507mod tests;
508
/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_str(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut rendered = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        rendered.push(char::from(DIGITS[usize::from(byte >> 4)]));
        rendered.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    rendered
}
512
/// Decodes exactly 32 ASCII hex digits into a 16-byte array.
///
/// Both upper- and lowercase digits are accepted. Returns `None` when the
/// input is not 32 bytes long or contains any byte that is not a hex digit.
///
/// The input is decoded byte-wise instead of slicing a `str` and calling
/// `u8::from_str_radix`, because that approach panics when a slice boundary
/// falls inside a multi-byte UTF-8 character and also accepts a leading `+`
/// (so `"+f"` would be read as `0x0f`).
fn hex_decode_16(hex_bytes: &[u8]) -> Option<[u8; 16]> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    if hex_bytes.len() != 32 {
        return None;
    }

    let mut key = [0u8; 16];
    for (slot, pair) in key.iter_mut().zip(hex_bytes.chunks_exact(2)) {
        let high = nibble(pair[0])?;
        let low = nibble(pair[1])?;
        *slot = (high << 4) | low;
    }
    Some(key)
}