[wish-server-rs] Refactor :)

This commit is contained in:
Charlotte Som 2023-03-13 16:18:32 +00:00
parent 5df33492f4
commit e15483198a
7 changed files with 409 additions and 338 deletions

View file

@ -0,0 +1,9 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = false
insert_final_newline = true

View file

@ -1,355 +1,33 @@
use std::{ use miette::{Context, IntoDiagnostic, Result};
collections::{hash_map::Entry, HashMap}, use std::{net::SocketAddr, sync::Arc};
fmt::Display,
net::SocketAddr,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
};
use axum::{ use axum::{routing, Router};
extract::State,
http::{header, HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
routing, Router,
};
use miette::{miette, Context, IntoDiagnostic, Result};
use once_cell::sync::Lazy;
use tower_http::{ use tower_http::{
cors::{Any, CorsLayer}, cors::{Any, CorsLayer},
trace::TraceLayer, trace::TraceLayer,
}; };
use tracing::{error, instrument};
use webrtc::{
api::{
interceptor_registry::register_default_interceptors, media_engine::MediaEngine,
setting_engine::SettingEngine, APIBuilder, API as WebRTC,
},
ice::network_type::NetworkType,
ice_transport::ice_server::RTCIceServer,
interceptor::registry::Registry as InterceptorRegistry,
peer_connection::{
configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState,
sdp::session_description::RTCSessionDescription, RTCPeerConnection,
},
rtp_transceiver::rtp_codec::RTCRtpCodecCapability,
track::track_local::{track_local_static_rtp::TrackLocalStaticRTP, TrackLocalWriter},
Error,
};
fn create_rtc_config() -> RTCConfiguration { mod streams;
RTCConfiguration { mod util;
ice_servers: vec![ mod wish;
RTCIceServer {
urls: vec!["stun:stun.cloudflare.com:3478".into()],
..Default::default()
},
RTCIceServer {
urls: vec!["stun:stun.l.google.com:19302".into()],
..Default::default()
},
],
..Default::default()
}
}
fn setup_webrtc() -> Result<WebRTC> { use crate::wish::{setup_webrtc, whep::handle_whep, whip::handle_whip};
let mut media_engine = MediaEngine::default();
media_engine
.register_default_codecs()
.into_diagnostic()
.wrap_err("Failed to register default media engine codecs.")?;
let interceptor_registry = InterceptorRegistry::new();
let interceptor_registry =
register_default_interceptors(interceptor_registry, &mut media_engine)
.into_diagnostic()
.wrap_err("Failed to register default interceptors.")?;
let mut setting_engine = SettingEngine::default();
setup_ice(&mut setting_engine)?;
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(interceptor_registry)
.with_setting_engine(setting_engine)
.build();
Ok(api)
}
fn setup_ice(setting_engine: &mut SettingEngine) -> Result<()> {
setting_engine.set_network_types(vec![
NetworkType::Tcp4,
NetworkType::Tcp6,
NetworkType::Udp4,
NetworkType::Udp6,
]);
// TODO: Set up UDP muxing?
Ok(())
}
#[derive(Clone)] #[derive(Clone)]
struct OngoingStream { pub struct AppState {
video_track: Arc<TrackLocalStaticRTP>, pub webrtc: Arc<webrtc::api::API>,
audio_track: Arc<TrackLocalStaticRTP>, pub db: Arc<&'static str>,
viewer_count: Arc<AtomicU64>,
}
static STREAMS: Lazy<Mutex<HashMap<String, OngoingStream>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
async fn get_ongoing_stream(channel: &str) -> Result<OngoingStream> {
let mut streams = STREAMS.lock().expect("Mutex was poisoned");
let stream = match streams.entry(channel.into()) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => {
let video_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: "video/h264".into(),
..Default::default()
},
"video".into(),
"pion".into(),
));
let audio_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: "audio/opus".into(),
..Default::default()
},
"audio".into(),
"pion".into(),
));
v.insert(OngoingStream {
video_track,
audio_track,
viewer_count: Arc::new(AtomicU64::new(0)),
})
}
};
Ok(stream.clone())
}
async fn setup_whip_connection(
channel: &str,
webrtc: Arc<WebRTC>,
) -> Result<Arc<RTCPeerConnection>> {
let rtc_config = create_rtc_config();
let peer_connection = Arc::new(
webrtc
.new_peer_connection(rtc_config)
.await
.into_diagnostic()?,
);
let OngoingStream {
video_track,
audio_track,
..
} = get_ongoing_stream(channel).await?;
peer_connection.on_track(Box::new(move |track, _recv, _tx| {
let local_track = if track.codec().capability.mime_type.starts_with("audio/") {
&audio_track
} else {
&video_track
}
.clone();
tokio::spawn(async move {
let mut rtp_buf = vec![0u8; 1500];
while let Ok((rtp_read, _)) = track.read(&mut rtp_buf).await {
let (Ok(_) | Err(Error::ErrClosedPipe)) = local_track.write(&rtp_buf[..rtp_read]).await else { break };
}
Result::<()>::Ok(())
});
Box::pin(async {})
}));
Ok(peer_connection)
}
fn log_http_error<E: Display>(status_code: StatusCode, error: E) -> Response {
error!("{error}");
(status_code, format!("{}", error)).into_response()
}
#[instrument(skip_all)]
async fn handle_whip(
method: Method,
headers: HeaderMap,
State(webrtc): State<Arc<WebRTC>>,
offer: String,
) -> Response {
if method != Method::POST {
return (StatusCode::METHOD_NOT_ALLOWED, "Please use POST!").into_response();
}
let Some(auth) = headers.get(header::AUTHORIZATION) else { return (StatusCode::UNAUTHORIZED, "Authorization was not set").into_response() };
let auth = auth
.to_str()
.into_diagnostic()
.wrap_err("Failed to decode auth header")
.unwrap();
let auth = auth.strip_prefix("Bearer ").unwrap_or(auth);
let Some((channel, _key)) = auth.split_once(':') else { return (StatusCode::UNAUTHORIZED, "Invalid Authorization header").into_response() };
// TODO: Validate the stream key
let peer_connection = match setup_whip_connection(channel, webrtc)
.await
.wrap_err("Failed to initialize peer connection")
{
Ok(peer_connection) => peer_connection,
Err(e) => {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
}
};
let Ok(description) = RTCSessionDescription::offer(offer) else { return log_http_error(StatusCode::BAD_REQUEST, "Malformed SDP offer") };
if let Err(e) = peer_connection.set_remote_description(description).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
};
let mut gather_complete = peer_connection.gathering_complete_promise().await;
let answer = match peer_connection.create_answer(None).await {
Ok(answer) => answer,
Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e),
};
if let Err(e) = peer_connection.set_local_description(answer).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
}
let _ = gather_complete.recv().await;
match peer_connection.local_description().await {
Some(desc) => (StatusCode::CREATED, desc.sdp).into_response(),
None => log_http_error(
StatusCode::INTERNAL_SERVER_ERROR,
miette!("No local description exists!"),
),
}
}
async fn setup_whep_connection(
channel: &str,
webrtc: Arc<WebRTC>,
) -> Result<Arc<RTCPeerConnection>> {
let rtc_config = RTCConfiguration::default();
let peer_connection = Arc::new(
webrtc
.new_peer_connection(rtc_config)
.await
.into_diagnostic()?,
);
let OngoingStream {
video_track,
audio_track,
viewer_count,
} = get_ongoing_stream(channel).await?;
peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
if s == RTCPeerConnectionState::Connected {
viewer_count.fetch_add(1, Ordering::Relaxed);
}
if s == RTCPeerConnectionState::Disconnected {
viewer_count.fetch_sub(1, Ordering::Relaxed);
}
Box::pin(async {})
}));
peer_connection
.add_track(video_track)
.await
.into_diagnostic()
.wrap_err("Failed to add video track")?;
peer_connection
.add_track(audio_track)
.await
.into_diagnostic()
.wrap_err("Failed to add audio track")?;
Ok(peer_connection)
}
#[instrument(skip_all)]
async fn handle_whep(
method: Method,
headers: HeaderMap,
State(webrtc): State<Arc<WebRTC>>,
offer: String,
) -> Response {
if method != Method::POST {
return (StatusCode::METHOD_NOT_ALLOWED, "Please use POST!").into_response();
}
let channel = match headers
.get(header::AUTHORIZATION)
.ok_or(miette!("Authorization header was not set"))
.and_then(|h| {
h.to_str()
.into_diagnostic()
.wrap_err("Authorization header was malformed")
}) {
Ok(a) => a,
Err(e) => return log_http_error(StatusCode::BAD_REQUEST, e),
};
let channel = channel.strip_prefix("Bearer ").unwrap_or(channel);
let peer_connection = match setup_whep_connection(channel, webrtc).await {
Ok(p) => p,
Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e),
};
let Ok(description) = RTCSessionDescription::offer(offer) else { return log_http_error(StatusCode::BAD_REQUEST, "Malformed SDP offer") };
if let Err(e) = peer_connection.set_remote_description(description).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
};
let mut gather_complete = peer_connection.gathering_complete_promise().await;
let answer = match peer_connection.create_answer(None).await {
Ok(answer) => answer,
Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e),
};
if let Err(e) = peer_connection.set_local_description(answer).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
}
let _ = gather_complete.recv().await;
match peer_connection.local_description().await {
Some(desc) => (StatusCode::CREATED, desc.sdp).into_response(),
None => log_http_error(
StatusCode::INTERNAL_SERVER_ERROR,
miette!("No local description exists!"),
),
}
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let webrtc = Arc::new(setup_webrtc()?);
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
// TODO: CORS let webrtc = Arc::new(setup_webrtc()?);
let app_state = AppState {
webrtc,
db: Arc::new("ooo weee i'm the data base!!!"), // TODO: sqlx
};
let app = Router::new() let app = Router::new()
.route("/api/wish-server/whip", routing::any(handle_whip)) .route("/api/wish-server/whip", routing::any(handle_whip))
@ -361,7 +39,7 @@ async fn main() -> Result<()> {
.allow_origin(Any) .allow_origin(Any)
.allow_headers(Any), .allow_headers(Any),
) )
.with_state(webrtc); .with_state(app_state);
let bind_addr: SocketAddr = "127.0.0.1:3001" let bind_addr: SocketAddr = "127.0.0.1:3001"
.parse() .parse()

