Pass a context when downloading remote media (#251)
parent
fef290c47e
commit
ce019738ff
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/common/config"
|
||||
"github.com/matrix-org/dendrite/mediaapi/routing"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
@ -51,10 +52,12 @@ func main() {
|
|||
log.WithError(err).Panic("Failed to open database")
|
||||
}
|
||||
|
||||
client := gomatrixserverlib.NewClient()
|
||||
|
||||
log.Info("Starting media API server on ", cfg.Listen.MediaAPI)
|
||||
|
||||
api := mux.NewRouter()
|
||||
routing.Setup(api, cfg, db)
|
||||
routing.Setup(api, cfg, db, client)
|
||||
common.SetupHTTPAPI(http.DefaultServeMux, api)
|
||||
|
||||
log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil))
|
||||
|
|
|
@ -325,7 +325,7 @@ func (m *monolith) setupAPIs() {
|
|||
)
|
||||
|
||||
mediaapi_routing.Setup(
|
||||
m.api, m.cfg, m.mediaAPIDB,
|
||||
m.api, m.cfg, m.mediaAPIDB, &m.federation.Client,
|
||||
)
|
||||
|
||||
syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool(
|
||||
|
|
|
@ -31,7 +31,12 @@ import (
|
|||
const pathPrefixR0 = "/_matrix/media/v1"
|
||||
|
||||
// Setup registers the media API HTTP handlers
|
||||
func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) {
|
||||
func Setup(
|
||||
apiMux *mux.Router,
|
||||
cfg *config.Dendrite,
|
||||
db *storage.Database,
|
||||
client *gomatrixserverlib.Client,
|
||||
) {
|
||||
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
|
||||
|
||||
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
|
||||
|
@ -47,14 +52,21 @@ func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) {
|
|||
MXCToResult: map[string]*types.RemoteRequestResult{},
|
||||
}
|
||||
r0mux.Handle("/download/{serverName}/{mediaId}",
|
||||
makeDownloadAPI("download", cfg, db, activeRemoteRequests, activeThumbnailGeneration),
|
||||
makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
|
||||
).Methods("GET")
|
||||
r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||
makeDownloadAPI("thumbnail", cfg, db, activeRemoteRequests, activeThumbnailGeneration),
|
||||
makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
|
||||
).Methods("GET")
|
||||
}
|
||||
|
||||
func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration) http.HandlerFunc {
|
||||
func makeDownloadAPI(
|
||||
name string,
|
||||
cfg *config.Dendrite,
|
||||
db *storage.Database,
|
||||
client *gomatrixserverlib.Client,
|
||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||
) http.HandlerFunc {
|
||||
return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req = util.RequestWithLogging(req)
|
||||
|
||||
|
@ -64,6 +76,17 @@ func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, ac
|
|||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
vars := mux.Vars(req)
|
||||
writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests, activeThumbnailGeneration, name == "thumbnail")
|
||||
writers.Download(
|
||||
w,
|
||||
req,
|
||||
gomatrixserverlib.ServerName(vars["serverName"]),
|
||||
types.MediaID(vars["mediaId"]),
|
||||
cfg,
|
||||
db,
|
||||
client,
|
||||
activeRemoteRequests,
|
||||
activeThumbnailGeneration,
|
||||
name == "thumbnail",
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -68,6 +68,7 @@ func Download(
|
|||
mediaID types.MediaID,
|
||||
cfg *config.Dendrite,
|
||||
db *storage.Database,
|
||||
client *gomatrixserverlib.Client,
|
||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||
isThumbnailRequest bool,
|
||||
|
@ -120,7 +121,8 @@ func Download(
|
|||
}
|
||||
|
||||
metadata, err := dReq.doDownload(
|
||||
req.Context(), w, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
|
||||
req.Context(), w, cfg, db, client,
|
||||
activeRemoteRequests, activeThumbnailGeneration,
|
||||
)
|
||||
if err != nil {
|
||||
// TODO: Handle the fact we might have started writing the response
|
||||
|
@ -199,6 +201,7 @@ func (r *downloadRequest) doDownload(
|
|||
w http.ResponseWriter,
|
||||
cfg *config.Dendrite,
|
||||
db *storage.Database,
|
||||
client *gomatrixserverlib.Client,
|
||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||
) (*types.MediaMetadata, error) {
|
||||
|
@ -216,7 +219,7 @@ func (r *downloadRequest) doDownload(
|
|||
}
|
||||
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
|
||||
resErr := r.getRemoteFile(
|
||||
ctx, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
|
||||
ctx, client, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
|
||||
)
|
||||
if resErr != nil {
|
||||
return nil, resErr
|
||||
|
@ -442,6 +445,7 @@ func (r *downloadRequest) generateThumbnail(
|
|||
// Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines.
|
||||
func (r *downloadRequest) getRemoteFile(
|
||||
ctx context.Context,
|
||||
client *gomatrixserverlib.Client,
|
||||
cfg *config.Dendrite,
|
||||
db *storage.Database,
|
||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||
|
@ -477,7 +481,8 @@ func (r *downloadRequest) getRemoteFile(
|
|||
if mediaMetadata == nil {
|
||||
// If we do not have a record, we need to fetch the remote file first and then respond from the local file
|
||||
err := r.fetchRemoteFileAndStoreMetadata(
|
||||
ctx, cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db,
|
||||
ctx, client,
|
||||
cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db,
|
||||
cfg.Media.ThumbnailSizes, activeThumbnailGeneration,
|
||||
cfg.Media.MaxThumbnailGenerators,
|
||||
)
|
||||
|
@ -541,6 +546,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act
|
|||
// fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database
|
||||
func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
||||
ctx context.Context,
|
||||
client *gomatrixserverlib.Client,
|
||||
absBasePath config.Path,
|
||||
maxFileSizeBytes config.FileSizeBytes,
|
||||
db *storage.Database,
|
||||
|
@ -548,7 +554,9 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
|||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||
maxThumbnailGenerators int,
|
||||
) error {
|
||||
finalPath, duplicate, err := r.fetchRemoteFile(absBasePath, maxFileSizeBytes)
|
||||
finalPath, duplicate, err := r.fetchRemoteFile(
|
||||
ctx, client, absBasePath, maxFileSizeBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -597,11 +605,16 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes) (types.Path, bool, error) {
|
||||
func (r *downloadRequest) fetchRemoteFile(
|
||||
ctx context.Context,
|
||||
client *gomatrixserverlib.Client,
|
||||
absBasePath config.Path,
|
||||
maxFileSizeBytes config.FileSizeBytes,
|
||||
) (types.Path, bool, error) {
|
||||
r.Logger.Info("Fetching remote file")
|
||||
|
||||
// create request for remote file
|
||||
resp, err := r.createRemoteRequest()
|
||||
resp, err := r.createRemoteRequest(ctx, client)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
@ -664,10 +677,10 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
|
|||
return types.Path(finalPath), duplicate, nil
|
||||
}
|
||||
|
||||
func (r *downloadRequest) createRemoteRequest() (*http.Response, error) {
|
||||
matrixClient := gomatrixserverlib.NewClient()
|
||||
|
||||
resp, err := matrixClient.CreateMediaDownloadRequest(r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||
func (r *downloadRequest) createRemoteRequest(
|
||||
ctx context.Context, matrixClient *gomatrixserverlib.Client,
|
||||
) (*http.Response, error) {
|
||||
resp, err := matrixClient.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
||||
}
|
||||
|
|
|
@ -116,7 +116,7 @@
|
|||
{
|
||||
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
||||
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
||||
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
|
||||
"revision": "40b35e1c997fc7e35342aeb39187ff6bf3e10b2e",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
|
|
|
@ -236,9 +236,15 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
|||
}
|
||||
|
||||
// CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error
|
||||
func (fc *Client) CreateMediaDownloadRequest(matrixServer ServerName, mediaID string) (*http.Response, error) {
|
||||
func (fc *Client) CreateMediaDownloadRequest(
|
||||
ctx context.Context, matrixServer ServerName, mediaID string,
|
||||
) (*http.Response, error) {
|
||||
requestURL := "matrix://" + string(matrixServer) + "/_matrix/media/v1/download/" + string(matrixServer) + "/" + mediaID
|
||||
resp, err := fc.client.Get(requestURL)
|
||||
req, err := http.NewRequest("GET", requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := fc.client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue