use anyhow::{anyhow, Result}; use axum::{ extract::{self, Extension, Form, Path, Query, TypedHeader}, http::{ header::{self, HeaderMap}, StatusCode, }, response::{self, IntoResponse, Redirect}, routing::{delete, get, get_service, post}, Json, Router, }; use figment::{ providers::{Env, Format, Serialized, Toml}, Figment, }; use headers::{ self, authorization::{Basic, Bearer}, Authorization, ContentType, Header, }; use openidconnect::core::{ CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata, CoreTokenResponse, CoreUserInfoClaims, CoreUserInfoJsonWebToken, }; use rand::rngs::OsRng; use rsa::{ pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey, LineEnding}, RsaPrivateKey, }; use std::net::SocketAddr; use tower_http::{ services::{ServeDir, ServeFile}, trace::TraceLayer, }; use tracing::info; use super::config; use super::oidc::{self, CustomError}; use ::siwe_oidc::db::*; impl IntoResponse for CustomError { fn into_response(self) -> response::Response { match self { CustomError::BadRequest(_) => { (StatusCode::BAD_REQUEST, self.to_string()).into_response() } CustomError::BadRequestRegister(e) => { (StatusCode::BAD_REQUEST, Json::from(e)).into_response() } CustomError::BadRequestToken(e) => { (StatusCode::BAD_REQUEST, Json::from(e)).into_response() } CustomError::Unauthorized(_) => { (StatusCode::UNAUTHORIZED, self.to_string()).into_response() } CustomError::NotFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(), CustomError::Redirect(uri) => Redirect::to( uri.parse().unwrap(), // .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ) .into_response(), CustomError::Other(_) => { (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() } } } } async fn jwk_set( Extension(private_key): Extension, ) -> Result, CustomError> { let jwks = oidc::jwks(private_key)?; Ok(jwks.into()) } async fn provider_metadata( Extension(config): Extension, ) -> Result, CustomError> { Ok(oidc::metadata(config.base_url)?.into()) } async fn token( Form(form): Form, bearer: Option>>, basic: Option>>, Extension(private_key): Extension, Extension(config): Extension, Extension(redis_client): Extension, ) -> Result, CustomError> { let secret = if let Some(b) = bearer { Some(b.0 .0.token().to_string()) } else { basic.map(|b| b.0 .0.password().to_string()) }; let token_response = oidc::token( form, secret, private_key, config.base_url, config.require_secret, config.eth_provider, &redis_client, ) .await?; Ok(token_response.into()) } async fn authorize( Query(params): Query, Extension(redis_client): Extension, ) -> Result<(HeaderMap, Redirect), CustomError> { let (url, session_cookie) = oidc::authorize(params, &redis_client).await?; let mut headers = HeaderMap::new(); headers.insert( header::SET_COOKIE, session_cookie.to_string().parse().unwrap(), ); Ok(( headers, Redirect::to( url.as_str() .parse() .map_err(|e| anyhow!("Could not parse URI: {}", e))?, ), )) } async fn sign_in( Query(params): Query, TypedHeader(cookies): TypedHeader, Extension(redis_client): Extension, Extension(config): Extension, ) -> Result { let url = oidc::sign_in(&config.base_url, params, cookies, &redis_client).await?; Ok(Redirect::to( url.as_str() .parse() .map_err(|e| anyhow!("Could not parse URI: {}", e))?, )) } async fn register( extract::Json(payload): extract::Json, Extension(config): Extension, Extension(redis_client): Extension, ) -> Result<(StatusCode, Json), CustomError> { let registration = oidc::register(payload, config.base_url, &redis_client).await?; Ok((StatusCode::CREATED, registration.into())) } struct UserInfoResponseJWT(Json); impl IntoResponse for UserInfoResponseJWT { fn into_response(self) -> response::Response { response::Response::builder() .status(StatusCode::OK) .header(ContentType::name(), "application/jwt") .body( serde_json::to_string(&self.0 .0) .unwrap() .replace('"', "") .into_response() .into_body(), ) .unwrap() } } enum UserInfoResponse { Json(Json), Jwt(UserInfoResponseJWT), } impl IntoResponse for UserInfoResponse { fn into_response(self) -> response::Response { match self { UserInfoResponse::Json(j) => j.into_response(), UserInfoResponse::Jwt(j) => j.into_response(), } } } // TODO CORS // TODO need validation of the token async fn userinfo( Extension(private_key): Extension, Extension(config): Extension, payload: Option>, bearer: Option>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs Extension(redis_client): Extension, ) -> Result { let payload = if let Some(Form(p)) = payload { p } else { oidc::UserInfoPayload { access_token: None } }; let claims = oidc::userinfo( config.base_url, config.eth_provider, private_key, bearer.map(|b| b.0 .0), payload, &redis_client, ) .await?; Ok(match claims { oidc::UserInfoResponse::Json(c) => UserInfoResponse::Json(c.into()), oidc::UserInfoResponse::Jwt(c) => UserInfoResponse::Jwt(UserInfoResponseJWT(c.into())), }) } async fn clientinfo( Path(client_id): Path, Extension(redis_client): Extension, ) -> Result, CustomError> { Ok(oidc::clientinfo(client_id, &redis_client).await?.into()) } async fn client_update( Path(client_id): Path, extract::Json(payload): extract::Json, bearer: Option>>, Extension(redis_client): Extension, ) -> Result<(), CustomError> { oidc::client_update(client_id, payload, bearer.map(|b| b.0 .0), &redis_client).await } async fn client_delete( Path(client_id): Path, bearer: Option>>, Extension(redis_client): Extension, ) -> Result<(StatusCode, ()), CustomError> { Ok(( StatusCode::NO_CONTENT, oidc::client_delete(client_id, bearer.map(|b| b.0 .0), &redis_client).await?, )) } async fn healthcheck() {} pub async fn main() { let config = Figment::from(Serialized::defaults(config::Config::default())) .merge(Toml::file("siwe-oidc.toml").nested()) .merge(Env::prefixed("SIWEOIDC_").split("__").global()); let config = config.extract::().unwrap(); tracing_subscriber::fmt::init(); let redis_client = RedisClient::new(&config.redis_url) .await .expect("Could not build Redis client"); for (id, entry) in &config.default_clients.clone() { let entry: ClientEntry = serde_json::from_str(entry).expect("Deserialisation of ClientEntry failed"); redis_client .set_client(id.to_string(), entry.clone()) .await .unwrap(); // TODO } let private_key = if let Some(key) = &config.rsa_pem { RsaPrivateKey::from_pkcs1_pem(key) .map_err(|e| anyhow!("Failed to load private key: {}", e)) .unwrap() } else { info!("Generating key..."); let mut rng = OsRng; let bits = 2048; let private = RsaPrivateKey::new(&mut rng, bits) .map_err(|e| anyhow!("Failed to generate a key: {}", e)) .unwrap(); info!("Generated key."); info!("{:?}", private.to_pkcs1_pem(LineEnding::LF).unwrap()); private }; let app = Router::new() .nest( "/build", get_service(ServeDir::new("./static/build")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .nest( "/legal", get_service(ServeDir::new("./static/legal")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .nest( "/img", get_service(ServeDir::new("./static/img")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route( "/", get_service(ServeFile::new("./static/index.html")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route( "/error", get_service(ServeFile::new("./static/error.html")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route( "/favicon.png", get_service(ServeFile::new("./static/favicon.png")).handle_error( |error: std::io::Error| async move { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), ) }, ), ) .route(oidc::METADATA_PATH, get(provider_metadata)) .route(oidc::JWK_PATH, get(jwk_set)) .route(oidc::TOKEN_PATH, post(token)) .route(oidc::AUTHORIZE_PATH, get(authorize)) .route(oidc::REGISTER_PATH, post(register)) .route(oidc::USERINFO_PATH, get(userinfo).post(userinfo)) .route(&format!("{}/:id", oidc::CLIENT_PATH), get(clientinfo)) .route(&format!("{}/:id", oidc::CLIENT_PATH), delete(client_delete)) .route(&format!("{}/:id", oidc::CLIENT_PATH), post(client_update)) .route(oidc::SIGNIN_PATH, get(sign_in)) .route("/health", get(healthcheck)) .layer(Extension(private_key)) .layer(Extension(config.clone())) .layer(Extension(redis_client)) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from((config.address, config.port)); tracing::info!("Listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); }