futu_gateway/handlers/
mod.rs

1pub mod qot;
2pub mod sys;
3pub mod trd;
4
5// 通用响应构建辅助
6
7/// 从 SharedBackend 加载当前后端连接(原子读取,支持重连后自动更新)
8pub fn load_backend(
9    shared: &crate::bridge::SharedBackend,
10) -> Option<std::sync::Arc<futu_backend::conn::BackendConn>> {
11    let guard = shared.load();
12    guard.as_ref().clone()
13}
14
15/// 构建成功响应(ret_type=0, s2c 为 protobuf 编码后的 bytes)
16pub fn make_success_response<M: prost::Message>(resp: &M) -> Vec<u8> {
17    prost::Message::encode_to_vec(resp)
18}
19
20/// 解码后端响应 protobuf,兼容可能存在的 4 字节长度前缀
21///
22/// 部分后端协议(如 CMD 5120/5121)的响应可能包含 OMBinSrz 4 字节大端长度前缀,
23/// 也可能是裸 protobuf。此函数先尝试裸 protobuf 解码,失败后尝试跳过前 4 字节。
24pub fn decode_backend_proto<M: prost::Message + Default>(
25    body: &[u8],
26) -> Result<M, prost::DecodeError> {
27    // 先尝试裸 protobuf 解码
28    match prost::Message::decode(body) {
29        Ok(m) => Ok(m),
30        Err(e1) => {
31            // 如果 body 长度 >= 4,尝试跳过 4 字节长度前缀
32            if body.len() >= 4 {
33                match prost::Message::decode(&body[4..]) {
34                    Ok(m) => {
35                        tracing::debug!(
36                            body_len = body.len(),
37                            "decoded backend proto after skipping 4-byte length prefix"
38                        );
39                        Ok(m)
40                    }
41                    Err(_) => Err(e1), // 返回原始错误
42                }
43            } else {
44                Err(e1)
45            }
46        }
47    }
48}
49
50/// 从 raw protobuf 中提取指定 field number 的所有 LengthDelimited 值,
51/// 然后将每个值解码为 prost::Message。
52/// 用于处理后端使用了不同于 proto 文件定义的 field number 的情况。
53pub fn extract_repeated_field<M: prost::Message + Default>(
54    body: &[u8],
55    target_field: u32,
56) -> Vec<M> {
57    let mut results = Vec::new();
58    let mut pos = 0;
59    while pos < body.len() {
60        // 解析 tag (varint)
61        let (tag, new_pos) = match decode_varint(body, pos) {
62            Some(v) => v,
63            None => break,
64        };
65        let field_number = (tag >> 3) as u32;
66        let wire_type = (tag & 0x7) as u8;
67
68        match wire_type {
69            0 => {
70                // varint — skip
71                match decode_varint(body, new_pos) {
72                    Some((_, p)) => pos = p,
73                    None => break,
74                }
75            }
76            1 => {
77                // 64-bit — skip 8 bytes
78                pos = new_pos + 8;
79            }
80            2 => {
81                // length-delimited
82                let (length, data_start) = match decode_varint(body, new_pos) {
83                    Some(v) => (v.0 as usize, v.1),
84                    None => break,
85                };
86                if data_start + length > body.len() {
87                    break;
88                }
89                if field_number == target_field {
90                    if let Ok(msg) = prost::Message::decode(&body[data_start..data_start + length])
91                    {
92                        results.push(msg);
93                    }
94                }
95                pos = data_start + length;
96            }
97            5 => {
98                // 32-bit — skip 4 bytes
99                pos = new_pos + 4;
100            }
101            _ => break,
102        }
103    }
104    results
105}
106
107/// 从 SRPC 封装的响应体中提取 field 5 的数据并解码为指定消息类型。
108/// 后端某些命令(CMD 5120/5121)的响应被 SRPC envelope 包装,实际数据在 field 5。
109pub fn extract_field5_message<M: prost::Message + Default>(body: &[u8]) -> Option<M> {
110    extract_field5_message_with(body, |_| true)
111}
112
113/// 从 SRPC field 5 解码消息,使用 validator 验证结果。
114/// 如果 field 5 直接解码无效,还会尝试从 field 5 → field 4 提取(嵌套 SRPC)。
115pub fn extract_field5_validated<M: prost::Message + Default>(
116    body: &[u8],
117    validator: impl Fn(&M) -> bool,
118) -> Option<M> {
119    extract_field5_message_with(body, validator)
120}
121
122fn extract_field5_message_with<M: prost::Message + Default>(
123    body: &[u8],
124    validator: impl Fn(&M) -> bool,
125) -> Option<M> {
126    let field5_data = extract_raw_field(body, 5)?;
127
128    // 尝试直接解码 field 5 数据
129    if let Ok(m) = prost::Message::decode(field5_data) {
130        if validator(&m) {
131            return Some(m);
132        }
133    }
134
135    // field 5 内部可能还有嵌套封装(如 SRPC service envelope:field 1=version, field 2=service, field 4=data)
136    // 尝试从 field 5 的 field 4 中提取
137    if let Some(inner) = extract_raw_field(field5_data, 4) {
138        if let Ok(m) = prost::Message::decode(inner) {
139            if validator(&m) {
140                tracing::debug!("decoded message from SRPC field 5 → inner field 4");
141                return Some(m);
142            }
143        }
144    }
145
146    None
147}
148
149/// 从 protobuf 消息中提取指定 field number 的第一个 length-delimited 数据切片
150fn extract_raw_field(body: &[u8], target_field: u32) -> Option<&[u8]> {
151    let mut pos = 0;
152    while pos < body.len() {
153        let (tag, new_pos) = decode_varint(body, pos)?;
154        let field_number = (tag >> 3) as u32;
155        let wire_type = (tag & 0x7) as u8;
156
157        match wire_type {
158            0 => {
159                let (_, p) = decode_varint(body, new_pos)?;
160                pos = p;
161            }
162            1 => {
163                pos = new_pos + 8;
164            }
165            2 => {
166                let (length, data_start) = decode_varint(body, new_pos)?;
167                let length = length as usize;
168                if data_start + length > body.len() {
169                    return None;
170                }
171                if field_number == target_field {
172                    return Some(&body[data_start..data_start + length]);
173                }
174                pos = data_start + length;
175            }
176            5 => {
177                pos = new_pos + 4;
178            }
179            _ => return None,
180        }
181    }
182    None
183}
184
/// Decodes one base-128 varint from `body`, starting at byte offset `start`.
///
/// Returns the decoded value together with the offset of the first byte after the
/// varint. Returns `None` when `start` is at or past the end of `body`, when the input
/// ends before a terminating byte (high bit clear) is seen, or when ten consecutive
/// bytes all carry the continuation bit.
///
/// Bits of the tenth byte that do not fit in 64 bits are discarded.
fn decode_varint(body: &[u8], start: usize) -> Option<(u64, usize)> {
    let mut value: u64 = 0;
    // A u64 needs at most ten 7-bit groups; the last one is shifted by 63.
    for (idx, &byte) in body.get(start..)?.iter().enumerate().take(10) {
        value |= u64::from(byte & 0x7f) << (7 * idx);
        if byte & 0x80 == 0 {
            return Some((value, start + idx + 1));
        }
    }
    None
}
205
206/// 统一 SRPC 封装解码:先尝试标准解码,如果结果无效则尝试 SRPC field 5 解码。
207///
208/// - `body`: 后端响应体
209/// - `validator`: 验证解码结果是否有效的闭包。返回 true 表示结果有效,直接使用;
210///   返回 false 表示结果无效,继续尝试 SRPC field 5 解码。
211///
212/// 适用于 CMD 5120/5121 等被 SRPC envelope 包装的后端命令。
213pub fn decode_srpc_or_direct<M: prost::Message + Default>(
214    body: &[u8],
215    validator: impl Fn(&M) -> bool,
216) -> M {
217    // 1. 尝试标准解码(裸 protobuf + 跳过 4 字节前缀)
218    if let Ok(r) = decode_backend_proto::<M>(body) {
219        if validator(&r) {
220            return r;
221        }
222    }
223
224    // 2. 标准解码无效 → 从 SRPC envelope 的 field 5 提取(带 validator)
225    if let Some(r) = extract_field5_validated::<M>(body, &validator) {
226        tracing::debug!("decoded message from SRPC field 5");
227        return r;
228    }
229
230    // 3. fallback: 返回默认值
231    tracing::warn!(
232        body_len = body.len(),
233        "SRPC decode: all attempts returned invalid data, using default"
234    );
235    M::default()
236}
237
/// One localized name entry, parsed by hand from raw bytes instead of through prost.
#[derive(Debug, Clone)]
pub struct MultiLangName {
    /// Language identifier read from field 1 of the wire entry.
    pub language_id: i32,
    /// Name text read from field 2 of the wire entry (invalid UTF-8 is replaced lossily).
    pub name: String,
}
244
245/// CMD5121 专用解码:SRPC field 5 包含的是 repeated GroupInfo(不是 GetGroupListResp)。
246///
247/// GroupInfo 的 field 4 (multi_lang_name) 包含后端数据导致 prost 整体解码失败,
248/// 因此 proto 中不定义 field 4。本函数先用 prost 解码 fields 1-3,
249/// 再手动从原始字节中提取 field 4 的 MultiLanguageName。
250pub fn decode_cmd5121_groups(
251    body: &[u8],
252) -> (
253    futu_backend::proto_internal::wch_lst::GetGroupListResp,
254    Vec<Vec<MultiLangName>>,
255) {
256    use futu_backend::proto_internal::wch_lst::{GetGroupListResp, GroupInfo};
257
258    // 1. 尝试标准解码
259    if let Ok(r) = decode_backend_proto::<GetGroupListResp>(body) {
260        if r.result_code.is_some() && !r.group_list.is_empty() {
261            let empty_langs = vec![vec![]; r.group_list.len()];
262            return (r, empty_langs);
263        }
264    }
265
266    // 2. 检查 SRPC metadata (field 3) 中的 result_code
267    //    SRPC field 3 内部: field 1 = result_code (0=成功, 非0=错误)
268    let srpc_ok = extract_raw_field_bytes(body, 3)
269        .first()
270        .and_then(|meta| {
271            // meta 内部 field 1 是 varint = result_code
272            let mut pos = 0;
273            while pos < meta.len() {
274                let (tag, new_pos) = decode_varint(meta, pos)?;
275                let field_number = (tag >> 3) as u32;
276                let wire_type = (tag & 0x7) as u8;
277                match wire_type {
278                    0 => {
279                        let (val, p) = decode_varint(meta, new_pos)?;
280                        if field_number == 1 {
281                            return Some(val == 0);
282                        }
283                        pos = p;
284                    }
285                    2 => {
286                        let (len, start) = decode_varint(meta, new_pos)?;
287                        pos = start + len as usize;
288                    }
289                    _ => break,
290                }
291            }
292            Some(true) // 没找到 result_code,默认成功
293        })
294        .unwrap_or(true);
295
296    if !srpc_ok {
297        // SRPC metadata 报错,但 field 5 可能仍然包含有效数据
298        // 尝试继续解码而非直接返回空
299        let meta = extract_raw_field_bytes(body, 3);
300        let hex: String = meta
301            .first()
302            .map(|m| {
303                m.iter()
304                    .map(|b| format!("{b:02x}"))
305                    .collect::<Vec<_>>()
306                    .join(" ")
307            })
308            .unwrap_or_default();
309        tracing::warn!(body_len = body.len(), srpc_meta_hex = %hex, "CMD5121: SRPC metadata error, trying field 5 anyway");
310    }
311
312    // 3. SRPC field 5 包含 repeated GroupInfo,用 extract_repeated_field 解码
313    //    (proto 中不定义 field 4,prost 自动跳过)
314    let group_list: Vec<GroupInfo> = extract_repeated_field(body, 5);
315    if group_list.is_empty() {
316        tracing::warn!(body_len = body.len(), "CMD5121 decode: no GroupInfo found");
317        return (GetGroupListResp::default(), vec![]);
318    }
319
320    // 4. 同时提取 field 5 的原始字节,从中手动解析 field 4 (multi_lang_name)
321    let raw_items: Vec<&[u8]> = extract_raw_field_bytes(body, 5);
322    let lang_list: Vec<Vec<MultiLangName>> = raw_items
323        .iter()
324        .map(|raw| extract_multi_lang_names(raw))
325        .collect();
326
327    tracing::debug!(
328        count = group_list.len(),
329        "decoded GroupInfo from SRPC field 5 with multi_lang_name"
330    );
331    let resp = GetGroupListResp {
332        result_code: Some(0),
333        group_count: Some(group_list.len() as u32),
334        group_list,
335        max_reached: None,
336    };
337    (resp, lang_list)
338}
339
340/// 从 raw protobuf 中提取指定 field number 的所有 LengthDelimited 值的原始字节切片。
341fn extract_raw_field_bytes(body: &[u8], target_field: u32) -> Vec<&[u8]> {
342    let mut results = Vec::new();
343    let mut pos = 0;
344    while pos < body.len() {
345        let (tag, new_pos) = match decode_varint(body, pos) {
346            Some(v) => v,
347            None => break,
348        };
349        let field_number = (tag >> 3) as u32;
350        let wire_type = (tag & 0x7) as u8;
351
352        match wire_type {
353            0 => match decode_varint(body, new_pos) {
354                Some((_, p)) => pos = p,
355                None => break,
356            },
357            1 => {
358                pos = new_pos + 8;
359            }
360            2 => {
361                let (length, data_start) = match decode_varint(body, new_pos) {
362                    Some(v) => (v.0 as usize, v.1),
363                    None => break,
364                };
365                if data_start + length > body.len() {
366                    break;
367                }
368                if field_number == target_field {
369                    results.push(&body[data_start..data_start + length]);
370                }
371                pos = data_start + length;
372            }
373            5 => {
374                pos = new_pos + 4;
375            }
376            _ => break,
377        }
378    }
379    results
380}
381
382/// 从 GroupInfo 原始字节中手动提取 field 4 (repeated MultiLanguageName)。
383///
384/// MultiLanguageName = { optional int32 language_id = 1; optional string name = 2; }
385/// 由于后端数据可能导致 prost 解码失败,这里用容错方式手动解析。
386fn extract_multi_lang_names(group_info_bytes: &[u8]) -> Vec<MultiLangName> {
387    let mut results = Vec::new();
388
389    // 提取所有 field 4 的 length-delimited 数据
390    let raw_items = extract_raw_field_bytes(group_info_bytes, 4);
391
392    for raw in raw_items {
393        // 手动解析 MultiLanguageName: field 1 = language_id (varint), field 2 = name (string)
394        let mut language_id: i32 = 0;
395        let mut name = String::new();
396        let mut pos = 0;
397
398        while pos < raw.len() {
399            let (tag, new_pos) = match decode_varint(raw, pos) {
400                Some(v) => v,
401                None => break,
402            };
403            let field_number = (tag >> 3) as u32;
404            let wire_type = (tag & 0x7) as u8;
405
406            match wire_type {
407                0 => {
408                    // varint
409                    let (val, p) = match decode_varint(raw, new_pos) {
410                        Some(v) => v,
411                        None => break,
412                    };
413                    if field_number == 1 {
414                        language_id = val as i32;
415                    }
416                    pos = p;
417                }
418                1 => {
419                    pos = new_pos + 8;
420                }
421                2 => {
422                    let (length, data_start) = match decode_varint(raw, new_pos) {
423                        Some(v) => (v.0 as usize, v.1),
424                        None => break,
425                    };
426                    if data_start + length > raw.len() {
427                        break;
428                    }
429                    if field_number == 2 {
430                        name = String::from_utf8_lossy(&raw[data_start..data_start + length])
431                            .to_string();
432                    }
433                    pos = data_start + length;
434                }
435                5 => {
436                    pos = new_pos + 4;
437                }
438                _ => break,
439            }
440        }
441
442        if !name.is_empty() {
443            results.push(MultiLangName { language_id, name });
444        }
445    }
446
447    results
448}
449
/// Renders at most `max_bytes` leading bytes of `data` as space-separated lowercase
/// hex pairs.
///
/// When `data` is longer than `max_bytes`, the output ends with
/// `... (N bytes total)`, where `N` is the full length of `data`.
#[allow(dead_code)]
fn hex_preview(data: &[u8], max_bytes: usize) -> String {
    let rendered = data
        .iter()
        .take(max_bytes)
        .map(|byte| format!("{byte:02x}"))
        .collect::<Vec<_>>()
        .join(" ");

    if data.len() <= max_bytes {
        return rendered;
    }
    format!("{rendered}... ({} bytes total)", data.len())
}
465
466/// 构建错误响应
467pub fn make_error_response(ret_type: i32, msg: &str) -> Vec<u8> {
468    let resp = futu_proto::init_connect::Response {
469        ret_type,
470        ret_msg: Some(msg.to_string()),
471        err_code: None,
472        s2c: None,
473    };
474    prost::Message::encode_to_vec(&resp)
475}