Cloudflare Worker version (#6)

Refactor/generalise API/DB interactions out of OIDC.
This commit is contained in:
Simon Bihel
2022-01-11 10:43:06 +00:00
committed by GitHub
parent 9d725552e0
commit bbcacf4232
19 changed files with 3236 additions and 2854 deletions

364
src/axum_lib.rs Normal file
View File

@@ -0,0 +1,364 @@
use anyhow::{anyhow, Result};
use async_redis_session::RedisSessionStore;
use axum::{
extract::{self, Extension, Form, Query, TypedHeader},
http::{
header::{self, HeaderMap},
StatusCode,
},
response::{self, IntoResponse, Redirect},
routing::{get, get_service, post},
AddExtensionLayer, Json, Router,
};
use bb8_redis::{bb8, RedisConnectionManager};
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use headers::{
self,
authorization::{Basic, Bearer},
Authorization,
};
use openidconnect::core::{
CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata,
CoreResponseType, CoreTokenResponse, CoreUserInfoClaims,
};
use rand::rngs::OsRng;
use rsa::{
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
RsaPrivateKey,
};
use std::net::SocketAddr;
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tracing::info;
use super::config;
use super::oidc::{self, CustomError};
use super::session::*;
use ::siwe_oidc::db::*;
impl IntoResponse for CustomError {
fn into_response(self) -> response::Response {
match self {
CustomError::BadRequest(_) => {
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
}
CustomError::BadRequestToken(e) => {
(StatusCode::BAD_REQUEST, Json::from(e)).into_response()
}
CustomError::Unauthorized(_) => {
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
}
CustomError::Redirect(uri) => Redirect::to(
uri.parse().unwrap(),
// .map_err(|e| anyhow!("Could not parse URI: {}", e))?,
)
.into_response(),
CustomError::Other(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
}
}
}
async fn jwk_set(
Extension(private_key): Extension<RsaPrivateKey>,
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
let jwks = oidc::jwks(private_key)?;
Ok(jwks.into())
}
async fn provider_metadata(
Extension(config): Extension<config::Config>,
) -> Result<Json<CoreProviderMetadata>, CustomError> {
Ok(oidc::metadata(config.base_url)?.into())
}
// TODO should check Authorization header
// Actually, client secret can be
// 1. in the POST (currently supported) [x]
// 2. Authorization header [x]
// 3. JWT [ ]
// 4. signed JWT [ ]
// according to Keycloak
async fn token(
Form(form): Form<oidc::TokenForm>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
basic: Option<TypedHeader<Authorization<Basic>>>,
Extension(private_key): Extension<RsaPrivateKey>,
Extension(config): Extension<config::Config>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<Json<CoreTokenResponse>, CustomError> {
let secret = if let Some(b) = bearer {
Some(b.0 .0.token().to_string())
} else {
basic.map(|b| b.0 .0.password().to_string())
};
let token_response = oidc::token(
form,
secret,
private_key,
config.base_url,
config.require_secret,
&redis_client,
)
.await?;
Ok(token_response.into())
}
// TODO handle `registration` parameter
async fn authorize(
session: UserSessionFromSession,
Query(params): Query<oidc::AuthorizeParams>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(cookie) => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, cookie);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
&params.client_id,
&params.redirect_uri.to_string(),
&params.scope.to_string(),
&params.response_type.unwrap_or(CoreResponseType::Code).as_ref(),
&params.state.unwrap_or_default(),
&params.client_id,
&params.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { header, nonce } => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, header);
(nonce, headers)
}
};
let url = oidc::authorize(params, nonce, &redis_client).await?;
Ok((
headers,
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
async fn sign_in(
session: UserSessionFromSession,
Query(params): Query<oidc::SignInParams>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(header) => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, header);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.client_id.clone(),
&params.redirect_uri.to_string(),
&params.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { .. } => {
return Ok((
HeaderMap::new(),
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.client_id.clone(),
&params.redirect_uri.to_string(),
&params.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
};
let url = oidc::sign_in(params, Some(nonce), cookies, &redis_client).await?;
Ok((
headers,
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
// TODO clear session
}
async fn register(
extract::Json(payload): extract::Json<CoreClientMetadata>,
Extension(redis_client): Extension<RedisClient>,
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
let registration = oidc::register(payload, &redis_client).await?;
Ok((StatusCode::CREATED, registration.into()))
}
// TODO CORS
// TODO need validation of the token
async fn userinfo(
payload: Option<Form<oidc::UserInfoPayload>>,
bearer: Option<TypedHeader<Authorization<Bearer>>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
Extension(redis_client): Extension<RedisClient>,
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
let payload = if let Some(Form(p)) = payload {
p
} else {
oidc::UserInfoPayload { access_token: None }
};
let claims = oidc::userinfo(bearer.map(|b| b.0 .0), payload, &redis_client).await?;
Ok(claims.into())
}
async fn healthcheck() {}
pub async fn main() {
let config = Figment::from(Serialized::defaults(config::Config::default()))
.merge(Toml::file("siwe-oidc.toml").nested())
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
let config = config.extract::<config::Config>().unwrap();
tracing_subscriber::fmt::init();
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
// let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
let redis_client = RedisClient { pool };
for (id, secret) in &config.default_clients.clone() {
let client_entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: vec![],
};
redis_client
.set_client(id.to_string(), client_entry)
.await
.unwrap(); // TODO
}
let private_key = if let Some(key) = &config.rsa_pem {
RsaPrivateKey::from_pkcs1_pem(key)
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap()
} else {
info!("Generating key...");
let mut rng = OsRng;
let bits = 2048;
let private = RsaPrivateKey::new(&mut rng, bits)
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
.unwrap();
info!("Generated key.");
info!("{:?}", private.to_pkcs1_pem().unwrap());
private
};
let app = Router::new()
.nest(
"/build",
get_service(ServeDir::new("./static/build")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.nest(
"/img",
get_service(ServeDir::new("./static/img")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/",
get_service(ServeFile::new("./static/index.html")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/error",
get_service(ServeFile::new("./static/error.html")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/favicon.png",
get_service(ServeFile::new("./static/favicon.png")).handle_error(
|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(oidc::METADATA_PATH, get(provider_metadata))
.route(oidc::JWK_PATH, get(jwk_set))
.route(oidc::TOKEN_PATH, post(token))
.route(oidc::AUTHORIZE_PATH, get(authorize))
.route(oidc::REGISTER_PATH, post(register))
.route(oidc::USERINFO_PATH, get(userinfo).post(userinfo))
.route(oidc::SIGNIN_PATH, get(sign_in))
.route("/health", get(healthcheck))
.layer(AddExtensionLayer::new(private_key))
.layer(AddExtensionLayer::new(config.clone()))
.layer(AddExtensionLayer::new(redis_client))
.layer(AddExtensionLayer::new(
RedisSessionStore::new(config.redis_url.clone())
.unwrap()
.with_prefix("async-sessions/"),
))
.layer(TraceLayer::new_for_http());
let addr = SocketAddr::from((config.address, config.port));
tracing::info!("Listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}

View File

@@ -1,43 +0,0 @@
use anyhow::{anyhow, Result};
use bb8_redis::{bb8::PooledConnection, redis::AsyncCommands, RedisConnectionManager};
use openidconnect::RedirectUrl;
use serde::{Deserialize, Serialize};
const KV_CLIENT_PREFIX: &str = "clients";
#[derive(Serialize, Deserialize)]
pub struct ClientEntry {
pub secret: String,
pub redirect_uris: Vec<RedirectUrl>,
}
pub async fn set_client(
mut conn: PooledConnection<'_, RedisConnectionManager>,
client_id: String,
client_entry: ClientEntry,
) -> Result<()> {
conn.set(
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
pub async fn get_client(
mut conn: PooledConnection<'_, RedisConnectionManager>,
client_id: String,
) -> Result<Option<ClientEntry>> {
let entry: Option<String> = conn
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}

199
src/db/cf.rs Normal file
View File

@@ -0,0 +1,199 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
// use cached::{stores::TimedCache, Cached};
use chrono::{DateTime, Duration, Utc};
use matchit::Node;
use std::collections::HashMap;
use worker::*;
use super::*;
const KV_NAMESPACE: &str = "SIWE-OIDC";
const DO_NAMESPACE: &str = "SIWE-OIDC-CODES";
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
// Heavily relying on:
// A Durable Object is given 30 seconds of additional CPU time for every
// request it processes, including WebSocket messages. In the absence of
// failures, in-memory state should not be reset after less than 30 seconds of
// inactivity.
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
// Wanted to use TimedCache but it (probably) crashes because it's using std::time::Instant which isn't available on wasm32.
#[durable_object]
pub struct DOCodes {
// codes: TimedCache<String, CodeEntry>,
codes: HashMap<String, (DateTime<Utc>, CodeEntry)>,
// state: State,
// env: Env,
}
#[durable_object]
impl DurableObject for DOCodes {
fn new(state: State, _env: Env) -> Self {
Self {
// codes: TimedCache::with_lifespan(ENTRY_LIFETIME.try_into().unwrap()),
codes: HashMap::new(),
// state,
// env,
}
}
async fn fetch(&mut self, mut req: Request) -> worker::Result<Response> {
// Can't use the Router because we need to reference self (thus move the var to the closure)
if matches!(req.method(), Method::Get) {
let mut matcher = Node::new();
matcher.insert("/:code", ())?;
let path = req.path();
let matched = match matcher.at(&path) {
Ok(m) => m,
Err(_) => return Response::error("Bad request", 400),
};
let code = if let Some(c) = matched.params.get("code") {
c
} else {
return Response::error("Bad request", 400);
};
if let Some(c) = self.codes.get(code) {
if c.0 + Duration::seconds(ENTRY_LIFETIME.try_into().unwrap()) < Utc::now() {
self.codes.remove(code);
Response::error("Not found", 404)
} else {
Response::from_json(&c.1)
}
} else {
Response::error("Not found", 404)
}
} else if matches!(req.method(), Method::Post) {
let mut matcher = Node::new();
matcher.insert("/:code", ())?;
let path = req.path();
let matched = match matcher.at(&path) {
Ok(m) => m,
Err(_) => return Response::error("Bad request", 400),
};
let code = if let Some(c) = matched.params.get("code") {
c
} else {
return Response::error("Bad request", 400);
};
let code_entry = match req.json().await {
Ok(p) => p,
Err(e) => return Response::error(format!("Bad request: {}", e), 400),
};
self.codes
.insert(code.to_string(), (Utc::now(), code_entry));
Response::empty()
} else {
Response::error("Method Not Allowed", 405)
}
}
}
pub struct CFClient {
pub ctx: RouteContext<()>,
pub url: Url,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl DBClient for CFClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
self.ctx
.kv(KV_NAMESPACE)
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
.put(
&format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to build KV put: {}", e))?
// TODO put some sort of expiration for dynamic registration
.execute()
.await
.map_err(|e| anyhow!("Failed to put KV: {}", e))?;
Ok(())
}
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
let entry = self
.ctx
.kv(KV_NAMESPACE)
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
.get(&format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get KV: {}", e))?
.map(|e| e.as_string());
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let namespace = self
.ctx
.durable_object(DO_NAMESPACE)
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
let stub = namespace
.id_from_name(&code)
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
.get_stub()
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
let mut headers = Headers::new();
headers.set("Content-Type", "application/json").unwrap();
let mut url = self.url.clone();
url.set_path(&code);
url.set_query(None);
let req = Request::new_with_init(
url.as_str(),
&RequestInit {
body: Some(wasm_bindgen::JsValue::from_str(
&serde_json::to_string(&code_entry)
.map_err(|e| anyhow!("Failed to serialize: {}", e))?,
)),
method: Method::Post,
headers,
..Default::default()
},
)
.map_err(|e| anyhow!("Failed to construct request for Durable Object: {}", e))?;
let res = stub
.fetch_with_request(req)
.await
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
match res.status_code() {
200 => Ok(()),
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
}
}
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
let namespace = self
.ctx
.durable_object(DO_NAMESPACE)
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
let stub = namespace
.id_from_name(&code)
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
.get_stub()
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
let mut url = self.url.clone();
url.set_path(&code);
url.set_query(None);
let mut res = stub
.fetch_with_str(url.as_str())
.await
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
match res.status_code() {
200 => Ok(Some(res.json().await.map_err(|e| {
anyhow!(
"Response to Durable Object failed to be deserialized: {}",
e
)
})?)),
404 => Ok(None),
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
}
}
}

40
src/db/mod.rs Normal file
View File

@@ -0,0 +1,40 @@
use anyhow::Result;
use async_trait::async_trait;
use openidconnect::{Nonce, RedirectUrl};
use serde::{Deserialize, Serialize};
#[cfg(not(target_arch = "wasm32"))]
mod redis;
#[cfg(not(target_arch = "wasm32"))]
pub use redis::RedisClient;
#[cfg(target_arch = "wasm32")]
mod cf;
#[cfg(target_arch = "wasm32")]
pub use cf::CFClient;
const KV_CLIENT_PREFIX: &str = "clients";
const ENTRY_LIFETIME: usize = 30;
#[derive(Clone, Serialize, Deserialize)]
pub struct CodeEntry {
pub exchange_count: usize,
pub address: String,
pub nonce: Option<Nonce>,
pub client_id: String,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientEntry {
pub secret: String,
pub redirect_uris: Vec<RedirectUrl>,
}
// Using a trait to easily pass async functions with async_trait
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait DBClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>;
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>;
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>;
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>;
}

89
src/db/redis.rs Normal file
View File

@@ -0,0 +1,89 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
use super::*;
#[derive(Clone)]
pub struct RedisClient {
pub pool: Pool<RedisConnectionManager>,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl DBClient for RedisClient {
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set(
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
serde_json::to_string(&client_entry)
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let entry: Option<String> = conn
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if let Some(e) = entry {
Ok(serde_json::from_str(&e)
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
} else {
Ok(None)
}
}
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set_ex(
code.to_string(),
hex::encode(
bincode::serialize(&code_entry)
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
),
ENTRY_LIFETIME,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
Ok(())
}
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
let mut conn = self
.pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let serialized_entry: Option<Vec<u8>> = conn
.get(code)
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() {
return Ok(None);
}
let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap())
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
Ok(Some(code_entry))
}
}

18
src/lib.rs Normal file
View File

@@ -0,0 +1,18 @@
#[cfg(target_arch = "wasm32")]
use worker::*;
pub mod db;
#[cfg(target_arch = "wasm32")]
pub mod oidc;
#[cfg(target_arch = "wasm32")]
mod worker_lib;
#[cfg(target_arch = "wasm32")]
use worker_lib::main as worker_main;
// pub use worker_lib::main;
#[cfg(target_arch = "wasm32")]
#[event(fetch)]
pub async fn main(req: Request, env: Env) -> Result<Response> {
worker_main(req, env).await
}

View File

@@ -1,882 +1,19 @@
use anyhow::{anyhow, Result};
use async_redis_session::RedisSessionStore;
use axum::{
body::{Bytes, Full},
error_handling::HandleErrorExt,
extract::{self, Extension, Form, Query, TypedHeader},
http::{
header::{self, HeaderMap},
Response, StatusCode,
},
response::{IntoResponse, Redirect},
routing::{get, post, service_method_routing},
AddExtensionLayer, Json, Router,
};
use bb8_redis::{bb8, bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
use chrono::{Duration, Utc};
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use headers::{self, authorization::Bearer, Authorization};
use hex::FromHex;
use iri_string::types::{UriAbsoluteString, UriString};
use openidconnect::{
core::{
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
},
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
url::Url,
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
};
use rand::rngs::OsRng;
use rsa::{
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
RsaPrivateKey,
};
use serde::{Deserialize, Serialize};
use siwe::eip4361::{Message, Version};
use std::{convert::Infallible, net::SocketAddr, str::FromStr};
use thiserror::Error;
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tracing::info;
use urlencoding::decode;
use uuid::Uuid;
#[cfg(not(target_arch = "wasm32"))]
mod axum_lib;
#[cfg(not(target_arch = "wasm32"))]
mod config;
mod db;
#[cfg(not(target_arch = "wasm32"))]
mod oidc;
#[cfg(not(target_arch = "wasm32"))]
mod session;
#[cfg(not(target_arch = "wasm32"))]
use axum_lib::main as axum_main;
use db::*;
use session::*;
const KID: &str = "key1";
const ENTRY_LIFETIME: usize = 30;
type ConnectionPool = Pool<RedisConnectionManager>;
#[derive(Serialize, Debug)]
pub struct TokenError {
pub error: CoreErrorResponseType,
}
#[derive(Debug, Error)]
pub enum CustomError {
#[error("{0}")]
BadRequest(String),
#[error("{0:?}")]
BadRequestToken(Json<TokenError>),
#[error("{0}")]
Unauthorized(String),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl IntoResponse for CustomError {
type Body = Full<Bytes>;
type BodyError = Infallible;
fn into_response(self) -> Response<Self::Body> {
match self {
CustomError::BadRequest(_) => {
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
}
CustomError::BadRequestToken(e) => (StatusCode::BAD_REQUEST, e).into_response(),
CustomError::Unauthorized(_) => {
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
}
CustomError::Other(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
}
}
}
}
async fn jwk_set(
Extension(private_key): Extension<RsaPrivateKey>,
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
&pem,
Some(JsonWebKeyId::new(KID.to_string())),
)
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
.as_verification_key()]);
Ok(jwks.into())
}
async fn provider_metadata(
Extension(config): Extension<config::Config>,
) -> Result<Json<CoreProviderMetadata>, CustomError> {
let pm = CoreProviderMetadata::new(
IssuerUrl::from_url(config.base_url.clone()),
AuthUrl::from_url(
config
.base_url
.join("authorize")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
JsonWebKeySetUrl::from_url(
config
.base_url
.join("jwk")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
vec![
ResponseTypes::new(vec![CoreResponseType::Code]),
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
],
vec![CoreSubjectIdentifierType::Pairwise],
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
EmptyAdditionalProviderMetadata {},
)
.set_token_endpoint(Some(TokenUrl::from_url(
config
.base_url
.join("token")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
config
.base_url
.join("userinfo")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_scopes_supported(Some(vec![
Scope::new("openid".to_string()),
// Scope::new("email".to_string()),
// Scope::new("profile".to_string()),
]))
.set_claims_supported(Some(vec![
CoreClaimName::new("sub".to_string()),
CoreClaimName::new("aud".to_string()),
// CoreClaimName::new("email".to_string()),
// CoreClaimName::new("email_verified".to_string()),
CoreClaimName::new("exp".to_string()),
CoreClaimName::new("iat".to_string()),
CoreClaimName::new("iss".to_string()),
// CoreClaimName::new("name".to_string()),
// CoreClaimName::new("given_name".to_string()),
// CoreClaimName::new("family_name".to_string()),
// CoreClaimName::new("picture".to_string()),
// CoreClaimName::new("locale".to_string()),
]))
.set_registration_endpoint(Some(RegistrationUrl::from_url(
config
.base_url
.join("register")
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_token_endpoint_auth_methods_supported(Some(vec![
CoreClientAuthMethod::ClientSecretBasic,
CoreClientAuthMethod::ClientSecretPost,
]));
Ok(pm.into())
}
#[derive(Serialize, Deserialize)]
struct TokenForm {
code: String,
client_id: Option<String>,
client_secret: Option<String>,
grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
}
// TODO should check Authorization header
// Actually, client secret can be
// 1. in the POST (currently supported) [x]
// 2. Authorization header [x]
// 3. JWT [ ]
// 4. signed JWT [ ]
// according to Keycloak
async fn token(
form: Form<TokenForm>,
bearer: Option<TypedHeader<Authorization<Bearer>>>,
Extension(private_key): Extension<RsaPrivateKey>,
Extension(config): Extension<config::Config>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<Json<CoreTokenResponse>, CustomError> {
let mut conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let serialized_entry: Option<Vec<u8>> = conn
.get(form.code.to_string())
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() {
return Err(CustomError::BadRequestToken(
TokenError {
error: CoreErrorResponseType::InvalidGrant,
}
.into(),
));
}
let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap())
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
let client_id = if let Some(c) = form.client_id.clone() {
c
} else {
code_entry.client_id.clone()
};
if let Some(secret) = if let Some(TypedHeader(Authorization(b))) = bearer {
Some(b.token().to_string())
} else {
form.client_secret.clone()
} {
let conn2 = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let client_entry = get_client(conn2, client_id.clone()).await?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
if secret != client_entry.unwrap().secret {
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
}
} else if config.require_secret {
return Err(CustomError::Unauthorized("Secret required.".to_string()));
}
if code_entry.exchange_count > 0 {
// TODO use Oauth error response
return Err(anyhow!("Code was previously exchanged.").into());
}
conn.set_ex(
form.code.to_string(),
hex::encode(
bincode::serialize(&code_entry)
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
),
ENTRY_LIFETIME,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
let access_token = AccessToken::new(form.code.to_string());
let core_id_token = CoreIdTokenClaims::new(
IssuerUrl::from_url(config.base_url),
vec![Audience::new(client_id.clone())],
Utc::now() + Duration::seconds(60),
Utc::now(),
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims {},
)
.set_nonce(code_entry.nonce);
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let id_token = CoreIdToken::new(
core_id_token,
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
Some(&access_token),
None,
)
.map_err(|e| anyhow!("{}", e))?;
Ok(CoreTokenResponse::new(
access_token,
CoreTokenType::Bearer,
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
)
.into())
}
#[derive(Deserialize)]
struct AuthorizeParams {
client_id: String,
redirect_uri: RedirectUrl,
scope: Scope,
response_type: Option<CoreResponseType>,
state: Option<String>,
nonce: Option<Nonce>,
prompt: Option<CoreAuthPrompt>,
request_uri: Option<RequestUrl>,
request: Option<String>,
}
// TODO handle `registration` parameter
async fn authorize(
session: UserSessionFromSession,
params: Query<AuthorizeParams>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let client_entry = get_client(conn, params.client_id.clone())
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
let mut r_u = params.0.redirect_uri.clone().url().clone();
r_u.set_query(None);
let mut r_us: Vec<Url> = client_entry
.unwrap()
.redirect_uris
.iter_mut()
.map(|u| u.url().clone())
.collect();
r_us.iter_mut().for_each(|u| u.set_query(None));
if !r_us.contains(&r_u) {
return Ok((
HeaderMap::new(),
Redirect::to(
"/error?message=unregistered_request_uri"
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
let state = if let Some(s) = params.0.state.clone() {
s
} else if params.0.request_uri.is_some() {
let mut url = params.0.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
} else if params.0.request.is_some() {
let mut url = params.0.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
} else {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing state");
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
};
if let Some(CoreAuthPrompt::None) = params.0.prompt {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
);
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
if params.0.response_type.is_none() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing response_type");
return Ok((
HeaderMap::new(),
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
let response_type = params.0.response_type.as_ref().unwrap();
if params.scope != Scope::new("openid".to_string()) {
return Err(anyhow!("Scope not supported").into());
}
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(cookie) => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, cookie);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
&params.0.client_id,
&params.0.redirect_uri.to_string(),
&params.0.scope.to_string(),
&response_type.as_ref(),
&state,
&params.0.client_id,
&params.0.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { header, nonce } => {
let mut headers = HeaderMap::new();
headers.insert(header::SET_COOKIE, header);
(nonce, headers)
}
};
let domain = params.redirect_uri.url().host().unwrap();
let oidc_nonce_param = if let Some(n) = &params.nonce {
format!("&oidc_nonce={}", n.secret())
} else {
"".to_string()
};
Ok((
headers,
Redirect::to(
format!(
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
nonce,
domain,
params.redirect_uri.to_string(),
state,
params.client_id,
oidc_nonce_param
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
#[derive(Serialize, Deserialize)]
struct SiweCookie {
message: Web3ModalMessage,
signature: String,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Web3ModalMessage {
pub domain: String,
pub address: String,
pub statement: String,
pub uri: String,
pub version: String,
pub chain_id: String,
pub nonce: String,
pub issued_at: String,
pub expiration_time: Option<String>,
pub not_before: Option<String>,
pub request_id: Option<String>,
pub resources: Option<Vec<String>>,
}
impl Web3ModalMessage {
pub fn to_eip4361_message(&self) -> Result<Message> {
let mut next_resources: Vec<UriString> = Vec::new();
match &self.resources {
Some(resources) => {
for resource in resources {
let x = UriString::from_str(resource)?;
next_resources.push(x)
}
}
None => {}
}
Ok(Message {
domain: self.domain.clone().try_into()?,
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
statement: self.statement.to_string(),
uri: UriAbsoluteString::from_str(&self.uri)?,
version: Version::from_str(&self.version)?,
chain_id: self.chain_id.to_string(),
nonce: self.nonce.to_string(),
issued_at: self.issued_at.to_string(),
expiration_time: self.expiration_time.clone(),
not_before: self.not_before.clone(),
request_id: self.request_id.clone(),
resources: next_resources,
})
}
}
#[derive(Serialize, Deserialize)]
struct CodeEntry {
exchange_count: usize,
address: String,
nonce: Option<Nonce>,
client_id: String,
}
#[derive(Deserialize)]
struct SignInParams {
redirect_uri: RedirectUrl,
state: String,
oidc_nonce: Option<Nonce>,
client_id: String,
}
async fn sign_in(
session: UserSessionFromSession,
params: Query<SignInParams>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<(HeaderMap, Redirect), CustomError> {
let mut headers = HeaderMap::new();
let siwe_cookie: SiweCookie = match cookies.get("siwe") {
Some(c) => serde_json::from_str(
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
)
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
None => {
return Err(anyhow!("No `siwe` cookie").into());
}
};
let (nonce, headers) = match session {
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
UserSessionFromSession::Invalid(header) => {
headers.insert(header::SET_COOKIE, header);
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.0.client_id.clone(),
&params.0.redirect_uri.to_string(),
&params.0.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
));
}
UserSessionFromSession::Created { .. } => {
return Ok((
headers,
Redirect::to(
format!(
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
&params.0.client_id.clone(),
&params.0.redirect_uri.to_string(),
&params.0.state,
)
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
}
};
let signature = match <[u8; 65]>::from_hex(
siwe_cookie
.signature
.chars()
.skip(2)
.take(130)
.collect::<String>(),
) {
Ok(s) => s,
Err(e) => {
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
}
};
let message = siwe_cookie
.message
.to_eip4361_message()
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
info!("{}", message);
message
.verify_eip191(signature)
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
let domain = params.redirect_uri.url().host().unwrap();
if domain.to_string() != siwe_cookie.message.domain {
return Err(anyhow!("Conflicting domains in message and redirect").into());
}
if nonce != siwe_cookie.message.nonce {
return Err(anyhow!("Conflicting nonces in message and session").into());
}
let code_entry = CodeEntry {
address: siwe_cookie.message.address,
nonce: params.oidc_nonce.clone(),
exchange_count: 0,
client_id: params.0.client_id.clone(),
};
let code = Uuid::new_v4();
let mut conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
conn.set_ex(
code.to_string(),
hex::encode(
bincode::serialize(&code_entry)
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
),
ENTRY_LIFETIME,
)
.await
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("code", &code.to_string());
url.query_pairs_mut().append_pair("state", &params.state);
Ok((
headers,
Redirect::to(
url.as_str()
.parse()
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
),
))
// TODO clear session
}
async fn register(
extract::Json(payload): extract::Json<CoreClientMetadata>,
Extension(pool): Extension<ConnectionPool>,
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
let id = Uuid::new_v4();
let secret = Uuid::new_v4();
let conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: payload.redirect_uris().to_vec(),
};
set_client(conn, id.to_string(), entry).await?;
Ok((
StatusCode::CREATED,
CoreClientRegistrationResponse::new(
ClientId::new(id.to_string()),
payload.redirect_uris().to_vec(),
EmptyAdditionalClientMetadata::default(),
EmptyAdditionalClientRegistrationResponse::default(),
)
.set_client_secret(Some(ClientSecret::new(secret.to_string())))
.into(),
))
}
// TODO CORS
// TODO need validation of the token
// TODO restrict access token use to only once?
async fn userinfo(
// access_token: AccessTokenUserInfo, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
Extension(pool): Extension<ConnectionPool>,
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
let code = bearer.token().to_string();
let mut conn = pool
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
let serialized_entry: Option<Vec<u8>> = conn
.get(code)
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if serialized_entry.is_none() {
return Err(CustomError::BadRequest("Unknown code.".to_string()));
}
let code_entry: CodeEntry = bincode::deserialize(
&hex::decode(serialized_entry.unwrap())
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
)
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
Ok(CoreUserInfoClaims::new(
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims::default(),
)
.into())
}
async fn healthcheck() {}
#[cfg(not(target_arch = "wasm32"))]
#[tokio::main]
async fn main() {
let config = Figment::from(Serialized::defaults(config::Config::default()))
.merge(Toml::file("siwe-oidc.toml").nested())
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
let config = config.extract::<config::Config>().unwrap();
tracing_subscriber::fmt::init();
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
for (id, secret) in &config.default_clients.clone() {
let conn = pool2
.get()
.await
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))
.unwrap();
let client_entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: vec![],
};
set_client(conn, id.to_string(), client_entry)
.await
.unwrap(); // TODO
}
let private_key = if let Some(key) = &config.rsa_pem {
RsaPrivateKey::from_pkcs1_pem(key)
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap()
} else {
info!("Generating key...");
let mut rng = OsRng;
let bits = 2048;
let private = RsaPrivateKey::new(&mut rng, bits)
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
.unwrap();
info!("Generated key.");
info!("{:?}", private.to_pkcs1_pem().unwrap());
private
};
let app = Router::new()
.nest(
"/build",
service_method_routing::get(ServeDir::new("./static/build")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.nest(
"/img",
service_method_routing::get(ServeDir::new("./static/img")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/",
service_method_routing::get(ServeFile::new("./static/index.html")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/error",
service_method_routing::get(ServeFile::new("./static/error.html")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route(
"/favicon.png",
service_method_routing::get(ServeFile::new("./static/favicon.png")).handle_error(
|error: std::io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
},
),
)
.route("/.well-known/openid-configuration", get(provider_metadata))
.route("/jwk", get(jwk_set))
.route("/token", post(token))
.route("/authorize", get(authorize))
.route("/register", post(register))
.route("/userinfo", get(userinfo).post(userinfo))
.route("/sign_in", get(sign_in))
.route("/health", get(healthcheck))
.layer(AddExtensionLayer::new(private_key))
.layer(AddExtensionLayer::new(config.clone()))
.layer(AddExtensionLayer::new(pool))
.layer(AddExtensionLayer::new(
RedisSessionStore::new(config.redis_url.clone())
.unwrap()
.with_prefix("async-sessions/"),
))
.layer(TraceLayer::new_for_http());
let addr = SocketAddr::from((config.address, config.port));
tracing::info!("Listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
axum_main().await
}
#[cfg(target_arch = "wasm32")]
fn main() {}

523
src/oidc.rs Normal file
View File

@@ -0,0 +1,523 @@
use anyhow::{anyhow, Result};
use chrono::{Duration, Utc};
use headers::{self, authorization::Bearer};
use hex::FromHex;
use iri_string::types::UriString;
use openidconnect::{
core::{
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
},
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
url::Url,
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
};
use rsa::{pkcs1::ToRsaPrivateKey, RsaPrivateKey};
use serde::{Deserialize, Serialize};
use siwe::eip4361::{Message, Version};
use std::str::FromStr;
use thiserror::Error;
use tracing::info;
use urlencoding::decode;
use uuid::Uuid;
#[cfg(target_arch = "wasm32")]
use super::db::*;
#[cfg(not(target_arch = "wasm32"))]
use siwe_oidc::db::*;
const KID: &str = "key1";
pub const METADATA_PATH: &str = "/.well-known/openid-configuration";
pub const JWK_PATH: &str = "/jwk";
pub const TOKEN_PATH: &str = "/token";
pub const AUTHORIZE_PATH: &str = "/authorize";
pub const REGISTER_PATH: &str = "/register";
pub const USERINFO_PATH: &str = "/userinfo";
pub const SIGNIN_PATH: &str = "/sign_in";
pub const SIWE_COOKIE_KEY: &str = "siwe";
#[cfg(not(target_arch = "wasm32"))]
type DBClientType = (dyn DBClient + Sync);
#[cfg(target_arch = "wasm32")]
type DBClientType = dyn DBClient;
#[derive(Serialize, Debug)]
pub struct TokenError {
pub error: CoreErrorResponseType,
pub error_description: String,
}
#[derive(Debug, Error)]
pub enum CustomError {
#[error("{0}")]
BadRequest(String),
#[error("{0:?}")]
BadRequestToken(TokenError),
#[error("{0}")]
Unauthorized(String),
#[error("{0:?}")]
Redirect(String),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub fn jwks(private_key: RsaPrivateKey) -> Result<CoreJsonWebKeySet, CustomError> {
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
&pem,
Some(JsonWebKeyId::new(KID.to_string())),
)
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
.as_verification_key()]);
Ok(jwks)
}
pub fn metadata(base_url: Url) -> Result<CoreProviderMetadata, CustomError> {
let pm = CoreProviderMetadata::new(
IssuerUrl::from_url(base_url.clone()),
AuthUrl::from_url(
base_url
.join(AUTHORIZE_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
JsonWebKeySetUrl::from_url(
base_url
.join(JWK_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
),
vec![
ResponseTypes::new(vec![CoreResponseType::Code]),
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
],
vec![CoreSubjectIdentifierType::Pairwise],
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
EmptyAdditionalProviderMetadata {},
)
.set_token_endpoint(Some(TokenUrl::from_url(
base_url
.join(TOKEN_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
base_url
.join(USERINFO_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_scopes_supported(Some(vec![
Scope::new("openid".to_string()),
// Scope::new("email".to_string()),
// Scope::new("profile".to_string()),
]))
.set_claims_supported(Some(vec![
CoreClaimName::new("sub".to_string()),
CoreClaimName::new("aud".to_string()),
// CoreClaimName::new("email".to_string()),
// CoreClaimName::new("email_verified".to_string()),
CoreClaimName::new("exp".to_string()),
CoreClaimName::new("iat".to_string()),
CoreClaimName::new("iss".to_string()),
// CoreClaimName::new("name".to_string()),
// CoreClaimName::new("given_name".to_string()),
// CoreClaimName::new("family_name".to_string()),
// CoreClaimName::new("picture".to_string()),
// CoreClaimName::new("locale".to_string()),
]))
.set_registration_endpoint(Some(RegistrationUrl::from_url(
base_url
.join(REGISTER_PATH)
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
)))
.set_token_endpoint_auth_methods_supported(Some(vec![
CoreClientAuthMethod::ClientSecretBasic,
CoreClientAuthMethod::ClientSecretPost,
]));
Ok(pm)
}
#[derive(Serialize, Deserialize)]
pub struct TokenForm {
pub code: String,
pub client_id: Option<String>,
pub client_secret: Option<String>,
pub grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
}
pub async fn token(
form: TokenForm,
// From the request's Authorization header
secret: Option<String>,
private_key: RsaPrivateKey,
base_url: Url,
require_secret: bool,
db_client: &DBClientType,
) -> Result<CoreTokenResponse, CustomError> {
let code_entry = if let Some(c) = db_client.get_code(form.code.to_string()).await? {
c
} else {
return Err(CustomError::BadRequestToken(TokenError {
error: CoreErrorResponseType::InvalidGrant,
error_description: "Unknown code.".to_string(),
}));
};
let client_id = if let Some(c) = form.client_id.clone() {
c
} else {
code_entry.client_id.clone()
};
if let Some(secret) = if let Some(b) = secret {
Some(b)
} else {
form.client_secret.clone()
} {
let client_entry = db_client.get_client(client_id.clone()).await?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
if secret != client_entry.unwrap().secret {
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
}
} else if require_secret {
return Err(CustomError::Unauthorized("Secret required.".to_string()));
}
if code_entry.exchange_count > 0 {
// TODO use Oauth error response
return Err(CustomError::BadRequestToken(TokenError {
error: CoreErrorResponseType::InvalidGrant,
error_description: "Code was previously exchanged.".to_string(),
}));
}
let mut code_entry2 = code_entry.clone();
code_entry2.exchange_count += 1;
db_client
.set_code(form.code.to_string(), code_entry2)
.await?;
let access_token = AccessToken::new(form.code);
let core_id_token = CoreIdTokenClaims::new(
IssuerUrl::from_url(base_url),
vec![Audience::new(client_id.clone())],
Utc::now() + Duration::seconds(60),
Utc::now(),
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims {},
)
.set_nonce(code_entry.nonce);
let pem = private_key
.to_pkcs1_pem()
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
let id_token = CoreIdToken::new(
core_id_token,
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
Some(&access_token),
None,
)
.map_err(|e| anyhow!("{}", e))?;
Ok(CoreTokenResponse::new(
access_token,
CoreTokenType::Bearer,
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
))
}
#[derive(Deserialize)]
pub struct AuthorizeParams {
pub client_id: String,
pub redirect_uri: RedirectUrl,
pub scope: Scope,
pub response_type: Option<CoreResponseType>,
pub state: Option<String>,
pub nonce: Option<Nonce>,
pub prompt: Option<CoreAuthPrompt>,
pub request_uri: Option<RequestUrl>,
pub request: Option<String>,
}
pub async fn authorize(
params: AuthorizeParams,
nonce: String,
db_client: &DBClientType,
) -> Result<String, CustomError> {
let client_entry = db_client
.get_client(params.client_id.clone())
.await
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
if client_entry.is_none() {
return Err(CustomError::Unauthorized(
"Unrecognised client id.".to_string(),
));
}
let mut r_u = params.redirect_uri.clone().url().clone();
r_u.set_query(None);
let mut r_us: Vec<Url> = client_entry
.unwrap()
.redirect_uris
.iter_mut()
.map(|u| u.url().clone())
.collect();
r_us.iter_mut().for_each(|u| u.set_query(None));
if !r_us.contains(&r_u) {
return Err(CustomError::Redirect(
"/error?message=unregistered_request_uri".to_string(),
));
}
let state = if let Some(s) = params.state.clone() {
s
} else if params.request_uri.is_some() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
);
return Err(CustomError::Redirect(url.to_string()));
} else if params.request.is_some() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
);
return Err(CustomError::Redirect(url.to_string()));
} else {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing state");
return Err(CustomError::Redirect(url.to_string()));
};
if let Some(CoreAuthPrompt::None) = params.prompt {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut().append_pair(
"error",
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
);
return Err(CustomError::Redirect(url.to_string()));
}
if params.response_type.is_none() {
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("state", &state);
url.query_pairs_mut()
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
url.query_pairs_mut()
.append_pair("error_description", "Missing response_type");
return Err(CustomError::Redirect(url.to_string()));
}
let _response_type = params.response_type.as_ref().unwrap();
if params.scope != Scope::new("openid".to_string()) {
return Err(anyhow!("Scope not supported").into());
}
let domain = params.redirect_uri.url().host().unwrap();
let oidc_nonce_param = if let Some(n) = &params.nonce {
format!("&oidc_nonce={}", n.secret())
} else {
"".to_string()
};
Ok(format!(
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
nonce,
domain,
params.redirect_uri.to_string(),
state,
params.client_id,
oidc_nonce_param
))
}
#[derive(Serialize, Deserialize)]
pub struct SiweCookie {
message: Web3ModalMessage,
signature: String,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Web3ModalMessage {
pub domain: String,
pub address: String,
pub statement: String,
pub uri: String,
pub version: String,
pub chain_id: String,
pub nonce: String,
pub issued_at: String,
pub expiration_time: Option<String>,
pub not_before: Option<String>,
pub request_id: Option<String>,
pub resources: Option<Vec<String>>,
}
impl Web3ModalMessage {
fn to_eip4361_message(&self) -> Result<Message> {
let mut next_resources: Vec<UriString> = Vec::new();
match &self.resources {
Some(resources) => {
for resource in resources {
let x = UriString::from_str(resource)?;
next_resources.push(x)
}
}
None => {}
}
Ok(Message {
domain: self.domain.clone().try_into()?,
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
statement: self.statement.to_string(),
uri: UriString::from_str(&self.uri)?,
version: Version::from_str(&self.version)?,
chain_id: self.chain_id.to_string(),
nonce: self.nonce.to_string(),
issued_at: self.issued_at.to_string(),
expiration_time: self.expiration_time.clone(),
not_before: self.not_before.clone(),
request_id: self.request_id.clone(),
resources: next_resources,
})
}
}
#[derive(Deserialize)]
pub struct SignInParams {
pub redirect_uri: RedirectUrl,
pub state: String,
pub oidc_nonce: Option<Nonce>,
pub client_id: String,
}
pub async fn sign_in(
params: SignInParams,
expected_nonce: Option<String>,
cookies: headers::Cookie,
db_client: &DBClientType,
) -> Result<Url, CustomError> {
let siwe_cookie: SiweCookie = match cookies.get(SIWE_COOKIE_KEY) {
Some(c) => serde_json::from_str(
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
)
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
None => {
return Err(anyhow!("No `siwe` cookie").into());
}
};
let signature = match <[u8; 65]>::from_hex(
siwe_cookie
.signature
.chars()
.skip(2)
.take(130)
.collect::<String>(),
) {
Ok(s) => s,
Err(e) => {
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
}
};
let message = siwe_cookie
.message
.to_eip4361_message()
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
info!("{}", message);
message
.verify(signature)
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
let domain = params.redirect_uri.url().host().unwrap();
if domain.to_string() != siwe_cookie.message.domain {
return Err(anyhow!("Conflicting domains in message and redirect").into());
}
if expected_nonce.is_some() && expected_nonce.unwrap() != siwe_cookie.message.nonce {
return Err(anyhow!("Conflicting nonces in message and session").into());
}
let code_entry = CodeEntry {
address: siwe_cookie.message.address,
nonce: params.oidc_nonce.clone(),
exchange_count: 0,
client_id: params.client_id.clone(),
};
let code = Uuid::new_v4();
db_client.set_code(code.to_string(), code_entry).await?;
let mut url = params.redirect_uri.url().clone();
url.query_pairs_mut().append_pair("code", &code.to_string());
url.query_pairs_mut().append_pair("state", &params.state);
Ok(url)
}
pub async fn register(
payload: CoreClientMetadata,
db_client: &DBClientType,
) -> Result<CoreClientRegistrationResponse, CustomError> {
let id = Uuid::new_v4();
let secret = Uuid::new_v4();
let entry = ClientEntry {
secret: secret.to_string(),
redirect_uris: payload.redirect_uris().to_vec(),
};
db_client.set_client(id.to_string(), entry).await?;
Ok(CoreClientRegistrationResponse::new(
ClientId::new(id.to_string()),
payload.redirect_uris().to_vec(),
EmptyAdditionalClientMetadata::default(),
EmptyAdditionalClientRegistrationResponse::default(),
)
.set_client_secret(Some(ClientSecret::new(secret.to_string()))))
}
#[derive(Deserialize)]
pub struct UserInfoPayload {
pub access_token: Option<String>,
}
pub async fn userinfo(
bearer: Option<Bearer>,
payload: UserInfoPayload,
db_client: &DBClientType,
) -> Result<CoreUserInfoClaims, CustomError> {
let code = if let Some(b) = bearer {
b.token().to_string()
} else if let Some(c) = payload.access_token {
c
} else {
return Err(CustomError::BadRequest("Missing access token.".to_string()));
};
let code_entry = if let Some(c) = db_client.get_code(code).await? {
c
} else {
return Err(CustomError::BadRequest("Unknown code.".to_string()));
};
Ok(CoreUserInfoClaims::new(
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
EmptyAdditionalClaims::default(),
))
}

210
src/worker_lib.rs Normal file
View File

@@ -0,0 +1,210 @@
use anyhow::anyhow;
use headers::{
self,
authorization::{Basic, Bearer, Credentials},
Authorization, Header, HeaderValue,
};
use rand::{distributions::Alphanumeric, Rng};
use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey};
use worker::*;
use super::db::CFClient;
use super::oidc::{self, CustomError, TokenForm, UserInfoPayload};
const BASE_URL_KEY: &str = "BASE_URL";
const RSA_PEM_KEY: &str = "RSA_PEM";
// https://github.com/cloudflare/workers-rs/issues/64
// #[global_allocator]
// static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
impl From<CustomError> for Result<Response> {
fn from(error: CustomError) -> Self {
match error {
CustomError::BadRequest(_) => Response::error(&error.to_string(), 400),
CustomError::BadRequestToken(e) => Response::from_json(&e).map(|r| r.with_status(400)),
CustomError::Unauthorized(_) => Response::error(&error.to_string(), 401),
CustomError::Redirect(uri) => Response::redirect(uri.parse().unwrap()),
CustomError::Other(_) => Response::error(&error.to_string(), 500),
}
}
}
pub async fn main(req: Request, env: Env) -> Result<Response> {
console_error_panic_hook::set_once();
// tracing_subscriber::fmt::init();
// console_log::init_with_level(log::Level::Info).expect("error initializing log");
let userinfo = |mut req: Request, ctx: RouteContext<()>| async move {
let bearer = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(Bearer::decode);
let payload = if bearer.is_none() {
match req.form_data().await {
Ok(f) => {
let access_token = if let Some(FormEntry::Field(a)) = f.get("access_token") {
Some(a)
} else {
return Response::error("Missing code", 400);
};
UserInfoPayload { access_token }
}
Err(_) => return Response::error("Bad request", 400),
}
} else {
UserInfoPayload { access_token: None }
};
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::userinfo(bearer, payload, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?),
Err(e) => e.into(),
}
};
let router = Router::new();
router
.get_async(oidc::METADATA_PATH, |_req, ctx| async move {
match oidc::metadata(ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap()) {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
.get_async(oidc::JWK_PATH, |_req, ctx| async move {
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap();
match oidc::jwks(private_key) {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
.post_async(oidc::TOKEN_PATH, |mut req, ctx| async move {
let form_data = req.form_data().await?;
let code = if let Some(FormEntry::Field(c)) = form_data.get("code") {
c
} else {
return Response::error("Missing code", 400);
};
let client_id = match form_data.get("client_id") {
Some(FormEntry::Field(c)) => Some(c),
None => None,
_ => return Response::error("Client ID not a field", 400),
};
let client_secret = match form_data.get("client_secret") {
Some(FormEntry::Field(c)) => Some(c),
None => None,
_ => return Response::error("Client secret not a field", 400),
};
let grant_type = if let Some(FormEntry::Field(c)) = form_data.get("code") {
if let Ok(cc) = serde_json::from_str(&format!("\"{}\"", c)) {
cc
} else {
return Response::error("Invalid grant type", 400);
}
} else {
return Response::error("Missing grant type", 400);
};
let secret = req
.headers()
.get(Authorization::<Bearer>::name().as_str())?
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
.as_ref()
.and_then(|b| {
if b.to_str().unwrap().starts_with("Bearer") {
Bearer::decode(b).map(|bb| bb.token().to_string())
} else {
Basic::decode(b).map(|bb| bb.password().to_string())
}
});
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
.map_err(|e| anyhow!("Failed to load private key: {}", e))
.unwrap();
let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let url = req.url()?;
let db_client = CFClient { ctx, url };
let token_response = oidc::token(
TokenForm {
code,
client_id,
client_secret,
grant_type,
},
secret,
private_key,
base_url,
false,
&db_client,
)
.await;
match token_response {
Ok(m) => Response::from_json(&m),
Err(e) => e.into(),
}
})
// TODO add browser session
.get_async(oidc::AUTHORIZE_PATH, |req, ctx| async move {
let base_url: Url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
let url = req.url()?;
let query = url.query().unwrap_or_default();
let params = match serde_urlencoded::from_str(query) {
Ok(p) => p,
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
};
let nonce = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect();
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::authorize(params, nonce, &db_client).await {
Ok(url) => Response::redirect(base_url.join(&url).unwrap()),
Err(e) => match e {
CustomError::Redirect(url) => {
CustomError::Redirect(base_url.join(&url).unwrap().to_string())
}
c => c,
}
.into(),
}
})
.post_async(oidc::REGISTER_PATH, |mut req, ctx| async move {
let payload = req.json().await?;
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::register(payload, &db_client).await {
Ok(r) => Ok(Response::from_json(&r)?.with_status(201)),
Err(e) => e.into(),
}
})
.post_async(oidc::USERINFO_PATH, userinfo)
.get_async(oidc::USERINFO_PATH, userinfo)
.get_async(oidc::SIGNIN_PATH, |req, ctx| async move {
let url = req.url()?;
let query = url.query().unwrap_or_default();
let params = match serde_urlencoded::from_str(query) {
Ok(p) => p,
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
};
let cookies = req
.headers()
.get(headers::Cookie::name().as_str())?
.and_then(|c| HeaderValue::from_str(&c).ok())
.and_then(|c| headers::Cookie::decode(&mut [c].iter()).ok());
if cookies.is_none() {
return Response::error("Missing cookies", 400);
}
let url = req.url()?;
let db_client = CFClient { ctx, url };
match oidc::sign_in(params, None, cookies.unwrap(), &db_client).await {
Ok(url) => Response::redirect(url),
Err(e) => e.into(),
}
})
.run(req, env)
.await
}