Cloudflare Worker version (#6)
Refactor/generalise API/DB interactions out of OIDC.
This commit is contained in:
199
src/db/cf.rs
Normal file
199
src/db/cf.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
// use cached::{stores::TimedCache, Cached};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use matchit::Node;
|
||||
use std::collections::HashMap;
|
||||
use worker::*;
|
||||
|
||||
use super::*;
|
||||
|
||||
const KV_NAMESPACE: &str = "SIWE-OIDC";
|
||||
const DO_NAMESPACE: &str = "SIWE-OIDC-CODES";
|
||||
|
||||
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
|
||||
// Heavily relying on:
|
||||
// A Durable Object is given 30 seconds of additional CPU time for every
|
||||
// request it processes, including WebSocket messages. In the absence of
|
||||
// failures, in-memory state should not be reset after less than 30 seconds of
|
||||
// inactivity.
|
||||
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
|
||||
|
||||
// Wanted to use TimedCache but it (probably) crashes because it's using std::time::Instant which isn't available on wasm32.
|
||||
|
||||
#[durable_object]
|
||||
pub struct DOCodes {
|
||||
// codes: TimedCache<String, CodeEntry>,
|
||||
codes: HashMap<String, (DateTime<Utc>, CodeEntry)>,
|
||||
// state: State,
|
||||
// env: Env,
|
||||
}
|
||||
|
||||
#[durable_object]
|
||||
impl DurableObject for DOCodes {
|
||||
fn new(state: State, _env: Env) -> Self {
|
||||
Self {
|
||||
// codes: TimedCache::with_lifespan(ENTRY_LIFETIME.try_into().unwrap()),
|
||||
codes: HashMap::new(),
|
||||
// state,
|
||||
// env,
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch(&mut self, mut req: Request) -> worker::Result<Response> {
|
||||
// Can't use the Router because we need to reference self (thus move the var to the closure)
|
||||
if matches!(req.method(), Method::Get) {
|
||||
let mut matcher = Node::new();
|
||||
matcher.insert("/:code", ())?;
|
||||
let path = req.path();
|
||||
let matched = match matcher.at(&path) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return Response::error("Bad request", 400),
|
||||
};
|
||||
let code = if let Some(c) = matched.params.get("code") {
|
||||
c
|
||||
} else {
|
||||
return Response::error("Bad request", 400);
|
||||
};
|
||||
if let Some(c) = self.codes.get(code) {
|
||||
if c.0 + Duration::seconds(ENTRY_LIFETIME.try_into().unwrap()) < Utc::now() {
|
||||
self.codes.remove(code);
|
||||
Response::error("Not found", 404)
|
||||
} else {
|
||||
Response::from_json(&c.1)
|
||||
}
|
||||
} else {
|
||||
Response::error("Not found", 404)
|
||||
}
|
||||
} else if matches!(req.method(), Method::Post) {
|
||||
let mut matcher = Node::new();
|
||||
matcher.insert("/:code", ())?;
|
||||
let path = req.path();
|
||||
let matched = match matcher.at(&path) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return Response::error("Bad request", 400),
|
||||
};
|
||||
let code = if let Some(c) = matched.params.get("code") {
|
||||
c
|
||||
} else {
|
||||
return Response::error("Bad request", 400);
|
||||
};
|
||||
let code_entry = match req.json().await {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Response::error(format!("Bad request: {}", e), 400),
|
||||
};
|
||||
self.codes
|
||||
.insert(code.to_string(), (Utc::now(), code_entry));
|
||||
Response::empty()
|
||||
} else {
|
||||
Response::error("Method Not Allowed", 405)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CFClient {
|
||||
pub ctx: RouteContext<()>,
|
||||
pub url: Url,
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||
impl DBClient for CFClient {
|
||||
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
|
||||
self.ctx
|
||||
.kv(KV_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
|
||||
.put(
|
||||
&format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
||||
serde_json::to_string(&client_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to build KV put: {}", e))?
|
||||
// TODO put some sort of expiration for dynamic registration
|
||||
.execute()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to put KV: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
|
||||
let entry = self
|
||||
.ctx
|
||||
.kv(KV_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
|
||||
.get(&format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get KV: {}", e))?
|
||||
.map(|e| e.as_string());
|
||||
if let Some(e) = entry {
|
||||
Ok(serde_json::from_str(&e)
|
||||
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
|
||||
let namespace = self
|
||||
.ctx
|
||||
.durable_object(DO_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
|
||||
let stub = namespace
|
||||
.id_from_name(&code)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
|
||||
.get_stub()
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
|
||||
let mut headers = Headers::new();
|
||||
headers.set("Content-Type", "application/json").unwrap();
|
||||
let mut url = self.url.clone();
|
||||
url.set_path(&code);
|
||||
url.set_query(None);
|
||||
let req = Request::new_with_init(
|
||||
url.as_str(),
|
||||
&RequestInit {
|
||||
body: Some(wasm_bindgen::JsValue::from_str(
|
||||
&serde_json::to_string(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialize: {}", e))?,
|
||||
)),
|
||||
method: Method::Post,
|
||||
headers,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to construct request for Durable Object: {}", e))?;
|
||||
let res = stub
|
||||
.fetch_with_request(req)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
|
||||
match res.status_code() {
|
||||
200 => Ok(()),
|
||||
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
|
||||
}
|
||||
}
|
||||
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
|
||||
let namespace = self
|
||||
.ctx
|
||||
.durable_object(DO_NAMESPACE)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
|
||||
let stub = namespace
|
||||
.id_from_name(&code)
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
|
||||
.get_stub()
|
||||
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
|
||||
let mut url = self.url.clone();
|
||||
url.set_path(&code);
|
||||
url.set_query(None);
|
||||
let mut res = stub
|
||||
.fetch_with_str(url.as_str())
|
||||
.await
|
||||
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
|
||||
match res.status_code() {
|
||||
200 => Ok(Some(res.json().await.map_err(|e| {
|
||||
anyhow!(
|
||||
"Response to Durable Object failed to be deserialized: {}",
|
||||
e
|
||||
)
|
||||
})?)),
|
||||
404 => Ok(None),
|
||||
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
|
||||
}
|
||||
}
|
||||
}
|
||||
40
src/db/mod.rs
Normal file
40
src/db/mod.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use openidconnect::{Nonce, RedirectUrl};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod redis;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use redis::RedisClient;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
mod cf;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use cf::CFClient;
|
||||
|
||||
const KV_CLIENT_PREFIX: &str = "clients";
|
||||
const ENTRY_LIFETIME: usize = 30;
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct CodeEntry {
|
||||
pub exchange_count: usize,
|
||||
pub address: String,
|
||||
pub nonce: Option<Nonce>,
|
||||
pub client_id: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct ClientEntry {
|
||||
pub secret: String,
|
||||
pub redirect_uris: Vec<RedirectUrl>,
|
||||
}
|
||||
|
||||
// Using a trait to easily pass async functions with async_trait
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||
pub trait DBClient {
|
||||
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>;
|
||||
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>;
|
||||
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>;
|
||||
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>;
|
||||
}
|
||||
89
src/db/redis.rs
Normal file
89
src/db/redis.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisClient {
|
||||
pub pool: Pool<RedisConnectionManager>,
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||
impl DBClient for RedisClient {
|
||||
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
|
||||
conn.set(
|
||||
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
||||
serde_json::to_string(&client_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let entry: Option<String> = conn
|
||||
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if let Some(e) = entry {
|
||||
Ok(serde_json::from_str(&e)
|
||||
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
conn.set_ex(
|
||||
code.to_string(),
|
||||
hex::encode(
|
||||
bincode::serialize(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
||||
),
|
||||
ENTRY_LIFETIME,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let serialized_entry: Option<Vec<u8>> = conn
|
||||
.get(code)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if serialized_entry.is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
let code_entry: CodeEntry = bincode::deserialize(
|
||||
&hex::decode(serialized_entry.unwrap())
|
||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
||||
Ok(Some(code_entry))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user