From 9c7c2858cc603cde9d5336c670c57dceabd53b9e Mon Sep 17 00:00:00 2001 From: nullishamy Date: Thu, 24 Apr 2025 20:19:54 +0100 Subject: [PATCH] feat: auth basics --- Cargo.lock | 1 + Cargo.toml | 3 +- ferri-main/Cargo.toml | 2 +- ferri-main/src/lib.rs | 9 ++ ferri-server/Cargo.toml | 1 + ferri-server/src/endpoints/api/apps.rs | 36 ++++++- ferri-server/src/endpoints/api/timeline.rs | 9 +- ferri-server/src/endpoints/oauth.rs | 111 +++++++++++++++++++-- ferri-server/src/lib.rs | 44 ++++++-- ferri-server/src/types/oauth.rs | 2 +- migrations/20250423182916_add_auth.sql | 26 +++++ 11 files changed, 215 insertions(+), 29 deletions(-) create mode 100644 migrations/20250423182916_add_auth.sql diff --git a/Cargo.lock b/Cargo.lock index 402385e..f1201d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2156,6 +2156,7 @@ version = "0.1.0" dependencies = [ "chrono", "main", + "rand 0.8.5", "reqwest", "rocket", "rocket_db_pools", diff --git a/Cargo.toml b/Cargo.toml index b0d9a9e..e039c64 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,4 +8,5 @@ serde = "1.0.219" rocket = { version = "0.5.1", features = ["json"] } sqlx = { version = "0.7", features = [ "runtime-tokio", "sqlite", "macros" ], default-features = false } uuid = { version = "1.16.0", features = ["v4"] } -chrono = "0.4.40" \ No newline at end of file +chrono = "0.4.40" +rand = "0.8" \ No newline at end of file diff --git a/ferri-main/Cargo.toml b/ferri-main/Cargo.toml index 855afc8..0d91f57 100644 --- a/ferri-main/Cargo.toml +++ b/ferri-main/Cargo.toml @@ -13,5 +13,5 @@ uuid = { workspace = true } base64 = "0.22.1" rsa = { version = "0.9.8", features = ["sha2"] } -rand = "0.8" +rand = { workspace = true } url = "2.5.4" diff --git a/ferri-main/src/lib.rs b/ferri-main/src/lib.rs index 4d8826e..620300a 100644 --- a/ferri-main/src/lib.rs +++ b/ferri-main/src/lib.rs @@ -1,2 +1,11 @@ pub mod ap; pub mod config; +use rand::{Rng, distributions::Alphanumeric}; + +pub fn gen_token(len: usize) -> String { + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(len) + .map(char::from) + .collect() +} diff --git a/ferri-server/Cargo.toml b/ferri-server/Cargo.toml index 5ff5b60..fba127f 100644 --- a/ferri-server/Cargo.toml +++ b/ferri-server/Cargo.toml @@ -11,5 +11,6 @@ reqwest = { workspace = true } sqlx = { workspace = true } uuid = { workspace = true } chrono = { workspace = true } +rand = { workspace = true } url = "2.5.4" \ No newline at end of file diff --git a/ferri-server/src/endpoints/api/apps.rs b/ferri-server/src/endpoints/api/apps.rs index 40425cf..1ed5001 100644 --- a/ferri-server/src/endpoints/api/apps.rs +++ b/ferri-server/src/endpoints/api/apps.rs @@ -1,14 +1,44 @@ use rocket::{form::Form, post, serde::json::Json}; +use crate::Db; use crate::types::oauth::{App, CredentialApplication}; +use rocket_db_pools::Connection; #[post("/apps", data = "")] -pub async fn new_app(app: Form) -> Json { +pub async fn new_app(app: Form, mut db: Connection) -> Json { + let secret = main::gen_token(15); + + // Abort when we encounter a duplicate + let is_app_present = sqlx::query!( + r#" + INSERT INTO app (client_id, client_secret, scopes) + VALUES (?1, ?2, ?3) + "#, + app.client_name, + app.scopes, + secret + ) + .execute(&mut **db) + .await + .is_err(); + + let mut app: App = app.clone(); + + if is_app_present { + let existing_app = sqlx::query!("SELECT * FROM app WHERE client_id = ?1", app.client_name) + .fetch_one(&mut **db) + .await + .unwrap(); + + app.client_name = existing_app.client_id; + app.scopes = existing_app.scopes; + } + Json(CredentialApplication { name: app.client_name.clone(), scopes: app.scopes.clone(), redirect_uris: app.redirect_uris.clone(), - client_id: format!("id-for-{}", app.client_name), - client_secret: format!("secret-for-{}", app.client_name), + client_id: app.client_name.clone(), + client_secret: secret, }) } diff --git a/ferri-server/src/endpoints/api/timeline.rs b/ferri-server/src/endpoints/api/timeline.rs index 385c0de..3a3ec13 100644 --- a/ferri-server/src/endpoints/api/timeline.rs +++ b/ferri-server/src/endpoints/api/timeline.rs @@ -1,4 +1,4 @@ -use crate::{Db, endpoints::api::user::CredentialAcount}; +use crate::{AuthenticatedUser, Db, endpoints::api::user::CredentialAcount}; use rocket::{ get, serde::{Deserialize, Serialize, json::Json}, @@ -32,7 +32,12 @@ pub struct TimelineStatus { } #[get("/timelines/home?")] -pub async fn home(mut db: Connection, limit: i64) -> Json> { +pub async fn home( + mut db: Connection, + limit: i64, + user: AuthenticatedUser, +) -> Json> { + dbg!(user); let posts = sqlx::query!( r#" SELECT p.id as "post_id", u.id as "user_id", p.content, p.uri as "post_uri", diff --git a/ferri-server/src/endpoints/oauth.rs b/ferri-server/src/endpoints/oauth.rs index 5b48ebc..7045ec1 100644 --- a/ferri-server/src/endpoints/oauth.rs +++ b/ferri-server/src/endpoints/oauth.rs @@ -1,8 +1,12 @@ +use crate::Db; use rocket::{ + FromForm, + form::Form, get, post, response::Redirect, serde::{Deserialize, Serialize, json::Json}, }; +use rocket_db_pools::Connection; #[get("/oauth/authorize?&&&")] pub async fn authorize( @@ -10,11 +14,45 @@ pub async fn authorize( scope: &str, redirect_uri: &str, response_type: &str, + mut db: Connection, ) -> Redirect { - Redirect::temporary(format!( - "{}?code=code-for-{}&state=state-for-{}", - redirect_uri, client_id, client_id - )) + // For now, we will always authorize the request and assign it to an admin user + let user_id = "9b9d497b-2731-435f-a929-e609ca69dac9"; + let code = main::gen_token(15); + + // This will act as a token for the user, but we will in future say that it expires very shortly + // and can only be used for obtaining an access token etc + sqlx::query!( + r#" + INSERT INTO auth (token, user_id) + VALUES (?1, ?2) + "#, + code, + user_id + ) + .execute(&mut **db) + .await + .unwrap(); + + let id_token = main::gen_token(10); + + // Add an oauth entry for the `code` which /oauth/token will rewrite + sqlx::query!( + r#" + INSERT INTO oauth (id_token, client_id, expires_in, scope, access_token) + VALUES (?1, ?2, ?3, ?4, ?5) + "#, + id_token, + client_id, + 3600, + scope, + code + ) + .execute(&mut **db) + .await + .unwrap(); + + Redirect::temporary(format!("{}?code={}", redirect_uri, code)) } #[derive(Serialize, Deserialize, Debug)] @@ -27,13 +65,66 @@ pub struct Token { pub id_token: String, } -#[post("/oauth/token")] -pub async fn new_token() -> Json { +#[derive(Deserialize, Debug, FromForm)] +#[serde(crate = "rocket::serde")] +struct NewTokenRequest { + client_id: String, + redirect_uri: String, + grant_type: String, + code: String, + client_secret: String, +} + +#[post("/oauth/token", data = "")] +pub async fn new_token(req: Form, mut db: Connection) -> Json { + let oauth = sqlx::query!( + r#" + SELECT o.*, a.* + FROM oauth o + INNER JOIN auth a ON a.token = ?2 + WHERE o.access_token = ?1 + "#, + req.code, + req.code + ) + .fetch_one(&mut **db) + .await + .unwrap(); + + let access_token = main::gen_token(15); + + // Important: setup 'auth' first + sqlx::query!( + r#" + INSERT INTO auth (token, user_id) + VALUES (?1, ?2) + "#, + access_token, + oauth.user_id + ) + .execute(&mut **db) + .await + .unwrap(); + + sqlx::query!( + "UPDATE oauth SET access_token = ?1 WHERE access_token = ?2", + access_token, + req.code + ) + .execute(&mut **db) + .await + .unwrap(); + + sqlx::query!("DELETE FROM auth WHERE token = ?1", req.code) + .execute(&mut **db) + .await + .unwrap(); + Json(Token { - access_token: "9b9d497b-2731-435f-a929-e609ca69dac9".to_string(), + access_token: access_token.to_string(), token_type: "Bearer".to_string(), - expires_in: 3600, - scope: "read write follow push".to_string(), - id_token: "id-token".to_string(), + expires_in: oauth.expires_in, + scope: oauth.scope.to_string(), + id_token: oauth.id_token, }) } diff --git a/ferri-server/src/lib.rs b/ferri-server/src/lib.rs index 176de74..baebb31 100644 --- a/ferri-server/src/lib.rs +++ b/ferri-server/src/lib.rs @@ -2,15 +2,17 @@ use endpoints::{ api::{self, timeline}, custom, inbox, oauth, user, well_known, }; + use main::ap::http; use main::config::Config; use rocket::{ Build, Request, Rocket, build, get, - http::ContentType, + http::{ContentType, Status}, + outcome::IntoOutcome, request::{FromRequest, Outcome}, routes, }; -use rocket_db_pools::{Database, sqlx}; +use rocket_db_pools::{Connection, Database, sqlx}; mod cors; mod endpoints; @@ -34,6 +36,7 @@ async fn activity_endpoint(activity: String) { #[derive(Debug)] struct AuthenticatedUser { username: String, + token: String, actor_id: String, } @@ -48,15 +51,34 @@ enum LoginError { impl<'a> FromRequest<'a> for AuthenticatedUser { type Error = LoginError; async fn from_request(request: &'a Request<'_>) -> Outcome { - let token = request.headers().get_one("Authorization").unwrap(); - let token = token - .strip_prefix("Bearer") - .map(|s| s.trim()) - .unwrap_or(token); - Outcome::Success(AuthenticatedUser { - username: token.to_string(), - actor_id: format!("https://ferri.amy.mov/users/{}", token), - }) + let token = request.headers().get_one("Authorization"); + + if let Some(token) = token { + let token = token + .strip_prefix("Bearer") + .map(|s| s.trim()) + .unwrap_or(token); + + let mut conn = request.guard::>().await.unwrap(); + let auth = sqlx::query!(r#" + SELECT * + FROM auth a + INNER JOIN user u ON a.user_id = u.id + WHERE token = ?1 + "#, token) + .fetch_one(&mut **conn) + .await; + + if let Ok(auth) = auth { + return Outcome::Success(AuthenticatedUser { + token: auth.token, + username: auth.display_name, + actor_id: auth.actor_id, + }) + } + } + + Outcome::Forward(Status::Unauthorized) } } diff --git a/ferri-server/src/types/oauth.rs b/ferri-server/src/types/oauth.rs index 567dd19..8791fe3 100644 --- a/ferri-server/src/types/oauth.rs +++ b/ferri-server/src/types/oauth.rs @@ -3,7 +3,7 @@ use rocket::{ serde::{Deserialize, Serialize}, }; -#[derive(Serialize, Deserialize, Debug, FromForm)] +#[derive(Serialize, Deserialize, Debug, FromForm, Clone)] #[serde(crate = "rocket::serde")] pub struct App { pub client_name: String, diff --git a/migrations/20250423182916_add_auth.sql b/migrations/20250423182916_add_auth.sql new file mode 100644 index 0000000..d3807cb --- /dev/null +++ b/migrations/20250423182916_add_auth.sql @@ -0,0 +1,26 @@ +CREATE TABLE IF NOT EXISTS auth +( + token TEXT PRIMARY KEY NOT NULL, + user_id TEXT NOT NULL, + + FOREIGN KEY(user_id) REFERENCES user(id) +); + +CREATE TABLE IF NOT EXISTS app +( + client_id TEXT PRIMARY KEY NOT NULL, + client_secret TEXT NOT NULL, + scopes TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS oauth +( + id_token TEXT PRIMARY KEY NOT NULL, + client_id TEXT NOT NULL, + expires_in INTEGER NOT NULL, + scope TEXT NOT NULL, + access_token TEXT NOT NULL, + + FOREIGN KEY(access_token) REFERENCES auth(token), + FOREIGN KEY(client_id) REFERENCES app(client_id) +);