From ce019738ffdefe8fcfd8796b3252fe4899d71a1c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 21 Sep 2017 16:20:10 +0100 Subject: [PATCH] Pass a context when downloading remote media (#251) --- .../cmd/dendrite-media-api-server/main.go | 5 ++- .../cmd/dendrite-monolith-server/main.go | 2 +- .../dendrite/mediaapi/routing/routing.go | 33 ++++++++++++++++--- .../dendrite/mediaapi/writers/download.go | 33 +++++++++++++------ vendor/manifest | 2 +- .../matrix-org/gomatrixserverlib/client.go | 10 ++++-- 6 files changed, 65 insertions(+), 20 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go index 51cd6017..04674a2c 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/mediaapi/routing" "github.com/matrix-org/dendrite/mediaapi/storage" + "github.com/matrix-org/gomatrixserverlib" log "github.com/Sirupsen/logrus" ) @@ -51,10 +52,12 @@ func main() { log.WithError(err).Panic("Failed to open database") } + client := gomatrixserverlib.NewClient() + log.Info("Starting media API server on ", cfg.Listen.MediaAPI) api := mux.NewRouter() - routing.Setup(api, cfg, db) + routing.Setup(api, cfg, db, client) common.SetupHTTPAPI(http.DefaultServeMux, api) log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil)) diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go index 82a77f2f..c1c19769 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go @@ -325,7 +325,7 @@ func (m *monolith) setupAPIs() { ) mediaapi_routing.Setup( - m.api, m.cfg, m.mediaAPIDB, + m.api, m.cfg, m.mediaAPIDB, &m.federation.Client, ) syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool( diff --git a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go index 85a40362..6a01a65b 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go @@ -31,7 +31,12 @@ import ( const pathPrefixR0 = "/_matrix/media/v1" // Setup registers the media API HTTP handlers -func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) { +func Setup( + apiMux *mux.Router, + cfg *config.Dendrite, + db *storage.Database, + client *gomatrixserverlib.Client, +) { r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ @@ -47,14 +52,21 @@ func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) { MXCToResult: map[string]*types.RemoteRequestResult{}, } r0mux.Handle("/download/{serverName}/{mediaId}", - makeDownloadAPI("download", cfg, db, activeRemoteRequests, activeThumbnailGeneration), + makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration), ).Methods("GET") r0mux.Handle("/thumbnail/{serverName}/{mediaId}", - makeDownloadAPI("thumbnail", cfg, db, activeRemoteRequests, activeThumbnailGeneration), + makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration), ).Methods("GET") } -func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration) http.HandlerFunc { +func makeDownloadAPI( + name string, + cfg *config.Dendrite, + db *storage.Database, + client *gomatrixserverlib.Client, + activeRemoteRequests *types.ActiveRemoteRequests, + activeThumbnailGeneration *types.ActiveThumbnailGeneration, +) http.HandlerFunc { return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -64,6 +76,17 @@ func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, ac w.Header().Set("Content-Type", "application/json") vars := mux.Vars(req) - writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests, activeThumbnailGeneration, name == "thumbnail") + writers.Download( + w, + req, + gomatrixserverlib.ServerName(vars["serverName"]), + types.MediaID(vars["mediaId"]), + cfg, + db, + client, + activeRemoteRequests, + activeThumbnailGeneration, + name == "thumbnail", + ) })) } diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index b92fe2d9..c611b336 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -68,6 +68,7 @@ func Download( mediaID types.MediaID, cfg *config.Dendrite, db *storage.Database, + client *gomatrixserverlib.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool, @@ -120,7 +121,8 @@ func Download( } metadata, err := dReq.doDownload( - req.Context(), w, cfg, db, activeRemoteRequests, activeThumbnailGeneration, + req.Context(), w, cfg, db, client, + activeRemoteRequests, activeThumbnailGeneration, ) if err != nil { // TODO: Handle the fact we might have started writing the response @@ -199,6 +201,7 @@ func (r *downloadRequest) doDownload( w http.ResponseWriter, cfg *config.Dendrite, db *storage.Database, + client *gomatrixserverlib.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) (*types.MediaMetadata, error) { @@ -216,7 +219,7 @@ func (r *downloadRequest) doDownload( } // If we do not have a record and the origin is remote, we need to fetch it and respond with that file resErr := r.getRemoteFile( - ctx, cfg, db, activeRemoteRequests, activeThumbnailGeneration, + ctx, client, cfg, db, activeRemoteRequests, activeThumbnailGeneration, ) if resErr != nil { return nil, resErr @@ -442,6 +445,7 @@ func (r *downloadRequest) generateThumbnail( // Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines. func (r *downloadRequest) getRemoteFile( ctx context.Context, + client *gomatrixserverlib.Client, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, @@ -477,7 +481,8 @@ func (r *downloadRequest) getRemoteFile( if mediaMetadata == nil { // If we do not have a record, we need to fetch the remote file first and then respond from the local file err := r.fetchRemoteFileAndStoreMetadata( - ctx, cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db, + ctx, client, + cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db, cfg.Media.ThumbnailSizes, activeThumbnailGeneration, cfg.Media.MaxThumbnailGenerators, ) @@ -541,6 +546,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act // fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( ctx context.Context, + client *gomatrixserverlib.Client, absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes, db *storage.Database, @@ -548,7 +554,9 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, ) error { - finalPath, duplicate, err := r.fetchRemoteFile(absBasePath, maxFileSizeBytes) + finalPath, duplicate, err := r.fetchRemoteFile( + ctx, client, absBasePath, maxFileSizeBytes, + ) if err != nil { return err } @@ -597,11 +605,16 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( return nil } -func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes) (types.Path, bool, error) { +func (r *downloadRequest) fetchRemoteFile( + ctx context.Context, + client *gomatrixserverlib.Client, + absBasePath config.Path, + maxFileSizeBytes config.FileSizeBytes, +) (types.Path, bool, error) { r.Logger.Info("Fetching remote file") // create request for remote file - resp, err := r.createRemoteRequest() + resp, err := r.createRemoteRequest(ctx, client) if err != nil { return "", false, err } @@ -664,10 +677,10 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy return types.Path(finalPath), duplicate, nil } -func (r *downloadRequest) createRemoteRequest() (*http.Response, error) { - matrixClient := gomatrixserverlib.NewClient() - - resp, err := matrixClient.CreateMediaDownloadRequest(r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) +func (r *downloadRequest) createRemoteRequest( + ctx context.Context, matrixClient *gomatrixserverlib.Client, +) (*http.Response, error) { + resp, err := matrixClient.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) if err != nil { return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) } diff --git a/vendor/manifest b/vendor/manifest index cac40142..e30b2f2e 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -116,7 +116,7 @@ { "importpath": "github.com/matrix-org/gomatrixserverlib", "repository": "https://github.com/matrix-org/gomatrixserverlib", - "revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e", + "revision": "40b35e1c997fc7e35342aeb39187ff6bf3e10b2e", "branch": "master" }, { diff --git a/vendor/src/github.com/matrix-org/gomatrixserverlib/client.go b/vendor/src/github.com/matrix-org/gomatrixserverlib/client.go index 3cd87196..87cc4753 100644 --- a/vendor/src/github.com/matrix-org/gomatrixserverlib/client.go +++ b/vendor/src/github.com/matrix-org/gomatrixserverlib/client.go @@ -236,9 +236,15 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo } // CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error -func (fc *Client) CreateMediaDownloadRequest(matrixServer ServerName, mediaID string) (*http.Response, error) { +func (fc *Client) CreateMediaDownloadRequest( + ctx context.Context, matrixServer ServerName, mediaID string, +) (*http.Response, error) { requestURL := "matrix://" + string(matrixServer) + "/_matrix/media/v1/download/" + string(matrixServer) + "/" + mediaID - resp, err := fc.client.Get(requestURL) + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, err + } + resp, err := fc.client.Do(req.WithContext(ctx)) if err != nil { return nil, err }