Cloudflare Worker version (#6)
Refactor/generalise API/DB interactions out of OIDC.
This commit is contained in:
parent
9d725552e0
commit
bbcacf4232
13
.github/workflows/ci.yml
vendored
13
.github/workflows/ci.yml
vendored
@ -8,12 +8,25 @@ env:
|
|||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- cargo_target: "x86_64-unknown-linux-gnu"
|
||||||
|
- cargo_target: "wasm32-unknown-unknown"
|
||||||
steps:
|
steps:
|
||||||
- name: Clone repo
|
- name: Clone repo
|
||||||
uses: actions/checkout@master
|
uses: actions/checkout@master
|
||||||
|
- name: Add targets
|
||||||
|
run: rustup target add wasm32-unknown-unknown
|
||||||
- name: Build
|
- name: Build
|
||||||
|
env:
|
||||||
|
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
|
||||||
run: cargo build --verbose
|
run: cargo build --verbose
|
||||||
- name: Clippy
|
- name: Clippy
|
||||||
|
env:
|
||||||
|
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
|
||||||
run: RUSTFLAGS="-Dwarnings" cargo clippy
|
run: RUSTFLAGS="-Dwarnings" cargo clippy
|
||||||
- name: Fmt
|
- name: Fmt
|
||||||
|
env:
|
||||||
|
CARGO_BUILD_TARGET: ${{ matrix.cargo_target }}
|
||||||
run: cargo fmt -- --check
|
run: cargo fmt -- --check
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
/target
|
/target
|
||||||
/static/build
|
/static/build
|
||||||
|
wrangler.toml
|
||||||
|
618
Cargo.lock
generated
618
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
57
Cargo.toml
57
Cargo.toml
@ -5,35 +5,68 @@ edition = "2021"
|
|||||||
authors = ["Spruce Systems, Inc."]
|
authors = ["Spruce Systems, Inc."]
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
repository = "https://github.com/spruceid/siwe-oidc/"
|
repository = "https://github.com/spruceid/siwe-oidc/"
|
||||||
|
description = "OpenID Connect Identity Provider for Sign-In with Ethereum."
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
[lib]
|
||||||
|
crate-type = ["cdylib", "rlib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.51"
|
anyhow = "1.0.51"
|
||||||
axum = { version = "0.3.4", features = ["headers"] }
|
|
||||||
chrono = "0.4.19"
|
|
||||||
headers = "0.3.5"
|
headers = "0.3.5"
|
||||||
hex = "0.4.3"
|
hex = "0.4.3"
|
||||||
iri-string = { version = "0.4", features = ["serde-std"] }
|
iri-string = { version = "0.4", features = ["serde-std"] }
|
||||||
openidconnect = "2.1.2"
|
# openidconnect = "2.1.2"
|
||||||
|
openidconnect = { git = "https://github.com/sbihel/openidconnect-rs", branch = "main", default-features = false, features = ["reqwest", "rustls-tls", "rustcrypto"] }
|
||||||
rand = "0.8.4"
|
rand = "0.8.4"
|
||||||
rsa = { version = "0.5.0", features = ["alloc"] }
|
rsa = { version = "0.5.0", features = ["alloc"] }
|
||||||
rust-argon2 = "0.8"
|
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0.72"
|
serde_json = "1.0.72"
|
||||||
siwe = "0.1"
|
siwe = "0.1.2"
|
||||||
async-session = "3.0.0"
|
|
||||||
thiserror = "1.0.30"
|
thiserror = "1.0.30"
|
||||||
tokio = { version = "1.14.0", features = ["full"] }
|
|
||||||
tower-http = { version = "0.2.0", features = ["fs", "trace", "cors"] }
|
|
||||||
tracing = "0.1.29"
|
tracing = "0.1.29"
|
||||||
tracing-subscriber = { version = "0.3.2", features = ["env-filter"] }
|
|
||||||
url = { version = "2.2", features = ["serde"] }
|
url = { version = "2.2", features = ["serde"] }
|
||||||
urlencoding = "2.1.0"
|
urlencoding = "2.1.0"
|
||||||
uuid = { version = "0.8", features = ["serde", "v4"] }
|
|
||||||
figment = { version = "0.10.6", features = ["toml", "env"] }
|
|
||||||
sha2 = "0.9.0"
|
sha2 = "0.9.0"
|
||||||
cookie = "0.15.1"
|
cookie = "0.15.1"
|
||||||
bincode = "1.3.3"
|
bincode = "1.3.3"
|
||||||
|
async-trait = "0.1.52"
|
||||||
|
|
||||||
|
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||||
|
async-session = "3.0.0"
|
||||||
|
axum = { version = "0.4.3", features = ["headers"] }
|
||||||
|
# axum-debug = "0.3.2"
|
||||||
|
chrono = "0.4.19"
|
||||||
|
figment = { version = "0.10.6", features = ["toml", "env"] }
|
||||||
|
tokio = { version = "1.14.0", features = ["full"] }
|
||||||
|
tower-http = { version = "0.2.0", features = ["fs", "trace", "cors"] }
|
||||||
|
tracing-subscriber = { version = "0.3.2", features = ["env-filter"] }
|
||||||
bb8-redis = "0.10.1"
|
bb8-redis = "0.10.1"
|
||||||
async-redis-session = "0.2.2"
|
async-redis-session = "0.2.2"
|
||||||
|
uuid = { version = "0.8", features = ["serde", "v4"] }
|
||||||
|
|
||||||
|
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||||
|
# cached = { version = "0.26", default-features = false }
|
||||||
|
chrono = { version = "0.4.19", features = ["wasmbind"] }
|
||||||
|
console_error_panic_hook = { version = "0.1" }
|
||||||
|
# console_log = "0.2"
|
||||||
|
getrandom = { version = "0.2", features = ["js"] }
|
||||||
|
# log = "0.4"
|
||||||
|
matchit = "0.4.2"
|
||||||
|
serde_urlencoded = "0.7.0"
|
||||||
|
uuid = { version = "0.8", features = ["serde", "v4", "wasm-bindgen"] }
|
||||||
|
wee_alloc = { version = "0.4" }
|
||||||
|
worker = "0.0.7"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = "z"
|
||||||
|
lto = true
|
||||||
|
|
||||||
|
# [target.'cfg(target_arch = "wasm32")'.profile.release]
|
||||||
|
# opt-level = "z"
|
||||||
|
|
||||||
|
# [target.'cfg(target_arch = "wasm32")'.profile.debug]
|
||||||
|
# opt-level = "z"
|
||||||
|
# lto = false
|
||||||
|
|
||||||
|
[package.metadata.wasm-pack.profile.profiling]
|
||||||
|
wasm-opt = ['-g', '-O']
|
||||||
|
61
README.md
61
README.md
@ -2,11 +2,54 @@
|
|||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
### Dependencies
|
Two versions are available, a stand-alone binary (using Axum and Redis) and a
|
||||||
|
Cloudflare Worker. They use the same code base and are selected at compile time
|
||||||
|
(compiling for `wasm32` will make the Worker version).
|
||||||
|
|
||||||
|
### Cloudflare Worker
|
||||||
|
|
||||||
|
You will need [`wrangler`](https://github.com/cloudflare/wrangler).
|
||||||
|
|
||||||
|
Then copy the configuration file template:
|
||||||
|
```bash
|
||||||
|
cp wrangler_example.toml wrangler.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
Replacing the following fields:
|
||||||
|
- `account_id`: your Cloudflare account ID;
|
||||||
|
- `zone_id`: (Optional) DNS zone ID; and
|
||||||
|
- `kv_namespaces`: a KV namespace ID (created with `wrangler kv:namespace create SIWE-OIDC`).
|
||||||
|
|
||||||
|
At this point, you should be able to create/publish the worker:
|
||||||
|
```
|
||||||
|
wrangler publish
|
||||||
|
```
|
||||||
|
|
||||||
|
The IdP currently only supports having the **frontend under the same subdomain as
|
||||||
|
the API**. Here is the configuration for Cloudflare Pages:
|
||||||
|
- `Build command`: `cd js/ui && npm install && npm run build`;
|
||||||
|
- `Build output directory`: `/static`; and
|
||||||
|
- `Root directory`: `/`.
|
||||||
|
And you will need to add some rules to do the routing between the Page and the
|
||||||
|
Worker. Here are the rules for the Worker (the Page being used as the fallback
|
||||||
|
on the subdomain):
|
||||||
|
```
|
||||||
|
siweoidc.example.com/s*
|
||||||
|
siweoidc.example.com/u*
|
||||||
|
siweoidc.example.com/r*
|
||||||
|
siweoidc.example.com/a*
|
||||||
|
siweoidc.example.com/t*
|
||||||
|
siweoidc.example.com/j*
|
||||||
|
siweoidc.example.com/.w*
|
||||||
|
```
|
||||||
|
|
||||||
|
### Stand-Alone Binary
|
||||||
|
|
||||||
|
#### Dependencies
|
||||||
|
|
||||||
Redis, or a Redis compatible database (e.g. MemoryDB in AWS), is required.
|
Redis, or a Redis compatible database (e.g. MemoryDB in AWS), is required.
|
||||||
|
|
||||||
### Starting the IdP
|
#### Starting the IdP
|
||||||
|
|
||||||
The Docker image is available at `ghcr.io/spruceid/siwe_oidc:0.1.0`. Here is an
|
The Docker image is available at `ghcr.io/spruceid/siwe_oidc:0.1.0`. Here is an
|
||||||
example usage:
|
example usage:
|
||||||
@ -35,9 +78,23 @@ For the core OIDC information, it is available under
|
|||||||
|
|
||||||
* Additional information, from native projects (e.g. ENS domains), to more
|
* Additional information, from native projects (e.g. ENS domains), to more
|
||||||
traditional ones (e.g. email).
|
traditional ones (e.g. email).
|
||||||
|
* PKCE support (code challenge).
|
||||||
|
* Browser session support for the Worker version.
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
|
### Cloudflare Worker
|
||||||
|
|
||||||
|
```bash
|
||||||
|
wrangler dev
|
||||||
|
```
|
||||||
|
You can now use http://127.0.0.1:8787/.well-known/openid-configuration.
|
||||||
|
|
||||||
|
> At the moment it's not possible to use it end-to-end with the frontend as they
|
||||||
|
> need to share the same host (i.e. port), unless using a local load-balancer.
|
||||||
|
|
||||||
|
### Stand Alone Binary
|
||||||
|
|
||||||
A Docker Compose is available to test the IdP locally with Keycloak.
|
A Docker Compose is available to test the IdP locally with Keycloak.
|
||||||
|
|
||||||
1. You will first need to run:
|
1. You will first need to run:
|
||||||
|
2895
js/ui/package-lock.json
generated
2895
js/ui/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -2,28 +2,28 @@
|
|||||||
"name": "svelte-app",
|
"name": "svelte-app",
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@tsconfig/svelte": "^1.0.10",
|
"@tsconfig/svelte": "^3.0.0",
|
||||||
"@types/node": "^14.11.1",
|
"@types/node": "^17.0.7",
|
||||||
"@typescript-eslint/eslint-plugin": "^4.21.0",
|
"@typescript-eslint/eslint-plugin": "^5.9.0",
|
||||||
"@typescript-eslint/parser": "^4.21.0",
|
"@typescript-eslint/parser": "^5.9.0",
|
||||||
"assert": "^2.0.0",
|
"assert": "^2.0.0",
|
||||||
"autoprefixer": "^10.2.5",
|
"autoprefixer": "^10.2.5",
|
||||||
"base64-loader": "^1.0.0",
|
"base64-loader": "^1.0.0",
|
||||||
"buffer": "^6.0.3",
|
"buffer": "^6.0.3",
|
||||||
"cross-env": "^7.0.3",
|
"cross-env": "^7.0.3",
|
||||||
"crypto-browserify": "^3.12.0",
|
"crypto-browserify": "^3.12.0",
|
||||||
"css-loader": "^5.0.1",
|
"css-loader": "^6.5.1",
|
||||||
"cssnano": "^5.0.8",
|
"cssnano": "^5.0.8",
|
||||||
"dotenv-webpack": "^7.0.3",
|
"dotenv-webpack": "^7.0.3",
|
||||||
"eslint": "^7.23.0",
|
"eslint": "^8.6.0",
|
||||||
"eslint-config-prettier": "^8.1.0",
|
"eslint-config-prettier": "^8.1.0",
|
||||||
"eslint-plugin-svelte3": "^3.1.2",
|
"eslint-plugin-svelte3": "^3.1.2",
|
||||||
"https-browserify": "^1.0.0",
|
"https-browserify": "^1.0.0",
|
||||||
"mini-css-extract-plugin": "^1.3.4",
|
"mini-css-extract-plugin": "^2.4.5",
|
||||||
"os-browserify": "^0.3.0",
|
"os-browserify": "^0.3.0",
|
||||||
"postcss": "^8.2.8",
|
"postcss": "^8.2.8",
|
||||||
"postcss-load-config": "^3.0.1",
|
"postcss-load-config": "^3.0.1",
|
||||||
"postcss-loader": "^5.2.0",
|
"postcss-loader": "^6.2.1",
|
||||||
"precss": "^4.0.0",
|
"precss": "^4.0.0",
|
||||||
"prettier": "^2.2.1",
|
"prettier": "^2.2.1",
|
||||||
"prettier-plugin-svelte": "^2.2.0",
|
"prettier-plugin-svelte": "^2.2.0",
|
||||||
@ -31,12 +31,12 @@
|
|||||||
"stream-browserify": "^3.0.0",
|
"stream-browserify": "^3.0.0",
|
||||||
"stream-http": "^3.2.0",
|
"stream-http": "^3.2.0",
|
||||||
"svelte": "^3.31.2",
|
"svelte": "^3.31.2",
|
||||||
"svelte-check": "^1.0.46",
|
"svelte-check": "^2.2.11",
|
||||||
"svelte-loader": "^3.0.0",
|
"svelte-loader": "^3.0.0",
|
||||||
"svelte-preprocess": "^4.3.0",
|
"svelte-preprocess": "^4.3.0",
|
||||||
"svg-url-loader": "^7.1.1",
|
"svg-url-loader": "^7.1.1",
|
||||||
"tailwindcss": "^2.0.4",
|
"tailwindcss": "^3.0.9",
|
||||||
"ts-loader": "^8.0.4",
|
"ts-loader": "^9.2.6",
|
||||||
"tslib": "^2.0.1",
|
"tslib": "^2.0.1",
|
||||||
"typescript": "^4.0.3",
|
"typescript": "^4.0.3",
|
||||||
"webpack": "^5.16.0",
|
"webpack": "^5.16.0",
|
||||||
@ -54,6 +54,7 @@
|
|||||||
"@toruslabs/torus-embed": "^1.18.3",
|
"@toruslabs/torus-embed": "^1.18.3",
|
||||||
"@walletconnect/web3-provider": "^1.6.6",
|
"@walletconnect/web3-provider": "^1.6.6",
|
||||||
"fortmatic": "^2.2.1",
|
"fortmatic": "^2.2.1",
|
||||||
|
"url": "^0.11.0",
|
||||||
"walletlink": "^2.2.8"
|
"walletlink": "^2.2.8"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
{
|
{
|
||||||
"extends": "@tsconfig/svelte/tsconfig.json",
|
"extends": "@tsconfig/svelte/tsconfig.json",
|
||||||
"include": ["src/**/*", "src/node_modules/**/*"],
|
"include": ["src/**/*", "src/node_modules/**/*"],
|
||||||
"exclude": ["node_modules/*", "__sapper__/*", "static/*"]
|
"exclude": ["node_modules/*", "__sapper__/*", "static/*"],
|
||||||
|
"compilerOptions": {
|
||||||
|
"types": ["node", "svelte"]
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
@ -27,6 +27,7 @@ module.exports = {
|
|||||||
path: false,
|
path: false,
|
||||||
process: require.resolve('process/browser'),
|
process: require.resolve('process/browser'),
|
||||||
stream: require.resolve('stream-browserify'),
|
stream: require.resolve('stream-browserify'),
|
||||||
|
url: require.resolve("url")
|
||||||
// util: false,
|
// util: false,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
364
src/axum_lib.rs
Normal file
364
src/axum_lib.rs
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_redis_session::RedisSessionStore;
|
||||||
|
use axum::{
|
||||||
|
extract::{self, Extension, Form, Query, TypedHeader},
|
||||||
|
http::{
|
||||||
|
header::{self, HeaderMap},
|
||||||
|
StatusCode,
|
||||||
|
},
|
||||||
|
response::{self, IntoResponse, Redirect},
|
||||||
|
routing::{get, get_service, post},
|
||||||
|
AddExtensionLayer, Json, Router,
|
||||||
|
};
|
||||||
|
use bb8_redis::{bb8, RedisConnectionManager};
|
||||||
|
use figment::{
|
||||||
|
providers::{Env, Format, Serialized, Toml},
|
||||||
|
Figment,
|
||||||
|
};
|
||||||
|
use headers::{
|
||||||
|
self,
|
||||||
|
authorization::{Basic, Bearer},
|
||||||
|
Authorization,
|
||||||
|
};
|
||||||
|
use openidconnect::core::{
|
||||||
|
CoreClientMetadata, CoreClientRegistrationResponse, CoreJsonWebKeySet, CoreProviderMetadata,
|
||||||
|
CoreResponseType, CoreTokenResponse, CoreUserInfoClaims,
|
||||||
|
};
|
||||||
|
use rand::rngs::OsRng;
|
||||||
|
use rsa::{
|
||||||
|
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
|
||||||
|
RsaPrivateKey,
|
||||||
|
};
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use tower_http::{
|
||||||
|
services::{ServeDir, ServeFile},
|
||||||
|
trace::TraceLayer,
|
||||||
|
};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use super::config;
|
||||||
|
use super::oidc::{self, CustomError};
|
||||||
|
use super::session::*;
|
||||||
|
use ::siwe_oidc::db::*;
|
||||||
|
|
||||||
|
impl IntoResponse for CustomError {
|
||||||
|
fn into_response(self) -> response::Response {
|
||||||
|
match self {
|
||||||
|
CustomError::BadRequest(_) => {
|
||||||
|
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
|
||||||
|
}
|
||||||
|
CustomError::BadRequestToken(e) => {
|
||||||
|
(StatusCode::BAD_REQUEST, Json::from(e)).into_response()
|
||||||
|
}
|
||||||
|
CustomError::Unauthorized(_) => {
|
||||||
|
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
|
||||||
|
}
|
||||||
|
CustomError::Redirect(uri) => Redirect::to(
|
||||||
|
uri.parse().unwrap(),
|
||||||
|
// .map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
CustomError::Other(_) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn jwk_set(
|
||||||
|
Extension(private_key): Extension<RsaPrivateKey>,
|
||||||
|
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
|
||||||
|
let jwks = oidc::jwks(private_key)?;
|
||||||
|
Ok(jwks.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn provider_metadata(
|
||||||
|
Extension(config): Extension<config::Config>,
|
||||||
|
) -> Result<Json<CoreProviderMetadata>, CustomError> {
|
||||||
|
Ok(oidc::metadata(config.base_url)?.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO should check Authorization header
|
||||||
|
// Actually, client secret can be
|
||||||
|
// 1. in the POST (currently supported) [x]
|
||||||
|
// 2. Authorization header [x]
|
||||||
|
// 3. JWT [ ]
|
||||||
|
// 4. signed JWT [ ]
|
||||||
|
// according to Keycloak
|
||||||
|
|
||||||
|
async fn token(
|
||||||
|
Form(form): Form<oidc::TokenForm>,
|
||||||
|
bearer: Option<TypedHeader<Authorization<Bearer>>>,
|
||||||
|
basic: Option<TypedHeader<Authorization<Basic>>>,
|
||||||
|
Extension(private_key): Extension<RsaPrivateKey>,
|
||||||
|
Extension(config): Extension<config::Config>,
|
||||||
|
Extension(redis_client): Extension<RedisClient>,
|
||||||
|
) -> Result<Json<CoreTokenResponse>, CustomError> {
|
||||||
|
let secret = if let Some(b) = bearer {
|
||||||
|
Some(b.0 .0.token().to_string())
|
||||||
|
} else {
|
||||||
|
basic.map(|b| b.0 .0.password().to_string())
|
||||||
|
};
|
||||||
|
let token_response = oidc::token(
|
||||||
|
form,
|
||||||
|
secret,
|
||||||
|
private_key,
|
||||||
|
config.base_url,
|
||||||
|
config.require_secret,
|
||||||
|
&redis_client,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(token_response.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO handle `registration` parameter
|
||||||
|
async fn authorize(
|
||||||
|
session: UserSessionFromSession,
|
||||||
|
Query(params): Query<oidc::AuthorizeParams>,
|
||||||
|
Extension(redis_client): Extension<RedisClient>,
|
||||||
|
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||||
|
let (nonce, headers) = match session {
|
||||||
|
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
||||||
|
UserSessionFromSession::Invalid(cookie) => {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(header::SET_COOKIE, cookie);
|
||||||
|
return Ok((
|
||||||
|
headers,
|
||||||
|
Redirect::to(
|
||||||
|
format!(
|
||||||
|
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
|
||||||
|
¶ms.client_id,
|
||||||
|
¶ms.redirect_uri.to_string(),
|
||||||
|
¶ms.scope.to_string(),
|
||||||
|
¶ms.response_type.unwrap_or(CoreResponseType::Code).as_ref(),
|
||||||
|
¶ms.state.unwrap_or_default(),
|
||||||
|
¶ms.client_id,
|
||||||
|
¶ms.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
|
||||||
|
)
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
UserSessionFromSession::Created { header, nonce } => {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(header::SET_COOKIE, header);
|
||||||
|
(nonce, headers)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = oidc::authorize(params, nonce, &redis_client).await?;
|
||||||
|
Ok((
|
||||||
|
headers,
|
||||||
|
Redirect::to(
|
||||||
|
url.as_str()
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sign_in(
|
||||||
|
session: UserSessionFromSession,
|
||||||
|
Query(params): Query<oidc::SignInParams>,
|
||||||
|
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
||||||
|
Extension(redis_client): Extension<RedisClient>,
|
||||||
|
) -> Result<(HeaderMap, Redirect), CustomError> {
|
||||||
|
let (nonce, headers) = match session {
|
||||||
|
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
||||||
|
UserSessionFromSession::Invalid(header) => {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(header::SET_COOKIE, header);
|
||||||
|
return Ok((
|
||||||
|
headers,
|
||||||
|
Redirect::to(
|
||||||
|
format!(
|
||||||
|
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||||
|
¶ms.client_id.clone(),
|
||||||
|
¶ms.redirect_uri.to_string(),
|
||||||
|
¶ms.state,
|
||||||
|
)
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
UserSessionFromSession::Created { .. } => {
|
||||||
|
return Ok((
|
||||||
|
HeaderMap::new(),
|
||||||
|
Redirect::to(
|
||||||
|
format!(
|
||||||
|
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
||||||
|
¶ms.client_id.clone(),
|
||||||
|
¶ms.redirect_uri.to_string(),
|
||||||
|
¶ms.state,
|
||||||
|
)
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = oidc::sign_in(params, Some(nonce), cookies, &redis_client).await?;
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
headers,
|
||||||
|
Redirect::to(
|
||||||
|
url.as_str()
|
||||||
|
.parse()
|
||||||
|
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
// TODO clear session
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn register(
|
||||||
|
extract::Json(payload): extract::Json<CoreClientMetadata>,
|
||||||
|
Extension(redis_client): Extension<RedisClient>,
|
||||||
|
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
|
||||||
|
let registration = oidc::register(payload, &redis_client).await?;
|
||||||
|
Ok((StatusCode::CREATED, registration.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO CORS
|
||||||
|
// TODO need validation of the token
|
||||||
|
async fn userinfo(
|
||||||
|
payload: Option<Form<oidc::UserInfoPayload>>,
|
||||||
|
bearer: Option<TypedHeader<Authorization<Bearer>>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
||||||
|
Extension(redis_client): Extension<RedisClient>,
|
||||||
|
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
|
||||||
|
let payload = if let Some(Form(p)) = payload {
|
||||||
|
p
|
||||||
|
} else {
|
||||||
|
oidc::UserInfoPayload { access_token: None }
|
||||||
|
};
|
||||||
|
let claims = oidc::userinfo(bearer.map(|b| b.0 .0), payload, &redis_client).await?;
|
||||||
|
Ok(claims.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn healthcheck() {}
|
||||||
|
|
||||||
|
pub async fn main() {
|
||||||
|
let config = Figment::from(Serialized::defaults(config::Config::default()))
|
||||||
|
.merge(Toml::file("siwe-oidc.toml").nested())
|
||||||
|
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
|
||||||
|
let config = config.extract::<config::Config>().unwrap();
|
||||||
|
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
|
||||||
|
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
|
||||||
|
// let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
|
||||||
|
|
||||||
|
let redis_client = RedisClient { pool };
|
||||||
|
|
||||||
|
for (id, secret) in &config.default_clients.clone() {
|
||||||
|
let client_entry = ClientEntry {
|
||||||
|
secret: secret.to_string(),
|
||||||
|
redirect_uris: vec![],
|
||||||
|
};
|
||||||
|
redis_client
|
||||||
|
.set_client(id.to_string(), client_entry)
|
||||||
|
.await
|
||||||
|
.unwrap(); // TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
let private_key = if let Some(key) = &config.rsa_pem {
|
||||||
|
RsaPrivateKey::from_pkcs1_pem(key)
|
||||||
|
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||||
|
.unwrap()
|
||||||
|
} else {
|
||||||
|
info!("Generating key...");
|
||||||
|
let mut rng = OsRng;
|
||||||
|
let bits = 2048;
|
||||||
|
let private = RsaPrivateKey::new(&mut rng, bits)
|
||||||
|
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
info!("Generated key.");
|
||||||
|
info!("{:?}", private.to_pkcs1_pem().unwrap());
|
||||||
|
private
|
||||||
|
};
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.nest(
|
||||||
|
"/build",
|
||||||
|
get_service(ServeDir::new("./static/build")).handle_error(
|
||||||
|
|error: std::io::Error| async move {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Unhandled internal error: {}", error),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.nest(
|
||||||
|
"/img",
|
||||||
|
get_service(ServeDir::new("./static/img")).handle_error(
|
||||||
|
|error: std::io::Error| async move {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Unhandled internal error: {}", error),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/",
|
||||||
|
get_service(ServeFile::new("./static/index.html")).handle_error(
|
||||||
|
|error: std::io::Error| async move {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Unhandled internal error: {}", error),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/error",
|
||||||
|
get_service(ServeFile::new("./static/error.html")).handle_error(
|
||||||
|
|error: std::io::Error| async move {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Unhandled internal error: {}", error),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/favicon.png",
|
||||||
|
get_service(ServeFile::new("./static/favicon.png")).handle_error(
|
||||||
|
|error: std::io::Error| async move {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Unhandled internal error: {}", error),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.route(oidc::METADATA_PATH, get(provider_metadata))
|
||||||
|
.route(oidc::JWK_PATH, get(jwk_set))
|
||||||
|
.route(oidc::TOKEN_PATH, post(token))
|
||||||
|
.route(oidc::AUTHORIZE_PATH, get(authorize))
|
||||||
|
.route(oidc::REGISTER_PATH, post(register))
|
||||||
|
.route(oidc::USERINFO_PATH, get(userinfo).post(userinfo))
|
||||||
|
.route(oidc::SIGNIN_PATH, get(sign_in))
|
||||||
|
.route("/health", get(healthcheck))
|
||||||
|
.layer(AddExtensionLayer::new(private_key))
|
||||||
|
.layer(AddExtensionLayer::new(config.clone()))
|
||||||
|
.layer(AddExtensionLayer::new(redis_client))
|
||||||
|
.layer(AddExtensionLayer::new(
|
||||||
|
RedisSessionStore::new(config.redis_url.clone())
|
||||||
|
.unwrap()
|
||||||
|
.with_prefix("async-sessions/"),
|
||||||
|
))
|
||||||
|
.layer(TraceLayer::new_for_http());
|
||||||
|
|
||||||
|
let addr = SocketAddr::from((config.address, config.port));
|
||||||
|
tracing::info!("Listening on {}", addr);
|
||||||
|
axum::Server::bind(&addr)
|
||||||
|
.serve(app.into_make_service())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
43
src/db.rs
43
src/db.rs
@ -1,43 +0,0 @@
|
|||||||
use anyhow::{anyhow, Result};
|
|
||||||
use bb8_redis::{bb8::PooledConnection, redis::AsyncCommands, RedisConnectionManager};
|
|
||||||
use openidconnect::RedirectUrl;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
const KV_CLIENT_PREFIX: &str = "clients";
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct ClientEntry {
|
|
||||||
pub secret: String,
|
|
||||||
pub redirect_uris: Vec<RedirectUrl>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_client(
|
|
||||||
mut conn: PooledConnection<'_, RedisConnectionManager>,
|
|
||||||
client_id: String,
|
|
||||||
client_entry: ClientEntry,
|
|
||||||
) -> Result<()> {
|
|
||||||
conn.set(
|
|
||||||
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
|
||||||
serde_json::to_string(&client_entry)
|
|
||||||
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_client(
|
|
||||||
mut conn: PooledConnection<'_, RedisConnectionManager>,
|
|
||||||
client_id: String,
|
|
||||||
) -> Result<Option<ClientEntry>> {
|
|
||||||
let entry: Option<String> = conn
|
|
||||||
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
|
||||||
if let Some(e) = entry {
|
|
||||||
Ok(serde_json::from_str(&e)
|
|
||||||
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
199
src/db/cf.rs
Normal file
199
src/db/cf.rs
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
// use cached::{stores::TimedCache, Cached};
|
||||||
|
use chrono::{DateTime, Duration, Utc};
|
||||||
|
use matchit::Node;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use worker::*;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
const KV_NAMESPACE: &str = "SIWE-OIDC";
|
||||||
|
const DO_NAMESPACE: &str = "SIWE-OIDC-CODES";
|
||||||
|
|
||||||
|
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
|
||||||
|
// Heavily relying on:
|
||||||
|
// A Durable Object is given 30 seconds of additional CPU time for every
|
||||||
|
// request it processes, including WebSocket messages. In the absence of
|
||||||
|
// failures, in-memory state should not be reset after less than 30 seconds of
|
||||||
|
// inactivity.
|
||||||
|
// /!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\/!\
|
||||||
|
|
||||||
|
// Wanted to use TimedCache but it (probably) crashes because it's using std::time::Instant which isn't available on wasm32.
|
||||||
|
|
||||||
|
#[durable_object]
|
||||||
|
pub struct DOCodes {
|
||||||
|
// codes: TimedCache<String, CodeEntry>,
|
||||||
|
codes: HashMap<String, (DateTime<Utc>, CodeEntry)>,
|
||||||
|
// state: State,
|
||||||
|
// env: Env,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[durable_object]
|
||||||
|
impl DurableObject for DOCodes {
|
||||||
|
fn new(state: State, _env: Env) -> Self {
|
||||||
|
Self {
|
||||||
|
// codes: TimedCache::with_lifespan(ENTRY_LIFETIME.try_into().unwrap()),
|
||||||
|
codes: HashMap::new(),
|
||||||
|
// state,
|
||||||
|
// env,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch(&mut self, mut req: Request) -> worker::Result<Response> {
|
||||||
|
// Can't use the Router because we need to reference self (thus move the var to the closure)
|
||||||
|
if matches!(req.method(), Method::Get) {
|
||||||
|
let mut matcher = Node::new();
|
||||||
|
matcher.insert("/:code", ())?;
|
||||||
|
let path = req.path();
|
||||||
|
let matched = match matcher.at(&path) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(_) => return Response::error("Bad request", 400),
|
||||||
|
};
|
||||||
|
let code = if let Some(c) = matched.params.get("code") {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
return Response::error("Bad request", 400);
|
||||||
|
};
|
||||||
|
if let Some(c) = self.codes.get(code) {
|
||||||
|
if c.0 + Duration::seconds(ENTRY_LIFETIME.try_into().unwrap()) < Utc::now() {
|
||||||
|
self.codes.remove(code);
|
||||||
|
Response::error("Not found", 404)
|
||||||
|
} else {
|
||||||
|
Response::from_json(&c.1)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Response::error("Not found", 404)
|
||||||
|
}
|
||||||
|
} else if matches!(req.method(), Method::Post) {
|
||||||
|
let mut matcher = Node::new();
|
||||||
|
matcher.insert("/:code", ())?;
|
||||||
|
let path = req.path();
|
||||||
|
let matched = match matcher.at(&path) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(_) => return Response::error("Bad request", 400),
|
||||||
|
};
|
||||||
|
let code = if let Some(c) = matched.params.get("code") {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
return Response::error("Bad request", 400);
|
||||||
|
};
|
||||||
|
let code_entry = match req.json().await {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => return Response::error(format!("Bad request: {}", e), 400),
|
||||||
|
};
|
||||||
|
self.codes
|
||||||
|
.insert(code.to_string(), (Utc::now(), code_entry));
|
||||||
|
Response::empty()
|
||||||
|
} else {
|
||||||
|
Response::error("Method Not Allowed", 405)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CFClient {
|
||||||
|
pub ctx: RouteContext<()>,
|
||||||
|
pub url: Url,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||||
|
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||||
|
impl DBClient for CFClient {
|
||||||
|
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
|
||||||
|
self.ctx
|
||||||
|
.kv(KV_NAMESPACE)
|
||||||
|
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
|
||||||
|
.put(
|
||||||
|
&format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
||||||
|
serde_json::to_string(&client_entry)
|
||||||
|
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
||||||
|
)
|
||||||
|
.map_err(|e| anyhow!("Failed to build KV put: {}", e))?
|
||||||
|
// TODO put some sort of expiration for dynamic registration
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to put KV: {}", e))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
|
||||||
|
let entry = self
|
||||||
|
.ctx
|
||||||
|
.kv(KV_NAMESPACE)
|
||||||
|
.map_err(|e| anyhow!("Failed to get KV store: {}", e))?
|
||||||
|
.get(&format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get KV: {}", e))?
|
||||||
|
.map(|e| e.as_string());
|
||||||
|
if let Some(e) = entry {
|
||||||
|
Ok(serde_json::from_str(&e)
|
||||||
|
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
|
||||||
|
let namespace = self
|
||||||
|
.ctx
|
||||||
|
.durable_object(DO_NAMESPACE)
|
||||||
|
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
|
||||||
|
let stub = namespace
|
||||||
|
.id_from_name(&code)
|
||||||
|
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
|
||||||
|
.get_stub()
|
||||||
|
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
|
||||||
|
let mut headers = Headers::new();
|
||||||
|
headers.set("Content-Type", "application/json").unwrap();
|
||||||
|
let mut url = self.url.clone();
|
||||||
|
url.set_path(&code);
|
||||||
|
url.set_query(None);
|
||||||
|
let req = Request::new_with_init(
|
||||||
|
url.as_str(),
|
||||||
|
&RequestInit {
|
||||||
|
body: Some(wasm_bindgen::JsValue::from_str(
|
||||||
|
&serde_json::to_string(&code_entry)
|
||||||
|
.map_err(|e| anyhow!("Failed to serialize: {}", e))?,
|
||||||
|
)),
|
||||||
|
method: Method::Post,
|
||||||
|
headers,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.map_err(|e| anyhow!("Failed to construct request for Durable Object: {}", e))?;
|
||||||
|
let res = stub
|
||||||
|
.fetch_with_request(req)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
|
||||||
|
match res.status_code() {
|
||||||
|
200 => Ok(()),
|
||||||
|
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
|
||||||
|
let namespace = self
|
||||||
|
.ctx
|
||||||
|
.durable_object(DO_NAMESPACE)
|
||||||
|
.map_err(|e| anyhow!("Failed to retrieve Durable Object: {}", e))?;
|
||||||
|
let stub = namespace
|
||||||
|
.id_from_name(&code)
|
||||||
|
.map_err(|e| anyhow!("Failed to retrieve Durable Object from ID: {}", e))?
|
||||||
|
.get_stub()
|
||||||
|
.map_err(|e| anyhow!("Failed to retrieve Durable Object stub: {}", e))?;
|
||||||
|
let mut url = self.url.clone();
|
||||||
|
url.set_path(&code);
|
||||||
|
url.set_query(None);
|
||||||
|
let mut res = stub
|
||||||
|
.fetch_with_str(url.as_str())
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Request to Durable Object failed: {}", e))?;
|
||||||
|
match res.status_code() {
|
||||||
|
200 => Ok(Some(res.json().await.map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Response to Durable Object failed to be deserialized: {}",
|
||||||
|
e
|
||||||
|
)
|
||||||
|
})?)),
|
||||||
|
404 => Ok(None),
|
||||||
|
code => Err(anyhow!("Error fetching from Durable Object: {}", code)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
40
src/db/mod.rs
Normal file
40
src/db/mod.rs
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use openidconnect::{Nonce, RedirectUrl};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
mod redis;
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
pub use redis::RedisClient;
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
mod cf;
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
pub use cf::CFClient;
|
||||||
|
|
||||||
|
const KV_CLIENT_PREFIX: &str = "clients";
|
||||||
|
const ENTRY_LIFETIME: usize = 30;
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
|
pub struct CodeEntry {
|
||||||
|
pub exchange_count: usize,
|
||||||
|
pub address: String,
|
||||||
|
pub nonce: Option<Nonce>,
|
||||||
|
pub client_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ClientEntry {
|
||||||
|
pub secret: String,
|
||||||
|
pub redirect_uris: Vec<RedirectUrl>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Using a trait to easily pass async functions with async_trait
|
||||||
|
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||||
|
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||||
|
pub trait DBClient {
|
||||||
|
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()>;
|
||||||
|
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>>;
|
||||||
|
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()>;
|
||||||
|
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>>;
|
||||||
|
}
|
89
src/db/redis.rs
Normal file
89
src/db/redis.rs
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct RedisClient {
|
||||||
|
pub pool: Pool<RedisConnectionManager>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||||
|
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||||
|
impl DBClient for RedisClient {
|
||||||
|
async fn set_client(&self, client_id: String, client_entry: ClientEntry) -> Result<()> {
|
||||||
|
let mut conn = self
|
||||||
|
.pool
|
||||||
|
.get()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||||
|
|
||||||
|
conn.set(
|
||||||
|
format!("{}/{}", KV_CLIENT_PREFIX, client_id),
|
||||||
|
serde_json::to_string(&client_entry)
|
||||||
|
.map_err(|e| anyhow!("Failed to serialize client entry: {}", e))?,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_client(&self, client_id: String) -> Result<Option<ClientEntry>> {
|
||||||
|
let mut conn = self
|
||||||
|
.pool
|
||||||
|
.get()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||||
|
let entry: Option<String> = conn
|
||||||
|
.get(format!("{}/{}", KV_CLIENT_PREFIX, client_id))
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||||
|
if let Some(e) = entry {
|
||||||
|
Ok(serde_json::from_str(&e)
|
||||||
|
.map_err(|e| anyhow!("Failed to deserialize client entry: {}", e))?)
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_code(&self, code: String, code_entry: CodeEntry) -> Result<()> {
|
||||||
|
let mut conn = self
|
||||||
|
.pool
|
||||||
|
.get()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||||
|
conn.set_ex(
|
||||||
|
code.to_string(),
|
||||||
|
hex::encode(
|
||||||
|
bincode::serialize(&code_entry)
|
||||||
|
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
||||||
|
),
|
||||||
|
ENTRY_LIFETIME,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_code(&self, code: String) -> Result<Option<CodeEntry>> {
|
||||||
|
let mut conn = self
|
||||||
|
.pool
|
||||||
|
.get()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
||||||
|
let serialized_entry: Option<Vec<u8>> = conn
|
||||||
|
.get(code)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||||
|
if serialized_entry.is_none() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
let code_entry: CodeEntry = bincode::deserialize(
|
||||||
|
&hex::decode(serialized_entry.unwrap())
|
||||||
|
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
||||||
|
)
|
||||||
|
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
||||||
|
Ok(Some(code_entry))
|
||||||
|
}
|
||||||
|
}
|
18
src/lib.rs
Normal file
18
src/lib.rs
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
use worker::*;
|
||||||
|
|
||||||
|
pub mod db;
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
pub mod oidc;
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
mod worker_lib;
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
use worker_lib::main as worker_main;
|
||||||
|
// pub use worker_lib::main;
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
#[event(fetch)]
|
||||||
|
pub async fn main(req: Request, env: Env) -> Result<Response> {
|
||||||
|
worker_main(req, env).await
|
||||||
|
}
|
889
src/main.rs
889
src/main.rs
@ -1,882 +1,19 @@
|
|||||||
use anyhow::{anyhow, Result};
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use async_redis_session::RedisSessionStore;
|
mod axum_lib;
|
||||||
use axum::{
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
body::{Bytes, Full},
|
|
||||||
error_handling::HandleErrorExt,
|
|
||||||
extract::{self, Extension, Form, Query, TypedHeader},
|
|
||||||
http::{
|
|
||||||
header::{self, HeaderMap},
|
|
||||||
Response, StatusCode,
|
|
||||||
},
|
|
||||||
response::{IntoResponse, Redirect},
|
|
||||||
routing::{get, post, service_method_routing},
|
|
||||||
AddExtensionLayer, Json, Router,
|
|
||||||
};
|
|
||||||
use bb8_redis::{bb8, bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
|
||||||
use chrono::{Duration, Utc};
|
|
||||||
use figment::{
|
|
||||||
providers::{Env, Format, Serialized, Toml},
|
|
||||||
Figment,
|
|
||||||
};
|
|
||||||
use headers::{self, authorization::Bearer, Authorization};
|
|
||||||
use hex::FromHex;
|
|
||||||
use iri_string::types::{UriAbsoluteString, UriString};
|
|
||||||
use openidconnect::{
|
|
||||||
core::{
|
|
||||||
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
|
|
||||||
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
|
|
||||||
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
|
|
||||||
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
|
|
||||||
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
|
|
||||||
},
|
|
||||||
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
|
|
||||||
url::Url,
|
|
||||||
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
|
|
||||||
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
|
|
||||||
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
|
|
||||||
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
|
|
||||||
};
|
|
||||||
use rand::rngs::OsRng;
|
|
||||||
use rsa::{
|
|
||||||
pkcs1::{FromRsaPrivateKey, ToRsaPrivateKey},
|
|
||||||
RsaPrivateKey,
|
|
||||||
};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use siwe::eip4361::{Message, Version};
|
|
||||||
use std::{convert::Infallible, net::SocketAddr, str::FromStr};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tower_http::{
|
|
||||||
services::{ServeDir, ServeFile},
|
|
||||||
trace::TraceLayer,
|
|
||||||
};
|
|
||||||
use tracing::info;
|
|
||||||
use urlencoding::decode;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
mod config;
|
mod config;
|
||||||
mod db;
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
mod oidc;
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
mod session;
|
mod session;
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use axum_lib::main as axum_main;
|
||||||
|
|
||||||
use db::*;
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use session::*;
|
|
||||||
|
|
||||||
const KID: &str = "key1";
|
|
||||||
const ENTRY_LIFETIME: usize = 30;
|
|
||||||
|
|
||||||
type ConnectionPool = Pool<RedisConnectionManager>;
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
pub struct TokenError {
|
|
||||||
pub error: CoreErrorResponseType,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum CustomError {
|
|
||||||
#[error("{0}")]
|
|
||||||
BadRequest(String),
|
|
||||||
#[error("{0:?}")]
|
|
||||||
BadRequestToken(Json<TokenError>),
|
|
||||||
#[error("{0}")]
|
|
||||||
Unauthorized(String),
|
|
||||||
#[error(transparent)]
|
|
||||||
Other(#[from] anyhow::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoResponse for CustomError {
|
|
||||||
type Body = Full<Bytes>;
|
|
||||||
type BodyError = Infallible;
|
|
||||||
|
|
||||||
fn into_response(self) -> Response<Self::Body> {
|
|
||||||
match self {
|
|
||||||
CustomError::BadRequest(_) => {
|
|
||||||
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
|
|
||||||
}
|
|
||||||
CustomError::BadRequestToken(e) => (StatusCode::BAD_REQUEST, e).into_response(),
|
|
||||||
CustomError::Unauthorized(_) => {
|
|
||||||
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
|
|
||||||
}
|
|
||||||
CustomError::Other(_) => {
|
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn jwk_set(
|
|
||||||
Extension(private_key): Extension<RsaPrivateKey>,
|
|
||||||
) -> Result<Json<CoreJsonWebKeySet>, CustomError> {
|
|
||||||
let pem = private_key
|
|
||||||
.to_pkcs1_pem()
|
|
||||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
|
||||||
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
|
|
||||||
&pem,
|
|
||||||
Some(JsonWebKeyId::new(KID.to_string())),
|
|
||||||
)
|
|
||||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
|
|
||||||
.as_verification_key()]);
|
|
||||||
Ok(jwks.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn provider_metadata(
|
|
||||||
Extension(config): Extension<config::Config>,
|
|
||||||
) -> Result<Json<CoreProviderMetadata>, CustomError> {
|
|
||||||
let pm = CoreProviderMetadata::new(
|
|
||||||
IssuerUrl::from_url(config.base_url.clone()),
|
|
||||||
AuthUrl::from_url(
|
|
||||||
config
|
|
||||||
.base_url
|
|
||||||
.join("authorize")
|
|
||||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
|
||||||
),
|
|
||||||
JsonWebKeySetUrl::from_url(
|
|
||||||
config
|
|
||||||
.base_url
|
|
||||||
.join("jwk")
|
|
||||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
|
||||||
),
|
|
||||||
vec![
|
|
||||||
ResponseTypes::new(vec![CoreResponseType::Code]),
|
|
||||||
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
|
|
||||||
],
|
|
||||||
vec![CoreSubjectIdentifierType::Pairwise],
|
|
||||||
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
|
|
||||||
EmptyAdditionalProviderMetadata {},
|
|
||||||
)
|
|
||||||
.set_token_endpoint(Some(TokenUrl::from_url(
|
|
||||||
config
|
|
||||||
.base_url
|
|
||||||
.join("token")
|
|
||||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
|
||||||
)))
|
|
||||||
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
|
|
||||||
config
|
|
||||||
.base_url
|
|
||||||
.join("userinfo")
|
|
||||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
|
||||||
)))
|
|
||||||
.set_scopes_supported(Some(vec![
|
|
||||||
Scope::new("openid".to_string()),
|
|
||||||
// Scope::new("email".to_string()),
|
|
||||||
// Scope::new("profile".to_string()),
|
|
||||||
]))
|
|
||||||
.set_claims_supported(Some(vec![
|
|
||||||
CoreClaimName::new("sub".to_string()),
|
|
||||||
CoreClaimName::new("aud".to_string()),
|
|
||||||
// CoreClaimName::new("email".to_string()),
|
|
||||||
// CoreClaimName::new("email_verified".to_string()),
|
|
||||||
CoreClaimName::new("exp".to_string()),
|
|
||||||
CoreClaimName::new("iat".to_string()),
|
|
||||||
CoreClaimName::new("iss".to_string()),
|
|
||||||
// CoreClaimName::new("name".to_string()),
|
|
||||||
// CoreClaimName::new("given_name".to_string()),
|
|
||||||
// CoreClaimName::new("family_name".to_string()),
|
|
||||||
// CoreClaimName::new("picture".to_string()),
|
|
||||||
// CoreClaimName::new("locale".to_string()),
|
|
||||||
]))
|
|
||||||
.set_registration_endpoint(Some(RegistrationUrl::from_url(
|
|
||||||
config
|
|
||||||
.base_url
|
|
||||||
.join("register")
|
|
||||||
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
|
||||||
)))
|
|
||||||
.set_token_endpoint_auth_methods_supported(Some(vec![
|
|
||||||
CoreClientAuthMethod::ClientSecretBasic,
|
|
||||||
CoreClientAuthMethod::ClientSecretPost,
|
|
||||||
]));
|
|
||||||
|
|
||||||
Ok(pm.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
struct TokenForm {
|
|
||||||
code: String,
|
|
||||||
client_id: Option<String>,
|
|
||||||
client_secret: Option<String>,
|
|
||||||
grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO should check Authorization header
|
|
||||||
// Actually, client secret can be
|
|
||||||
// 1. in the POST (currently supported) [x]
|
|
||||||
// 2. Authorization header [x]
|
|
||||||
// 3. JWT [ ]
|
|
||||||
// 4. signed JWT [ ]
|
|
||||||
// according to Keycloak
|
|
||||||
|
|
||||||
async fn token(
|
|
||||||
form: Form<TokenForm>,
|
|
||||||
bearer: Option<TypedHeader<Authorization<Bearer>>>,
|
|
||||||
Extension(private_key): Extension<RsaPrivateKey>,
|
|
||||||
Extension(config): Extension<config::Config>,
|
|
||||||
Extension(pool): Extension<ConnectionPool>,
|
|
||||||
) -> Result<Json<CoreTokenResponse>, CustomError> {
|
|
||||||
let mut conn = pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
|
||||||
|
|
||||||
let serialized_entry: Option<Vec<u8>> = conn
|
|
||||||
.get(form.code.to_string())
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
|
||||||
if serialized_entry.is_none() {
|
|
||||||
return Err(CustomError::BadRequestToken(
|
|
||||||
TokenError {
|
|
||||||
error: CoreErrorResponseType::InvalidGrant,
|
|
||||||
}
|
|
||||||
.into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
let code_entry: CodeEntry = bincode::deserialize(
|
|
||||||
&hex::decode(serialized_entry.unwrap())
|
|
||||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
|
||||||
)
|
|
||||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
|
||||||
|
|
||||||
let client_id = if let Some(c) = form.client_id.clone() {
|
|
||||||
c
|
|
||||||
} else {
|
|
||||||
code_entry.client_id.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(secret) = if let Some(TypedHeader(Authorization(b))) = bearer {
|
|
||||||
Some(b.token().to_string())
|
|
||||||
} else {
|
|
||||||
form.client_secret.clone()
|
|
||||||
} {
|
|
||||||
let conn2 = pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
|
||||||
let client_entry = get_client(conn2, client_id.clone()).await?;
|
|
||||||
if client_entry.is_none() {
|
|
||||||
return Err(CustomError::Unauthorized(
|
|
||||||
"Unrecognised client id.".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if secret != client_entry.unwrap().secret {
|
|
||||||
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
|
|
||||||
}
|
|
||||||
} else if config.require_secret {
|
|
||||||
return Err(CustomError::Unauthorized("Secret required.".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if code_entry.exchange_count > 0 {
|
|
||||||
// TODO use Oauth error response
|
|
||||||
return Err(anyhow!("Code was previously exchanged.").into());
|
|
||||||
}
|
|
||||||
conn.set_ex(
|
|
||||||
form.code.to_string(),
|
|
||||||
hex::encode(
|
|
||||||
bincode::serialize(&code_entry)
|
|
||||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
|
||||||
),
|
|
||||||
ENTRY_LIFETIME,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
|
||||||
|
|
||||||
let access_token = AccessToken::new(form.code.to_string());
|
|
||||||
let core_id_token = CoreIdTokenClaims::new(
|
|
||||||
IssuerUrl::from_url(config.base_url),
|
|
||||||
vec![Audience::new(client_id.clone())],
|
|
||||||
Utc::now() + Duration::seconds(60),
|
|
||||||
Utc::now(),
|
|
||||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
|
||||||
EmptyAdditionalClaims {},
|
|
||||||
)
|
|
||||||
.set_nonce(code_entry.nonce);
|
|
||||||
|
|
||||||
let pem = private_key
|
|
||||||
.to_pkcs1_pem()
|
|
||||||
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
|
||||||
|
|
||||||
let id_token = CoreIdToken::new(
|
|
||||||
core_id_token,
|
|
||||||
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
|
|
||||||
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
|
|
||||||
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
|
|
||||||
Some(&access_token),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.map_err(|e| anyhow!("{}", e))?;
|
|
||||||
|
|
||||||
Ok(CoreTokenResponse::new(
|
|
||||||
access_token,
|
|
||||||
CoreTokenType::Bearer,
|
|
||||||
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
|
|
||||||
)
|
|
||||||
.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct AuthorizeParams {
|
|
||||||
client_id: String,
|
|
||||||
redirect_uri: RedirectUrl,
|
|
||||||
scope: Scope,
|
|
||||||
response_type: Option<CoreResponseType>,
|
|
||||||
state: Option<String>,
|
|
||||||
nonce: Option<Nonce>,
|
|
||||||
prompt: Option<CoreAuthPrompt>,
|
|
||||||
request_uri: Option<RequestUrl>,
|
|
||||||
request: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO handle `registration` parameter
|
|
||||||
async fn authorize(
|
|
||||||
session: UserSessionFromSession,
|
|
||||||
params: Query<AuthorizeParams>,
|
|
||||||
Extension(pool): Extension<ConnectionPool>,
|
|
||||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
|
||||||
let conn = pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
|
||||||
let client_entry = get_client(conn, params.client_id.clone())
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
|
||||||
if client_entry.is_none() {
|
|
||||||
return Err(CustomError::Unauthorized(
|
|
||||||
"Unrecognised client id.".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut r_u = params.0.redirect_uri.clone().url().clone();
|
|
||||||
r_u.set_query(None);
|
|
||||||
let mut r_us: Vec<Url> = client_entry
|
|
||||||
.unwrap()
|
|
||||||
.redirect_uris
|
|
||||||
.iter_mut()
|
|
||||||
.map(|u| u.url().clone())
|
|
||||||
.collect();
|
|
||||||
r_us.iter_mut().for_each(|u| u.set_query(None));
|
|
||||||
if !r_us.contains(&r_u) {
|
|
||||||
return Ok((
|
|
||||||
HeaderMap::new(),
|
|
||||||
Redirect::to(
|
|
||||||
"/error?message=unregistered_request_uri"
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let state = if let Some(s) = params.0.state.clone() {
|
|
||||||
s
|
|
||||||
} else if params.0.request_uri.is_some() {
|
|
||||||
let mut url = params.0.redirect_uri.url().clone();
|
|
||||||
url.query_pairs_mut().append_pair(
|
|
||||||
"error",
|
|
||||||
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
|
|
||||||
);
|
|
||||||
return Ok((
|
|
||||||
HeaderMap::new(),
|
|
||||||
Redirect::to(
|
|
||||||
url.as_str()
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
} else if params.0.request.is_some() {
|
|
||||||
let mut url = params.0.redirect_uri.url().clone();
|
|
||||||
url.query_pairs_mut().append_pair(
|
|
||||||
"error",
|
|
||||||
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
|
|
||||||
);
|
|
||||||
return Ok((
|
|
||||||
HeaderMap::new(),
|
|
||||||
Redirect::to(
|
|
||||||
url.as_str()
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
} else {
|
|
||||||
let mut url = params.redirect_uri.url().clone();
|
|
||||||
url.query_pairs_mut()
|
|
||||||
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
|
||||||
url.query_pairs_mut()
|
|
||||||
.append_pair("error_description", "Missing state");
|
|
||||||
return Ok((
|
|
||||||
HeaderMap::new(),
|
|
||||||
Redirect::to(
|
|
||||||
url.as_str()
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(CoreAuthPrompt::None) = params.0.prompt {
|
|
||||||
let mut url = params.redirect_uri.url().clone();
|
|
||||||
url.query_pairs_mut().append_pair("state", &state);
|
|
||||||
url.query_pairs_mut().append_pair(
|
|
||||||
"error",
|
|
||||||
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
|
|
||||||
);
|
|
||||||
return Ok((
|
|
||||||
HeaderMap::new(),
|
|
||||||
Redirect::to(
|
|
||||||
url.as_str()
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
if params.0.response_type.is_none() {
|
|
||||||
let mut url = params.redirect_uri.url().clone();
|
|
||||||
url.query_pairs_mut().append_pair("state", &state);
|
|
||||||
url.query_pairs_mut()
|
|
||||||
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
|
||||||
url.query_pairs_mut()
|
|
||||||
.append_pair("error_description", "Missing response_type");
|
|
||||||
return Ok((
|
|
||||||
HeaderMap::new(),
|
|
||||||
Redirect::to(
|
|
||||||
url.as_str()
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
let response_type = params.0.response_type.as_ref().unwrap();
|
|
||||||
|
|
||||||
if params.scope != Scope::new("openid".to_string()) {
|
|
||||||
return Err(anyhow!("Scope not supported").into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let (nonce, headers) = match session {
|
|
||||||
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
|
||||||
UserSessionFromSession::Invalid(cookie) => {
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
headers.insert(header::SET_COOKIE, cookie);
|
|
||||||
return Ok((
|
|
||||||
headers,
|
|
||||||
Redirect::to(
|
|
||||||
format!(
|
|
||||||
"/authorize?client_id={}&redirect_uri={}&scope={}&response_type={}&state={}&client_id={}{}",
|
|
||||||
¶ms.0.client_id,
|
|
||||||
¶ms.0.redirect_uri.to_string(),
|
|
||||||
¶ms.0.scope.to_string(),
|
|
||||||
&response_type.as_ref(),
|
|
||||||
&state,
|
|
||||||
¶ms.0.client_id,
|
|
||||||
¶ms.0.nonce.map(|n| format!("&nonce={}", n.secret())).unwrap_or_default()
|
|
||||||
)
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
UserSessionFromSession::Created { header, nonce } => {
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
headers.insert(header::SET_COOKIE, header);
|
|
||||||
(nonce, headers)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let domain = params.redirect_uri.url().host().unwrap();
|
|
||||||
let oidc_nonce_param = if let Some(n) = ¶ms.nonce {
|
|
||||||
format!("&oidc_nonce={}", n.secret())
|
|
||||||
} else {
|
|
||||||
"".to_string()
|
|
||||||
};
|
|
||||||
Ok((
|
|
||||||
headers,
|
|
||||||
Redirect::to(
|
|
||||||
format!(
|
|
||||||
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
|
|
||||||
nonce,
|
|
||||||
domain,
|
|
||||||
params.redirect_uri.to_string(),
|
|
||||||
state,
|
|
||||||
params.client_id,
|
|
||||||
oidc_nonce_param
|
|
||||||
)
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
struct SiweCookie {
|
|
||||||
message: Web3ModalMessage,
|
|
||||||
signature: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
struct Web3ModalMessage {
|
|
||||||
pub domain: String,
|
|
||||||
pub address: String,
|
|
||||||
pub statement: String,
|
|
||||||
pub uri: String,
|
|
||||||
pub version: String,
|
|
||||||
pub chain_id: String,
|
|
||||||
pub nonce: String,
|
|
||||||
pub issued_at: String,
|
|
||||||
pub expiration_time: Option<String>,
|
|
||||||
pub not_before: Option<String>,
|
|
||||||
pub request_id: Option<String>,
|
|
||||||
pub resources: Option<Vec<String>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Web3ModalMessage {
|
|
||||||
pub fn to_eip4361_message(&self) -> Result<Message> {
|
|
||||||
let mut next_resources: Vec<UriString> = Vec::new();
|
|
||||||
match &self.resources {
|
|
||||||
Some(resources) => {
|
|
||||||
for resource in resources {
|
|
||||||
let x = UriString::from_str(resource)?;
|
|
||||||
next_resources.push(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Message {
|
|
||||||
domain: self.domain.clone().try_into()?,
|
|
||||||
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
|
|
||||||
statement: self.statement.to_string(),
|
|
||||||
uri: UriAbsoluteString::from_str(&self.uri)?,
|
|
||||||
version: Version::from_str(&self.version)?,
|
|
||||||
chain_id: self.chain_id.to_string(),
|
|
||||||
nonce: self.nonce.to_string(),
|
|
||||||
issued_at: self.issued_at.to_string(),
|
|
||||||
expiration_time: self.expiration_time.clone(),
|
|
||||||
not_before: self.not_before.clone(),
|
|
||||||
request_id: self.request_id.clone(),
|
|
||||||
resources: next_resources,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
struct CodeEntry {
|
|
||||||
exchange_count: usize,
|
|
||||||
address: String,
|
|
||||||
nonce: Option<Nonce>,
|
|
||||||
client_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct SignInParams {
|
|
||||||
redirect_uri: RedirectUrl,
|
|
||||||
state: String,
|
|
||||||
oidc_nonce: Option<Nonce>,
|
|
||||||
client_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn sign_in(
|
|
||||||
session: UserSessionFromSession,
|
|
||||||
params: Query<SignInParams>,
|
|
||||||
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
|
||||||
Extension(pool): Extension<ConnectionPool>,
|
|
||||||
) -> Result<(HeaderMap, Redirect), CustomError> {
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
let siwe_cookie: SiweCookie = match cookies.get("siwe") {
|
|
||||||
Some(c) => serde_json::from_str(
|
|
||||||
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
|
|
||||||
)
|
|
||||||
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
|
|
||||||
None => {
|
|
||||||
return Err(anyhow!("No `siwe` cookie").into());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (nonce, headers) = match session {
|
|
||||||
UserSessionFromSession::Found(nonce) => (nonce, HeaderMap::new()),
|
|
||||||
UserSessionFromSession::Invalid(header) => {
|
|
||||||
headers.insert(header::SET_COOKIE, header);
|
|
||||||
return Ok((
|
|
||||||
headers,
|
|
||||||
Redirect::to(
|
|
||||||
format!(
|
|
||||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
|
||||||
¶ms.0.client_id.clone(),
|
|
||||||
¶ms.0.redirect_uri.to_string(),
|
|
||||||
¶ms.0.state,
|
|
||||||
)
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
UserSessionFromSession::Created { .. } => {
|
|
||||||
return Ok((
|
|
||||||
headers,
|
|
||||||
Redirect::to(
|
|
||||||
format!(
|
|
||||||
"/authorize?client_id={}&redirect_uri={}&scope=openid&response_type=code&state={}",
|
|
||||||
¶ms.0.client_id.clone(),
|
|
||||||
¶ms.0.redirect_uri.to_string(),
|
|
||||||
¶ms.0.state,
|
|
||||||
)
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let signature = match <[u8; 65]>::from_hex(
|
|
||||||
siwe_cookie
|
|
||||||
.signature
|
|
||||||
.chars()
|
|
||||||
.skip(2)
|
|
||||||
.take(130)
|
|
||||||
.collect::<String>(),
|
|
||||||
) {
|
|
||||||
Ok(s) => s,
|
|
||||||
Err(e) => {
|
|
||||||
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let message = siwe_cookie
|
|
||||||
.message
|
|
||||||
.to_eip4361_message()
|
|
||||||
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
|
|
||||||
info!("{}", message);
|
|
||||||
message
|
|
||||||
.verify_eip191(signature)
|
|
||||||
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
|
|
||||||
|
|
||||||
let domain = params.redirect_uri.url().host().unwrap();
|
|
||||||
if domain.to_string() != siwe_cookie.message.domain {
|
|
||||||
return Err(anyhow!("Conflicting domains in message and redirect").into());
|
|
||||||
}
|
|
||||||
if nonce != siwe_cookie.message.nonce {
|
|
||||||
return Err(anyhow!("Conflicting nonces in message and session").into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let code_entry = CodeEntry {
|
|
||||||
address: siwe_cookie.message.address,
|
|
||||||
nonce: params.oidc_nonce.clone(),
|
|
||||||
exchange_count: 0,
|
|
||||||
client_id: params.0.client_id.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let code = Uuid::new_v4();
|
|
||||||
let mut conn = pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
|
||||||
conn.set_ex(
|
|
||||||
code.to_string(),
|
|
||||||
hex::encode(
|
|
||||||
bincode::serialize(&code_entry)
|
|
||||||
.map_err(|e| anyhow!("Failed to serialise code: {}", e))?,
|
|
||||||
),
|
|
||||||
ENTRY_LIFETIME,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to set kv: {}", e))?;
|
|
||||||
|
|
||||||
let mut url = params.redirect_uri.url().clone();
|
|
||||||
url.query_pairs_mut().append_pair("code", &code.to_string());
|
|
||||||
url.query_pairs_mut().append_pair("state", ¶ms.state);
|
|
||||||
Ok((
|
|
||||||
headers,
|
|
||||||
Redirect::to(
|
|
||||||
url.as_str()
|
|
||||||
.parse()
|
|
||||||
.map_err(|e| anyhow!("Could not parse URI: {}", e))?,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
// TODO clear session
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn register(
|
|
||||||
extract::Json(payload): extract::Json<CoreClientMetadata>,
|
|
||||||
Extension(pool): Extension<ConnectionPool>,
|
|
||||||
) -> Result<(StatusCode, Json<CoreClientRegistrationResponse>), CustomError> {
|
|
||||||
let id = Uuid::new_v4();
|
|
||||||
let secret = Uuid::new_v4();
|
|
||||||
|
|
||||||
let conn = pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
|
||||||
let entry = ClientEntry {
|
|
||||||
secret: secret.to_string(),
|
|
||||||
redirect_uris: payload.redirect_uris().to_vec(),
|
|
||||||
};
|
|
||||||
set_client(conn, id.to_string(), entry).await?;
|
|
||||||
|
|
||||||
Ok((
|
|
||||||
StatusCode::CREATED,
|
|
||||||
CoreClientRegistrationResponse::new(
|
|
||||||
ClientId::new(id.to_string()),
|
|
||||||
payload.redirect_uris().to_vec(),
|
|
||||||
EmptyAdditionalClientMetadata::default(),
|
|
||||||
EmptyAdditionalClientRegistrationResponse::default(),
|
|
||||||
)
|
|
||||||
.set_client_secret(Some(ClientSecret::new(secret.to_string())))
|
|
||||||
.into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO CORS
|
|
||||||
// TODO need validation of the token
|
|
||||||
// TODO restrict access token use to only once?
|
|
||||||
async fn userinfo(
|
|
||||||
// access_token: AccessTokenUserInfo, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
|
||||||
TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, // TODO maybe go through FromRequest https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs
|
|
||||||
Extension(pool): Extension<ConnectionPool>,
|
|
||||||
) -> Result<Json<CoreUserInfoClaims>, CustomError> {
|
|
||||||
let code = bearer.token().to_string();
|
|
||||||
let mut conn = pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))?;
|
|
||||||
let serialized_entry: Option<Vec<u8>> = conn
|
|
||||||
.get(code)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
|
||||||
if serialized_entry.is_none() {
|
|
||||||
return Err(CustomError::BadRequest("Unknown code.".to_string()));
|
|
||||||
}
|
|
||||||
let code_entry: CodeEntry = bincode::deserialize(
|
|
||||||
&hex::decode(serialized_entry.unwrap())
|
|
||||||
.map_err(|e| anyhow!("Failed to decode code entry: {}", e))?,
|
|
||||||
)
|
|
||||||
.map_err(|e| anyhow!("Failed to deserialize code: {}", e))?;
|
|
||||||
|
|
||||||
Ok(CoreUserInfoClaims::new(
|
|
||||||
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
|
||||||
EmptyAdditionalClaims::default(),
|
|
||||||
)
|
|
||||||
.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn healthcheck() {}
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let config = Figment::from(Serialized::defaults(config::Config::default()))
|
axum_main().await
|
||||||
.merge(Toml::file("siwe-oidc.toml").nested())
|
|
||||||
.merge(Env::prefixed("SIWEOIDC_").split("__").global());
|
|
||||||
let config = config.extract::<config::Config>().unwrap();
|
|
||||||
|
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
|
|
||||||
let manager = RedisConnectionManager::new(config.redis_url.clone()).unwrap();
|
|
||||||
let pool = bb8::Pool::builder().build(manager.clone()).await.unwrap();
|
|
||||||
let pool2 = bb8::Pool::builder().build(manager).await.unwrap();
|
|
||||||
|
|
||||||
for (id, secret) in &config.default_clients.clone() {
|
|
||||||
let conn = pool2
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Failed to get connection to database: {}", e))
|
|
||||||
.unwrap();
|
|
||||||
let client_entry = ClientEntry {
|
|
||||||
secret: secret.to_string(),
|
|
||||||
redirect_uris: vec![],
|
|
||||||
};
|
|
||||||
set_client(conn, id.to_string(), client_entry)
|
|
||||||
.await
|
|
||||||
.unwrap(); // TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
let private_key = if let Some(key) = &config.rsa_pem {
|
|
||||||
RsaPrivateKey::from_pkcs1_pem(key)
|
|
||||||
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
|
||||||
.unwrap()
|
|
||||||
} else {
|
|
||||||
info!("Generating key...");
|
|
||||||
let mut rng = OsRng;
|
|
||||||
let bits = 2048;
|
|
||||||
let private = RsaPrivateKey::new(&mut rng, bits)
|
|
||||||
.map_err(|e| anyhow!("Failed to generate a key: {}", e))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
info!("Generated key.");
|
|
||||||
info!("{:?}", private.to_pkcs1_pem().unwrap());
|
|
||||||
private
|
|
||||||
};
|
|
||||||
|
|
||||||
let app = Router::new()
|
|
||||||
.nest(
|
|
||||||
"/build",
|
|
||||||
service_method_routing::get(ServeDir::new("./static/build")).handle_error(
|
|
||||||
|error: std::io::Error| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Unhandled internal error: {}", error),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.nest(
|
|
||||||
"/img",
|
|
||||||
service_method_routing::get(ServeDir::new("./static/img")).handle_error(
|
|
||||||
|error: std::io::Error| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Unhandled internal error: {}", error),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/",
|
|
||||||
service_method_routing::get(ServeFile::new("./static/index.html")).handle_error(
|
|
||||||
|error: std::io::Error| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Unhandled internal error: {}", error),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/error",
|
|
||||||
service_method_routing::get(ServeFile::new("./static/error.html")).handle_error(
|
|
||||||
|error: std::io::Error| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Unhandled internal error: {}", error),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/favicon.png",
|
|
||||||
service_method_routing::get(ServeFile::new("./static/favicon.png")).handle_error(
|
|
||||||
|error: std::io::Error| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Unhandled internal error: {}", error),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route("/.well-known/openid-configuration", get(provider_metadata))
|
|
||||||
.route("/jwk", get(jwk_set))
|
|
||||||
.route("/token", post(token))
|
|
||||||
.route("/authorize", get(authorize))
|
|
||||||
.route("/register", post(register))
|
|
||||||
.route("/userinfo", get(userinfo).post(userinfo))
|
|
||||||
.route("/sign_in", get(sign_in))
|
|
||||||
.route("/health", get(healthcheck))
|
|
||||||
.layer(AddExtensionLayer::new(private_key))
|
|
||||||
.layer(AddExtensionLayer::new(config.clone()))
|
|
||||||
.layer(AddExtensionLayer::new(pool))
|
|
||||||
.layer(AddExtensionLayer::new(
|
|
||||||
RedisSessionStore::new(config.redis_url.clone())
|
|
||||||
.unwrap()
|
|
||||||
.with_prefix("async-sessions/"),
|
|
||||||
))
|
|
||||||
.layer(TraceLayer::new_for_http());
|
|
||||||
|
|
||||||
let addr = SocketAddr::from((config.address, config.port));
|
|
||||||
tracing::info!("Listening on {}", addr);
|
|
||||||
axum::Server::bind(&addr)
|
|
||||||
.serve(app.into_make_service())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
fn main() {}
|
||||||
|
523
src/oidc.rs
Normal file
523
src/oidc.rs
Normal file
@ -0,0 +1,523 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use chrono::{Duration, Utc};
|
||||||
|
use headers::{self, authorization::Bearer};
|
||||||
|
use hex::FromHex;
|
||||||
|
use iri_string::types::UriString;
|
||||||
|
use openidconnect::{
|
||||||
|
core::{
|
||||||
|
CoreAuthErrorResponseType, CoreAuthPrompt, CoreClaimName, CoreClientAuthMethod,
|
||||||
|
CoreClientMetadata, CoreClientRegistrationResponse, CoreErrorResponseType, CoreGrantType,
|
||||||
|
CoreIdToken, CoreIdTokenClaims, CoreIdTokenFields, CoreJsonWebKeySet,
|
||||||
|
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, CoreRsaPrivateSigningKey,
|
||||||
|
CoreSubjectIdentifierType, CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
|
||||||
|
},
|
||||||
|
registration::{EmptyAdditionalClientMetadata, EmptyAdditionalClientRegistrationResponse},
|
||||||
|
url::Url,
|
||||||
|
AccessToken, Audience, AuthUrl, ClientId, ClientSecret, EmptyAdditionalClaims,
|
||||||
|
EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, IssuerUrl, JsonWebKeyId,
|
||||||
|
JsonWebKeySetUrl, Nonce, PrivateSigningKey, RedirectUrl, RegistrationUrl, RequestUrl,
|
||||||
|
ResponseTypes, Scope, StandardClaims, SubjectIdentifier, TokenUrl, UserInfoUrl,
|
||||||
|
};
|
||||||
|
use rsa::{pkcs1::ToRsaPrivateKey, RsaPrivateKey};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use siwe::eip4361::{Message, Version};
|
||||||
|
use std::str::FromStr;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tracing::info;
|
||||||
|
use urlencoding::decode;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
use super::db::*;
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
use siwe_oidc::db::*;
|
||||||
|
|
||||||
|
const KID: &str = "key1";
|
||||||
|
pub const METADATA_PATH: &str = "/.well-known/openid-configuration";
|
||||||
|
pub const JWK_PATH: &str = "/jwk";
|
||||||
|
pub const TOKEN_PATH: &str = "/token";
|
||||||
|
pub const AUTHORIZE_PATH: &str = "/authorize";
|
||||||
|
pub const REGISTER_PATH: &str = "/register";
|
||||||
|
pub const USERINFO_PATH: &str = "/userinfo";
|
||||||
|
pub const SIGNIN_PATH: &str = "/sign_in";
|
||||||
|
pub const SIWE_COOKIE_KEY: &str = "siwe";
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
type DBClientType = (dyn DBClient + Sync);
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
type DBClientType = dyn DBClient;
|
||||||
|
|
||||||
|
#[derive(Serialize, Debug)]
|
||||||
|
pub struct TokenError {
|
||||||
|
pub error: CoreErrorResponseType,
|
||||||
|
pub error_description: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum CustomError {
|
||||||
|
#[error("{0}")]
|
||||||
|
BadRequest(String),
|
||||||
|
#[error("{0:?}")]
|
||||||
|
BadRequestToken(TokenError),
|
||||||
|
#[error("{0}")]
|
||||||
|
Unauthorized(String),
|
||||||
|
#[error("{0:?}")]
|
||||||
|
Redirect(String),
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn jwks(private_key: RsaPrivateKey) -> Result<CoreJsonWebKeySet, CustomError> {
|
||||||
|
let pem = private_key
|
||||||
|
.to_pkcs1_pem()
|
||||||
|
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||||
|
let jwks = CoreJsonWebKeySet::new(vec![CoreRsaPrivateSigningKey::from_pem(
|
||||||
|
&pem,
|
||||||
|
Some(JsonWebKeyId::new(KID.to_string())),
|
||||||
|
)
|
||||||
|
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?
|
||||||
|
.as_verification_key()]);
|
||||||
|
Ok(jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn metadata(base_url: Url) -> Result<CoreProviderMetadata, CustomError> {
|
||||||
|
let pm = CoreProviderMetadata::new(
|
||||||
|
IssuerUrl::from_url(base_url.clone()),
|
||||||
|
AuthUrl::from_url(
|
||||||
|
base_url
|
||||||
|
.join(AUTHORIZE_PATH)
|
||||||
|
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||||
|
),
|
||||||
|
JsonWebKeySetUrl::from_url(
|
||||||
|
base_url
|
||||||
|
.join(JWK_PATH)
|
||||||
|
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||||
|
),
|
||||||
|
vec![
|
||||||
|
ResponseTypes::new(vec![CoreResponseType::Code]),
|
||||||
|
ResponseTypes::new(vec![CoreResponseType::Token, CoreResponseType::IdToken]),
|
||||||
|
],
|
||||||
|
vec![CoreSubjectIdentifierType::Pairwise],
|
||||||
|
vec![CoreJwsSigningAlgorithm::RsaSsaPssSha256],
|
||||||
|
EmptyAdditionalProviderMetadata {},
|
||||||
|
)
|
||||||
|
.set_token_endpoint(Some(TokenUrl::from_url(
|
||||||
|
base_url
|
||||||
|
.join(TOKEN_PATH)
|
||||||
|
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||||
|
)))
|
||||||
|
.set_userinfo_endpoint(Some(UserInfoUrl::from_url(
|
||||||
|
base_url
|
||||||
|
.join(USERINFO_PATH)
|
||||||
|
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||||
|
)))
|
||||||
|
.set_scopes_supported(Some(vec![
|
||||||
|
Scope::new("openid".to_string()),
|
||||||
|
// Scope::new("email".to_string()),
|
||||||
|
// Scope::new("profile".to_string()),
|
||||||
|
]))
|
||||||
|
.set_claims_supported(Some(vec![
|
||||||
|
CoreClaimName::new("sub".to_string()),
|
||||||
|
CoreClaimName::new("aud".to_string()),
|
||||||
|
// CoreClaimName::new("email".to_string()),
|
||||||
|
// CoreClaimName::new("email_verified".to_string()),
|
||||||
|
CoreClaimName::new("exp".to_string()),
|
||||||
|
CoreClaimName::new("iat".to_string()),
|
||||||
|
CoreClaimName::new("iss".to_string()),
|
||||||
|
// CoreClaimName::new("name".to_string()),
|
||||||
|
// CoreClaimName::new("given_name".to_string()),
|
||||||
|
// CoreClaimName::new("family_name".to_string()),
|
||||||
|
// CoreClaimName::new("picture".to_string()),
|
||||||
|
// CoreClaimName::new("locale".to_string()),
|
||||||
|
]))
|
||||||
|
.set_registration_endpoint(Some(RegistrationUrl::from_url(
|
||||||
|
base_url
|
||||||
|
.join(REGISTER_PATH)
|
||||||
|
.map_err(|e| anyhow!("Unable to join URL: {}", e))?,
|
||||||
|
)))
|
||||||
|
.set_token_endpoint_auth_methods_supported(Some(vec![
|
||||||
|
CoreClientAuthMethod::ClientSecretBasic,
|
||||||
|
CoreClientAuthMethod::ClientSecretPost,
|
||||||
|
]));
|
||||||
|
|
||||||
|
Ok(pm)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct TokenForm {
|
||||||
|
pub code: String,
|
||||||
|
pub client_id: Option<String>,
|
||||||
|
pub client_secret: Option<String>,
|
||||||
|
pub grant_type: CoreGrantType, // TODO should just be authorization_code apparently?
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn token(
|
||||||
|
form: TokenForm,
|
||||||
|
// From the request's Authorization header
|
||||||
|
secret: Option<String>,
|
||||||
|
private_key: RsaPrivateKey,
|
||||||
|
base_url: Url,
|
||||||
|
require_secret: bool,
|
||||||
|
db_client: &DBClientType,
|
||||||
|
) -> Result<CoreTokenResponse, CustomError> {
|
||||||
|
let code_entry = if let Some(c) = db_client.get_code(form.code.to_string()).await? {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
return Err(CustomError::BadRequestToken(TokenError {
|
||||||
|
error: CoreErrorResponseType::InvalidGrant,
|
||||||
|
error_description: "Unknown code.".to_string(),
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
let client_id = if let Some(c) = form.client_id.clone() {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
code_entry.client_id.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(secret) = if let Some(b) = secret {
|
||||||
|
Some(b)
|
||||||
|
} else {
|
||||||
|
form.client_secret.clone()
|
||||||
|
} {
|
||||||
|
let client_entry = db_client.get_client(client_id.clone()).await?;
|
||||||
|
if client_entry.is_none() {
|
||||||
|
return Err(CustomError::Unauthorized(
|
||||||
|
"Unrecognised client id.".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if secret != client_entry.unwrap().secret {
|
||||||
|
return Err(CustomError::Unauthorized("Bad secret.".to_string()));
|
||||||
|
}
|
||||||
|
} else if require_secret {
|
||||||
|
return Err(CustomError::Unauthorized("Secret required.".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if code_entry.exchange_count > 0 {
|
||||||
|
// TODO use Oauth error response
|
||||||
|
return Err(CustomError::BadRequestToken(TokenError {
|
||||||
|
error: CoreErrorResponseType::InvalidGrant,
|
||||||
|
error_description: "Code was previously exchanged.".to_string(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
let mut code_entry2 = code_entry.clone();
|
||||||
|
code_entry2.exchange_count += 1;
|
||||||
|
db_client
|
||||||
|
.set_code(form.code.to_string(), code_entry2)
|
||||||
|
.await?;
|
||||||
|
let access_token = AccessToken::new(form.code);
|
||||||
|
let core_id_token = CoreIdTokenClaims::new(
|
||||||
|
IssuerUrl::from_url(base_url),
|
||||||
|
vec![Audience::new(client_id.clone())],
|
||||||
|
Utc::now() + Duration::seconds(60),
|
||||||
|
Utc::now(),
|
||||||
|
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||||
|
EmptyAdditionalClaims {},
|
||||||
|
)
|
||||||
|
.set_nonce(code_entry.nonce);
|
||||||
|
|
||||||
|
let pem = private_key
|
||||||
|
.to_pkcs1_pem()
|
||||||
|
.map_err(|e| anyhow!("Failed to serialise key as PEM: {}", e))?;
|
||||||
|
|
||||||
|
let id_token = CoreIdToken::new(
|
||||||
|
core_id_token,
|
||||||
|
&CoreRsaPrivateSigningKey::from_pem(&pem, Some(JsonWebKeyId::new(KID.to_string())))
|
||||||
|
.map_err(|e| anyhow!("Invalid RSA private key: {}", e))?,
|
||||||
|
CoreJwsSigningAlgorithm::RsaSsaPkcs1V15Sha256,
|
||||||
|
Some(&access_token),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.map_err(|e| anyhow!("{}", e))?;
|
||||||
|
|
||||||
|
Ok(CoreTokenResponse::new(
|
||||||
|
access_token,
|
||||||
|
CoreTokenType::Bearer,
|
||||||
|
CoreIdTokenFields::new(Some(id_token), EmptyExtraTokenFields {}),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct AuthorizeParams {
|
||||||
|
pub client_id: String,
|
||||||
|
pub redirect_uri: RedirectUrl,
|
||||||
|
pub scope: Scope,
|
||||||
|
pub response_type: Option<CoreResponseType>,
|
||||||
|
pub state: Option<String>,
|
||||||
|
pub nonce: Option<Nonce>,
|
||||||
|
pub prompt: Option<CoreAuthPrompt>,
|
||||||
|
pub request_uri: Option<RequestUrl>,
|
||||||
|
pub request: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn authorize(
|
||||||
|
params: AuthorizeParams,
|
||||||
|
nonce: String,
|
||||||
|
db_client: &DBClientType,
|
||||||
|
) -> Result<String, CustomError> {
|
||||||
|
let client_entry = db_client
|
||||||
|
.get_client(params.client_id.clone())
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get kv: {}", e))?;
|
||||||
|
if client_entry.is_none() {
|
||||||
|
return Err(CustomError::Unauthorized(
|
||||||
|
"Unrecognised client id.".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut r_u = params.redirect_uri.clone().url().clone();
|
||||||
|
r_u.set_query(None);
|
||||||
|
let mut r_us: Vec<Url> = client_entry
|
||||||
|
.unwrap()
|
||||||
|
.redirect_uris
|
||||||
|
.iter_mut()
|
||||||
|
.map(|u| u.url().clone())
|
||||||
|
.collect();
|
||||||
|
r_us.iter_mut().for_each(|u| u.set_query(None));
|
||||||
|
if !r_us.contains(&r_u) {
|
||||||
|
return Err(CustomError::Redirect(
|
||||||
|
"/error?message=unregistered_request_uri".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let state = if let Some(s) = params.state.clone() {
|
||||||
|
s
|
||||||
|
} else if params.request_uri.is_some() {
|
||||||
|
let mut url = params.redirect_uri.url().clone();
|
||||||
|
url.query_pairs_mut().append_pair(
|
||||||
|
"error",
|
||||||
|
CoreAuthErrorResponseType::RequestUriNotSupported.as_ref(),
|
||||||
|
);
|
||||||
|
return Err(CustomError::Redirect(url.to_string()));
|
||||||
|
} else if params.request.is_some() {
|
||||||
|
let mut url = params.redirect_uri.url().clone();
|
||||||
|
url.query_pairs_mut().append_pair(
|
||||||
|
"error",
|
||||||
|
CoreAuthErrorResponseType::RequestNotSupported.as_ref(),
|
||||||
|
);
|
||||||
|
return Err(CustomError::Redirect(url.to_string()));
|
||||||
|
} else {
|
||||||
|
let mut url = params.redirect_uri.url().clone();
|
||||||
|
url.query_pairs_mut()
|
||||||
|
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
||||||
|
url.query_pairs_mut()
|
||||||
|
.append_pair("error_description", "Missing state");
|
||||||
|
return Err(CustomError::Redirect(url.to_string()));
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(CoreAuthPrompt::None) = params.prompt {
|
||||||
|
let mut url = params.redirect_uri.url().clone();
|
||||||
|
url.query_pairs_mut().append_pair("state", &state);
|
||||||
|
url.query_pairs_mut().append_pair(
|
||||||
|
"error",
|
||||||
|
CoreAuthErrorResponseType::InteractionRequired.as_ref(),
|
||||||
|
);
|
||||||
|
return Err(CustomError::Redirect(url.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.response_type.is_none() {
|
||||||
|
let mut url = params.redirect_uri.url().clone();
|
||||||
|
url.query_pairs_mut().append_pair("state", &state);
|
||||||
|
url.query_pairs_mut()
|
||||||
|
.append_pair("error", CoreAuthErrorResponseType::InvalidRequest.as_ref());
|
||||||
|
url.query_pairs_mut()
|
||||||
|
.append_pair("error_description", "Missing response_type");
|
||||||
|
return Err(CustomError::Redirect(url.to_string()));
|
||||||
|
}
|
||||||
|
let _response_type = params.response_type.as_ref().unwrap();
|
||||||
|
|
||||||
|
if params.scope != Scope::new("openid".to_string()) {
|
||||||
|
return Err(anyhow!("Scope not supported").into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let domain = params.redirect_uri.url().host().unwrap();
|
||||||
|
let oidc_nonce_param = if let Some(n) = ¶ms.nonce {
|
||||||
|
format!("&oidc_nonce={}", n.secret())
|
||||||
|
} else {
|
||||||
|
"".to_string()
|
||||||
|
};
|
||||||
|
Ok(format!(
|
||||||
|
"/?nonce={}&domain={}&redirect_uri={}&state={}&client_id={}{}",
|
||||||
|
nonce,
|
||||||
|
domain,
|
||||||
|
params.redirect_uri.to_string(),
|
||||||
|
state,
|
||||||
|
params.client_id,
|
||||||
|
oidc_nonce_param
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct SiweCookie {
|
||||||
|
message: Web3ModalMessage,
|
||||||
|
signature: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
struct Web3ModalMessage {
|
||||||
|
pub domain: String,
|
||||||
|
pub address: String,
|
||||||
|
pub statement: String,
|
||||||
|
pub uri: String,
|
||||||
|
pub version: String,
|
||||||
|
pub chain_id: String,
|
||||||
|
pub nonce: String,
|
||||||
|
pub issued_at: String,
|
||||||
|
pub expiration_time: Option<String>,
|
||||||
|
pub not_before: Option<String>,
|
||||||
|
pub request_id: Option<String>,
|
||||||
|
pub resources: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Web3ModalMessage {
|
||||||
|
fn to_eip4361_message(&self) -> Result<Message> {
|
||||||
|
let mut next_resources: Vec<UriString> = Vec::new();
|
||||||
|
match &self.resources {
|
||||||
|
Some(resources) => {
|
||||||
|
for resource in resources {
|
||||||
|
let x = UriString::from_str(resource)?;
|
||||||
|
next_resources.push(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Message {
|
||||||
|
domain: self.domain.clone().try_into()?,
|
||||||
|
address: <[u8; 20]>::from_hex(self.address.chars().skip(2).collect::<String>())?,
|
||||||
|
statement: self.statement.to_string(),
|
||||||
|
uri: UriString::from_str(&self.uri)?,
|
||||||
|
version: Version::from_str(&self.version)?,
|
||||||
|
chain_id: self.chain_id.to_string(),
|
||||||
|
nonce: self.nonce.to_string(),
|
||||||
|
issued_at: self.issued_at.to_string(),
|
||||||
|
expiration_time: self.expiration_time.clone(),
|
||||||
|
not_before: self.not_before.clone(),
|
||||||
|
request_id: self.request_id.clone(),
|
||||||
|
resources: next_resources,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct SignInParams {
|
||||||
|
pub redirect_uri: RedirectUrl,
|
||||||
|
pub state: String,
|
||||||
|
pub oidc_nonce: Option<Nonce>,
|
||||||
|
pub client_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn sign_in(
|
||||||
|
params: SignInParams,
|
||||||
|
expected_nonce: Option<String>,
|
||||||
|
cookies: headers::Cookie,
|
||||||
|
db_client: &DBClientType,
|
||||||
|
) -> Result<Url, CustomError> {
|
||||||
|
let siwe_cookie: SiweCookie = match cookies.get(SIWE_COOKIE_KEY) {
|
||||||
|
Some(c) => serde_json::from_str(
|
||||||
|
&decode(c).map_err(|e| anyhow!("Could not decode siwe cookie: {}", e))?,
|
||||||
|
)
|
||||||
|
.map_err(|e| anyhow!("Could not deserialize siwe cookie: {}", e))?,
|
||||||
|
None => {
|
||||||
|
return Err(anyhow!("No `siwe` cookie").into());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let signature = match <[u8; 65]>::from_hex(
|
||||||
|
siwe_cookie
|
||||||
|
.signature
|
||||||
|
.chars()
|
||||||
|
.skip(2)
|
||||||
|
.take(130)
|
||||||
|
.collect::<String>(),
|
||||||
|
) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(CustomError::BadRequest(format!("Bad signature: {}", e)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let message = siwe_cookie
|
||||||
|
.message
|
||||||
|
.to_eip4361_message()
|
||||||
|
.map_err(|e| anyhow!("Failed to serialise message: {}", e))?;
|
||||||
|
info!("{}", message);
|
||||||
|
message
|
||||||
|
.verify(signature)
|
||||||
|
.map_err(|e| anyhow!("Failed signature validation: {}", e))?;
|
||||||
|
|
||||||
|
let domain = params.redirect_uri.url().host().unwrap();
|
||||||
|
if domain.to_string() != siwe_cookie.message.domain {
|
||||||
|
return Err(anyhow!("Conflicting domains in message and redirect").into());
|
||||||
|
}
|
||||||
|
if expected_nonce.is_some() && expected_nonce.unwrap() != siwe_cookie.message.nonce {
|
||||||
|
return Err(anyhow!("Conflicting nonces in message and session").into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let code_entry = CodeEntry {
|
||||||
|
address: siwe_cookie.message.address,
|
||||||
|
nonce: params.oidc_nonce.clone(),
|
||||||
|
exchange_count: 0,
|
||||||
|
client_id: params.client_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let code = Uuid::new_v4();
|
||||||
|
db_client.set_code(code.to_string(), code_entry).await?;
|
||||||
|
|
||||||
|
let mut url = params.redirect_uri.url().clone();
|
||||||
|
url.query_pairs_mut().append_pair("code", &code.to_string());
|
||||||
|
url.query_pairs_mut().append_pair("state", ¶ms.state);
|
||||||
|
Ok(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register(
|
||||||
|
payload: CoreClientMetadata,
|
||||||
|
db_client: &DBClientType,
|
||||||
|
) -> Result<CoreClientRegistrationResponse, CustomError> {
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
let secret = Uuid::new_v4();
|
||||||
|
|
||||||
|
let entry = ClientEntry {
|
||||||
|
secret: secret.to_string(),
|
||||||
|
redirect_uris: payload.redirect_uris().to_vec(),
|
||||||
|
};
|
||||||
|
db_client.set_client(id.to_string(), entry).await?;
|
||||||
|
|
||||||
|
Ok(CoreClientRegistrationResponse::new(
|
||||||
|
ClientId::new(id.to_string()),
|
||||||
|
payload.redirect_uris().to_vec(),
|
||||||
|
EmptyAdditionalClientMetadata::default(),
|
||||||
|
EmptyAdditionalClientRegistrationResponse::default(),
|
||||||
|
)
|
||||||
|
.set_client_secret(Some(ClientSecret::new(secret.to_string()))))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct UserInfoPayload {
|
||||||
|
pub access_token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn userinfo(
|
||||||
|
bearer: Option<Bearer>,
|
||||||
|
payload: UserInfoPayload,
|
||||||
|
db_client: &DBClientType,
|
||||||
|
) -> Result<CoreUserInfoClaims, CustomError> {
|
||||||
|
let code = if let Some(b) = bearer {
|
||||||
|
b.token().to_string()
|
||||||
|
} else if let Some(c) = payload.access_token {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
return Err(CustomError::BadRequest("Missing access token.".to_string()));
|
||||||
|
};
|
||||||
|
let code_entry = if let Some(c) = db_client.get_code(code).await? {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
return Err(CustomError::BadRequest("Unknown code.".to_string()));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(CoreUserInfoClaims::new(
|
||||||
|
StandardClaims::new(SubjectIdentifier::new(code_entry.address)),
|
||||||
|
EmptyAdditionalClaims::default(),
|
||||||
|
))
|
||||||
|
}
|
210
src/worker_lib.rs
Normal file
210
src/worker_lib.rs
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
use anyhow::anyhow;
|
||||||
|
use headers::{
|
||||||
|
self,
|
||||||
|
authorization::{Basic, Bearer, Credentials},
|
||||||
|
Authorization, Header, HeaderValue,
|
||||||
|
};
|
||||||
|
use rand::{distributions::Alphanumeric, Rng};
|
||||||
|
use rsa::{pkcs1::FromRsaPrivateKey, RsaPrivateKey};
|
||||||
|
use worker::*;
|
||||||
|
|
||||||
|
use super::db::CFClient;
|
||||||
|
use super::oidc::{self, CustomError, TokenForm, UserInfoPayload};
|
||||||
|
|
||||||
|
const BASE_URL_KEY: &str = "BASE_URL";
|
||||||
|
const RSA_PEM_KEY: &str = "RSA_PEM";
|
||||||
|
|
||||||
|
// https://github.com/cloudflare/workers-rs/issues/64
|
||||||
|
// #[global_allocator]
|
||||||
|
// static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
|
||||||
|
|
||||||
|
impl From<CustomError> for Result<Response> {
|
||||||
|
fn from(error: CustomError) -> Self {
|
||||||
|
match error {
|
||||||
|
CustomError::BadRequest(_) => Response::error(&error.to_string(), 400),
|
||||||
|
CustomError::BadRequestToken(e) => Response::from_json(&e).map(|r| r.with_status(400)),
|
||||||
|
CustomError::Unauthorized(_) => Response::error(&error.to_string(), 401),
|
||||||
|
CustomError::Redirect(uri) => Response::redirect(uri.parse().unwrap()),
|
||||||
|
CustomError::Other(_) => Response::error(&error.to_string(), 500),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn main(req: Request, env: Env) -> Result<Response> {
|
||||||
|
console_error_panic_hook::set_once();
|
||||||
|
// tracing_subscriber::fmt::init();
|
||||||
|
// console_log::init_with_level(log::Level::Info).expect("error initializing log");
|
||||||
|
|
||||||
|
let userinfo = |mut req: Request, ctx: RouteContext<()>| async move {
|
||||||
|
let bearer = req
|
||||||
|
.headers()
|
||||||
|
.get(Authorization::<Bearer>::name().as_str())?
|
||||||
|
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
|
||||||
|
.as_ref()
|
||||||
|
.and_then(Bearer::decode);
|
||||||
|
let payload = if bearer.is_none() {
|
||||||
|
match req.form_data().await {
|
||||||
|
Ok(f) => {
|
||||||
|
let access_token = if let Some(FormEntry::Field(a)) = f.get("access_token") {
|
||||||
|
Some(a)
|
||||||
|
} else {
|
||||||
|
return Response::error("Missing code", 400);
|
||||||
|
};
|
||||||
|
UserInfoPayload { access_token }
|
||||||
|
}
|
||||||
|
Err(_) => return Response::error("Bad request", 400),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
UserInfoPayload { access_token: None }
|
||||||
|
};
|
||||||
|
let url = req.url()?;
|
||||||
|
let db_client = CFClient { ctx, url };
|
||||||
|
match oidc::userinfo(bearer, payload, &db_client).await {
|
||||||
|
Ok(r) => Ok(Response::from_json(&r)?),
|
||||||
|
Err(e) => e.into(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let router = Router::new();
|
||||||
|
router
|
||||||
|
.get_async(oidc::METADATA_PATH, |_req, ctx| async move {
|
||||||
|
match oidc::metadata(ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap()) {
|
||||||
|
Ok(m) => Response::from_json(&m),
|
||||||
|
Err(e) => e.into(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.get_async(oidc::JWK_PATH, |_req, ctx| async move {
|
||||||
|
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
|
||||||
|
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||||
|
.unwrap();
|
||||||
|
match oidc::jwks(private_key) {
|
||||||
|
Ok(m) => Response::from_json(&m),
|
||||||
|
Err(e) => e.into(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.post_async(oidc::TOKEN_PATH, |mut req, ctx| async move {
|
||||||
|
let form_data = req.form_data().await?;
|
||||||
|
let code = if let Some(FormEntry::Field(c)) = form_data.get("code") {
|
||||||
|
c
|
||||||
|
} else {
|
||||||
|
return Response::error("Missing code", 400);
|
||||||
|
};
|
||||||
|
let client_id = match form_data.get("client_id") {
|
||||||
|
Some(FormEntry::Field(c)) => Some(c),
|
||||||
|
None => None,
|
||||||
|
_ => return Response::error("Client ID not a field", 400),
|
||||||
|
};
|
||||||
|
let client_secret = match form_data.get("client_secret") {
|
||||||
|
Some(FormEntry::Field(c)) => Some(c),
|
||||||
|
None => None,
|
||||||
|
_ => return Response::error("Client secret not a field", 400),
|
||||||
|
};
|
||||||
|
let grant_type = if let Some(FormEntry::Field(c)) = form_data.get("code") {
|
||||||
|
if let Ok(cc) = serde_json::from_str(&format!("\"{}\"", c)) {
|
||||||
|
cc
|
||||||
|
} else {
|
||||||
|
return Response::error("Invalid grant type", 400);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Response::error("Missing grant type", 400);
|
||||||
|
};
|
||||||
|
let secret = req
|
||||||
|
.headers()
|
||||||
|
.get(Authorization::<Bearer>::name().as_str())?
|
||||||
|
.and_then(|b| HeaderValue::from_str(b.as_ref()).ok())
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|b| {
|
||||||
|
if b.to_str().unwrap().starts_with("Bearer") {
|
||||||
|
Bearer::decode(b).map(|bb| bb.token().to_string())
|
||||||
|
} else {
|
||||||
|
Basic::decode(b).map(|bb| bb.password().to_string())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let private_key = RsaPrivateKey::from_pkcs1_pem(&ctx.secret(RSA_PEM_KEY)?.to_string())
|
||||||
|
.map_err(|e| anyhow!("Failed to load private key: {}", e))
|
||||||
|
.unwrap();
|
||||||
|
let base_url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
|
||||||
|
let url = req.url()?;
|
||||||
|
let db_client = CFClient { ctx, url };
|
||||||
|
let token_response = oidc::token(
|
||||||
|
TokenForm {
|
||||||
|
code,
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
grant_type,
|
||||||
|
},
|
||||||
|
secret,
|
||||||
|
private_key,
|
||||||
|
base_url,
|
||||||
|
false,
|
||||||
|
&db_client,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
match token_response {
|
||||||
|
Ok(m) => Response::from_json(&m),
|
||||||
|
Err(e) => e.into(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
// TODO add browser session
|
||||||
|
.get_async(oidc::AUTHORIZE_PATH, |req, ctx| async move {
|
||||||
|
let base_url: Url = ctx.var(BASE_URL_KEY)?.to_string().parse().unwrap();
|
||||||
|
let url = req.url()?;
|
||||||
|
let query = url.query().unwrap_or_default();
|
||||||
|
let params = match serde_urlencoded::from_str(query) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
|
||||||
|
};
|
||||||
|
let nonce = rand::thread_rng()
|
||||||
|
.sample_iter(&Alphanumeric)
|
||||||
|
.take(16)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
let url = req.url()?;
|
||||||
|
let db_client = CFClient { ctx, url };
|
||||||
|
match oidc::authorize(params, nonce, &db_client).await {
|
||||||
|
Ok(url) => Response::redirect(base_url.join(&url).unwrap()),
|
||||||
|
Err(e) => match e {
|
||||||
|
CustomError::Redirect(url) => {
|
||||||
|
CustomError::Redirect(base_url.join(&url).unwrap().to_string())
|
||||||
|
}
|
||||||
|
c => c,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.post_async(oidc::REGISTER_PATH, |mut req, ctx| async move {
|
||||||
|
let payload = req.json().await?;
|
||||||
|
let url = req.url()?;
|
||||||
|
let db_client = CFClient { ctx, url };
|
||||||
|
match oidc::register(payload, &db_client).await {
|
||||||
|
Ok(r) => Ok(Response::from_json(&r)?.with_status(201)),
|
||||||
|
Err(e) => e.into(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.post_async(oidc::USERINFO_PATH, userinfo)
|
||||||
|
.get_async(oidc::USERINFO_PATH, userinfo)
|
||||||
|
.get_async(oidc::SIGNIN_PATH, |req, ctx| async move {
|
||||||
|
let url = req.url()?;
|
||||||
|
let query = url.query().unwrap_or_default();
|
||||||
|
let params = match serde_urlencoded::from_str(query) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => return CustomError::BadRequest("Bad query params".to_string()).into(),
|
||||||
|
};
|
||||||
|
let cookies = req
|
||||||
|
.headers()
|
||||||
|
.get(headers::Cookie::name().as_str())?
|
||||||
|
.and_then(|c| HeaderValue::from_str(&c).ok())
|
||||||
|
.and_then(|c| headers::Cookie::decode(&mut [c].iter()).ok());
|
||||||
|
if cookies.is_none() {
|
||||||
|
return Response::error("Missing cookies", 400);
|
||||||
|
}
|
||||||
|
let url = req.url()?;
|
||||||
|
let db_client = CFClient { ctx, url };
|
||||||
|
match oidc::sign_in(params, None, cookies.unwrap(), &db_client).await {
|
||||||
|
Ok(url) => Response::redirect(url),
|
||||||
|
Err(e) => e.into(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.run(req, env)
|
||||||
|
.await
|
||||||
|
}
|
35
wrangle_example.toml
Normal file
35
wrangle_example.toml
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
name = "siwe_oidc"
|
||||||
|
type = "javascript"
|
||||||
|
account_id = ""
|
||||||
|
# zone_id = ""
|
||||||
|
workers_dev = false
|
||||||
|
compatibility_date = "2021-12-20"
|
||||||
|
|
||||||
|
kv_namespaces = [
|
||||||
|
{ binding = "SIWE-OIDC", id = "", preview_id = "" }
|
||||||
|
]
|
||||||
|
|
||||||
|
[vars]
|
||||||
|
WORKERS_RS_VERSION = "0.0.7"
|
||||||
|
BASE_URL = "https://siweoidc.spruceid.xyz"
|
||||||
|
|
||||||
|
[durable_objects]
|
||||||
|
bindings = [
|
||||||
|
{ name = "SIWE-OIDC-CODES", class_name = "DOCodes" }
|
||||||
|
]
|
||||||
|
|
||||||
|
[[migrations]]
|
||||||
|
tag = "v1"
|
||||||
|
new_classes = ["DOCodes"]
|
||||||
|
|
||||||
|
[build]
|
||||||
|
command = "cargo install -q worker-build && worker-build --release"
|
||||||
|
|
||||||
|
[build.upload]
|
||||||
|
dir = "build/worker"
|
||||||
|
format = "modules"
|
||||||
|
main = "./shim.mjs"
|
||||||
|
|
||||||
|
[[build.upload.rules]]
|
||||||
|
globs = ["**/*.wasm"]
|
||||||
|
type = "CompiledWasm"
|
Loading…
Reference in New Issue
Block a user