View file

@ -0,0 +1,57 @@
use std::{
collections::{hash_map::Entry, HashMap},
sync::{atomic::AtomicU64, Arc, Mutex},
};
use miette::Result;
use once_cell::sync::Lazy;
use webrtc::{
rtp_transceiver::rtp_codec::RTCRtpCodecCapability,
track::track_local::track_local_static_rtp::TrackLocalStaticRTP,
};
#[derive(Clone)]
pub struct OngoingStream {
pub video_track: Arc<TrackLocalStaticRTP>,
pub audio_track: Arc<TrackLocalStaticRTP>,
pub viewer_count: Arc<AtomicU64>,
}
static STREAMS: Lazy<Mutex<HashMap<String, OngoingStream>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
pub async fn get_ongoing_stream(channel: &str) -> Result<OngoingStream> {
let mut streams = STREAMS.lock().expect("Mutex was poisoned");
let stream = match streams.entry(channel.into()) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => {
let video_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: "video/h264".into(),
..Default::default()
},
"video".into(),
"pion".into(),
));
let audio_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: "audio/opus".into(),
..Default::default()
},
"audio".into(),
"pion".into(),
));
v.insert(OngoingStream {
video_track,
audio_track,
viewer_count: Arc::new(AtomicU64::new(0)),
})
}
};
Ok(stream.clone())
}

