Skip to main content

futu_backend/auth/
webtcp.rs

1//! WebTCP-short request path for broker auth.
2//!
3//! C++ FTLogin sends broker-auth HTTP payloads through a short-lived TLS
4//! channel before falling back to ordinary HTTP. The wire command is 65507 and
5//! the protobuf body is an internal `TcpHttpRequest`.
6
7use crate::conn::{BackendConn, PushCallback};
8use futu_core::error::{FutuError, Result};
9use rustls_pki_types::ServerName;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::TcpStream;
13use tokio_rustls::TlsConnector;
14use tokio_rustls::rustls::{ClientConfig, RootCertStore};
15
/// Wire command id carrying a WebTCP `TcpHttpRequest` frame.
const WEB_REQUEST_CMD: u16 = 65507;
/// Upper bound on the TCP connect phase of one WebTCP attempt.
const WEBTCP_CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Upper bound on the TLS handshake phase of one WebTCP attempt.
const WEBTCP_TLS_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Guards the one-time installation of the rustls crypto provider.
static RUSTLS_PROVIDER_INIT: std::sync::Once = std::sync::Once::new();
20
21/// rustls 0.23 requires an explicit process-level crypto provider when more
22/// than one provider feature is present. WebTCP uses rustls directly, so do
23/// this before the first `ClientConfig::builder()` call.
24pub fn install_default_rustls_crypto_provider() {
25    RUSTLS_PROVIDER_INIT.call_once(|| {
26        let _ = tokio_rustls::rustls::crypto::ring::default_provider().install_default();
27    });
28}
29
30#[derive(Debug)]
31pub(super) struct ProtoError(String);
32
33impl std::fmt::Display for ProtoError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        self.0.fmt(f)
36    }
37}
38
39impl std::error::Error for ProtoError {}
40
41pub(super) type ProtoResult<T> = std::result::Result<T, ProtoError>;
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq)]
44pub(super) enum WireType {
45    Varint = 0,
46    Fixed64 = 1,
47    LengthDelimited = 2,
48    Fixed32 = 5,
49}
50
51impl WireType {
52    fn from_u64(value: u64) -> ProtoResult<Self> {
53        match value {
54            0 => Ok(Self::Varint),
55            1 => Ok(Self::Fixed64),
56            2 => Ok(Self::LengthDelimited),
57            5 => Ok(Self::Fixed32),
58            other => Err(ProtoError(format!("unsupported wire type {other}"))),
59        }
60    }
61}
62
63pub(super) struct ProtoWriter {
64    buf: Vec<u8>,
65}
66
67impl ProtoWriter {
68    pub(super) fn new() -> Self {
69        Self { buf: Vec::new() }
70    }
71
72    pub(super) fn finish(self) -> Vec<u8> {
73        self.buf
74    }
75
76    pub(super) fn write_key(&mut self, field: u32, wire_type: WireType) {
77        self.write_varint(((field as u64) << 3) | wire_type as u64);
78    }
79
80    pub(super) fn write_varint(&mut self, mut value: u64) {
81        while value >= 0x80 {
82            self.buf.push((value as u8) | 0x80);
83            value >>= 7;
84        }
85        self.buf.push(value as u8);
86    }
87
88    pub(super) fn write_bytes(&mut self, field: u32, value: &[u8]) {
89        self.write_key(field, WireType::LengthDelimited);
90        self.write_varint(value.len() as u64);
91        self.buf.extend_from_slice(value);
92    }
93
94    pub(super) fn write_string(&mut self, field: u32, value: &str) {
95        self.write_bytes(field, value.as_bytes());
96    }
97
98    pub(super) fn write_message(&mut self, field: u32, value: Vec<u8>) {
99        self.write_bytes(field, &value);
100    }
101}
102
103pub(super) struct ProtoReader<'a> {
104    data: &'a [u8],
105    pos: usize,
106}
107
108impl<'a> ProtoReader<'a> {
109    pub(super) fn new(data: &'a [u8]) -> Self {
110        Self { data, pos: 0 }
111    }
112
113    fn eof(&self) -> bool {
114        self.pos >= self.data.len()
115    }
116
117    fn read_byte(&mut self) -> ProtoResult<u8> {
118        let byte = *self
119            .data
120            .get(self.pos)
121            .ok_or_else(|| ProtoError("unexpected eof".into()))?;
122        self.pos += 1;
123        Ok(byte)
124    }
125
126    pub(super) fn read_varint(&mut self) -> ProtoResult<u64> {
127        let mut value = 0u64;
128        for shift in (0..64).step_by(7) {
129            let byte = self.read_byte()? as u64;
130            value |= (byte & 0x7f) << shift;
131            if byte & 0x80 == 0 {
132                return Ok(value);
133            }
134        }
135        Err(ProtoError("varint too long".into()))
136    }
137
138    pub(super) fn next_key(&mut self) -> ProtoResult<Option<(u32, WireType)>> {
139        if self.eof() {
140            return Ok(None);
141        }
142        let key = self.read_varint()?;
143        Ok(Some(((key >> 3) as u32, WireType::from_u64(key & 0x07)?)))
144    }
145
146    pub(super) fn read_bytes(&mut self) -> ProtoResult<&'a [u8]> {
147        let len = self.read_varint()? as usize;
148        let end = self
149            .pos
150            .checked_add(len)
151            .ok_or_else(|| ProtoError("length overflow".into()))?;
152        if end > self.data.len() {
153            return Err(ProtoError(format!(
154                "length-delimited out of bounds len={len} remaining={}",
155                self.data.len().saturating_sub(self.pos)
156            )));
157        }
158        let out = &self.data[self.pos..end];
159        self.pos = end;
160        Ok(out)
161    }
162
163    pub(super) fn read_string(&mut self) -> ProtoResult<String> {
164        let bytes = self.read_bytes()?;
165        std::str::from_utf8(bytes)
166            .map(|s| s.to_string())
167            .map_err(|e| ProtoError(format!("utf8: {e}")))
168    }
169
170    pub(super) fn read_i32(&mut self) -> ProtoResult<i32> {
171        Ok(self.read_varint()? as i32)
172    }
173
174    fn read_bool(&mut self) -> ProtoResult<bool> {
175        Ok(self.read_varint()? != 0)
176    }
177
178    pub(super) fn skip_field(&mut self, wire_type: WireType) -> ProtoResult<()> {
179        match wire_type {
180            WireType::Varint => {
181                let _ = self.read_varint()?;
182            }
183            WireType::Fixed64 => {
184                self.skip_bytes(8)?;
185            }
186            WireType::LengthDelimited => {
187                let len = self.read_varint()? as usize;
188                self.skip_bytes(len)?;
189            }
190            WireType::Fixed32 => {
191                self.skip_bytes(4)?;
192            }
193        }
194        Ok(())
195    }
196
197    fn skip_bytes(&mut self, len: usize) -> ProtoResult<()> {
198        let end = self
199            .pos
200            .checked_add(len)
201            .ok_or_else(|| ProtoError("skip overflow".into()))?;
202        if end > self.data.len() {
203            return Err(ProtoError("skip out of bounds".into()));
204        }
205        self.pos = end;
206        Ok(())
207    }
208}
209
210#[derive(Debug, Clone)]
211struct TcpHttpHeader {
212    key: String,
213    value: String,
214}
215
216impl TcpHttpHeader {
217    fn encode(&self) -> Vec<u8> {
218        let mut writer = ProtoWriter::new();
219        writer.write_string(1, &self.key);
220        writer.write_string(2, &self.value);
221        writer.finish()
222    }
223}
224
225#[derive(Debug, Clone)]
226struct TcpHttpRequest {
227    method: String,
228    url: String,
229    headers: Vec<TcpHttpHeader>,
230    body: Vec<u8>,
231}
232
233impl TcpHttpRequest {
234    fn encode(&self) -> Vec<u8> {
235        let mut writer = ProtoWriter::new();
236        writer.write_string(1, &self.method);
237        writer.write_string(2, &self.url);
238        for header in &self.headers {
239            writer.write_message(3, header.encode());
240        }
241        writer.write_bytes(4, &self.body);
242        writer.finish()
243    }
244}
245
246#[derive(Debug, Clone)]
247struct TcpHttpResponseBody {
248    status_code: Option<i32>,
249    message: Option<String>,
250    data: Option<Vec<u8>>,
251}
252
253impl TcpHttpResponseBody {
254    fn decode(bytes: &[u8]) -> ProtoResult<Self> {
255        let mut reader = ProtoReader::new(bytes);
256        let mut status_code = None;
257        let mut message = None;
258        let mut data = None;
259        while let Some((field, wire_type)) = reader.next_key()? {
260            match (field, wire_type) {
261                (1, WireType::Varint) => status_code = Some(reader.read_i32()?),
262                (2, WireType::LengthDelimited) => message = Some(reader.read_string()?),
263                (5, WireType::LengthDelimited) => data = Some(reader.read_bytes()?.to_vec()),
264                _ => reader.skip_field(wire_type)?,
265            }
266        }
267        Ok(Self {
268            status_code,
269            message,
270            data,
271        })
272    }
273}
274
275#[derive(Debug, Clone)]
276struct TcpHttpResponse {
277    code: Option<i32>,
278    message: Option<String>,
279    response_body: Option<TcpHttpResponseBody>,
280    current_need_fallback: Option<bool>,
281}
282
283impl TcpHttpResponse {
284    fn decode(bytes: &[u8]) -> ProtoResult<Self> {
285        let mut reader = ProtoReader::new(bytes);
286        let mut code = None;
287        let mut message = None;
288        let mut response_body = None;
289        let mut current_need_fallback = None;
290        while let Some((field, wire_type)) = reader.next_key()? {
291            match (field, wire_type) {
292                (1, WireType::Varint) => code = Some(reader.read_i32()?),
293                (2, WireType::LengthDelimited) => message = Some(reader.read_string()?),
294                (3, WireType::LengthDelimited) => {
295                    response_body = Some(TcpHttpResponseBody::decode(reader.read_bytes()?)?)
296                }
297                (4, WireType::Varint) => current_need_fallback = Some(reader.read_bool()?),
298                _ => reader.skip_field(wire_type)?,
299            }
300        }
301        Ok(Self {
302            code,
303            message,
304            response_body,
305            current_need_fallback,
306        })
307    }
308}
309
310/// C++ uses wildcard certificate domains; rustls needs a concrete SNI name.
311pub(super) fn sni_for_web_identity(identity: u32) -> &'static str {
312    match identity {
313        crate::auth::commconfig::CONN_WEB_CN | crate::auth::commconfig::CONN_WEB_HK => {
314            "www.futunn.com"
315        }
316        _ => "www.moomoo.com",
317    }
318}
319
320fn tls_connector() -> TlsConnector {
321    install_default_rustls_crypto_provider();
322
323    let roots = RootCertStore {
324        roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
325    };
326    let config = ClientConfig::builder()
327        .with_root_certificates(roots)
328        .with_no_client_auth();
329    TlsConnector::from(Arc::new(config))
330}
331
332fn proto_err(context: &str, err: ProtoError) -> FutuError {
333    FutuError::Codec(format!("webtcp {context}: {err}"))
334}
335
336fn decode_response_json(frame_body: &[u8]) -> Result<serde_json::Value> {
337    let response = TcpHttpResponse::decode(frame_body).map_err(|e| proto_err("decode", e))?;
338    if response.code.unwrap_or(-1) != 0 {
339        return Err(FutuError::Codec(format!(
340            "webtcp response code={} message={:?} fallback={:?}",
341            response.code.unwrap_or(-1),
342            response.message,
343            response.current_need_fallback
344        )));
345    }
346    if response.current_need_fallback.unwrap_or(false) {
347        return Err(FutuError::Codec(format!(
348            "webtcp response requested fallback message={:?}",
349            response.message
350        )));
351    }
352    let response_body = response
353        .response_body
354        .ok_or_else(|| FutuError::Codec("webtcp missing response_body".into()))?;
355    let status_code = response_body.status_code.unwrap_or(0);
356    if !(200..300).contains(&status_code) {
357        return Err(FutuError::Codec(format!(
358            "webtcp http status={status_code} message={:?}",
359            response_body.message
360        )));
361    }
362    let data = response_body
363        .data
364        .ok_or_else(|| FutuError::Codec("webtcp missing http response data".into()))?;
365    serde_json::from_slice(&data).map_err(|e| FutuError::Codec(format!("webtcp json decode: {e}")))
366}
367
368/// C++ `FTAuthImpl::InitRequest` builds one request object for WebTCP and HTTP,
369/// then `SetHttpHeaders` adds the `X-Futu-*` auth headers before dispatch.
370/// Ref: `FTLogin/Src/ftlogin/auth/impl/auth_impl.cpp:2415-2422,3014-3030`.
371fn cpp_auth_headers(client_type: u8, host: String) -> Vec<TcpHttpHeader> {
372    vec![
373        TcpHttpHeader {
374            key: "Host".to_string(),
375            value: host,
376        },
377        TcpHttpHeader {
378            key: "Content-Type".to_string(),
379            value: "application/json".to_string(),
380        },
381        TcpHttpHeader {
382            key: "X-Futu-Client-Type".to_string(),
383            value: client_type.to_string(),
384        },
385        TcpHttpHeader {
386            key: "X-Futu-Client-Version".to_string(),
387            value: crate::conn::BackendConn::CLIENT_VER_FTGTW.to_string(),
388        },
389        TcpHttpHeader {
390            key: "X-Futu-Client-Lang".to_string(),
391            value: "sc".to_string(),
392        },
393    ]
394}
395
396pub(super) async fn connect_webtcp(ip: &str, port: u16, sni: &'static str) -> Result<BackendConn> {
397    let addr = format!("{ip}:{port}");
398    let stream = tokio::time::timeout(WEBTCP_CONNECT_TIMEOUT, TcpStream::connect(&addr))
399        .await
400        .map_err(|_| {
401            FutuError::Network(std::io::Error::new(
402                std::io::ErrorKind::TimedOut,
403                format!("webtcp connect {addr} timed out"),
404            ))
405        })?
406        .map_err(FutuError::Network)?;
407    stream.set_nodelay(true)?;
408
409    let server_name = ServerName::try_from(sni.to_string())
410        .map_err(|e| FutuError::Codec(format!("webtcp invalid tls server name {sni}: {e}")))?;
411    let tls = tokio::time::timeout(
412        WEBTCP_TLS_TIMEOUT,
413        tls_connector().connect(server_name, stream),
414    )
415    .await
416    .map_err(|_| {
417        FutuError::Network(std::io::Error::new(
418            std::io::ErrorKind::TimedOut,
419            format!("webtcp tls connect {addr} timed out"),
420        ))
421    })?
422    .map_err(|e| FutuError::Codec(format!("webtcp tls connect {addr}: {e}")))?;
423
424    let noop: PushCallback = Arc::new(|_, _| {});
425    BackendConn::from_stream_inner(tls, noop)
426}
427
428/// POST JSON through FTLogin WebTCP-short.
429pub(crate) async fn post_json_via_webtcp(
430    client_type: u8,
431    web_identity: u32,
432    addrs: &[(String, u16)],
433    url: &str,
434    body: &serde_json::Value,
435) -> Result<serde_json::Value> {
436    if addrs.is_empty() {
437        return Err(FutuError::Codec(format!(
438            "webtcp identity {web_identity} has empty addr pool"
439        )));
440    }
441    let parsed = reqwest::Url::parse(url)
442        .map_err(|e| FutuError::Codec(format!("webtcp invalid url {url}: {e}")))?;
443    let host = parsed
444        .host_str()
445        .ok_or_else(|| FutuError::Codec(format!("webtcp url missing host: {url}")))?
446        .to_string();
447    let payload = serde_json::to_vec(body)
448        .map_err(|e| FutuError::Codec(format!("webtcp json encode: {e}")))?;
449    let request = TcpHttpRequest {
450        method: "POST".to_string(),
451        url: url.to_string(),
452        headers: cpp_auth_headers(client_type, host),
453        body: payload,
454    };
455    let request_body = request.encode();
456
457    let sni = sni_for_web_identity(web_identity);
458    let mut last_error: Option<FutuError> = None;
459    for (ip, port) in addrs {
460        tracing::debug!(
461            web_identity,
462            ip,
463            port,
464            sni,
465            "broker_auth webtcp-short connecting"
466        );
467        let conn = match connect_webtcp(ip, *port, sni).await {
468            Ok(conn) => conn,
469            Err(e) => {
470                tracing::warn!(
471                    web_identity,
472                    ip,
473                    port,
474                    sni,
475                    error = %e,
476                    "broker_auth webtcp-short connect failed; trying next IP"
477                );
478                last_error = Some(e);
479                continue;
480            }
481        };
482        let frame = match conn.request(WEB_REQUEST_CMD, request_body.clone()).await {
483            Ok(frame) => frame,
484            Err(e) => {
485                tracing::warn!(
486                    web_identity,
487                    ip,
488                    port,
489                    sni,
490                    error = %e,
491                    "broker_auth webtcp-short request failed; trying next IP"
492                );
493                last_error = Some(e);
494                continue;
495            }
496        };
497        return decode_response_json(&frame.body);
498    }
499    Err(last_error.unwrap_or_else(|| {
500        FutuError::Network(std::io::Error::other(format!(
501            "webtcp identity {web_identity}: all IPs failed"
502        )))
503    }))
504}
505
// Unit tests for this module live in a separate `tests` submodule file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;