Pass a context when downloading remote media (#251)
parent
fef290c47e
commit
ce019738ff
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/common/config"
|
"github.com/matrix-org/dendrite/common/config"
|
||||||
"github.com/matrix-org/dendrite/mediaapi/routing"
|
"github.com/matrix-org/dendrite/mediaapi/routing"
|
||||||
"github.com/matrix-org/dendrite/mediaapi/storage"
|
"github.com/matrix-org/dendrite/mediaapi/storage"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
@ -51,10 +52,12 @@ func main() {
|
||||||
log.WithError(err).Panic("Failed to open database")
|
log.WithError(err).Panic("Failed to open database")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := gomatrixserverlib.NewClient()
|
||||||
|
|
||||||
log.Info("Starting media API server on ", cfg.Listen.MediaAPI)
|
log.Info("Starting media API server on ", cfg.Listen.MediaAPI)
|
||||||
|
|
||||||
api := mux.NewRouter()
|
api := mux.NewRouter()
|
||||||
routing.Setup(api, cfg, db)
|
routing.Setup(api, cfg, db, client)
|
||||||
common.SetupHTTPAPI(http.DefaultServeMux, api)
|
common.SetupHTTPAPI(http.DefaultServeMux, api)
|
||||||
|
|
||||||
log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil))
|
log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil))
|
||||||
|
|
|
@ -325,7 +325,7 @@ func (m *monolith) setupAPIs() {
|
||||||
)
|
)
|
||||||
|
|
||||||
mediaapi_routing.Setup(
|
mediaapi_routing.Setup(
|
||||||
m.api, m.cfg, m.mediaAPIDB,
|
m.api, m.cfg, m.mediaAPIDB, &m.federation.Client,
|
||||||
)
|
)
|
||||||
|
|
||||||
syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool(
|
syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool(
|
||||||
|
|
|
@ -31,7 +31,12 @@ import (
|
||||||
const pathPrefixR0 = "/_matrix/media/v1"
|
const pathPrefixR0 = "/_matrix/media/v1"
|
||||||
|
|
||||||
// Setup registers the media API HTTP handlers
|
// Setup registers the media API HTTP handlers
|
||||||
func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) {
|
func Setup(
|
||||||
|
apiMux *mux.Router,
|
||||||
|
cfg *config.Dendrite,
|
||||||
|
db *storage.Database,
|
||||||
|
client *gomatrixserverlib.Client,
|
||||||
|
) {
|
||||||
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
|
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
|
||||||
|
|
||||||
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
|
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
|
||||||
|
@ -47,14 +52,21 @@ func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) {
|
||||||
MXCToResult: map[string]*types.RemoteRequestResult{},
|
MXCToResult: map[string]*types.RemoteRequestResult{},
|
||||||
}
|
}
|
||||||
r0mux.Handle("/download/{serverName}/{mediaId}",
|
r0mux.Handle("/download/{serverName}/{mediaId}",
|
||||||
makeDownloadAPI("download", cfg, db, activeRemoteRequests, activeThumbnailGeneration),
|
makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
|
||||||
).Methods("GET")
|
).Methods("GET")
|
||||||
r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
|
||||||
makeDownloadAPI("thumbnail", cfg, db, activeRemoteRequests, activeThumbnailGeneration),
|
makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
|
||||||
).Methods("GET")
|
).Methods("GET")
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration) http.HandlerFunc {
|
func makeDownloadAPI(
|
||||||
|
name string,
|
||||||
|
cfg *config.Dendrite,
|
||||||
|
db *storage.Database,
|
||||||
|
client *gomatrixserverlib.Client,
|
||||||
|
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||||
|
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||||
|
) http.HandlerFunc {
|
||||||
return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
req = util.RequestWithLogging(req)
|
req = util.RequestWithLogging(req)
|
||||||
|
|
||||||
|
@ -64,6 +76,17 @@ func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, ac
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests, activeThumbnailGeneration, name == "thumbnail")
|
writers.Download(
|
||||||
|
w,
|
||||||
|
req,
|
||||||
|
gomatrixserverlib.ServerName(vars["serverName"]),
|
||||||
|
types.MediaID(vars["mediaId"]),
|
||||||
|
cfg,
|
||||||
|
db,
|
||||||
|
client,
|
||||||
|
activeRemoteRequests,
|
||||||
|
activeThumbnailGeneration,
|
||||||
|
name == "thumbnail",
|
||||||
|
)
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,6 +68,7 @@ func Download(
|
||||||
mediaID types.MediaID,
|
mediaID types.MediaID,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
db *storage.Database,
|
db *storage.Database,
|
||||||
|
client *gomatrixserverlib.Client,
|
||||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||||
isThumbnailRequest bool,
|
isThumbnailRequest bool,
|
||||||
|
@ -120,7 +121,8 @@ func Download(
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, err := dReq.doDownload(
|
metadata, err := dReq.doDownload(
|
||||||
req.Context(), w, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
|
req.Context(), w, cfg, db, client,
|
||||||
|
activeRemoteRequests, activeThumbnailGeneration,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: Handle the fact we might have started writing the response
|
// TODO: Handle the fact we might have started writing the response
|
||||||
|
@ -199,6 +201,7 @@ func (r *downloadRequest) doDownload(
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
db *storage.Database,
|
db *storage.Database,
|
||||||
|
client *gomatrixserverlib.Client,
|
||||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||||
) (*types.MediaMetadata, error) {
|
) (*types.MediaMetadata, error) {
|
||||||
|
@ -216,7 +219,7 @@ func (r *downloadRequest) doDownload(
|
||||||
}
|
}
|
||||||
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
|
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
|
||||||
resErr := r.getRemoteFile(
|
resErr := r.getRemoteFile(
|
||||||
ctx, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
|
ctx, client, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
|
||||||
)
|
)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return nil, resErr
|
return nil, resErr
|
||||||
|
@ -442,6 +445,7 @@ func (r *downloadRequest) generateThumbnail(
|
||||||
// Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines.
|
// Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines.
|
||||||
func (r *downloadRequest) getRemoteFile(
|
func (r *downloadRequest) getRemoteFile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
client *gomatrixserverlib.Client,
|
||||||
cfg *config.Dendrite,
|
cfg *config.Dendrite,
|
||||||
db *storage.Database,
|
db *storage.Database,
|
||||||
activeRemoteRequests *types.ActiveRemoteRequests,
|
activeRemoteRequests *types.ActiveRemoteRequests,
|
||||||
|
@ -477,7 +481,8 @@ func (r *downloadRequest) getRemoteFile(
|
||||||
if mediaMetadata == nil {
|
if mediaMetadata == nil {
|
||||||
// If we do not have a record, we need to fetch the remote file first and then respond from the local file
|
// If we do not have a record, we need to fetch the remote file first and then respond from the local file
|
||||||
err := r.fetchRemoteFileAndStoreMetadata(
|
err := r.fetchRemoteFileAndStoreMetadata(
|
||||||
ctx, cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db,
|
ctx, client,
|
||||||
|
cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db,
|
||||||
cfg.Media.ThumbnailSizes, activeThumbnailGeneration,
|
cfg.Media.ThumbnailSizes, activeThumbnailGeneration,
|
||||||
cfg.Media.MaxThumbnailGenerators,
|
cfg.Media.MaxThumbnailGenerators,
|
||||||
)
|
)
|
||||||
|
@ -541,6 +546,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act
|
||||||
// fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database
|
// fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database
|
||||||
func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
client *gomatrixserverlib.Client,
|
||||||
absBasePath config.Path,
|
absBasePath config.Path,
|
||||||
maxFileSizeBytes config.FileSizeBytes,
|
maxFileSizeBytes config.FileSizeBytes,
|
||||||
db *storage.Database,
|
db *storage.Database,
|
||||||
|
@ -548,7 +554,9 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
||||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
|
||||||
maxThumbnailGenerators int,
|
maxThumbnailGenerators int,
|
||||||
) error {
|
) error {
|
||||||
finalPath, duplicate, err := r.fetchRemoteFile(absBasePath, maxFileSizeBytes)
|
finalPath, duplicate, err := r.fetchRemoteFile(
|
||||||
|
ctx, client, absBasePath, maxFileSizeBytes,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -597,11 +605,16 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes) (types.Path, bool, error) {
|
func (r *downloadRequest) fetchRemoteFile(
|
||||||
|
ctx context.Context,
|
||||||
|
client *gomatrixserverlib.Client,
|
||||||
|
absBasePath config.Path,
|
||||||
|
maxFileSizeBytes config.FileSizeBytes,
|
||||||
|
) (types.Path, bool, error) {
|
||||||
r.Logger.Info("Fetching remote file")
|
r.Logger.Info("Fetching remote file")
|
||||||
|
|
||||||
// create request for remote file
|
// create request for remote file
|
||||||
resp, err := r.createRemoteRequest()
|
resp, err := r.createRemoteRequest(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", false, err
|
return "", false, err
|
||||||
}
|
}
|
||||||
|
@ -664,10 +677,10 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
|
||||||
return types.Path(finalPath), duplicate, nil
|
return types.Path(finalPath), duplicate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *downloadRequest) createRemoteRequest() (*http.Response, error) {
|
func (r *downloadRequest) createRemoteRequest(
|
||||||
matrixClient := gomatrixserverlib.NewClient()
|
ctx context.Context, matrixClient *gomatrixserverlib.Client,
|
||||||
|
) (*http.Response, error) {
|
||||||
resp, err := matrixClient.CreateMediaDownloadRequest(r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
resp, err := matrixClient.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,7 +116,7 @@
|
||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
||||||
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
||||||
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
|
"revision": "40b35e1c997fc7e35342aeb39187ff6bf3e10b2e",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -236,9 +236,15 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error
|
// CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error
|
||||||
func (fc *Client) CreateMediaDownloadRequest(matrixServer ServerName, mediaID string) (*http.Response, error) {
|
func (fc *Client) CreateMediaDownloadRequest(
|
||||||
|
ctx context.Context, matrixServer ServerName, mediaID string,
|
||||||
|
) (*http.Response, error) {
|
||||||
requestURL := "matrix://" + string(matrixServer) + "/_matrix/media/v1/download/" + string(matrixServer) + "/" + mediaID
|
requestURL := "matrix://" + string(matrixServer) + "/_matrix/media/v1/download/" + string(matrixServer) + "/" + mediaID
|
||||||
resp, err := fc.client.Get(requestURL)
|
req, err := http.NewRequest("GET", requestURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := fc.client.Do(req.WithContext(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue