@@ -9,7 +9,6 @@ use axum::{
|
||||
routing::{delete, get, get_service, post},
|
||||
Json, Router,
|
||||
};
|
||||
use bb8_redis::{bb8, RedisConnectionManager};
|
||||
use figment::{
|
||||
providers::{Env, Format, Serialized, Toml},
|
||||
Figment,
|
||||
@@ -247,10 +246,9 @@ pub async fn main() {
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
|
||||
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
|
||||
|
||||
let redis_client = RedisClient { pool };
|
||||
let redis_client = RedisClient::new(&config.redis_url)
|
||||
.await
|
||||
.expect("Could not build Redis client");
|
||||
|
||||
for (id, entry) in &config.default_clients.clone() {
|
||||
let entry: ClientEntry =
|
||||
|
||||
@@ -1,12 +1,29 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
||||
use bb8_redis::{
|
||||
bb8::{self, Pool},
|
||||
redis::AsyncCommands,
|
||||
RedisConnectionManager,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisClient {
|
||||
pub pool: Pool<RedisConnectionManager>,
|
||||
pool: Pool<RedisConnectionManager>,
|
||||
}
|
||||
|
||||
impl RedisClient {
|
||||
pub async fn new(url: &Url) -> Result<Self> {
|
||||
let manager = RedisConnectionManager::new(url.clone())
|
||||
.context("Could not build Redis connection manager")?;
|
||||
let pool = bb8::Pool::builder()
|
||||
.build(manager.clone())
|
||||
.await
|
||||
.context("Coud not build Redis pool")?;
|
||||
Ok(Self { pool })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
|
||||
111
src/oidc.rs
111
src/oidc.rs
@@ -630,8 +630,12 @@ pub async fn sign_in(
|
||||
.map_err(|e| anyhow!("Failed message verification: {}", e))?;
|
||||
|
||||
let domain = params.redirect_uri.url();
|
||||
if *domain != Url::from_str(siwe_cookie.message.resources.get(0).unwrap().as_ref()).unwrap() {
|
||||
return Err(anyhow!("Conflicting domains in message and redirect").into());
|
||||
if let Some(r) = siwe_cookie.message.resources.get(0) {
|
||||
if *domain != Url::from_str(r.as_ref()).unwrap() {
|
||||
return Err(anyhow!("Conflicting domains in message and redirect").into());
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Missing resource in SIWE message").into());
|
||||
}
|
||||
|
||||
let code_entry = CodeEntry {
|
||||
@@ -828,9 +832,34 @@ pub async fn userinfo(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::config::Config;
|
||||
|
||||
use super::*;
|
||||
use ethers_signers::{LocalWallet, Signer};
|
||||
use headers::{HeaderMap, HeaderMapExt, HeaderValue};
|
||||
use rand::rngs::OsRng;
|
||||
use test_log::test;
|
||||
|
||||
async fn default_config() -> (Config, RedisClient) {
|
||||
let config = Config::default();
|
||||
let db_client = RedisClient::new(&config.redis_url).await.unwrap();
|
||||
db_client
|
||||
.set_client(
|
||||
"client".into(),
|
||||
ClientEntry {
|
||||
secret: "secret".into(),
|
||||
metadata: CoreClientMetadata::new(
|
||||
vec![RedirectUrl::new("https://example.com".into()).unwrap()],
|
||||
EmptyAdditionalClientMetadata {},
|
||||
),
|
||||
access_token: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
(config, db_client)
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn test_claims() {
|
||||
let res = resolve_claims(
|
||||
@@ -850,4 +879,82 @@ mod tests {
|
||||
Some("https://ipfs.io/ipfs/QmSP4nq9fnN9dAiCj42ug9Wa79rqmQerZXZch82VqpiH7U/image.gif")
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeQueryParams {
|
||||
nonce: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SignInQueryParams {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn e2e_flow() {
|
||||
let (_config, db_client) = default_config().await;
|
||||
let wallet = "dcf2cbdd171a21c480aa7f53d77f31bb102282b3ff099c78e3118b37348c72f7"
|
||||
.parse::<LocalWallet>()
|
||||
.unwrap();
|
||||
|
||||
let base_url = Url::parse("https://example.com").unwrap();
|
||||
let params = AuthorizeParams {
|
||||
client_id: "client".into(),
|
||||
redirect_uri: RedirectUrl::from_url(base_url.clone()),
|
||||
scope: Scope::new("openid".to_string()),
|
||||
response_type: Some(CoreResponseType::IdToken),
|
||||
state: Some("state".into()),
|
||||
nonce: None,
|
||||
prompt: None,
|
||||
request_uri: None,
|
||||
request: None,
|
||||
};
|
||||
let (redirect_url, cookie) = authorize(params, &db_client).await.unwrap();
|
||||
let authorize_params: AuthorizeQueryParams =
|
||||
serde_urlencoded::from_str(redirect_url.split("/?").collect::<Vec<&str>>()[1]).unwrap();
|
||||
let params: SignInParams = serde_urlencoded::from_str(&redirect_url).unwrap();
|
||||
let message = Web3ModalMessage {
|
||||
domain: "example.com".into(),
|
||||
address: wallet.address(),
|
||||
statement: "statement".to_string(),
|
||||
uri: base_url.to_string(),
|
||||
version: "1".into(),
|
||||
chain_id: 1,
|
||||
nonce: authorize_params.nonce,
|
||||
issued_at: "2023-04-17T11:01:24.862Z".into(),
|
||||
expiration_time: None,
|
||||
not_before: None,
|
||||
request_id: None,
|
||||
resources: vec!["https://example.com".try_into().unwrap()],
|
||||
};
|
||||
let signature = wallet
|
||||
.sign_message(message.to_eip4361_message().unwrap().to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
let signature = format!("0x{signature}");
|
||||
let siwe_cookie = serde_json::to_string(&SiweCookie { message, signature }).unwrap();
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
"cookie",
|
||||
HeaderValue::from_str(&format!("{cookie}; {SIWE_COOKIE_KEY}={siwe_cookie}")).unwrap(),
|
||||
);
|
||||
let cookie = headers.typed_get::<headers::Cookie>().unwrap();
|
||||
let redirect_url = sign_in(&base_url, params, cookie, &db_client)
|
||||
.await
|
||||
.unwrap();
|
||||
let signin_params: SignInQueryParams =
|
||||
serde_urlencoded::from_str(redirect_url.query().unwrap()).unwrap();
|
||||
let _ = userinfo(
|
||||
base_url,
|
||||
None,
|
||||
RsaPrivateKey::new(&mut OsRng, 1024).unwrap(),
|
||||
None,
|
||||
UserInfoPayload {
|
||||
access_token: Some(signin_params.code),
|
||||
},
|
||||
&db_client,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,15 +31,15 @@ macro_rules! json_bad_request {
|
||||
impl From<CustomError> for Result<Response> {
|
||||
fn from(error: CustomError) -> Self {
|
||||
match error {
|
||||
CustomError::BadRequest(_) => Response::error(&error.to_string(), 400),
|
||||
CustomError::BadRequest(_) => Response::error(error.to_string(), 400),
|
||||
CustomError::BadRequestRegister(e) => {
|
||||
Response::from_json(&e).map(|r| r.with_status(400))
|
||||
}
|
||||
CustomError::BadRequestToken(e) => Response::from_json(&e).map(|r| r.with_status(400)),
|
||||
CustomError::Unauthorized(_) => Response::error(&error.to_string(), 401),
|
||||
CustomError::NotFound => Response::error(&error.to_string(), 404),
|
||||
CustomError::Unauthorized(_) => Response::error(error.to_string(), 401),
|
||||
CustomError::NotFound => Response::error(error.to_string(), 404),
|
||||
CustomError::Redirect(uri) => Response::redirect(uri.parse().unwrap()),
|
||||
CustomError::Other(_) => Response::error(&error.to_string(), 500),
|
||||
CustomError::Other(_) => Response::error(error.to_string(), 500),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user