diff --git a/src/github.com/matrix-org/dendrite/mediaapi/types/types.go b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go index 82cc1d7c..d54bcdf6 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/types/types.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go @@ -18,6 +18,7 @@ import ( "sync" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) // FileSizeBytes is a file size in bytes @@ -63,8 +64,10 @@ type MediaMetadata struct { type RemoteRequestResult struct { // Condition used for the requester to signal the result to all other routines waiting on this condition Cond *sync.Cond - // Resulting HTTP status code from the request - Result int + // MediaMetadata of the requested file to avoid querying the database for every waiting routine + MediaMetadata *MediaMetadata + // An error in util.JSONResponse form. nil in case of no error. + ErrorResponse *util.JSONResponse } // ActiveRemoteRequests is a lockable map of media URIs requested from remote homeservers diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index aecb1dd3..eac5d764 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -140,10 +140,14 @@ func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.MediaAPI } } // If we do not have a record and the origin is remote, we need to fetch it and respond with that file - return r.respondFromRemoteFile(w, cfg, db, activeRemoteRequests) + resErr := r.getRemoteFile(cfg, db, activeRemoteRequests) + if resErr != nil { + return resErr + } + } else { + // If we have a record, we can respond from the local file + r.MediaMetadata = mediaMetadata } - // If we have a record, we can respond from the local file - r.MediaMetadata = mediaMetadata return r.respondFromLocalFile(w, cfg.AbsBasePath) } @@ -207,137 +211,103 @@ func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePat return nil } -// respondFromRemoteFile fetches the remote file, caches it locally and responds from that local file +// getRemoteFile fetches the remote file and caches it locally // A hash map of active remote requests to a struct containing a sync.Cond is used to only download remote files once, // regardless of how many download requests are received. +// Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines. // Returns a util.JSONResponse error in case of error -func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) *util.JSONResponse { - // Note: getMediaMetadataForRemoteFile uses mutexes and conditions from activeRemoteRequests - mediaMetadata, resErr := r.getMediaMetadataForRemoteFile(db, activeRemoteRequests) +func (r *downloadRequest) getRemoteFile(cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) (errorResponse *util.JSONResponse) { + // Note: getMediaMetadataFromActiveRequest uses mutexes and conditions from activeRemoteRequests + mediaMetadata, resErr := r.getMediaMetadataFromActiveRequest(activeRemoteRequests) if resErr != nil { return resErr } else if mediaMetadata != nil { - // If we have a record, we can respond from the local file + // If we got metadata from an active request, we can respond from the local file r.MediaMetadata = mediaMetadata } else { - // If we do not have a record, we need to fetch the remote file first and then respond from the local file - // Note: getRemoteFile uses mutexes and conditions from activeRemoteRequests - if resErr := r.getRemoteFile(cfg.AbsBasePath, cfg.MaxFileSizeBytes, db, activeRemoteRequests); resErr != nil { - return resErr - } - } - return r.respondFromLocalFile(w, cfg.AbsBasePath) -} - -func (r *downloadRequest) getMediaMetadataForRemoteFile(db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) (*types.MediaMetadata, *util.JSONResponse) { - activeRemoteRequests.Lock() - defer activeRemoteRequests.Unlock() - - // check if we have a record of the media in our database - mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) - if err != nil { - r.Logger.WithError(err).Error("Error querying the database.") - resErr := jsonerror.InternalServerError() - return nil, &resErr - } - - if mediaMetadata != nil { - // If we have a record, we can respond from the local file - return mediaMetadata, nil - } - - // No record was found - - // Check if there is an active remote request for the file - mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) - if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { - r.Logger.Info("Waiting for another goroutine to fetch the remote file.") - - // NOTE: Wait unlocks and locks again internally. There is still a deferred Unlock() that will unlock this. - activeRemoteRequestResult.Cond.Wait() + // Note: This is an active request that MUST broadcastMediaMetadata to wake up waiting goroutines! + // Note: errorResponse is the named return variable + // Note: broadcastMediaMetadata uses mutexes and conditions from activeRemoteRequests + defer r.broadcastMediaMetadata(activeRemoteRequests, errorResponse) // check if we have a record of the media in our database mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) if err != nil { r.Logger.WithError(err).Error("Error querying the database.") resErr := jsonerror.InternalServerError() - return nil, &resErr + return &resErr } - if mediaMetadata != nil { + if mediaMetadata == nil { + // If we do not have a record, we need to fetch the remote file first and then respond from the local file + resErr := r.fetchRemoteFileAndStoreMetadata(cfg.AbsBasePath, cfg.MaxFileSizeBytes, db) + if resErr != nil { + return resErr + } + } else { // If we have a record, we can respond from the local file - return mediaMetadata, nil + r.MediaMetadata = mediaMetadata + } + } + return +} + +func (r *downloadRequest) getMediaMetadataFromActiveRequest(activeRemoteRequests *types.ActiveRemoteRequests) (*types.MediaMetadata, *util.JSONResponse) { + // Check if there is an active remote request for the file + mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) + + activeRemoteRequests.Lock() + defer activeRemoteRequests.Unlock() + + if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { + r.Logger.Info("Waiting for another goroutine to fetch the remote file.") + + // NOTE: Wait unlocks and locks again internally. There is still a deferred Unlock() that will unlock this. + activeRemoteRequestResult.Cond.Wait() + if activeRemoteRequestResult.ErrorResponse != nil { + return nil, activeRemoteRequestResult.ErrorResponse } - // Note: if the result was 200, we shouldn't get here - switch activeRemoteRequestResult.Result { - case 404: + if activeRemoteRequestResult.MediaMetadata == nil { return nil, &util.JSONResponse{ Code: 404, JSON: jsonerror.NotFound("File not found."), } - case 500: - r.Logger.Error("Other goroutine failed to fetch the remote file.") - resErr := jsonerror.InternalServerError() - return nil, &resErr - default: - r.Logger.Error("Other goroutine failed to fetch the remote file.") - return nil, &util.JSONResponse{ - Code: activeRemoteRequestResult.Result, - JSON: jsonerror.Unknown("Failed to fetch file from remote server."), - } } + + return activeRemoteRequestResult.MediaMetadata, nil } // No active remote request so create one activeRemoteRequests.MXCToResult[mxcURL] = &types.RemoteRequestResult{ Cond: &sync.Cond{L: activeRemoteRequests}, } + return nil, nil } -// getRemoteFile fetches the file from the remote server and stores its metadata in the database +// broadcastMediaMetadata broadcasts the media metadata and error response to waiting goroutines // Only the owner of the activeRemoteRequestResult for this origin and media ID should call this function. -func (r *downloadRequest) getRemoteFile(absBasePath types.Path, maxFileSizeBytes types.FileSizeBytes, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) *util.JSONResponse { - // Wake up other goroutines after this function returns. - isError := true - var result int - defer func() { - if isError { - // If an error happens, the lock MUST NOT have been taken, isError MUST be true and so the lock is taken here. - activeRemoteRequests.Lock() - } - defer activeRemoteRequests.Unlock() - mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) - if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { - r.Logger.Info("Signalling other goroutines waiting for this goroutine to fetch the file.") - if result == 0 { - r.Logger.Error("Invalid result, treating as InternalServerError") - result = 500 - } - activeRemoteRequestResult.Result = result - activeRemoteRequestResult.Cond.Broadcast() - } - delete(activeRemoteRequests.MXCToResult, mxcURL) - }() +func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.ActiveRemoteRequests, errorResponse *util.JSONResponse) { + activeRemoteRequests.Lock() + defer activeRemoteRequests.Unlock() + mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) + if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { + r.Logger.Info("Signalling other goroutines waiting for this goroutine to fetch the file.") + activeRemoteRequestResult.MediaMetadata = r.MediaMetadata + activeRemoteRequestResult.ErrorResponse = errorResponse + activeRemoteRequestResult.Cond.Broadcast() + } + delete(activeRemoteRequests.MXCToResult, mxcURL) +} +// fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database +func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(absBasePath types.Path, maxFileSizeBytes types.FileSizeBytes, db *storage.Database) *util.JSONResponse { finalPath, duplicate, resErr := r.fetchRemoteFile(absBasePath, maxFileSizeBytes) if resErr != nil { - result = resErr.Code return resErr } - // NOTE: Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic. - // If it were not atomic, a new request for the same file could come in in routine A and check the database before the INSERT. - // Routine B which was fetching could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests. - // If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database. - // The locking below mitigates this situation. - - // NOTE: The following two lines MUST remain together! - // isError == true causes the lock to be taken in a deferred function! - activeRemoteRequests.Lock() - isError = false - r.Logger.WithFields(log.Fields{ "Base64Hash": r.MediaMetadata.Base64Hash, "UploadName": r.MediaMetadata.UploadName, @@ -357,7 +327,6 @@ func (r *downloadRequest) getRemoteFile(absBasePath types.Path, maxFileSizeBytes // NOTE: It should really not be possible to fail the uniqueness test here so // there is no need to handle that separately resErr := jsonerror.InternalServerError() - result = resErr.Code return &resErr } @@ -370,7 +339,6 @@ func (r *downloadRequest) getRemoteFile(absBasePath types.Path, maxFileSizeBytes "Content-Type": r.MediaMetadata.ContentType, }).Infof("Remote file cached") - result = 200 return nil }