1use std::collections::HashMap;
4use std::future::Future;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9
10use crate::conn::IncomingRequest;
11
12#[async_trait]
14pub trait RequestHandler: Send + Sync + 'static {
15 async fn handle(&self, conn_id: u64, request: &IncomingRequest) -> Option<Vec<u8>>;
18}
19
/// Wraps a closure so it can be registered as a [`RequestHandler`].
///
/// The wrapped closure is called with the connection id and the request body,
/// and must return a future yielding the optional response bytes.
pub struct FnHandler<F>(
    /// The wrapped closure.
    pub F,
);
25
26#[async_trait]
27impl<F, Fut> RequestHandler for FnHandler<F>
28where
29 F: Fn(u64, bytes::Bytes) -> Fut + Send + Sync + 'static,
30 Fut: Future<Output = Option<Vec<u8>>> + Send + 'static,
31{
32 async fn handle(&self, conn_id: u64, request: &IncomingRequest) -> Option<Vec<u8>> {
33 (self.0)(conn_id, request.body.clone()).await
34 }
35}
36
37pub struct RequestRouter {
39 handlers: RwLock<HashMap<u32, Arc<dyn RequestHandler>>>,
40}
41
42impl RequestRouter {
43 pub fn new() -> Self {
45 Self {
46 handlers: RwLock::new(HashMap::new()),
47 }
48 }
49
50 pub fn register(&self, proto_id: u32, handler: Arc<dyn RequestHandler>) {
52 self.handlers.write().insert(proto_id, handler);
53 }
54
55 pub async fn dispatch(&self, conn_id: u64, request: &IncomingRequest) -> Option<Vec<u8>> {
57 let handler = {
58 let handlers = self.handlers.read();
59 handlers.get(&request.proto_id).cloned()
60 };
61
62 match handler {
63 Some(h) => h.handle(conn_id, request).await,
64 None => {
65 tracing::warn!(
66 proto_id = request.proto_id,
67 conn_id = conn_id,
68 "no handler registered"
69 );
70 Some(make_error_response(-1, "unknown protocol"))
72 }
73 }
74 }
75}
76
77impl Default for RequestRouter {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83fn make_error_response(ret_type: i32, msg: &str) -> Vec<u8> {
85 let resp = futu_proto::init_connect::Response {
86 ret_type,
87 ret_msg: Some(msg.to_string()),
88 err_code: None,
89 s2c: None,
90 };
91 prost::Message::encode_to_vec(&resp)
92}