use std::{ collections::{hash_map::Entry, HashMap}, fmt::Display, net::SocketAddr, sync::{ atomic::{AtomicU64, Ordering}, Arc, Mutex, }, }; use axum::{ extract::State, http::{header, HeaderMap, Method, StatusCode}, response::{IntoResponse, Response}, routing, Router, }; use miette::{miette, Context, IntoDiagnostic, Result}; use once_cell::sync::Lazy; use tower_http::{ cors::{Any, CorsLayer}, trace::TraceLayer, }; use tracing::{error, instrument}; use webrtc::{ api::{ interceptor_registry::register_default_interceptors, media_engine::MediaEngine, setting_engine::SettingEngine, APIBuilder, API as WebRTC, }, ice::network_type::NetworkType, ice_transport::ice_server::RTCIceServer, interceptor::registry::Registry as InterceptorRegistry, peer_connection::{ configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState, sdp::session_description::RTCSessionDescription, RTCPeerConnection, }, rtp_transceiver::rtp_codec::RTCRtpCodecCapability, track::track_local::{track_local_static_rtp::TrackLocalStaticRTP, TrackLocalWriter}, Error, }; fn create_rtc_config() -> RTCConfiguration { RTCConfiguration { ice_servers: vec![ RTCIceServer { urls: vec!["stun:stun.cloudflare.com:3478".into()], ..Default::default() }, RTCIceServer { urls: vec!["stun:stun.l.google.com:19302".into()], ..Default::default() }, ], ..Default::default() } } fn setup_webrtc() -> Result { let mut media_engine = MediaEngine::default(); media_engine .register_default_codecs() .into_diagnostic() .wrap_err("Failed to register default media engine codecs.")?; let interceptor_registry = InterceptorRegistry::new(); let interceptor_registry = register_default_interceptors(interceptor_registry, &mut media_engine) .into_diagnostic() .wrap_err("Failed to register default interceptors.")?; let mut setting_engine = SettingEngine::default(); setup_ice(&mut setting_engine)?; let api = APIBuilder::new() .with_media_engine(media_engine) .with_interceptor_registry(interceptor_registry) .with_setting_engine(setting_engine) .build(); Ok(api) } fn setup_ice(setting_engine: &mut SettingEngine) -> Result<()> { setting_engine.set_network_types(vec![ NetworkType::Tcp4, NetworkType::Tcp6, NetworkType::Udp4, NetworkType::Udp6, ]); // TODO: Set up UDP muxing? Ok(()) } #[derive(Clone)] struct OngoingStream { video_track: Arc, audio_track: Arc, viewer_count: Arc, } static STREAMS: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); async fn get_ongoing_stream(channel: &str) -> Result { let mut streams = STREAMS.lock().expect("Mutex was poisoned"); let stream = match streams.entry(channel.into()) { Entry::Occupied(o) => o.into_mut(), Entry::Vacant(v) => { let video_track = Arc::new(TrackLocalStaticRTP::new( RTCRtpCodecCapability { mime_type: "video/h264".into(), ..Default::default() }, "video".into(), "pion".into(), )); let audio_track = Arc::new(TrackLocalStaticRTP::new( RTCRtpCodecCapability { mime_type: "audio/opus".into(), ..Default::default() }, "audio".into(), "pion".into(), )); v.insert(OngoingStream { video_track, audio_track, viewer_count: Arc::new(AtomicU64::new(0)), }) } }; Ok(stream.clone()) } async fn setup_whip_connection( channel: &str, webrtc: Arc, ) -> Result> { let rtc_config = create_rtc_config(); let peer_connection = Arc::new( webrtc .new_peer_connection(rtc_config) .await .into_diagnostic()?, ); let OngoingStream { video_track, audio_track, .. } = get_ongoing_stream(channel).await?; peer_connection.on_track(Box::new(move |track, _recv, _tx| { let local_track = if track.codec().capability.mime_type.starts_with("audio/") { &audio_track } else { &video_track } .clone(); tokio::spawn(async move { let mut rtp_buf = vec![0u8; 1500]; while let Ok((rtp_read, _)) = track.read(&mut rtp_buf).await { let (Ok(_) | Err(Error::ErrClosedPipe)) = local_track.write(&rtp_buf[..rtp_read]).await else { break }; } Result::<()>::Ok(()) }); Box::pin(async {}) })); Ok(peer_connection) } fn log_http_error(status_code: StatusCode, error: E) -> Response { error!("{error}"); (status_code, format!("{}", error)).into_response() } #[instrument(skip_all)] async fn handle_whip( method: Method, headers: HeaderMap, State(webrtc): State>, offer: String, ) -> Response { if method != Method::POST { return (StatusCode::METHOD_NOT_ALLOWED, "Please use POST!").into_response(); } let Some(auth) = headers.get(header::AUTHORIZATION) else { return (StatusCode::UNAUTHORIZED, "Authorization was not set").into_response() }; let auth = auth .to_str() .into_diagnostic() .wrap_err("Failed to decode auth header") .unwrap(); let auth = auth.strip_prefix("Bearer ").unwrap_or(auth); let Some((channel, _key)) = auth.split_once(':') else { return (StatusCode::UNAUTHORIZED, "Invalid Authorization header").into_response() }; // TODO: Validate the stream key let peer_connection = match setup_whip_connection(channel, webrtc) .await .wrap_err("Failed to initialize peer connection") { Ok(peer_connection) => peer_connection, Err(e) => { return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e); } }; let Ok(description) = RTCSessionDescription::offer(offer) else { return log_http_error(StatusCode::BAD_REQUEST, "Malformed SDP offer") }; if let Err(e) = peer_connection.set_remote_description(description).await { return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e); }; let mut gather_complete = peer_connection.gathering_complete_promise().await; let answer = match peer_connection.create_answer(None).await { Ok(answer) => answer, Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e), }; if let Err(e) = peer_connection.set_local_description(answer).await { return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e); } let _ = gather_complete.recv().await; match peer_connection.local_description().await { Some(desc) => (StatusCode::CREATED, desc.sdp).into_response(), None => log_http_error( StatusCode::INTERNAL_SERVER_ERROR, miette!("No local description exists!"), ), } } async fn setup_whep_connection( channel: &str, webrtc: Arc, ) -> Result> { let rtc_config = RTCConfiguration::default(); let peer_connection = Arc::new( webrtc .new_peer_connection(rtc_config) .await .into_diagnostic()?, ); let OngoingStream { video_track, audio_track, viewer_count, } = get_ongoing_stream(channel).await?; peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { if s == RTCPeerConnectionState::Connected { viewer_count.fetch_add(1, Ordering::Relaxed); } if s == RTCPeerConnectionState::Disconnected { viewer_count.fetch_sub(1, Ordering::Relaxed); } Box::pin(async {}) })); peer_connection .add_track(video_track) .await .into_diagnostic() .wrap_err("Failed to add video track")?; peer_connection .add_track(audio_track) .await .into_diagnostic() .wrap_err("Failed to add audio track")?; Ok(peer_connection) } #[instrument(skip_all)] async fn handle_whep( method: Method, headers: HeaderMap, State(webrtc): State>, offer: String, ) -> Response { if method != Method::POST { return (StatusCode::METHOD_NOT_ALLOWED, "Please use POST!").into_response(); } let channel = match headers .get(header::AUTHORIZATION) .ok_or(miette!("Authorization header was not set")) .and_then(|h| { h.to_str() .into_diagnostic() .wrap_err("Authorization header was malformed") }) { Ok(a) => a, Err(e) => return log_http_error(StatusCode::BAD_REQUEST, e), }; let channel = channel.strip_prefix("Bearer ").unwrap_or(channel); let peer_connection = match setup_whep_connection(channel, webrtc).await { Ok(p) => p, Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e), }; let Ok(description) = RTCSessionDescription::offer(offer) else { return log_http_error(StatusCode::BAD_REQUEST, "Malformed SDP offer") }; if let Err(e) = peer_connection.set_remote_description(description).await { return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e); }; let mut gather_complete = peer_connection.gathering_complete_promise().await; let answer = match peer_connection.create_answer(None).await { Ok(answer) => answer, Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e), }; if let Err(e) = peer_connection.set_local_description(answer).await { return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e); } let _ = gather_complete.recv().await; match peer_connection.local_description().await { Some(desc) => (StatusCode::CREATED, desc.sdp).into_response(), None => log_http_error( StatusCode::INTERNAL_SERVER_ERROR, miette!("No local description exists!"), ), } } #[tokio::main] async fn main() -> Result<()> { let webrtc = Arc::new(setup_webrtc()?); tracing_subscriber::fmt::init(); // TODO: CORS let app = Router::new() .route("/api/wish-server/whip", routing::any(handle_whip)) .route("/api/wish-server/whep", routing::any(handle_whep)) .layer(TraceLayer::new_for_http()) .layer( CorsLayer::new() .allow_methods(Any) .allow_origin(Any) .allow_headers(Any), ) .with_state(webrtc); let bind_addr: SocketAddr = "127.0.0.1:3001" .parse() .into_diagnostic() .wrap_err("Couldn't parse bind address")?; println!("Listening at http://{bind_addr}/ ..."); axum::Server::bind(&bind_addr) .serve(app.into_make_service()) .await .unwrap(); Ok(()) }