Cloudflare Worker version (#6)

Refactor/generalise API/DB interactions out of OIDC.
This commit is contained in:
Simon Bihel
2022-01-11 10:43:06 +00:00
committed by GitHub
parent 9d725552e0
commit bbcacf4232
19 changed files with 3236 additions and 2854 deletions

199
src/db/cf.rs Normal file
View File

@@ -0,0 +1,199 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
// use cached::{stores::TimedCache, Cached};
use chrono::{DateTime, Duration, Utc};
use matchit::Node;
use std::collections::HashMap;
use worker::*;
use super::*;
const KV_NAMESPACE: &str = "SIWE-OIDC";
const DO_NAMESPACE: &str = "SIWE-OIDC-CODES";
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
// Heavily relying on:
// A Durable Object is given 30 seconds of additional CPU time for every
// request it processes, including WebSocket messages. In the absence of
// failures, in-memory state should not be reset after less than 30 seconds of
// inactivity.
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
// Wanted to use TimedCache but it (probably) crashes because it's using std::time::Instant which isn't available on wasm32.
#[durable_object]
pub struct DOCodes {
// codes: TimedCache<String, CodeEntry>,
codes: HashMap<String, (DateTime<Utc>, CodeEntry)>,
// state: State,
// env: Env,
}
#[durable_object]
impl DurableObject for DOCodes {
fn new(state: State, _env: Env) -> Self {
Self {
// codes: TimedCache::with_lifespan(ENTRY_LIFETIME.try_into().unwrap()),
codes: HashMap::new(),
// state,
// env,
}
}
async fn fetch(&mut self, mut req: Request) -> worker::Result<Response> {
// Can't use the Router because we need to reference self (thus move the var to the closure)
if matches!(req.method(), Method::Get) {
let mut matcher = Node::new();
matcher.insert("/:code", ())?;
let path = req.path();
let matched = match matcher.at(&path) {
Ok(m) => m,
Err(_) => return Response::error("Bad request", 400),
};
let code = if let Some(c) = matched.params.get("code") {
c
} else {
return Response::error("Bad request", 400);
};
if let Some(c) = self.codes.get(code) {
if c.0 + Duration::seconds(ENTRY_LIFETIME.try_into().unwrap()) < Utc::now() {
self.codes.remove(code);
Response::error("Not found", 404)
} else {
Response::from_json(&c.1)
}
} else {
Response::error("Not found", 404)
}
} else if matches!(req.method(), Method::Post) {
let mut matcher = Node::new();
matcher.insert("/:code", ())?;
let path = req.path();
let matched = match matcher.at(&path) {
Ok(m) => m,
Err(_) => return Response::error("Bad request", 400),
};
let code = if let Some(c) = matched.params.get("code") {
c
} else {
return Response::error("Bad request", 400);
};
let code_entry = match req.json().await {
Ok(p) => p,
Err(e) => return Response::error(format!("Bad request: {}", e), 400),
};
self.codes
.insert(code.to_string(), (Utc::now(), code_entry));
Response::empty()
} else {
Response::error("Method Not Allowed", 405)
}
}
}
pub struct CFClient {
pub ctx: RouteContext<()>,
pub url: Url,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl DBClient for CFClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
self.ctx
.kv(KV_NAMESPACE)
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
.put(
&format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to build KV put: {}", e))?
// TODO put some sort of expiration for dynamic registration
.execute()
.await
.map_err(|e| anyhow!("Failed to put KV: {}", e))?;
Ok(())
}
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
let entry = self
.ctx
.kv(KV_NAMESPACE)
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
.get(&format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get KV: {}", e))?
.map(|e| e.as_string());
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let namespace = self
.ctx
.durable_object(DO_NAMESPACE)
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
let stub = namespace
.id_from_name(&code)
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
.get_stub()
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
let mut headers = Headers::new();
headers.set("Content-Type", "application/json").unwrap();
let mut url = self.url.clone();
url.set_path(&code);
url.set_query(None);
let req = Request::new_with_init(
url.as_str(),
&RequestInit {
body: Some(wasm_bindgen::JsValue::from_str(
&serde_json::to_string(&code_entry)
.map_err(|e| anyhow!("Failed to serialize: {}", e))?,
)),
method: Method::Post,
headers,
..Default::default()
},
)
.map_err(|e| anyhow!("Failed to construct request for Durable Object: {}", e))?;
let res = stub
.fetch_with_request(req)
.await
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
match res.status_code() {
200 => Ok(()),
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
}
}
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
let namespace = self
.ctx
.durable_object(DO_NAMESPACE)
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
let stub = namespace
.id_from_name(&code)
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
.get_stub()
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
let mut url = self.url.clone();
url.set_path(&code);
url.set_query(None);
let mut res = stub
.fetch_with_str(url.as_str())
.await
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
match res.status_code() {
200 => Ok(Some(res.json().await.map_err(|e| {
anyhow!(
"Response to Durable Object failed to be deserialized: {}",
e
)
})?)),
404 => Ok(None),
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
}
}
}

40
src/db/mod.rs Normal file
View File

@@ -0,0 +1,40 @@
use anyhow::Result;
use async_trait::async_trait;
use openidconnect::{Nonce, RedirectUrl};
use serde::{Deserialize, Serialize};
#[cfg(not(target_arch = "wasm32"))]
mod redis;
#[cfg(not(target_arch = "wasm32"))]
pub use redis::RedisClient;
#[cfg(target_arch = "wasm32")]
mod cf;
#[cfg(target_arch = "wasm32")]
pub use cf::CFClient;
const KV_CLIENT_PREFIX: &str = "clients";
const ENTRY_LIFETIME: usize = 30;
#[derive(Clone, Serialize, Deserialize)]
pub struct CodeEntry {
pub exchange_count: usize,
pub address: String,
pub nonce: Option<Nonce>,
pub client_id: String,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientEntry {
pub secret: String,
pub redirect_uris: Vec<RedirectUrl>,
}
// Using a trait to easily pass async functions with async_trait
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait DBClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>;
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>;
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>;
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>;
}

89
src/db/redis.rs Normal file
View File

@@ -0,0 +1,89 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
use super::*;
#[derive(Clone)]
pub struct RedisClient {
pub pool: Pool<RedisConnectionManager>,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl DBClient for RedisClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set(
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let entry: Option<String> = conn
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set_ex(
code.to_string(),
hex::encode(
bincode::serialize(&code_entry)
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
),
ENTRY_LIFETIME,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let serialized_entry: Option<Vec<u8>> = conn
.get(code)
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() {
return Ok(None);
}
let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap())
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
Ok(Some(code_entry))
}
}