diff --git a/wish-server-rs/.gitignore b/wish-server-rs/.gitignore index aea5790..955f3d9 100644 --- a/wish-server-rs/.gitignore +++ b/wish-server-rs/.gitignore @@ -1,2 +1,3 @@ /target /data.db +/data.db-* diff --git a/wish-server-rs/Cargo.lock b/wish-server-rs/Cargo.lock index a000163..4402a40 100644 --- a/wish-server-rs/Cargo.lock +++ b/wish-server-rs/Cargo.lock @@ -736,6 +736,12 @@ dependencies = [ "syn", ] +[[package]] +name = "dotenv" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" + [[package]] name = "dotenvy" version = "0.15.6" @@ -3138,6 +3144,7 @@ name = "wish-server-rs" version = "0.1.0" dependencies = [ "axum", + "dotenv", "miette", "once_cell", "sqlx", diff --git a/wish-server-rs/Cargo.toml b/wish-server-rs/Cargo.toml index 935d00e..e5b826b 100644 --- a/wish-server-rs/Cargo.toml +++ b/wish-server-rs/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] axum = { version = "0.6.11", features = ["macros"] } +dotenv = "0.15.0" miette = { version = "5.5.0", features = ["fancy"] } once_cell = "1.17.1" sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "sqlite"] } diff --git a/wish-server-rs/src/db.rs b/wish-server-rs/src/db.rs deleted file mode 100644 index e69de29..0000000 diff --git a/wish-server-rs/src/main.rs b/wish-server-rs/src/main.rs index 15e5fb5..2409c7e 100644 --- a/wish-server-rs/src/main.rs +++ b/wish-server-rs/src/main.rs @@ -1,6 +1,6 @@ use miette::{Context, IntoDiagnostic, Result}; use sqlx::{sqlite::SqlitePoolOptions, SqlitePool}; -use std::{net::SocketAddr, sync::Arc}; +use std::{env::VarError, net::SocketAddr, sync::Arc}; use axum::{routing, Router}; use tower_http::{ @@ -22,6 +22,16 @@ pub struct AppState { #[tokio::main] async fn main() -> Result<()> { + let _ = dotenv::dotenv(); + { + if let Err(VarError::NotPresent) = std::env::var("RUST_LOG") { + std::env::set_var( + "RUST_LOG", + "wish_server_rs=debug,tower_http=debug,webrtc=info", + ) + } + } + tracing_subscriber::fmt::init(); let db = { @@ -54,7 +64,8 @@ async fn main() -> Result<()> { ) .with_state(app_state); - let bind_addr: SocketAddr = "127.0.0.1:3001" + let bind_addr: SocketAddr = std::env::var("BIND_ADDRESS") + .unwrap_or_else(|_| "127.0.0.1:3001".into()) .parse() .into_diagnostic() .wrap_err("Couldn't parse bind address")?; diff --git a/wish-server-rs/src/wish/whip.rs b/wish-server-rs/src/wish/whip.rs index 603ab9c..aaed768 100644 --- a/wish-server-rs/src/wish/whip.rs +++ b/wish-server-rs/src/wish/whip.rs @@ -83,8 +83,27 @@ pub async fn handle_whip( .unwrap(); let auth = auth.strip_prefix("Bearer ").unwrap_or(auth); - let Some((channel, _key)) = auth.split_once(':') else { return (StatusCode::UNAUTHORIZED, "Invalid Authorization header").into_response() }; - // TODO: Validate the stream key + let Some((channel, key)) = auth.split_once(':') else { return (StatusCode::UNAUTHORIZED, "Invalid Authorization header").into_response() }; + + let mut db = app + .db + .acquire() + .await + .expect("SQLite connections never fail"); + match sqlx::query!( + "SELECT stream FROM streams WHERE stream = ? AND password = ?", + channel, + key + ) + .fetch_one(&mut db) + .await + { + Err(sqlx::Error::RowNotFound) => { + return log_http_error(StatusCode::UNAUTHORIZED, "Invalid stream key.") + } + Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e), + _ => {} + }; let peer_connection = match setup_whip_connection(channel, app.webrtc) .await