View file

@ -0,0 +1,11 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use std::fmt::Display;
use tracing::error;
pub fn log_http_error<E: Display>(status_code: StatusCode, error: E) -> Response {
error!("{error}");
(status_code, format!("{}", error)).into_response()
}

View file

@ -0,0 +1,68 @@
use miette::{Context, IntoDiagnostic, Result};
use webrtc::{
api::{
interceptor_registry::register_default_interceptors, media_engine::MediaEngine,
setting_engine::SettingEngine, APIBuilder, API as WebRTC,
},
ice::network_type::NetworkType,
ice_transport::ice_server::RTCIceServer,
interceptor::registry::Registry as InterceptorRegistry,
peer_connection::configuration::RTCConfiguration,
};
pub mod whep;
pub mod whip;
fn setup_ice(setting_engine: &mut SettingEngine) -> Result<()> {
setting_engine.set_network_types(vec![
NetworkType::Tcp4,
NetworkType::Tcp6,
NetworkType::Udp4,
NetworkType::Udp6,
]);
// TODO: Set up UDP muxing?
Ok(())
}
pub fn setup_webrtc() -> Result<WebRTC> {
let mut media_engine = MediaEngine::default();
media_engine
.register_default_codecs()
.into_diagnostic()
.wrap_err("Failed to register default media engine codecs.")?;
let interceptor_registry = InterceptorRegistry::new();
let interceptor_registry =
register_default_interceptors(interceptor_registry, &mut media_engine)
.into_diagnostic()
.wrap_err("Failed to register default interceptors.")?;
let mut setting_engine = SettingEngine::default();
setup_ice(&mut setting_engine)?;
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(interceptor_registry)
.with_setting_engine(setting_engine)
.build();
Ok(api)
}
pub fn create_rtc_config() -> RTCConfiguration {
RTCConfiguration {
ice_servers: vec![
RTCIceServer {
urls: vec!["stun:stun.cloudflare.com:3478".into()],
..Default::default()
},
RTCIceServer {
urls: vec!["stun:stun.l.google.com:19302".into()],
..Default::default()
},
],
..Default::default()
}
}

