1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
9use std::sync::Arc;
10
11use axum::http::StatusCode;
12use axum::response::Json;
13use bytes::Bytes;
14use prost::Message;
15use serde_json::Value;
16
17use futu_auth::KeyStore;
18use futu_codec::header::ProtoFmtType;
19use futu_server::conn::IncomingRequest;
20use futu_server::router::RequestRouter;
21
22use crate::ws::WsBroadcaster;
23
24#[derive(Clone)]
26pub struct RestState {
27 pub router: Arc<RequestRouter>,
29 pub ws_broadcaster: Arc<WsBroadcaster>,
31 pub key_store: Arc<KeyStore>,
33 pub counters: Arc<futu_auth::RuntimeCounters>,
36 conn_id_counter: Arc<AtomicU64>,
38 serial_counter: Arc<AtomicU32>,
40}
41
42impl RestState {
43 pub fn new(router: Arc<RequestRouter>, ws_broadcaster: Arc<WsBroadcaster>) -> Self {
44 Self::with_key_store(router, ws_broadcaster, Arc::new(KeyStore::empty()))
45 }
46
47 pub fn with_key_store(
48 router: Arc<RequestRouter>,
49 ws_broadcaster: Arc<WsBroadcaster>,
50 key_store: Arc<KeyStore>,
51 ) -> Self {
52 Self::with_auth(
53 router,
54 ws_broadcaster,
55 key_store,
56 Arc::new(futu_auth::RuntimeCounters::new()),
57 )
58 }
59
60 pub fn with_auth(
62 router: Arc<RequestRouter>,
63 ws_broadcaster: Arc<WsBroadcaster>,
64 key_store: Arc<KeyStore>,
65 counters: Arc<futu_auth::RuntimeCounters>,
66 ) -> Self {
67 Self {
68 router,
69 ws_broadcaster,
70 key_store,
71 counters,
72 conn_id_counter: Arc::new(AtomicU64::new(10_000_000)),
73 serial_counter: Arc::new(AtomicU32::new(1)),
74 }
75 }
76
77 pub fn next_conn_id(&self) -> u64 {
79 self.conn_id_counter.fetch_add(1, Ordering::Relaxed)
80 }
81
82 fn next_serial(&self) -> u32 {
84 self.serial_counter.fetch_add(1, Ordering::Relaxed)
85 }
86}
87
88pub async fn proto_request<Req, Rsp>(
96 state: &RestState,
97 proto_id: u32,
98 json_body: Option<Value>,
99) -> Result<Json<Value>, (StatusCode, Json<Value>)>
100where
101 Req: Message + Default + serde::de::DeserializeOwned,
102 Rsp: Message + Default + serde::Serialize,
103{
104 let req_msg: Req = if let Some(body) = json_body {
106 serde_json::from_value(body).map_err(|e| {
107 (
108 StatusCode::BAD_REQUEST,
109 Json(serde_json::json!({
110 "error": format!("invalid request body: {e}")
111 })),
112 )
113 })?
114 } else {
115 Req::default()
116 };
117
118 let body = Bytes::from(req_msg.encode_to_vec());
120
121 let incoming = IncomingRequest {
123 conn_id: state.next_conn_id(),
124 proto_id,
125 serial_no: state.next_serial(),
126 proto_fmt_type: ProtoFmtType::Protobuf,
127 body,
128 };
129
130 let resp_bytes = state
131 .router
132 .dispatch(incoming.conn_id, &incoming)
133 .await
134 .ok_or_else(|| {
135 (
136 StatusCode::INTERNAL_SERVER_ERROR,
137 Json(serde_json::json!({
138 "error": "handler returned no response"
139 })),
140 )
141 })?;
142
143 let rsp_msg = Rsp::decode(Bytes::from(resp_bytes)).map_err(|e| {
145 (
146 StatusCode::INTERNAL_SERVER_ERROR,
147 Json(serde_json::json!({
148 "error": format!("failed to decode response: {e}")
149 })),
150 )
151 })?;
152
153 let json_rsp = serde_json::to_value(&rsp_msg).map_err(|e| {
155 (
156 StatusCode::INTERNAL_SERVER_ERROR,
157 Json(serde_json::json!({
158 "error": format!("failed to serialize response: {e}")
159 })),
160 )
161 })?;
162
163 Ok(Json(json_rsp))
164}
165
166#[derive(serde::Serialize)]
168pub struct ApiResponse<T: serde::Serialize> {
169 pub ret_type: i32,
170 #[serde(skip_serializing_if = "Option::is_none")]
171 pub ret_msg: Option<String>,
172 #[serde(skip_serializing_if = "Option::is_none")]
173 pub data: Option<T>,
174}