package main import ( "errors" "fmt" "io" "log" "net" "net/http" "os" "strconv" "strings" "sync" "database/sql" "github.com/joho/godotenv" "github.com/pion/ice/v2" "github.com/pion/interceptor" "github.com/pion/webrtc/v3" _ "github.com/mattn/go-sqlite3" ) type WebRTCStream struct { audioTrack, videoTrack *webrtc.TrackLocalStaticRTP } var ( runningStreams map[string]WebRTCStream runningStreamsLock sync.Mutex api *webrtc.API db *sql.DB ) func main() { if err := godotenv.Load(".env"); err != nil { log.Fatal(err) } db = setupDatabase() defer db.Close() api = setupWebRTC() runningStreams = map[string]WebRTCStream{} mux := http.NewServeMux() mux.HandleFunc("/api/wish-server/whip", withCors(HandleWHIP)) mux.HandleFunc("/api/wish-server/whep", withCors(HandleWHEP)) log.Fatal((&http.Server{ Handler: mux, Addr: os.Getenv("HTTP_ADDRESS"), }).ListenAndServe()) } func setupDatabase() *sql.DB { db, err := sql.Open("sqlite3", os.Getenv("DATABASE_URL")) if err != nil { log.Fatal(err) } return db } func setupWebRTC() *webrtc.API { mediaEngine := &webrtc.MediaEngine{} if err := mediaEngine.RegisterDefaultCodecs(); err != nil { log.Fatal(err) } interceptorRegistry := &interceptor.Registry{} if err := webrtc.RegisterDefaultInterceptors(mediaEngine, interceptorRegistry); err != nil { log.Fatal(err) } settingEngine := webrtc.SettingEngine{} setupICE(&settingEngine) return webrtc.NewAPI( webrtc.WithMediaEngine(mediaEngine), webrtc.WithInterceptorRegistry(interceptorRegistry), webrtc.WithSettingEngine(settingEngine), ) } func setupICE(settingEngine *webrtc.SettingEngine) { settingEngine.SetNetworkTypes([]webrtc.NetworkType{ webrtc.NetworkTypeUDP4, webrtc.NetworkTypeUDP6, webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6, }) if udpPort := os.Getenv("UDP_MUX_PORT"); udpPort != "" { port, err := strconv.Atoi(udpPort) if err != nil { log.Fatal(err) } mux, err := ice.NewMultiUDPMuxFromPort(port) if err != nil { log.Fatal(err) } settingEngine.SetICEUDPMux(mux) } if tcpAddr := os.Getenv("TCP_MUX_ADDR"); tcpAddr != "" { addr, err := net.ResolveTCPAddr("tcp", tcpAddr) if err != nil { log.Fatal(err) } listener, err := net.ListenTCP("tcp", addr) if err != nil { log.Fatal(err) } mux := webrtc.NewICETCPMux(nil, listener, 8) settingEngine.SetICETCPMux(mux) } } func withCors(next func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { res.Header().Set("Access-Control-Allow-Origin", "*") res.Header().Set("Access-Control-Allow-Methods", "*") res.Header().Set("Access-Control-Allow-Headers", "*") if req.Method != http.MethodOptions { next(res, req) } } } func logHTTPError(w http.ResponseWriter, err string, code int) { log.Println(err) http.Error(w, err, code) } func getTracksForStream(streamName string) ( *webrtc.TrackLocalStaticRTP, *webrtc.TrackLocalStaticRTP, error, ) { runningStreamsLock.Lock() defer runningStreamsLock.Unlock() foundStream, ok := runningStreams[streamName] if !ok { videoTrack, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264}, "video", "pion") if err != nil { return nil, nil, err } audioTrack, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeOpus}, "audio", "pion") if err != nil { return nil, nil, err } foundStream = WebRTCStream{ audioTrack: audioTrack, videoTrack: videoTrack, } runningStreams[streamName] = foundStream } return foundStream.audioTrack, foundStream.videoTrack, nil } func HandleWHIP(res http.ResponseWriter, req *http.Request) { authorization := req.Header.Get("Authorization") if authorization == "" { logHTTPError(res, "Authorization was not set", http.StatusBadRequest) return } streamName, streamPassword, _ := strings.Cut(authorization, ":") streamName, _ = strings.CutPrefix(strings.ToLower(streamName), "bearer ") var qN string var qP string if err := db.QueryRow("SELECT * FROM streams WHERE stream = ? AND password = ?", streamName, streamPassword).Scan(&qN, &qP); err != nil { logHTTPError(res, "Invalid stream authorization for: "+streamName+" - "+err.Error(), http.StatusUnauthorized) return } offer, err := io.ReadAll(req.Body) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } peerConnection, err := api.NewPeerConnection(webrtc.Configuration{ ICEServers: []webrtc.ICEServer{ { URLs: []string{"stun:stun.cloudflare.com:3478"}, }, { URLs: []string{"stun:stun.l.google.com:19302"}, }, }, }) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } audioTrack, videoTrack, err := getTracksForStream(streamName) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } peerConnection.OnTrack(func(track *webrtc.TrackRemote, _recv *webrtc.RTPReceiver) { var localTrack *webrtc.TrackLocalStaticRTP if strings.HasPrefix(track.Codec().RTPCodecCapability.MimeType, "audio/") { localTrack = audioTrack } else { localTrack = videoTrack } rtpBuf := make([]byte, 1500) for { rtpRead, _, readErr := track.Read(rtpBuf) switch { case errors.Is(readErr, io.EOF): return case readErr != nil: log.Println(readErr) return } if _, writeErr := localTrack.Write(rtpBuf[:rtpRead]); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) { log.Println(writeErr) return } } }) peerConnection.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { if state == webrtc.ICEConnectionStateFailed { if err := peerConnection.Close(); err != nil { log.Println(err) return } } }) if err := peerConnection.SetRemoteDescription(webrtc.SessionDescription{ SDP: string(offer), Type: webrtc.SDPTypeOffer, }); err != nil { logHTTPError(res, err.Error(), http.StatusBadRequest) return } gatheringComplete := webrtc.GatheringCompletePromise(peerConnection) answer, err := peerConnection.CreateAnswer(nil) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } else if err = peerConnection.SetLocalDescription(answer); err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } <-gatheringComplete fmt.Fprint(res, peerConnection.LocalDescription().SDP) } func HandleWHEP(res http.ResponseWriter, req *http.Request) { streamName := req.Header.Get("Authorization") streamName, _ = strings.CutPrefix(strings.ToLower(streamName), "bearer ") if streamName == "" { logHTTPError(res, "Stream name was not set", http.StatusBadRequest) return } offer, err := io.ReadAll(req.Body) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } peerConnection, err := api.NewPeerConnection(webrtc.Configuration{}) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } audioTrack, videoTrack, err := getTracksForStream(streamName) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } if _, err = peerConnection.AddTrack(audioTrack); err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } if _, err = peerConnection.AddTrack(videoTrack); err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } if err := peerConnection.SetRemoteDescription(webrtc.SessionDescription{ SDP: string(offer), Type: webrtc.SDPTypeOffer, }); err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } gatheringComplete := webrtc.GatheringCompletePromise(peerConnection) answer, err := peerConnection.CreateAnswer(nil) if err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } else if err = peerConnection.SetLocalDescription(answer); err != nil { logHTTPError(res, err.Error(), http.StatusInternalServerError) return } <-gatheringComplete fmt.Fprint(res, peerConnection.LocalDescription().SDP) }