Pass a context when downloading remote media (#251)

main
Mark Haines 2017-09-21 16:20:10 +01:00 committed by GitHub
parent fef290c47e
commit ce019738ff
6 changed files with 65 additions and 20 deletions

View File

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/mediaapi/routing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/Sirupsen/logrus"
)
@ -51,10 +52,12 @@ func main() {
log.WithError(err).Panic("Failed to open database")
}
client := gomatrixserverlib.NewClient()
log.Info("Starting media API server on ", cfg.Listen.MediaAPI)
api := mux.NewRouter()
routing.Setup(api, cfg, db)
routing.Setup(api, cfg, db, client)
common.SetupHTTPAPI(http.DefaultServeMux, api)
log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil))

View File

@ -325,7 +325,7 @@ func (m *monolith) setupAPIs() {
)
mediaapi_routing.Setup(
m.api, m.cfg, m.mediaAPIDB,
m.api, m.cfg, m.mediaAPIDB, &m.federation.Client,
)
syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool(

View File

@ -31,7 +31,12 @@ import (
const pathPrefixR0 = "/_matrix/media/v1"
// Setup registers the media API HTTP handlers
func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) {
func Setup(
apiMux *mux.Router,
cfg *config.Dendrite,
db *storage.Database,
client *gomatrixserverlib.Client,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
@ -47,14 +52,21 @@ func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) {
MXCToResult: map[string]*types.RemoteRequestResult{},
}
r0mux.Handle("/download/{serverName}/{mediaId}",
makeDownloadAPI("download", cfg, db, activeRemoteRequests, activeThumbnailGeneration),
makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
).Methods("GET")
r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail", cfg, db, activeRemoteRequests, activeThumbnailGeneration),
makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
).Methods("GET")
}
func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration) http.HandlerFunc {
func makeDownloadAPI(
name string,
cfg *config.Dendrite,
db *storage.Database,
client *gomatrixserverlib.Client,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) http.HandlerFunc {
return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)
@ -64,6 +76,17 @@ func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, ac
w.Header().Set("Content-Type", "application/json")
vars := mux.Vars(req)
writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests, activeThumbnailGeneration, name == "thumbnail")
writers.Download(
w,
req,
gomatrixserverlib.ServerName(vars["serverName"]),
types.MediaID(vars["mediaId"]),
cfg,
db,
client,
activeRemoteRequests,
activeThumbnailGeneration,
name == "thumbnail",
)
}))
}

View File

@ -68,6 +68,7 @@ func Download(
mediaID types.MediaID,
cfg *config.Dendrite,
db *storage.Database,
client *gomatrixserverlib.Client,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool,
@ -120,7 +121,8 @@ func Download(
}
metadata, err := dReq.doDownload(
req.Context(), w, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
req.Context(), w, cfg, db, client,
activeRemoteRequests, activeThumbnailGeneration,
)
if err != nil {
// TODO: Handle the fact we might have started writing the response
@ -199,6 +201,7 @@ func (r *downloadRequest) doDownload(
w http.ResponseWriter,
cfg *config.Dendrite,
db *storage.Database,
client *gomatrixserverlib.Client,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) (*types.MediaMetadata, error) {
@ -216,7 +219,7 @@ func (r *downloadRequest) doDownload(
}
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
resErr := r.getRemoteFile(
ctx, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
ctx, client, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
)
if resErr != nil {
return nil, resErr
@ -442,6 +445,7 @@ func (r *downloadRequest) generateThumbnail(
// Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines.
func (r *downloadRequest) getRemoteFile(
ctx context.Context,
client *gomatrixserverlib.Client,
cfg *config.Dendrite,
db *storage.Database,
activeRemoteRequests *types.ActiveRemoteRequests,
@ -477,7 +481,8 @@ func (r *downloadRequest) getRemoteFile(
if mediaMetadata == nil {
// If we do not have a record, we need to fetch the remote file first and then respond from the local file
err := r.fetchRemoteFileAndStoreMetadata(
ctx, cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db,
ctx, client,
cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db,
cfg.Media.ThumbnailSizes, activeThumbnailGeneration,
cfg.Media.MaxThumbnailGenerators,
)
@ -541,6 +546,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act
// fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database
func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
ctx context.Context,
client *gomatrixserverlib.Client,
absBasePath config.Path,
maxFileSizeBytes config.FileSizeBytes,
db *storage.Database,
@ -548,7 +554,9 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
maxThumbnailGenerators int,
) error {
finalPath, duplicate, err := r.fetchRemoteFile(absBasePath, maxFileSizeBytes)
finalPath, duplicate, err := r.fetchRemoteFile(
ctx, client, absBasePath, maxFileSizeBytes,
)
if err != nil {
return err
}
@ -597,11 +605,16 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
return nil
}
func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes) (types.Path, bool, error) {
func (r *downloadRequest) fetchRemoteFile(
ctx context.Context,
client *gomatrixserverlib.Client,
absBasePath config.Path,
maxFileSizeBytes config.FileSizeBytes,
) (types.Path, bool, error) {
r.Logger.Info("Fetching remote file")
// create request for remote file
resp, err := r.createRemoteRequest()
resp, err := r.createRemoteRequest(ctx, client)
if err != nil {
return "", false, err
}
@ -664,10 +677,10 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
return types.Path(finalPath), duplicate, nil
}
func (r *downloadRequest) createRemoteRequest() (*http.Response, error) {
matrixClient := gomatrixserverlib.NewClient()
resp, err := matrixClient.CreateMediaDownloadRequest(r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
func (r *downloadRequest) createRemoteRequest(
ctx context.Context, matrixClient *gomatrixserverlib.Client,
) (*http.Response, error) {
resp, err := matrixClient.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil {
return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
}

2
vendor/manifest vendored
View File

@ -116,7 +116,7 @@
{
"importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
"revision": "40b35e1c997fc7e35342aeb39187ff6bf3e10b2e",
"branch": "master"
},
{

View File

@ -236,9 +236,15 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
}
// CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error
func (fc *Client) CreateMediaDownloadRequest(matrixServer ServerName, mediaID string) (*http.Response, error) {
func (fc *Client) CreateMediaDownloadRequest(
ctx context.Context, matrixServer ServerName, mediaID string,
) (*http.Response, error) {
requestURL := "matrix://" + string(matrixServer) + "/_matrix/media/v1/download/" + string(matrixServer) + "/" + mediaID
resp, err := fc.client.Get(requestURL)
req, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return nil, err
}
resp, err := fc.client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}