live.umm.gay/wish-server/main.go

278 lines
6.8 KiB
Go

package main
import (
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"sync"
"database/sql"
"github.com/joho/godotenv"
"github.com/pion/interceptor"
"github.com/pion/webrtc/v3"
_ "github.com/mattn/go-sqlite3"
)
// WebRTCStream holds the pair of local RTP tracks that belong to one named
// stream: one carrying audio and one carrying video.
type WebRTCStream struct {
	audioTrack *webrtc.TrackLocalStaticRTP
	videoTrack *webrtc.TrackLocalStaticRTP
}
var (
runningStreams map[string]WebRTCStream
runningStreamsLock sync.Mutex
api *webrtc.API
db *sql.DB
)
// main loads configuration from ".env", opens the database, builds the WebRTC
// API, and serves the WHIP and WHEP endpoints on the address given by the
// HTTP_ADDRESS environment variable until the server stops.
func main() {
	if err := godotenv.Load(".env"); err != nil {
		log.Fatal(err)
	}

	db = setupDatabase()
	api = setupWebRTC()
	runningStreams = map[string]WebRTCStream{}

	mux := http.NewServeMux()
	mux.HandleFunc("/api/wish-server/whip", withCors(HandleWHIP))
	mux.HandleFunc("/api/wish-server/whep", withCors(HandleWHEP))

	server := &http.Server{
		Handler: mux,
		Addr:    os.Getenv("HTTP_ADDRESS"),
	}

	// ListenAndServe always returns a non-nil error. log.Fatal exits the
	// process without running deferred calls, so the database is closed
	// explicitly here instead of through a defer that would never fire.
	serveErr := server.ListenAndServe()
	if closeErr := db.Close(); closeErr != nil {
		log.Println(closeErr)
	}
	log.Fatal(serveErr)
}
// setupDatabase opens the sqlite3 database named by the DATABASE_URL
// environment variable and returns the handle. The process is terminated via
// log.Fatal if the database cannot be opened or reached.
func setupDatabase() *sql.DB {
	conn, err := sql.Open("sqlite3", os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}

	// sql.Open may only validate its arguments without connecting; Ping
	// forces a real connection so a bad DATABASE_URL is reported at startup
	// rather than on the first request.
	if err := conn.Ping(); err != nil {
		log.Fatal(err)
	}
	return conn
}
// setupWebRTC builds the webrtc.API used for every peer connection. It
// registers pion's default codecs and default interceptors and uses a
// zero-value SettingEngine. Any registration failure terminates the process
// via log.Fatal.
func setupWebRTC() *webrtc.API {
	engine := new(webrtc.MediaEngine)
	if err := engine.RegisterDefaultCodecs(); err != nil {
		log.Fatal(err)
	}

	registry := new(interceptor.Registry)
	if err := webrtc.RegisterDefaultInterceptors(engine, registry); err != nil {
		log.Fatal(err)
	}

	options := []func(*webrtc.API){
		webrtc.WithMediaEngine(engine),
		webrtc.WithInterceptorRegistry(registry),
		webrtc.WithSettingEngine(webrtc.SettingEngine{}),
	}
	return webrtc.NewAPI(options...)
}
// withCors wraps next with permissive CORS response headers. Every response
// gets the Access-Control-Allow-* headers; OPTIONS (preflight) requests are
// answered with the headers only and never reach next.
func withCors(next func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
	return func(res http.ResponseWriter, req *http.Request) {
		header := res.Header()
		header.Set("Access-Control-Allow-Origin", "*")
		header.Set("Access-Control-Allow-Methods", "*")
		// The "*" wildcard does not cover the Authorization request header,
		// which both wrapped handlers read, so it is listed explicitly;
		// otherwise browsers that follow the Fetch spec reject the preflight.
		header.Set("Access-Control-Allow-Headers", "*, Authorization")

		if req.Method == http.MethodOptions {
			return
		}
		next(res, req)
	}
}
// logHTTPError records message in the standard logger and then sends the same
// text to the client as an error response with the given HTTP status.
func logHTTPError(w http.ResponseWriter, message string, status int) {
	log.Println(message)
	http.Error(w, message, status)
}
// getTracksForStream returns the audio track and the video track (in that
// order) registered under streamName. If the name is not yet present in
// runningStreams, a new H264 video track and Opus audio track are created,
// stored under that name and returned. The lookup and the insertion happen
// under runningStreamsLock.
func getTracksForStream(streamName string) (
	*webrtc.TrackLocalStaticRTP,
	*webrtc.TrackLocalStaticRTP,
	error,
) {
	runningStreamsLock.Lock()
	defer runningStreamsLock.Unlock()

	if existing, known := runningStreams[streamName]; known {
		return existing.audioTrack, existing.videoTrack, nil
	}

	video, err := webrtc.NewTrackLocalStaticRTP(
		webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264},
		"video",
		"pion",
	)
	if err != nil {
		return nil, nil, err
	}

	audio, err := webrtc.NewTrackLocalStaticRTP(
		webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeOpus},
		"audio",
		"pion",
	)
	if err != nil {
		return nil, nil, err
	}

	runningStreams[streamName] = WebRTCStream{audioTrack: audio, videoTrack: video}
	return audio, video, nil
}
// HandleWHIP accepts a publisher's SDP offer and answers it.
//
// The Authorization header must have the form "<stream>:<password>" and is
// checked against the streams table. The request body is the SDP offer; the
// response body is the SDP answer, written once ICE gathering has finished.
// RTP packets received from the publisher are copied into the stream's shared
// local audio or video track.
//
// Responses: 400 if the header is missing or the offer is rejected, 401 if the
// credentials do not match a row, 500 for database or WebRTC failures.
func HandleWHIP(res http.ResponseWriter, req *http.Request) {
	authorization := req.Header.Get("Authorization")
	if authorization == "" {
		logHTTPError(res, "Authorization was not set", http.StatusBadRequest)
		return
	}

	// Without a ":" the whole header is the stream name and the password is
	// empty; the lookup below then decides whether that is acceptable.
	streamName, streamPassword, _ := strings.Cut(authorization, ":")

	// Scan needs exactly one destination per selected column, so select a
	// constant and scan it. Calling Scan with no destinations on "SELECT *"
	// fails even when a row matches, which rejected every publisher.
	var matched int
	err := db.QueryRow(
		"SELECT 1 FROM streams WHERE stream = ? AND password = ? LIMIT 1",
		streamName, streamPassword,
	).Scan(&matched)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		logHTTPError(res, "Invalid stream authorization", http.StatusUnauthorized)
		return
	case err != nil:
		// A database failure is not the client's fault; log the detail and
		// keep it out of the response.
		log.Println(err)
		http.Error(res, "Unable to verify stream authorization", http.StatusInternalServerError)
		return
	}

	offer, err := io.ReadAll(req.Body)
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	peerConnection, err := api.NewPeerConnection(webrtc.Configuration{})
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	// Release the peer connection on every early return below; it is only
	// kept alive once an answer is about to be sent.
	established := false
	defer func() {
		if established {
			return
		}
		if closeErr := peerConnection.Close(); closeErr != nil {
			log.Println(closeErr)
		}
	}()

	audioTrack, videoTrack, err := getTracksForStream(streamName)
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	peerConnection.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
		// Anything whose MIME type is not audio/* is forwarded as video.
		localTrack := videoTrack
		if strings.HasPrefix(track.Codec().RTPCodecCapability.MimeType, "audio/") {
			localTrack = audioTrack
		}

		rtpBuf := make([]byte, 1500)
		for {
			rtpRead, _, readErr := track.Read(rtpBuf)
			switch {
			case errors.Is(readErr, io.EOF):
				return
			case readErr != nil:
				log.Println(readErr)
				return
			}

			// io.ErrClosedPipe is tolerated so forwarding continues; any
			// other write error ends this track's forwarding loop.
			if _, writeErr := localTrack.Write(rtpBuf[:rtpRead]); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) {
				log.Println(writeErr)
				return
			}
		}
	})

	peerConnection.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
		if state != webrtc.ICEConnectionStateFailed {
			return
		}
		if closeErr := peerConnection.Close(); closeErr != nil {
			log.Println(closeErr)
		}
	})

	if err := peerConnection.SetRemoteDescription(webrtc.SessionDescription{
		SDP:  string(offer),
		Type: webrtc.SDPTypeOffer,
	}); err != nil {
		logHTTPError(res, err.Error(), http.StatusBadRequest)
		return
	}

	gatheringComplete := webrtc.GatheringCompletePromise(peerConnection)

	answer, err := peerConnection.CreateAnswer(nil)
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}
	if err := peerConnection.SetLocalDescription(answer); err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	// Wait for gathering to finish so the answer carries every candidate.
	<-gatheringComplete

	established = true
	// The status code is left at the implicit 200 that existing clients
	// already receive; only the media type of the answer is declared.
	res.Header().Set("Content-Type", "application/sdp")
	fmt.Fprint(res, peerConnection.LocalDescription().SDP)
}
// HandleWHEP accepts a viewer's SDP offer and answers it.
//
// The Authorization header carries the bare stream name (no password is
// checked here). The request body is the SDP offer; the response body is the
// SDP answer, written once ICE gathering has finished. The stream's shared
// audio and video tracks are attached to the viewer's peer connection.
//
// Responses: 400 if the header is missing or the offer is rejected, 500 for
// WebRTC failures.
//
// NOTE(review): getTracksForStream creates an entry for any name it has not
// seen, so unauthenticated viewers can add entries to runningStreams —
// confirm this is intended.
func HandleWHEP(res http.ResponseWriter, req *http.Request) {
	streamName := req.Header.Get("Authorization")
	if streamName == "" {
		logHTTPError(res, "Stream name was not set", http.StatusBadRequest)
		return
	}

	offer, err := io.ReadAll(req.Body)
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	peerConnection, err := api.NewPeerConnection(webrtc.Configuration{})
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	// Release the peer connection on every early return below; it is only
	// kept alive once an answer is about to be sent.
	established := false
	defer func() {
		if established {
			return
		}
		if closeErr := peerConnection.Close(); closeErr != nil {
			log.Println(closeErr)
		}
	}()

	// Mirror HandleWHIP: tear the connection down when ICE fails so that
	// departed viewers do not accumulate.
	peerConnection.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
		if state != webrtc.ICEConnectionStateFailed {
			return
		}
		if closeErr := peerConnection.Close(); closeErr != nil {
			log.Println(closeErr)
		}
	})

	audioTrack, videoTrack, err := getTracksForStream(streamName)
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	// Audio is added before video, matching the original track order.
	for _, track := range []*webrtc.TrackLocalStaticRTP{audioTrack, videoTrack} {
		sender, addErr := peerConnection.AddTrack(track)
		if addErr != nil {
			logHTTPError(res, addErr.Error(), http.StatusInternalServerError)
			return
		}

		// Per the pion documentation, incoming RTCP is only passed through
		// the interceptors (e.g. NACK handling) while the sender is being
		// read. The goroutine ends when Read fails, which happens once the
		// sender is stopped by closing the peer connection.
		go func() {
			rtcpBuf := make([]byte, 1500)
			for {
				if _, _, readErr := sender.Read(rtcpBuf); readErr != nil {
					return
				}
			}
		}()
	}

	// A rejected offer is a client error, as in HandleWHIP.
	if err := peerConnection.SetRemoteDescription(webrtc.SessionDescription{
		SDP:  string(offer),
		Type: webrtc.SDPTypeOffer,
	}); err != nil {
		logHTTPError(res, err.Error(), http.StatusBadRequest)
		return
	}

	gatheringComplete := webrtc.GatheringCompletePromise(peerConnection)

	answer, err := peerConnection.CreateAnswer(nil)
	if err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}
	if err := peerConnection.SetLocalDescription(answer); err != nil {
		logHTTPError(res, err.Error(), http.StatusInternalServerError)
		return
	}

	// Wait for gathering to finish so the answer carries every candidate.
	<-gatheringComplete

	established = true
	// The status code is left at the implicit 200 that existing clients
	// already receive; only the media type of the answer is declared.
	res.Header().Set("Content-Type", "application/sdp")
	fmt.Fprint(res, peerConnection.LocalDescription().SDP)
}