278 lines
6.8 KiB
Go
278 lines
6.8 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
|
|
"database/sql"
|
|
|
|
"github.com/joho/godotenv"
|
|
"github.com/pion/interceptor"
|
|
"github.com/pion/webrtc/v3"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
type WebRTCStream struct {
|
|
audioTrack, videoTrack *webrtc.TrackLocalStaticRTP
|
|
}
|
|
|
|
var (
|
|
runningStreams map[string]WebRTCStream
|
|
runningStreamsLock sync.Mutex
|
|
api *webrtc.API
|
|
db *sql.DB
|
|
)
|
|
|
|
func main() {
|
|
if err := godotenv.Load(".env"); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
db = setupDatabase()
|
|
defer db.Close()
|
|
|
|
api = setupWebRTC()
|
|
runningStreams = map[string]WebRTCStream{}
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/api/wish-server/whip", withCors(HandleWHIP))
|
|
mux.HandleFunc("/api/wish-server/whep", withCors(HandleWHEP))
|
|
|
|
log.Fatal((&http.Server{
|
|
Handler: mux,
|
|
Addr: os.Getenv("HTTP_ADDRESS"),
|
|
}).ListenAndServe())
|
|
}
|
|
|
|
// setupDatabase opens the sqlite3 database named by the DATABASE_URL
// environment variable and returns the handle. Any failure is fatal.
func setupDatabase() *sql.DB {
	conn, err := sql.Open("sqlite3", os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}

	// sql.Open may only validate its arguments without connecting; Ping
	// forces a real connection so a bad data source is reported at startup
	// instead of on the first request.
	if err := conn.Ping(); err != nil {
		log.Fatal(err)
	}

	return conn
}
|
|
|
|
func setupWebRTC() *webrtc.API {
|
|
mediaEngine := &webrtc.MediaEngine{}
|
|
if err := mediaEngine.RegisterDefaultCodecs(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
interceptorRegistry := &interceptor.Registry{}
|
|
if err := webrtc.RegisterDefaultInterceptors(mediaEngine, interceptorRegistry); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
settingEngine := webrtc.SettingEngine{}
|
|
|
|
return webrtc.NewAPI(
|
|
webrtc.WithMediaEngine(mediaEngine),
|
|
webrtc.WithInterceptorRegistry(interceptorRegistry),
|
|
webrtc.WithSettingEngine(settingEngine),
|
|
)
|
|
}
|
|
|
|
// withCors decorates next with wildcard CORS response headers. OPTIONS
// requests are answered with those headers alone; every other method is
// passed on to next.
func withCors(next func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
	corsHeaders := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "*"},
		{"Access-Control-Allow-Headers", "*"},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		for _, kv := range corsHeaders {
			header.Set(kv[0], kv[1])
		}

		// Preflight: nothing beyond the headers is sent.
		if r.Method == http.MethodOptions {
			return
		}

		next(w, r)
	})
}
|
|
|
|
// logHTTPError reports a failure in two places: the message is written to the
// standard logger and then sent to the client via http.Error with the given
// status code.
func logHTTPError(w http.ResponseWriter, err string, code int) {
	message, status := err, code

	log.Println(message)
	http.Error(w, message, status)
}
|
|
|
|
func getTracksForStream(streamName string) (
|
|
*webrtc.TrackLocalStaticRTP,
|
|
*webrtc.TrackLocalStaticRTP,
|
|
error,
|
|
) {
|
|
runningStreamsLock.Lock()
|
|
defer runningStreamsLock.Unlock()
|
|
|
|
foundStream, ok := runningStreams[streamName]
|
|
if !ok {
|
|
videoTrack, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264}, "video", "pion")
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
audioTrack, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeOpus}, "audio", "pion")
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
foundStream = WebRTCStream{
|
|
audioTrack: audioTrack,
|
|
videoTrack: videoTrack,
|
|
}
|
|
runningStreams[streamName] = foundStream
|
|
}
|
|
|
|
return foundStream.audioTrack, foundStream.videoTrack, nil
|
|
}
|
|
|
|
func HandleWHIP(res http.ResponseWriter, req *http.Request) {
|
|
authorization := req.Header.Get("Authorization")
|
|
if authorization == "" {
|
|
logHTTPError(res, "Authorization was not set", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
streamName, streamPassword, _ := strings.Cut(authorization, ":")
|
|
if err := db.QueryRow("SELECT * FROM streams WHERE stream = ? AND password = ?", streamName, streamPassword).Scan(); err != nil {
|
|
logHTTPError(res, "Invalid stream authorization", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
offer, err := io.ReadAll(req.Body)
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
peerConnection, err := api.NewPeerConnection(webrtc.Configuration{})
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
audioTrack, videoTrack, err := getTracksForStream(streamName)
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
peerConnection.OnTrack(func(track *webrtc.TrackRemote, _recv *webrtc.RTPReceiver) {
|
|
var localTrack *webrtc.TrackLocalStaticRTP
|
|
if strings.HasPrefix(track.Codec().RTPCodecCapability.MimeType, "audio/") {
|
|
localTrack = audioTrack
|
|
} else {
|
|
localTrack = videoTrack
|
|
}
|
|
|
|
rtpBuf := make([]byte, 1500)
|
|
for {
|
|
rtpRead, _, readErr := track.Read(rtpBuf)
|
|
switch {
|
|
case errors.Is(readErr, io.EOF):
|
|
return
|
|
case readErr != nil:
|
|
log.Println(readErr)
|
|
return
|
|
}
|
|
|
|
if _, writeErr := localTrack.Write(rtpBuf[:rtpRead]); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) {
|
|
log.Println(writeErr)
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
peerConnection.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
|
|
if state == webrtc.ICEConnectionStateFailed {
|
|
if err := peerConnection.Close(); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
if err := peerConnection.SetRemoteDescription(webrtc.SessionDescription{
|
|
SDP: string(offer),
|
|
Type: webrtc.SDPTypeOffer,
|
|
}); err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
gatheringComplete := webrtc.GatheringCompletePromise(peerConnection)
|
|
|
|
answer, err := peerConnection.CreateAnswer(nil)
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
} else if err = peerConnection.SetLocalDescription(answer); err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
<-gatheringComplete
|
|
|
|
fmt.Fprint(res, peerConnection.LocalDescription().SDP)
|
|
}
|
|
|
|
func HandleWHEP(res http.ResponseWriter, req *http.Request) {
|
|
streamName := req.Header.Get("Authorization")
|
|
if streamName == "" {
|
|
logHTTPError(res, "Stream name was not set", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
offer, err := io.ReadAll(req.Body)
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
peerConnection, err := api.NewPeerConnection(webrtc.Configuration{})
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
audioTrack, videoTrack, err := getTracksForStream(streamName)
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if _, err = peerConnection.AddTrack(audioTrack); err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if _, err = peerConnection.AddTrack(videoTrack); err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err := peerConnection.SetRemoteDescription(webrtc.SessionDescription{
|
|
SDP: string(offer),
|
|
Type: webrtc.SDPTypeOffer,
|
|
}); err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
gatheringComplete := webrtc.GatheringCompletePromise(peerConnection)
|
|
|
|
answer, err := peerConnection.CreateAnswer(nil)
|
|
if err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
} else if err = peerConnection.SetLocalDescription(answer); err != nil {
|
|
logHTTPError(res, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
<-gatheringComplete
|
|
|
|
fmt.Fprint(res, peerConnection.LocalDescription().SDP)
|
|
}
|