1use std::collections::HashSet;
4
5use chrono::{DateTime, NaiveTime, Utc};
6use rand::RngCore;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::limits::Limits;
11use crate::scope::Scope;
12
/// Serializable record describing one generated key: its hash, scopes, and optional limits.
///
/// The plaintext key is never stored here; `KeyRecord::generate` /
/// `KeyRecord::generate_with_machines` return it once and keep only its
/// `"sha256:<hex>"` digest in `hash`.
///
/// Every `Option` field annotated with
/// `#[serde(default, skip_serializing_if = "Option::is_none")]` is omitted from
/// serialized output when `None`, and deserializes to `None` when the field is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyRecord {
    /// Identifier of the key; also passed to `crate::machine::check` by `check_machine`.
    pub id: String,
    /// `"sha256:"` followed by the hex-encoded SHA-256 digest of the plaintext key
    /// (the format produced by `hash_plaintext`); compared by `KeyRecord::matches`.
    pub hash: String,
    /// Set of `Scope`s attached to this key.
    pub scopes: HashSet<Scope>,

    // The limit fields below mirror the identically named fields of `Limits`:
    // `generate_with_machines` copies them in and `limits()` copies them back out.
    // NOTE(review): enforcement of these limits happens outside this file — confirm there.
    /// Optional set of allowed markets (mirrors `Limits::allowed_markets`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_markets: Option<HashSet<String>>,
    /// Optional set of allowed symbols (mirrors `Limits::allowed_symbols`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_symbols: Option<HashSet<String>>,
    /// Optional per-order value cap (mirrors `Limits::max_order_value`; units not visible here).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_order_value: Option<f64>,
    /// Optional daily value cap (mirrors `Limits::max_daily_value`; units not visible here).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_daily_value: Option<f64>,
    /// Optional time window as text in the form `HH:MM-HH:MM`; parsed by `KeyRecord::hours_range`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hours_window: Option<String>,
    /// Optional per-minute order count cap (mirrors `Limits::max_orders_per_minute`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_orders_per_minute: Option<u32>,
    /// Optional set of allowed trade sides (mirrors `Limits::allowed_trd_sides`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_trd_sides: Option<HashSet<String>>,
    /// Optional set of allowed account ids (mirrors `Limits::allowed_acc_ids`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_acc_ids: Option<HashSet<u64>>,
    /// Optional list of allowed card numbers (mirrors `Limits::allowed_card_nums`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_card_nums: Option<Vec<String>>,
    /// Expiry instant; `is_expired(now)` is true once `now >= expires_at`. `None` never expires.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<DateTime<Utc>>,
    /// Creation time; set to `Utc::now()` by `generate_with_machines`.
    /// Has no serde default, so it must be present when deserializing.
    pub created_at: DateTime<Utc>,
    /// Optional note supplied at generation time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub note: Option<String>,
    /// Optional machine allow-list, passed as a slice to `crate::machine::check`
    /// by `check_machine`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_machines: Option<Vec<String>>,

    /// Never serialized (`#[serde(skip)]`), so it is `None` for any record loaded via
    /// `Deserialize`. `generate_with_machines` sets it to a copy of the `allowed_acc_ids`
    /// it was given.
    /// NOTE(review): its consumer is not visible here; presumably it distinguishes account
    /// ids supplied explicitly at creation from ones restored from storage — confirm.
    #[serde(skip)]
    pub raw_explicit_acc_ids: Option<HashSet<u64>>,
}
94
95impl KeyRecord {
96 #[must_use = "丢弃生成结果会丢失 plaintext; 调用方必须立即展示给用户"]
100 pub fn generate(
101 id: impl Into<String>,
102 scopes: HashSet<Scope>,
103 limits: Option<Limits>,
104 expires_at: Option<DateTime<Utc>>,
105 note: Option<String>,
106 ) -> (String, KeyRecord) {
107 Self::generate_with_machines(id, scopes, limits, expires_at, note, None)
108 }
109
110 #[must_use = "丢弃生成结果会丢失 plaintext; 调用方必须立即展示给用户"]
112 pub fn generate_with_machines(
113 id: impl Into<String>,
114 scopes: HashSet<Scope>,
115 limits: Option<Limits>,
116 expires_at: Option<DateTime<Utc>>,
117 note: Option<String>,
118 allowed_machines: Option<Vec<String>>,
119 ) -> (String, KeyRecord) {
120 let mut bytes = [0u8; 32];
121 rand::thread_rng().fill_bytes(&mut bytes);
122 let plaintext = hex::encode(bytes);
123 let hash = format!(
124 "sha256:{}",
125 hex::encode(Sha256::digest(plaintext.as_bytes()))
126 );
127 let limits = limits.unwrap_or_default();
128 let raw_explicit_acc_ids = limits.allowed_acc_ids.clone();
130 let record = KeyRecord {
131 id: id.into(),
132 hash,
133 scopes,
134 allowed_markets: limits.allowed_markets,
135 allowed_symbols: limits.allowed_symbols,
136 max_order_value: limits.max_order_value,
137 max_daily_value: limits.max_daily_value,
138 hours_window: limits.hours_window,
139 max_orders_per_minute: limits.max_orders_per_minute,
140 allowed_trd_sides: limits.allowed_trd_sides,
141 allowed_acc_ids: limits.allowed_acc_ids,
142 allowed_card_nums: limits.allowed_card_nums,
143 expires_at,
144 created_at: Utc::now(),
145 note,
146 allowed_machines,
147 raw_explicit_acc_ids,
148 };
149 (plaintext, record)
150 }
151
152 pub fn check_machine(&self) -> Result<(), crate::machine::MachineError> {
154 crate::machine::check(&self.id, self.allowed_machines.as_deref())
155 }
156
157 #[must_use]
159 pub fn matches(&self, plaintext: &str) -> bool {
160 let computed = hash_plaintext(plaintext);
161 let a = self.hash.as_bytes();
163 let b = computed.as_bytes();
164 if a.len() != b.len() {
165 return false;
166 }
167 let mut acc: u8 = 0;
168 for (x, y) in a.iter().zip(b.iter()) {
169 acc |= x ^ y;
170 }
171 acc == 0
172 }
173
174 #[must_use]
176 pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
177 self.expires_at.map(|t| now >= t).unwrap_or(false)
178 }
179
180 pub fn hours_range(&self) -> Result<Option<(NaiveTime, NaiveTime)>, String> {
182 let Some(s) = &self.hours_window else {
183 return Ok(None);
184 };
185 let (l, r) = s
186 .split_once('-')
187 .ok_or_else(|| format!("invalid hours_window {s:?}: expect HH:MM-HH:MM"))?;
188 let parse = |p: &str| {
189 NaiveTime::parse_from_str(p.trim(), "%H:%M")
190 .map_err(|e| format!("invalid time {p:?}: {e}"))
191 };
192 Ok(Some((parse(l)?, parse(r)?)))
193 }
194
195 #[must_use]
197 pub fn limits(&self) -> Limits {
198 Limits {
199 allowed_markets: self.allowed_markets.clone(),
200 allowed_symbols: self.allowed_symbols.clone(),
201 max_order_value: self.max_order_value,
202 max_daily_value: self.max_daily_value,
203 hours_window: self.hours_window.clone(),
204 max_orders_per_minute: self.max_orders_per_minute,
205 allowed_trd_sides: self.allowed_trd_sides.clone(),
206 allowed_acc_ids: self.allowed_acc_ids.clone(),
207 allowed_card_nums: self.allowed_card_nums.clone(),
208 }
209 }
210}
211
212#[must_use]
214pub fn hash_plaintext(plaintext: &str) -> String {
215 format!(
216 "sha256:{}",
217 hex::encode(Sha256::digest(plaintext.as_bytes()))
218 )
219}
220
221#[cfg(test)]
222mod tests;