// HTTP front end for the OIDC provider, built on axum. Most handlers below are thin adapters:
// they extract request data and shared state (`Extension(..)` values installed as layers in
// `main`), delegate to a function in `super::oidc`, and convert the result into a response.
//
// NOTE(review): in the copy reviewed here the file is collapsed onto a few very long lines,
// several signatures read as if their generic arguments were lost (e.g. `Extension,`), and
// `¶ms` looks like a mangled `&params` -- confirm against version control before editing code.
//
// Items, in source order:
//
// * `impl IntoResponse for CustomError` -- maps each error variant to a response:
//   `BadRequest` -> 400 with `self.to_string()` as the body; `BadRequestRegister` and
//   `BadRequestToken` -> 400 with the inner value wrapped in `Json`; `Unauthorized` -> 401;
//   `NotFound` -> 404; `Redirect(uri)` -> `Redirect::to(uri)`; `Other` -> 500.
//   NOTE(review): the `Redirect` arm uses `uri.parse().unwrap()`, so a URI that fails to parse
//   panics instead of producing an error response; the commented-out `map_err` next to it suggests
//   this was known -- confirm whether callers guarantee a valid URI.
//
// * `jwk_set` -- returns `oidc::jwks(private_key)` converted with `.into()`.
//
// * `provider_metadata` -- returns `oidc::metadata(config.base_url)` converted with `.into()`.
//
// * `token` -- picks the client secret from `bearer` (its `token()`) when that header was
//   extracted, otherwise from `basic` (its `password()`), otherwise `None`, and forwards it with
//   the form, private key, `base_url`, `require_secret`, `eth_provider` and the Redis client to
//   `oidc::token`.
//
// * `authorize` -- branches on the session extractor:
//     - `Found(nonce)`: continue with an empty header map.
//     - `Invalid(cookie)`: return early with `Set-Cookie: cookie` and a redirect back to
//       `/authorize`, rebuilt from the incoming params (`response_type` falls back to
//       `CoreResponseType::Code`, `state` to its default, `&nonce=..` is appended only if present).
//     - `Created { header, nonce }`: continue, sending `header` as `Set-Cookie`.
//   It then redirects to the URL returned by `oidc::authorize(params, nonce, &redis_client)`.
//   NOTE(review): the rebuilt query string interpolates values with `format!` and no visible
//   percent-encoding, and `client_id` appears twice in the format string -- confirm intent.
//
// * `sign_in` -- `Found(nonce)` continues; `Invalid(header)` sets the cookie and redirects to
//   `/authorize` with `scope=openid&response_type=code`; `Created { .. }` performs the same
//   redirect with an empty header map (the created header is not sent). Otherwise redirects to the
//   URL returned by `oidc::sign_in(params, Some(nonce), cookies, &redis_client)`.
//
// * `register` -- calls `oidc::register(payload, &redis_client)` and answers `201 Created`.
//
// * `UserInfoResponseJWT` + `IntoResponse` -- builds a `200 OK` response whose content-type header
//   is `application/jwt` and whose body is `serde_json::to_string` of the wrapped value with every
//   `"` character removed. Both `unwrap()` calls panic on serialization/builder failure.
//   NOTE(review): `replace('"', "")` strips all double quotes, not just the enclosing pair --
//   presumably safe for a compact JWT string; verify.
//
// * `UserInfoResponse` + `IntoResponse` -- two-variant wrapper (`Json` / `Jwt`) that forwards to
//   the inner value's `into_response()`.
//
// * `userinfo` -- uses the submitted form when one was extracted, otherwise
//   `oidc::UserInfoPayload { access_token: None }`; calls `oidc::userinfo(..)` and maps
//   `oidc::UserInfoResponse::{Json, Jwt}` onto the local `UserInfoResponse`.
//
// * `clientinfo` -- returns `oidc::clientinfo(client_id, &redis_client)` converted with `.into()`.
//
// * `healthcheck` -- empty handler body; mounted at `/health`.
//
// * `main` -- startup sequence:
//     1. Configuration via Figment: `config::Config::default()`, merged with the TOML file
//        `siwe-oidc.toml`, then environment variables prefixed `SIWEOIDC_` (split on `__`).
//     2. `tracing_subscriber::fmt::init()`.
//     3. A bb8 Redis pool built from `config.redis_url`, wrapped in `RedisClient { pool }`.
//        (A commented-out block that would register `config.default_clients` follows.)
//     4. RSA key: parsed from `config.rsa_pem` with `from_pkcs1_pem` when set, otherwise a new
//        2048-bit key is generated with `OsRng`.
//        NOTE(review): the generated key is written to the log in PKCS#1 PEM form via
//        `info!("{:?}", private.to_pkcs1_pem().unwrap())` -- confirm this is intended.
//     5. Router: static files under `/build` and `/img`, single files at `/`, `/error` and
//        `/favicon.png` (I/O errors become a 500 with "Unhandled internal error: .."), the OIDC
//        routes at the `oidc::*_PATH` constants, `{CLIENTINFO_PATH}/:id`, and `/health`.
//     6. Layers: the private key, config, Redis client and a `RedisSessionStore` (prefix
//        `async-sessions/`) as extensions, plus `TraceLayer::new_for_http()`.
//     7. Binds to `(config.address, config.port)` and serves.
//   Every fallible startup step is `unwrap()`ed, so any failure here panics.
use anyhow::{anyhow, Result}; use async_redis_session::RedisSessionStore; use axum::{ extract::{self, Extension, Form, Path, Query, TypedHeader}, http::{ header::{self, HeaderMap}, StatusCode, }, response::{self, IntoResponse, Redirect}, routing::{get, get_service, post}, AddExtensionLayer, Json, Router, }; use bb8_redis::{bb8, RedisConnectionManager}; use figment::{ providers::{Env, Format, Serialized, Toml}, Figment, }; use headers::{ self, authorization::{Basic, Bearer}, Authorization, ContentType, Header, }; use openidconnect::core::{ CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata, CoreResponseType, CoreTokenResponse, CoreUserInfoClaims, CoreUserInfoJsonWebToken, }; use rand::rngs::OsRng; use rsa::{ pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey}, RsaPrivateKey, }; use std::net::SocketAddr; use tower_http::{ services::{ServeDir, ServeFile}, trace::TraceLayer, }; use tracing::info; use super::config; use super::oidc::{self, CustomError}; use super::session::*; use ::siwe_oidc::db::*; impl IntoResponse for CustomError { fn into_response(self) -> response::Response { match self { CustomError::BadRequest(_) => { (StatusCode::BAD_REQUEST, self.to_string()).into_response() } CustomError::BadRequestRegister(e) => { (StatusCode::BAD_REQUEST, Json::from(e)).into_response() } CustomError::BadRequestToken(e) => { (StatusCode::BAD_REQUEST, Json::from(e)).into_response() } CustomError::Unauthorized(_) => { (StatusCode::UNAUTHORIZED, self.to_string()).into_response() } CustomError::NotFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(), CustomError::Redirect(uri) => Redirect::to( uri.parse().unwrap(), // .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ) .into_response(), CustomError::Other(_) => { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() } } } } async fn jwk_set( Extension(private_key): Extension, ) -> Result, CustomError> { let jwks = oidc::jwks(private_key)?; Ok(jwks.into()) } async 
fn provider_metadata( Extension(config): Extension, ) -> Result, CustomError> { Ok(oidc::metadata(config.base_url)?.into()) } // TODO should check Authorization header // Actually, client secret can be // 1. in the POST (currently supported) [x] // 2. Authorization header [x] // 3. JWT [ ] // 4. signed JWT [ ] // according to Keycloak async fn token( Form(form): Form, bearer: Option>>, basic: Option>>, Extension(private_key): Extension, Extension(config): Extension, Extension(redis_client): Extension, ) -> Result, CustomError> { let secret = if let Some(b) = bearer { Some(b.0 .0.token().to_string()) } else { basic.map(|b| b.0 .0.password().to_string()) }; let token_response = oidc::token( form, secret, private_key, config.base_url, config.require_secret, config.eth_provider, &redis_client, ) .await?; Ok(token_response.into()) } // TODO handle `registration` parameter async fn authorize( session: UserSessionFromSession, Query(params): Query, Extension(redis_client): Extension, ) -> Result<(HeaderMap, Redirect), CustomError> { let (nonce, headers) = match session { UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()), UserSessionFromSession::Invalid(cookie) => { let mut headers = HeaderMap::new(); headers.insert(header::SET_COOKIE, cookie); return Ok(( headers, Redirect::to( format!( "/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}", ¶ms.client_id, ¶ms.redirect_uri.to_string(), ¶ms.scope.to_string(), ¶ms.response_type.unwrap_or(CoreResponseType::Code).as_ref(), ¶ms.state.unwrap_or_default(), ¶ms.client_id, ¶ms.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default() ) .parse() .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ), )); } UserSessionFromSession::Created { header, nonce } => { let mut headers = HeaderMap::new(); headers.insert(header::SET_COOKIE, header); (nonce, headers) } }; let url = oidc::authorize(params, nonce, &redis_client).await?; Ok(( headers, Redirect::to( url.as_str() 
.parse() .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ), )) } async fn sign_in( session: UserSessionFromSession, Query(params): Query, TypedHeader(cookies): TypedHeader, Extension(redis_client): Extension, ) -> Result<(HeaderMap, Redirect), CustomError> { let (nonce, headers) = match session { UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()), UserSessionFromSession::Invalid(header) => { let mut headers = HeaderMap::new(); headers.insert(header::SET_COOKIE, header); return Ok(( headers, Redirect::to( format!( "/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}", ¶ms.client_id.clone(), ¶ms.redirect_uri.to_string(), ¶ms.state, ) .parse() .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ), )); } UserSessionFromSession::Created { .. } => { return Ok(( HeaderMap::new(), Redirect::to( format!( "/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}", ¶ms.client_id.clone(), ¶ms.redirect_uri.to_string(), ¶ms.state, ) .parse() .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ), )) } }; let url = oidc::sign_in(params, Some(nonce), cookies, &redis_client).await?; Ok(( headers, Redirect::to( url.as_str() .parse() .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ), )) // TODO clear session } async fn register( extract::Json(payload): extract::Json, Extension(redis_client): Extension, ) -> Result<(StatusCode, Json), CustomError> { let registration = oidc::register(payload, &redis_client).await?; Ok((StatusCode::CREATED, registration.into())) } struct UserInfoResponseJWT(Json); impl IntoResponse for UserInfoResponseJWT { fn into_response(self) -> response::Response { response::Response::builder() .status(StatusCode::OK) .header(ContentType::name(), "application/jwt") .body( serde_json::to_string(&self.0 .0) .unwrap() .replace('"', "") .into_response() .into_body(), ) .unwrap() } } enum UserInfoResponse { Json(Json), Jwt(UserInfoResponseJWT), } impl IntoResponse for UserInfoResponse { 
fn into_response(self) -> response::Response { match self { UserInfoResponse::Json(j) => j.into_response(), UserInfoResponse::Jwt(j) => j.into_response(), } } } // TODO CORS // TODO need validation of the token async fn userinfo( Extension(private_key): Extension, Extension(config): Extension, payload: Option>, bearer: Option>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs Extension(redis_client): Extension, ) -> Result { let payload = if let Some(Form(p)) = payload { p } else { oidc::UserInfoPayload { access_token: None } }; let claims = oidc::userinfo( config.base_url, config.eth_provider, private_key, bearer.map(|b| b.0 .0), payload, &redis_client, ) .await?; Ok(match claims { oidc::UserInfoResponse::Json(c) => UserInfoResponse::Json(c.into()), oidc::UserInfoResponse::Jwt(c) => UserInfoResponse::Jwt(UserInfoResponseJWT(c.into())), }) } async fn clientinfo( Path(client_id): Path, Extension(redis_client): Extension, ) -> Result, CustomError> { Ok(oidc::clientinfo(client_id, &redis_client).await?.into()) } async fn healthcheck() {} pub async fn main() { let config = Figment::from(Serialized::defaults(config::Config::default())) .merge(Toml::file("siwe-oidc.toml").nested()) .merge(Env::prefixed("SIWEOIDC_").split("__").global()); let config = config.extract::().unwrap(); tracing_subscriber::fmt::init(); let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap(); let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap(); // let pool2 = bb8::Pool::builder().build(manager).await.unwrap(); let redis_client = RedisClient { pool }; // for (id, secret) in &config.default_clients.clone() { // let client_entry = ClientEntry { // secret: secret.to_string(), // redirect_uris: vec![], // }; // redis_client // .set_client(id.to_string(), client_entry) // .await // .unwrap(); // TODO // } let private_key = if let Some(key) = &config.rsa_pem { RsaPrivateKey::from_pkcs1_pem(key) .map_err(|e| 
anyhow!("Failed to load private key: {}", e)) .unwrap() } else { info!("Generating key..."); let mut rng = OsRng; let bits = 2048; let private = RsaPrivateKey::new(&mut rng, bits) .map_err(|e| anyhow!("Failed to generate a key: {}", e)) .unwrap(); info!("Generated key."); info!("{:?}", private.to_pkcs1_pem().unwrap()); private }; let app = Router::new() .nest( "/build", get_service(ServeDir::new("./static/build")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .nest( "/img", get_service(ServeDir::new("./static/img")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route( "/", get_service(ServeFile::new("./static/index.html")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route( "/error", get_service(ServeFile::new("./static/error.html")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route( "/favicon.png", get_service(ServeFile::new("./static/favicon.png")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route(oidc::METADATA_PATH, get(provider_metadata)) .route(oidc::JWK_PATH, get(jwk_set)) .route(oidc::TOKEN_PATH, post(token)) .route(oidc::AUTHORIZE_PATH, get(authorize)) .route(oidc::REGISTER_PATH, post(register)) .route(oidc::USERINFO_PATH, get(userinfo).post(userinfo)) .route(&format!("{}/:id", oidc::CLIENTINFO_PATH), get(clientinfo)) .route(oidc::SIGNIN_PATH, get(sign_in)) .route("/health", get(healthcheck)) .layer(AddExtensionLayer::new(private_key)) .layer(AddExtensionLayer::new(config.clone())) .layer(AddExtensionLayer::new(redis_client)) .layer(AddExtensionLayer::new( 
RedisSessionStore::new(config.redis_url.clone()) .unwrap() .with_prefix("async-sessions/"), )) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from((config.address, config.port)); tracing::info!("Listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); }