32
src/config.rs
Normal file
32
src/config.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct Config {
|
||||
pub address: IpAddr,
|
||||
pub port: u16,
|
||||
pub base_url: Url,
|
||||
pub rsa_pem: Option<String>,
|
||||
pub redis_url: Url,
|
||||
pub default_clients: HashMap<String, String>,
|
||||
// TODO secret is more complicated than that, and needs to be in the well-known config
|
||||
pub require_secret: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
address: Ipv4Addr::new(127, 0, 0, 1).into(),
|
||||
port: 8000,
|
||||
base_url: Url::parse("http://127.0.0.1:8000").unwrap(),
|
||||
rsa_pem: None,
|
||||
redis_url: Url::parse("redis://localhost").unwrap(),
|
||||
default_clients: HashMap::default(),
|
||||
require_secret: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
707
src/main.rs
Normal file
707
src/main.rs
Normal file
@@ -0,0 +1,707 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_redis_session::RedisSessionStore;
|
||||
use axum::{
|
||||
body::{Bytes, Full},
|
||||
error_handling::HandleErrorExt,
|
||||
extract::{self, Extension, Form, Query, TypedHeader},
|
||||
http::{
|
||||
header::{self, HeaderMap},
|
||||
Response, StatusCode,
|
||||
},
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::{get, post, service_method_routing},
|
||||
AddExtensionLayer, Json, Router,
|
||||
};
|
||||
use bb8_redis::{bb8, bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
||||
use chrono::{Duration, Utc};
|
||||
use figment::{
|
||||
providers::{Env, Format, Serialized, Toml},
|
||||
Figment,
|
||||
};
|
||||
use headers::{self, authorization::Bearer, Authorization};
|
||||
use hex::FromHex;
|
||||
use iri_string::types::{UriAbsoluteString, UriString};
|
||||
use openidconnect::{
|
||||
core::{
|
||||
CoreClaimName, CoreClientAuthMethod, CoreClientMetadata, CoreClientRegistrationResponse,
|
||||
CoreGrantType, CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
|
||||
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
|
||||
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
|
||||
},
|
||||
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
|
||||
AccessToken, Audience, AuthUrl, ClientId, EmptyAdditionalClaims,
|
||||
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
|
||||
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, ResponseTypes, Scope,
|
||||
StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
|
||||
};
|
||||
use rand::rngs::OsRng;
|
||||
use rsa::{
|
||||
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
|
||||
RsaPrivateKey,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use siwe::eip4361::{Message, Version};
|
||||
use std::{convert::Infallible, net::SocketAddr, str::FromStr};
|
||||
use thiserror::Error;
|
||||
use tower_http::{
|
||||
services::{ServeDir, ServeFile},
|
||||
trace::TraceLayer,
|
||||
};
|
||||
use tracing::info;
|
||||
use urlencoding::decode;
|
||||
use uuid::Uuid;
|
||||
|
||||
mod config;
|
||||
mod session;
|
||||
|
||||
use session::*;
|
||||
|
||||
const KID: &str = "key1";
|
||||
const KV_CLIENT_PREFIX: &str = "clients";
|
||||
const ENTRY_LIFETIME: usize = 60 * 60 * 24 * 2;
|
||||
|
||||
type ConnectionPool = Pool<RedisConnectionManager>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CustomError {
|
||||
#[error("{0}")]
|
||||
BadRequest(String),
|
||||
#[error("{0}")]
|
||||
Unauthorized(String),
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl IntoResponse for CustomError {
|
||||
type Body = Full<Bytes>;
|
||||
type BodyError = Infallible;
|
||||
|
||||
fn into_response(self) -> Response<Self::Body> {
|
||||
match self {
|
||||
CustomError::BadRequest(_) => (StatusCode::BAD_REQUEST, self.to_string()),
|
||||
CustomError::Unauthorized(_) => (StatusCode::UNAUTHORIZED, self.to_string()),
|
||||
CustomError::Other(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
||||
}
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn jwk_set(
|
||||
Extension(private_key): Extension<RsaPrivateKey>,
|
||||
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
|
||||
let pem = private_key
|
||||
.to_pkcs1_pem()
|
||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
|
||||
&pem,
|
||||
Some(JsonWebKeyId::new(KID.to_string())),
|
||||
)
|
||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
|
||||
.as_verification_key()]);
|
||||
Ok(jwks.into())
|
||||
}
|
||||
|
||||
async fn provider_metadata(
|
||||
Extension(config): Extension<config::Config>,
|
||||
) -> Result<Json<CoreProviderMetadata>, CustomError> {
|
||||
let pm = CoreProviderMetadata::new(
|
||||
IssuerUrl::from_url(config.base_url.clone()),
|
||||
AuthUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("authorize")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
),
|
||||
JsonWebKeySetUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("jwk")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
),
|
||||
vec![
|
||||
ResponseTypes::new(vec![CoreResponseType::Code]),
|
||||
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
|
||||
],
|
||||
vec![CoreSubjectIdentifierType::Pairwise],
|
||||
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
|
||||
EmptyAdditionalProviderMetadata {},
|
||||
)
|
||||
.set_token_endpoint(Some(TokenUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("token")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("userinfo")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_scopes_supported(Some(vec![
|
||||
Scope::new("openid".to_string()),
|
||||
// Scope::new("email".to_string()),
|
||||
// Scope::new("profile".to_string()),
|
||||
]))
|
||||
.set_claims_supported(Some(vec![
|
||||
CoreClaimName::new("sub".to_string()),
|
||||
CoreClaimName::new("aud".to_string()),
|
||||
// CoreClaimName::new("email".to_string()),
|
||||
// CoreClaimName::new("email_verified".to_string()),
|
||||
CoreClaimName::new("exp".to_string()),
|
||||
CoreClaimName::new("iat".to_string()),
|
||||
CoreClaimName::new("iss".to_string()),
|
||||
// CoreClaimName::new("name".to_string()),
|
||||
// CoreClaimName::new("given_name".to_string()),
|
||||
// CoreClaimName::new("family_name".to_string()),
|
||||
// CoreClaimName::new("picture".to_string()),
|
||||
// CoreClaimName::new("locale".to_string()),
|
||||
]))
|
||||
.set_registration_endpoint(Some(RegistrationUrl::from_url(
|
||||
config
|
||||
.base_url
|
||||
.join("register")
|
||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||
)))
|
||||
.set_token_endpoint_auth_methods_supported(Some(vec![CoreClientAuthMethod::ClientSecretPost]));
|
||||
|
||||
Ok(pm.into())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenForm {
|
||||
code: String,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
|
||||
}
|
||||
|
||||
// TODO should check Authorization header
|
||||
// Actually, client secret can be
|
||||
// 1. in the POST (currently supported)
|
||||
// 2. Authorization header
|
||||
// 3. JWT
|
||||
// 4. signed JWT
|
||||
// according to Keycloak
|
||||
|
||||
async fn token(
|
||||
form: Form<TokenForm>,
|
||||
Extension(private_key): Extension<RsaPrivateKey>,
|
||||
Extension(config): Extension<config::Config>,
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<Json<CoreTokenResponse>, CustomError> {
|
||||
let mut conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
|
||||
if let Some(secret) = form.client_secret.clone() {
|
||||
let stored_secret: Option<String> = conn
|
||||
.get(format!("{}/{}", KV_CLIENT_PREFIX, form.client_id))
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if stored_secret.is_none() {
|
||||
Err(CustomError::Unauthorized(
|
||||
"Unrecognised client id.".to_string(),
|
||||
))?;
|
||||
}
|
||||
if secret != stored_secret.unwrap() {
|
||||
Err(CustomError::Unauthorized("Bad secret.".to_string()))?;
|
||||
}
|
||||
} else if config.require_secret {
|
||||
Err(CustomError::Unauthorized("Secret required.".to_string()))?;
|
||||
}
|
||||
|
||||
let serialized_entry: Option<Vec<u8>> = conn
|
||||
.get(form.code.to_string())
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if serialized_entry.is_none() {
|
||||
Err(CustomError::BadRequest("Unknown code.".to_string()))?;
|
||||
}
|
||||
let code_entry: CodeEntry = bincode::deserialize(
|
||||
&hex::decode(serialized_entry.unwrap())
|
||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
||||
|
||||
if code_entry.exchange_count > 0 {
|
||||
// TODO use Oauth error response
|
||||
Err(anyhow!("Code was previously exchanged."))?;
|
||||
}
|
||||
conn.set_ex(
|
||||
form.code.to_string(),
|
||||
hex::encode(
|
||||
bincode::serialize(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
||||
),
|
||||
ENTRY_LIFETIME,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
|
||||
let access_token = AccessToken::new(form.code.clone());
|
||||
let core_id_token = CoreIdTokenClaims::new(
|
||||
IssuerUrl::from_url(config.base_url),
|
||||
vec![Audience::new(form.client_id.clone())],
|
||||
Utc::now() + Duration::seconds(60),
|
||||
Utc::now(),
|
||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||
EmptyAdditionalClaims {},
|
||||
)
|
||||
.set_nonce(code_entry.nonce);
|
||||
|
||||
let pem = private_key
|
||||
.to_pkcs1_pem()
|
||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||
|
||||
let id_token = CoreIdToken::new(
|
||||
core_id_token,
|
||||
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
|
||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
|
||||
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
|
||||
Some(&access_token),
|
||||
None,
|
||||
)
|
||||
.map_err(|e| anyhow!("{}", e))?;
|
||||
|
||||
Ok(CoreTokenResponse::new(
|
||||
access_token,
|
||||
CoreTokenType::Bearer,
|
||||
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeParams {
|
||||
client_id: String,
|
||||
redirect_uri: RedirectUrl,
|
||||
scope: Scope,
|
||||
response_type: CoreResponseType,
|
||||
state: String,
|
||||
nonce: Option<Nonce>,
|
||||
}
|
||||
|
||||
// TODO handle `registration` parameter
|
||||
async fn authorize(
|
||||
session: UserSessionFromSession,
|
||||
params: Query<AuthorizeParams>,
|
||||
// Extension(private_key): Extension<RsaPrivateKey>,
|
||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||
// TODO: Enforce Client Registration
|
||||
// let d = std::str::from_utf8(
|
||||
// &jwk.decrypt(
|
||||
// PaddingScheme::new_pkcs1v15_encrypt(),
|
||||
// ¶ms.client_id.as_bytes(),
|
||||
// )
|
||||
// .map_err(|e| anyhow!("Failed to decrypt client id: {}", e))?,
|
||||
// )
|
||||
// .map_err(|e| anyhow!("Failed to decrypt client id: {}", e))?
|
||||
// if d != params.redirect_uri.as_str() {
|
||||
// return Err(anyhow!("Client id not composed of redirect url"));
|
||||
// };
|
||||
|
||||
if params.scope != Scope::new("openid".to_string()) {
|
||||
Err(anyhow!("Scope not supported"))?;
|
||||
}
|
||||
|
||||
let (nonce, headers) = match session {
|
||||
UserSessionFromSession::FoundUserSession(nonce) => (nonce, HeaderMap::new()),
|
||||
UserSessionFromSession::InvalidUserSession(cookie) => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::SET_COOKIE, cookie);
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}{}",
|
||||
¶ms.0.client_id,
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.scope.to_string(),
|
||||
¶ms.0.response_type.as_ref(),
|
||||
¶ms.0.state,
|
||||
¶ms.0.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or(String::new())
|
||||
)
|
||||
.to_string()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
UserSessionFromSession::CreatedFreshUserSession { header, nonce } => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::SET_COOKIE, header);
|
||||
(nonce, headers)
|
||||
}
|
||||
};
|
||||
|
||||
let domain = params.redirect_uri.url().host().unwrap();
|
||||
let oidc_nonce_param = if let Some(n) = ¶ms.nonce {
|
||||
format!("&oidc_nonce={}", n.secret())
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/?nonce={}&domain={}&redirect_uri={}&state={}{}",
|
||||
nonce,
|
||||
domain,
|
||||
params.redirect_uri.to_string(),
|
||||
params.state,
|
||||
oidc_nonce_param
|
||||
)
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SiweCookie {
|
||||
message: Web3ModalMessage,
|
||||
signature: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct Web3ModalMessage {
|
||||
pub domain: String,
|
||||
pub address: String,
|
||||
pub statement: String,
|
||||
pub uri: String,
|
||||
pub version: String,
|
||||
pub chain_id: String,
|
||||
pub nonce: String,
|
||||
pub issued_at: String,
|
||||
pub expiration_time: Option<String>,
|
||||
pub not_before: Option<String>,
|
||||
pub request_id: Option<String>,
|
||||
pub resources: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Web3ModalMessage {
|
||||
pub fn to_eip4361_message(&self) -> Result<Message> {
|
||||
let mut next_resources: Vec<UriString> = Vec::new();
|
||||
match &self.resources {
|
||||
Some(resources) => {
|
||||
for resource in resources {
|
||||
let x = UriString::from_str(resource)?;
|
||||
next_resources.push(x)
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
Ok(Message {
|
||||
domain: self.domain.clone().try_into()?,
|
||||
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
|
||||
statement: self.statement.to_string(),
|
||||
uri: UriAbsoluteString::from_str(&self.uri)?,
|
||||
version: Version::from_str(&self.version)?,
|
||||
chain_id: self.chain_id.to_string(),
|
||||
nonce: self.nonce.to_string(),
|
||||
issued_at: self.issued_at.to_string(),
|
||||
expiration_time: self.expiration_time.clone(),
|
||||
not_before: self.not_before.clone(),
|
||||
request_id: self.request_id.clone(),
|
||||
resources: next_resources,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct CodeEntry {
|
||||
exchange_count: usize,
|
||||
address: String,
|
||||
nonce: Option<Nonce>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SignInParams {
|
||||
redirect_uri: RedirectUrl,
|
||||
state: String,
|
||||
oidc_nonce: Option<Nonce>,
|
||||
}
|
||||
|
||||
async fn sign_in(
|
||||
session: UserSessionFromSession,
|
||||
params: Query<SignInParams>,
|
||||
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||
let mut headers = HeaderMap::new();
|
||||
let siwe_cookie: SiweCookie = match cookies.get("siwe") {
|
||||
Some(c) => serde_json::from_str(
|
||||
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
|
||||
None => Err(anyhow!("No `siwe` cookie"))?,
|
||||
};
|
||||
|
||||
let (nonce, headers) = match session {
|
||||
UserSessionFromSession::FoundUserSession(nonce) => (nonce, HeaderMap::new()),
|
||||
UserSessionFromSession::InvalidUserSession(header) => {
|
||||
headers.insert(header::SET_COOKIE, header);
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.state,
|
||||
)
|
||||
.to_string()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
));
|
||||
}
|
||||
UserSessionFromSession::CreatedFreshUserSession { .. } => {
|
||||
return Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
format!(
|
||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.redirect_uri.to_string(),
|
||||
¶ms.0.state,
|
||||
)
|
||||
.to_string()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let signature = match <[u8; 65]>::from_hex(
|
||||
siwe_cookie
|
||||
.signature
|
||||
.chars()
|
||||
.skip(2)
|
||||
.take(130)
|
||||
.collect::<String>()
|
||||
.clone(),
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(e) => Err(CustomError::BadRequest(format!("Bad signature: {}", e)))?,
|
||||
};
|
||||
|
||||
let message = siwe_cookie
|
||||
.message
|
||||
.to_eip4361_message()
|
||||
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
|
||||
info!("{}", message);
|
||||
message
|
||||
.verify_eip191(signature)
|
||||
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
|
||||
|
||||
let domain = params.redirect_uri.url().host().unwrap();
|
||||
if domain.to_string() != siwe_cookie.message.domain {
|
||||
Err(anyhow!("Conflicting domains in message and redirect"))?
|
||||
}
|
||||
if nonce != siwe_cookie.message.nonce {
|
||||
Err(anyhow!("Conflicting nonces in message and session"))?
|
||||
}
|
||||
|
||||
let code_entry = CodeEntry {
|
||||
address: siwe_cookie.message.address,
|
||||
nonce: params.oidc_nonce.clone(),
|
||||
exchange_count: 0,
|
||||
};
|
||||
|
||||
let code = Uuid::new_v4();
|
||||
let mut conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
conn.set_ex(
|
||||
code.to_string(),
|
||||
hex::encode(
|
||||
bincode::serialize(&code_entry)
|
||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
||||
),
|
||||
ENTRY_LIFETIME,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
|
||||
let mut url = params.redirect_uri.url().clone();
|
||||
url.query_pairs_mut().append_pair("code", &code.to_string());
|
||||
url.query_pairs_mut().append_pair("state", ¶ms.state);
|
||||
Ok((
|
||||
headers,
|
||||
Redirect::to(
|
||||
url.as_str()
|
||||
.parse()
|
||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||
),
|
||||
))
|
||||
// TODO clear session
|
||||
}
|
||||
|
||||
async fn register(
|
||||
extract::Json(payload): extract::Json<CoreClientMetadata>,
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<Json<CoreClientRegistrationResponse>, CustomError> {
|
||||
let id = Uuid::new_v4();
|
||||
let secret = Uuid::new_v4();
|
||||
|
||||
let mut conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
conn.set(format!("{}/{}", KV_CLIENT_PREFIX, id), secret.to_string())
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||
|
||||
Ok(CoreClientRegistrationResponse::new(
|
||||
ClientId::new(id.to_string()),
|
||||
payload.redirect_uris().to_vec(),
|
||||
EmptyAdditionalClientMetadata::default(),
|
||||
EmptyAdditionalClientRegistrationResponse::default(),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
// TODO CORS
|
||||
// TODO need validation of the token
|
||||
// TODO restrict access token use to only once?
|
||||
async fn userinfo(
|
||||
// access_token: AccessTokenUserInfo, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
||||
TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
||||
Extension(pool): Extension<ConnectionPool>,
|
||||
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
|
||||
let code = bearer.token().to_string();
|
||||
let mut conn = pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||
let serialized_entry: Option<Vec<u8>> = conn
|
||||
.get(code)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||
if serialized_entry.is_none() {
|
||||
Err(CustomError::BadRequest("Unknown code.".to_string()))?;
|
||||
}
|
||||
let code_entry: CodeEntry = bincode::deserialize(
|
||||
&hex::decode(serialized_entry.unwrap())
|
||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
||||
|
||||
Ok(CoreUserInfoClaims::new(
|
||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||
EmptyAdditionalClaims::default(),
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
async fn healthcheck() {}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let config = Figment::from(Serialized::defaults(config::Config::default()))
|
||||
.merge(Toml::file("siwe-oidc.toml").nested())
|
||||
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
|
||||
let config = config.extract::<config::Config>().unwrap();
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
|
||||
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
|
||||
let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
|
||||
|
||||
let mut conn = pool2
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))
|
||||
.unwrap();
|
||||
for (id, secret) in &config.default_clients.clone() {
|
||||
let _: () = conn
|
||||
.set(format!("{}/{}", KV_CLIENT_PREFIX, id), secret)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let private_key = if let Some(key) = &config.rsa_pem {
|
||||
RsaPrivateKey::from_pkcs1_pem(&key)
|
||||
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||
.unwrap()
|
||||
} else {
|
||||
info!("Generating key...");
|
||||
let mut rng = OsRng;
|
||||
let bits = 2048;
|
||||
let private = RsaPrivateKey::new(&mut rng, bits)
|
||||
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
|
||||
.unwrap();
|
||||
|
||||
info!("Generated key.");
|
||||
info!("{:?}", private.to_pkcs1_pem().unwrap());
|
||||
private
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
"/build",
|
||||
service_method_routing::get(ServeDir::new("./static/build")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/",
|
||||
service_method_routing::get(ServeFile::new("./static/index.html")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/favicon.png",
|
||||
service_method_routing::get(ServeFile::new("./static/favicon.png")).handle_error(
|
||||
|error: std::io::Error| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route("/.well-known/openid-configuration", get(provider_metadata))
|
||||
.route("/jwk", get(jwk_set))
|
||||
.route("/token", post(token))
|
||||
.route("/authorize", get(authorize))
|
||||
.route("/register", post(register))
|
||||
.route("/userinfo", get(userinfo).post(userinfo))
|
||||
.route("/sign_in", get(sign_in))
|
||||
.route("/health", get(healthcheck))
|
||||
.layer(AddExtensionLayer::new(private_key))
|
||||
.layer(AddExtensionLayer::new(config.clone()))
|
||||
.layer(AddExtensionLayer::new(pool))
|
||||
.layer(AddExtensionLayer::new(
|
||||
RedisSessionStore::new(config.redis_url.clone())
|
||||
.unwrap()
|
||||
.with_prefix("async-sessions/"),
|
||||
))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
let addr = SocketAddr::from((config.address, config.port));
|
||||
tracing::info!("Listening on {}", addr);
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
124
src/session.rs
Normal file
124
src/session.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use async_redis_session::RedisSessionStore;
|
||||
use async_session::{Session, SessionStore as _};
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{Extension, FromRequest, RequestParts},
|
||||
http::{self, header::HeaderValue, StatusCode},
|
||||
};
|
||||
use cookie::Cookie;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
|
||||
const SESSION_COOKIE_NAME: &str = "session";
|
||||
const SESSION_KEY: &str = "user_session";
|
||||
|
||||
pub enum UserSessionFromSession {
|
||||
FoundUserSession(String),
|
||||
CreatedFreshUserSession { header: HeaderValue, nonce: String },
|
||||
InvalidUserSession(HeaderValue),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for UserSessionFromSession
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = (StatusCode, String);
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let Extension(store) = match Extension::<RedisSessionStore>::from_request(req).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("`MemoryStore` extension missing: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let headers = if let Some(h) = req.headers() {
|
||||
h
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"other extractor taken headers".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let session_cookie: Cookie = if let Some(session_cookie) = headers
|
||||
.get(http::header::COOKIE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(|header| {
|
||||
header
|
||||
.split(";")
|
||||
.map(|cookie| Cookie::parse(cookie).ok())
|
||||
.filter(|cookie| {
|
||||
cookie.is_some() && cookie.as_ref().unwrap().name() == SESSION_COOKIE_NAME
|
||||
})
|
||||
.next()
|
||||
})
|
||||
.flatten()
|
||||
.flatten()
|
||||
{
|
||||
session_cookie
|
||||
} else {
|
||||
let user_session = UserSession::new();
|
||||
let mut session = Session::new();
|
||||
session.insert(SESSION_KEY, user_session.clone()).unwrap();
|
||||
let cookie = store.store_session(session).await.unwrap().unwrap();
|
||||
|
||||
return Ok(Self::CreatedFreshUserSession {
|
||||
header: Cookie::new(SESSION_COOKIE_NAME, cookie)
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
nonce: user_session.nonce,
|
||||
});
|
||||
};
|
||||
|
||||
let session = match store.load_session(session_cookie.value().to_string()).await {
|
||||
Ok(Some(s)) => s,
|
||||
_ => {
|
||||
debug!("Could not load session");
|
||||
let mut cookie = session_cookie.clone();
|
||||
cookie.make_removal();
|
||||
return Ok(Self::InvalidUserSession(
|
||||
cookie.to_string().parse().unwrap(),
|
||||
));
|
||||
}
|
||||
};
|
||||
let user_session = if let Some(user_session) = session.get::<UserSession>(SESSION_KEY) {
|
||||
user_session
|
||||
} else {
|
||||
debug!("No `user_session` found in session");
|
||||
let mut cookie = session_cookie.clone();
|
||||
cookie.make_removal();
|
||||
return Ok(Self::InvalidUserSession(
|
||||
cookie.to_string().parse().unwrap(),
|
||||
));
|
||||
};
|
||||
|
||||
Ok(Self::FoundUserSession(user_session.nonce))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct UserSession {
|
||||
id: Uuid,
|
||||
nonce: String,
|
||||
}
|
||||
|
||||
impl UserSession {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4(),
|
||||
nonce: rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(16)
|
||||
.map(char::from)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user