1use std::collections::HashSet;
4
5use chrono::{DateTime, NaiveTime, Utc};
6use rand::RngCore;
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::limits::Limits;
11use crate::scope::Scope;
12
/// Stored metadata for one API key.
///
/// The record never holds the secret itself: only its digest is kept in `hash`.
/// The plaintext is returned once, alongside the record, by
/// [`KeyRecord::generate`] / [`KeyRecord::generate_with_machines`].
///
/// Every optional field below is omitted from serialized output when `None`
/// (`skip_serializing_if`) and falls back to `None` when absent on input
/// (`default`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyRecord {
    /// Identifier supplied at generation time; also the id handed to
    /// `crate::machine::check` by `check_machine`.
    pub id: String,
    /// `"sha256:"` followed by the hex SHA-256 digest of the plaintext key,
    /// i.e. the output of [`hash_plaintext`]; `matches` compares against it.
    pub hash: String,
    /// Scopes attached to this key.
    pub scopes: HashSet<Scope>,

    // The fields from `allowed_markets` through `allowed_trd_sides` mirror
    // `Limits` one-to-one: `generate_with_machines` copies them in and
    // `limits()` rebuilds a `Limits` from them.
    // NOTE(review): enforcement happens outside this file; presumably `None`
    // means "no restriction" — confirm against the enforcing code.
    /// Allowed markets, copied from `Limits::allowed_markets`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_markets: Option<HashSet<String>>,
    /// Allowed symbols, copied from `Limits::allowed_symbols`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_symbols: Option<HashSet<String>>,
    /// Per-order value cap, copied from `Limits::max_order_value`
    /// (units not defined in this file).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_order_value: Option<f64>,
    /// Daily value cap, copied from `Limits::max_daily_value`
    /// (units not defined in this file).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_daily_value: Option<f64>,
    /// Time-of-day window in the form `"HH:MM-HH:MM"`; parsed by `hours_range`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub hours_window: Option<String>,
    /// Order-rate cap, copied from `Limits::max_orders_per_minute`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_orders_per_minute: Option<u32>,
    /// Allowed sides, copied from `Limits::allowed_trd_sides`
    /// (presumably trade sides such as buy/sell — confirm the accepted values).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_trd_sides: Option<HashSet<String>>,
    /// Expiry instant; `is_expired(now)` is true once `now >= expires_at`.
    /// `None` means the key never expires.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<DateTime<Utc>>,
    /// Creation timestamp; the generators set it to `Utc::now()`.
    pub created_at: DateTime<Utc>,
    /// Optional free-form note supplied at generation time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub note: Option<String>,
    /// Optional machine allow-list passed to `crate::machine::check` by
    /// `check_machine`; how entries (and `None`) are interpreted is defined there.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_machines: Option<Vec<String>>,
}
55
56impl KeyRecord {
57 pub fn generate(
61 id: impl Into<String>,
62 scopes: HashSet<Scope>,
63 limits: Option<Limits>,
64 expires_at: Option<DateTime<Utc>>,
65 note: Option<String>,
66 ) -> (String, KeyRecord) {
67 Self::generate_with_machines(id, scopes, limits, expires_at, note, None)
68 }
69
70 pub fn generate_with_machines(
72 id: impl Into<String>,
73 scopes: HashSet<Scope>,
74 limits: Option<Limits>,
75 expires_at: Option<DateTime<Utc>>,
76 note: Option<String>,
77 allowed_machines: Option<Vec<String>>,
78 ) -> (String, KeyRecord) {
79 let mut bytes = [0u8; 32];
80 rand::thread_rng().fill_bytes(&mut bytes);
81 let plaintext = hex::encode(bytes);
82 let hash = format!(
83 "sha256:{}",
84 hex::encode(Sha256::digest(plaintext.as_bytes()))
85 );
86 let limits = limits.unwrap_or_default();
87 let record = KeyRecord {
88 id: id.into(),
89 hash,
90 scopes,
91 allowed_markets: limits.allowed_markets,
92 allowed_symbols: limits.allowed_symbols,
93 max_order_value: limits.max_order_value,
94 max_daily_value: limits.max_daily_value,
95 hours_window: limits.hours_window,
96 max_orders_per_minute: limits.max_orders_per_minute,
97 allowed_trd_sides: limits.allowed_trd_sides,
98 expires_at,
99 created_at: Utc::now(),
100 note,
101 allowed_machines,
102 };
103 (plaintext, record)
104 }
105
106 pub fn check_machine(&self) -> Result<(), crate::machine::MachineError> {
108 crate::machine::check(&self.id, self.allowed_machines.as_deref())
109 }
110
111 pub fn matches(&self, plaintext: &str) -> bool {
113 let computed = hash_plaintext(plaintext);
114 let a = self.hash.as_bytes();
116 let b = computed.as_bytes();
117 if a.len() != b.len() {
118 return false;
119 }
120 let mut acc: u8 = 0;
121 for (x, y) in a.iter().zip(b.iter()) {
122 acc |= x ^ y;
123 }
124 acc == 0
125 }
126
127 pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
129 self.expires_at.map(|t| now >= t).unwrap_or(false)
130 }
131
132 pub fn hours_range(&self) -> Result<Option<(NaiveTime, NaiveTime)>, String> {
134 let Some(s) = &self.hours_window else {
135 return Ok(None);
136 };
137 let (l, r) = s
138 .split_once('-')
139 .ok_or_else(|| format!("invalid hours_window {s:?}: expect HH:MM-HH:MM"))?;
140 let parse = |p: &str| {
141 NaiveTime::parse_from_str(p.trim(), "%H:%M")
142 .map_err(|e| format!("invalid time {p:?}: {e}"))
143 };
144 Ok(Some((parse(l)?, parse(r)?)))
145 }
146
147 pub fn limits(&self) -> Limits {
149 Limits {
150 allowed_markets: self.allowed_markets.clone(),
151 allowed_symbols: self.allowed_symbols.clone(),
152 max_order_value: self.max_order_value,
153 max_daily_value: self.max_daily_value,
154 hours_window: self.hours_window.clone(),
155 max_orders_per_minute: self.max_orders_per_minute,
156 allowed_trd_sides: self.allowed_trd_sides.clone(),
157 }
158 }
159}
160
161pub fn hash_plaintext(plaintext: &str) -> String {
163 format!(
164 "sha256:{}",
165 hex::encode(Sha256::digest(plaintext.as_bytes()))
166 )
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal record with no limits, no expiry and hash of `"x"`.
    /// Tests override individual fields via struct-update syntax.
    fn bare_record() -> KeyRecord {
        KeyRecord {
            id: "t".into(),
            hash: hash_plaintext("x"),
            scopes: HashSet::new(),
            allowed_markets: None,
            allowed_symbols: None,
            max_order_value: None,
            max_daily_value: None,
            hours_window: None,
            max_orders_per_minute: None,
            allowed_trd_sides: None,
            expires_at: None,
            created_at: Utc::now(),
            note: None,
            allowed_machines: None,
        }
    }

    #[test]
    fn generate_and_verify() {
        let (plaintext, rec) = KeyRecord::generate(
            "test",
            [Scope::QotRead].into_iter().collect(),
            None,
            None,
            None,
        );
        assert_eq!(plaintext.len(), 64);
        // The stored hash must be exactly what `matches` will recompute.
        assert_eq!(rec.hash, hash_plaintext(&plaintext));
        assert!(rec.matches(&plaintext));
        assert!(!rec.matches("deadbeef"));
    }

    #[test]
    fn matches_rejects_malformed_stored_hash() {
        let mut rec = bare_record();
        assert!(rec.matches("x"));
        // A stored hash of the wrong length must never match.
        rec.hash = "sha256:".into();
        assert!(!rec.matches("x"));
    }

    #[test]
    fn hash_deterministic() {
        assert_eq!(hash_plaintext("abc"), hash_plaintext("abc"));
        assert_ne!(hash_plaintext("abc"), hash_plaintext("abd"));
        let h = hash_plaintext("abc");
        assert!(h.starts_with("sha256:"));
        // Prefix (7 chars) + 32-byte digest as hex (64 chars).
        assert_eq!(h.len(), 7 + 64);
    }

    #[test]
    fn hours_range_parse() {
        let nine_thirty = NaiveTime::from_hms_opt(9, 30, 0).unwrap();
        let four_pm = NaiveTime::from_hms_opt(16, 0, 0).unwrap();

        let mut rec = KeyRecord {
            hours_window: Some("09:30-16:00".into()),
            ..bare_record()
        };
        let (a, b) = rec.hours_range().unwrap().unwrap();
        assert_eq!(a, nine_thirty);
        assert_eq!(b, four_pm);

        // Whitespace around each half is trimmed before parsing.
        rec.hours_window = Some(" 09:30 - 16:00 ".into());
        assert_eq!(rec.hours_range(), Ok(Some((nine_thirty, four_pm))));

        rec.hours_window = Some("bad".into());
        assert!(rec.hours_range().is_err());

        rec.hours_window = Some("25:00-16:00".into());
        assert!(rec.hours_range().is_err());

        rec.hours_window = None;
        assert_eq!(rec.hours_range(), Ok(None));
    }

    #[test]
    fn expiration() {
        let now = Utc::now();
        let mut rec = KeyRecord {
            expires_at: Some(now - chrono::Duration::seconds(1)),
            ..bare_record()
        };
        assert!(rec.is_expired(now));

        // The expiry instant itself already counts as expired.
        rec.expires_at = Some(now);
        assert!(rec.is_expired(now));

        rec.expires_at = Some(now + chrono::Duration::seconds(1));
        assert!(!rec.is_expired(now));

        rec.expires_at = None;
        assert!(!rec.is_expired(now));
    }
}