View file

@ -0,0 +1,124 @@
use std::sync::{atomic::Ordering, Arc};
use axum::{
extract::State,
http::{header, HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
};
use miette::{miette, Context, IntoDiagnostic, Result};
use crate::{
streams::{get_ongoing_stream, OngoingStream},
util::log_http_error,
AppState,
};
use tracing::instrument;
use webrtc::{
api::API as WebRTC,
peer_connection::{
configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState,
sdp::session_description::RTCSessionDescription, RTCPeerConnection,
},
};
async fn setup_whep_connection(
channel: &str,
webrtc: Arc<WebRTC>,
) -> Result<Arc<RTCPeerConnection>> {
let rtc_config = RTCConfiguration::default();
let peer_connection = Arc::new(
webrtc
.new_peer_connection(rtc_config)
.await
.into_diagnostic()?,
);
let OngoingStream {
video_track,
audio_track,
viewer_count,
} = get_ongoing_stream(channel).await?;
peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
if s == RTCPeerConnectionState::Connected {
viewer_count.fetch_add(1, Ordering::Relaxed);
}
if s == RTCPeerConnectionState::Disconnected {
viewer_count.fetch_sub(1, Ordering::Relaxed);
}
Box::pin(async {})
}));
peer_connection
.add_track(video_track)
.await
.into_diagnostic()
.wrap_err("Failed to add video track")?;
peer_connection
.add_track(audio_track)
.await
.into_diagnostic()
.wrap_err("Failed to add audio track")?;
Ok(peer_connection)
}
#[instrument(skip_all)]
pub async fn handle_whep(
method: Method,
headers: HeaderMap,
State(app): State<AppState>,
offer: String,
) -> Response {
if method != Method::POST {
return (StatusCode::METHOD_NOT_ALLOWED, "Please use POST!").into_response();
}
let channel = match headers
.get(header::AUTHORIZATION)
.ok_or(miette!("Authorization header was not set"))
.and_then(|h| {
h.to_str()
.into_diagnostic()
.wrap_err("Authorization header was malformed")
}) {
Ok(a) => a,
Err(e) => return log_http_error(StatusCode::BAD_REQUEST, e),
};
let channel = channel.strip_prefix("Bearer ").unwrap_or(channel);
let peer_connection = match setup_whep_connection(channel, app.webrtc).await {
Ok(p) => p,
Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e),
};
let Ok(description) = RTCSessionDescription::offer(offer) else { return log_http_error(StatusCode::BAD_REQUEST, "Malformed SDP offer") };
if let Err(e) = peer_connection.set_remote_description(description).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
};
let mut gather_complete = peer_connection.gathering_complete_promise().await;
let answer = match peer_connection.create_answer(None).await {
Ok(answer) => answer,
Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e),
};
if let Err(e) = peer_connection.set_local_description(answer).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
}
let _ = gather_complete.recv().await;
match peer_connection.local_description().await {
Some(desc) => (StatusCode::CREATED, desc.sdp).into_response(),
None => log_http_error(
StatusCode::INTERNAL_SERVER_ERROR,
miette!("No local description exists!"),
),
}
}

View file

@ -0,0 +1,124 @@
use std::sync::Arc;
use axum::{
extract::State,
http::{header, HeaderMap, Method, StatusCode},
response::{IntoResponse, Response},
};
use miette::{miette, Context, IntoDiagnostic, Result};
use tracing::instrument;
use crate::{
streams::{get_ongoing_stream, OngoingStream},
util::log_http_error,
AppState,
};
use webrtc::{
api::API as WebRTC,
peer_connection::{sdp::session_description::RTCSessionDescription, RTCPeerConnection},
track::track_local::TrackLocalWriter,
};
use super::create_rtc_config;
async fn setup_whip_connection(
channel: &str,
webrtc: Arc<WebRTC>,
) -> Result<Arc<RTCPeerConnection>> {
let rtc_config = create_rtc_config();
let peer_connection = Arc::new(
webrtc
.new_peer_connection(rtc_config)
.await
.into_diagnostic()?,
);
let OngoingStream {
video_track,
audio_track,
..
} = get_ongoing_stream(channel).await?;
peer_connection.on_track(Box::new(move |track, _recv, _tx| {
let local_track = if track.codec().capability.mime_type.starts_with("audio/") {
&audio_track
} else {
&video_track
}
.clone();
tokio::spawn(async move {
let mut rtp_buf = vec![0u8; 1500];
while let Ok((rtp_read, _)) = track.read(&mut rtp_buf).await {
let (Ok(_) | Err(webrtc::Error::ErrClosedPipe)) = local_track.write(&rtp_buf[..rtp_read]).await else { break };
}
Result::<()>::Ok(())
});
Box::pin(async {})
}));
Ok(peer_connection)
}
#[instrument(skip_all)]
#[axum::debug_handler]
pub async fn handle_whip(
method: Method,
headers: HeaderMap,
State(app): State<AppState>,
offer: String,
) -> Response {
if method != Method::POST {
return (StatusCode::METHOD_NOT_ALLOWED, "Please use POST!").into_response();
}
let Some(auth) = headers.get(header::AUTHORIZATION) else { return (StatusCode::UNAUTHORIZED, "Authorization was not set").into_response() };
let auth = auth
.to_str()
.into_diagnostic()
.wrap_err("Failed to decode auth header")
.unwrap();
let auth = auth.strip_prefix("Bearer ").unwrap_or(auth);
let Some((channel, _key)) = auth.split_once(':') else { return (StatusCode::UNAUTHORIZED, "Invalid Authorization header").into_response() };
// TODO: Validate the stream key
let peer_connection = match setup_whip_connection(channel, app.webrtc)
.await
.wrap_err("Failed to initialize peer connection")
{
Ok(peer_connection) => peer_connection,
Err(e) => {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
}
};
let Ok(description) = RTCSessionDescription::offer(offer) else { return log_http_error(StatusCode::BAD_REQUEST, "Malformed SDP offer") };
if let Err(e) = peer_connection.set_remote_description(description).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
};
let mut gather_complete = peer_connection.gathering_complete_promise().await;
let answer = match peer_connection.create_answer(None).await {
Ok(answer) => answer,
Err(e) => return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e),
};
if let Err(e) = peer_connection.set_local_description(answer).await {
return log_http_error(StatusCode::INTERNAL_SERVER_ERROR, e);
}
let _ = gather_complete.recv().await;
match peer_connection.local_description().await {
Some(desc) => (StatusCode::CREATED, desc.sdp).into_response(),
None => log_http_error(
StatusCode::INTERNAL_SERVER_ERROR,
miette!("No local description exists!"),
),
}
}