- Make sure we always cleanup the temp directory on error.
- Complain about it having an error prone API shape.
main
Kegsay 2020-08-26 15:38:34 +01:00 committed by GitHub
parent 29d6481842
commit 3802efe301
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 16 deletions

View File

@ -16,6 +16,7 @@ package fileutils
import ( import (
"bufio" "bufio"
"context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -27,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -104,15 +106,23 @@ func RemoveDir(dir types.Path, logger *log.Entry) {
} }
} }
// WriteTempFile writes to a new temporary file // WriteTempFile writes to a new temporary file.
func WriteTempFile(reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) { // The file is deleted if there was an error while writing.
func WriteTempFile(
ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path,
) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
size = -1 size = -1
logger := util.GetLogger(ctx)
tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath) tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath)
if err != nil { if err != nil {
return return
} }
defer (func() { err = tmpFile.Close() })() defer func() {
err2 := tmpFile.Close()
if err == nil {
err = err2
}
}()
// The amount of data read is limited to maxFileSizeBytes. At this point, if there is more data it will be truncated. // The amount of data read is limited to maxFileSizeBytes. At this point, if there is more data it will be truncated.
limitedReader := io.LimitReader(reqReader, int64(maxFileSizeBytes)) limitedReader := io.LimitReader(reqReader, int64(maxFileSizeBytes))
@ -123,11 +133,13 @@ func WriteTempFile(reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, a
teeReader := io.TeeReader(limitedReader, hasher) teeReader := io.TeeReader(limitedReader, hasher)
bytesWritten, err := io.Copy(tmpFileWriter, teeReader) bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
RemoveDir(tmpDir, logger)
return return
} }
err = tmpFileWriter.Flush() err = tmpFileWriter.Flush()
if err != nil { if err != nil {
RemoveDir(tmpDir, logger)
return return
} }

View File

@ -728,12 +728,11 @@ func (r *downloadRequest) fetchRemoteFile(
// method of deduplicating files to save storage, as well as a way to conduct // method of deduplicating files to save storage, as well as a way to conduct
// integrity checks on the file data in the repository. // integrity checks on the file data in the repository.
// Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK. // Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK.
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(resp.Body, maxFileSizeBytes, absBasePath) hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, resp.Body, maxFileSizeBytes, absBasePath)
if err != nil { if err != nil {
r.Logger.WithError(err).WithFields(log.Fields{ r.Logger.WithError(err).WithFields(log.Fields{
"MaxFileSizeBytes": maxFileSizeBytes, "MaxFileSizeBytes": maxFileSizeBytes,
}).Warn("Error while downloading file from remote server") }).Warn("Error while downloading file from remote server")
fileutils.RemoveDir(tmpDir, r.Logger)
return "", false, errors.New("file could not be downloaded from remote server") return "", false, errors.New("file could not be downloaded from remote server")
} }

View File

@ -53,8 +53,8 @@ func Setup(
uploadHandler := httputil.MakeAuthAPI( uploadHandler := httputil.MakeAuthAPI(
"upload", userAPI, "upload", userAPI,
func(req *http.Request, _ *userapi.Device) util.JSONResponse { func(req *http.Request, dev *userapi.Device) util.JSONResponse {
return Upload(req, cfg, db, activeThumbnailGeneration) return Upload(req, cfg, dev, db, activeThumbnailGeneration)
}, },
) )

View File

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/thumbnailer" "github.com/matrix-org/dendrite/mediaapi/thumbnailer"
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -55,8 +56,8 @@ type uploadResponse struct {
// This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large. // This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large.
// Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory. // Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory.
// TODO: We should time out requests if they have not received any data within a configured timeout period. // TODO: We should time out requests if they have not received any data within a configured timeout period.
func Upload(req *http.Request, cfg *config.MediaAPI, db storage.Database, activeThumbnailGeneration *types.ActiveThumbnailGeneration) util.JSONResponse { func Upload(req *http.Request, cfg *config.MediaAPI, dev *userapi.Device, db storage.Database, activeThumbnailGeneration *types.ActiveThumbnailGeneration) util.JSONResponse {
r, resErr := parseAndValidateRequest(req, cfg) r, resErr := parseAndValidateRequest(req, cfg, dev)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -76,13 +77,14 @@ func Upload(req *http.Request, cfg *config.MediaAPI, db storage.Database, active
// parseAndValidateRequest parses the incoming upload request to validate and extract // parseAndValidateRequest parses the incoming upload request to validate and extract
// all the metadata about the media being uploaded. // all the metadata about the media being uploaded.
// Returns either an uploadRequest or an error formatted as a util.JSONResponse // Returns either an uploadRequest or an error formatted as a util.JSONResponse
func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRequest, *util.JSONResponse) { func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI, dev *userapi.Device) (*uploadRequest, *util.JSONResponse) {
r := &uploadRequest{ r := &uploadRequest{
MediaMetadata: &types.MediaMetadata{ MediaMetadata: &types.MediaMetadata{
Origin: cfg.Matrix.ServerName, Origin: cfg.Matrix.ServerName,
FileSizeBytes: types.FileSizeBytes(req.ContentLength), FileSizeBytes: types.FileSizeBytes(req.ContentLength),
ContentType: types.ContentType(req.Header.Get("Content-Type")), ContentType: types.ContentType(req.Header.Get("Content-Type")),
UploadName: types.Filename(url.PathEscape(req.FormValue("filename"))), UploadName: types.Filename(url.PathEscape(req.FormValue("filename"))),
UserID: types.MatrixUserID(dev.UserID),
}, },
Logger: util.GetLogger(req.Context()).WithField("Origin", cfg.Matrix.ServerName), Logger: util.GetLogger(req.Context()).WithField("Origin", cfg.Matrix.ServerName),
} }
@ -138,12 +140,18 @@ func (r *uploadRequest) doUpload(
// method of deduplicating files to save storage, as well as a way to conduct // method of deduplicating files to save storage, as well as a way to conduct
// integrity checks on the file data in the repository. // integrity checks on the file data in the repository.
// Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK. // Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK.
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath) //
// TODO: This has a bad API shape where you either need to call:
// fileutils.RemoveDir(tmpDir, r.Logger)
// or call:
// r.storeFileAndMetadata(ctx, tmpDir, ...)
// before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
// nested function to guarantee either storage or cleanup.
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath)
if err != nil { if err != nil {
r.Logger.WithError(err).WithFields(log.Fields{ r.Logger.WithError(err).WithFields(log.Fields{
"MaxFileSizeBytes": *cfg.MaxFileSizeBytes, "MaxFileSizeBytes": *cfg.MaxFileSizeBytes,
}).Warn("Error while transferring file") }).Warn("Error while transferring file")
fileutils.RemoveDir(tmpDir, r.Logger)
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to upload"), JSON: jsonerror.Unknown("Failed to upload"),
@ -157,11 +165,14 @@ func (r *uploadRequest) doUpload(
ctx, hash, r.MediaMetadata.Origin, ctx, hash, r.MediaMetadata.Origin,
) )
if err != nil { if err != nil {
fileutils.RemoveDir(tmpDir, r.Logger)
r.Logger.WithError(err).Error("Error querying the database by hash.") r.Logger.WithError(err).Error("Error querying the database by hash.")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return &resErr return &resErr
} }
if existingMetadata != nil { if existingMetadata != nil {
// The file already exists, delete the uploaded temporary file.
defer fileutils.RemoveDir(tmpDir, r.Logger)
// The file already exists. Make a new media ID up for it. // The file already exists. Make a new media ID up for it.
mediaID, merr := r.generateMediaID(ctx, db) mediaID, merr := r.generateMediaID(ctx, db)
if merr != nil { if merr != nil {
@ -181,15 +192,13 @@ func (r *uploadRequest) doUpload(
Base64Hash: hash, Base64Hash: hash,
UserID: r.MediaMetadata.UserID, UserID: r.MediaMetadata.UserID,
} }
// Clean up the uploaded temporary file.
fileutils.RemoveDir(tmpDir, r.Logger)
} else { } else {
// The file doesn't exist. Update the request metadata. // The file doesn't exist. Update the request metadata.
r.MediaMetadata.FileSizeBytes = bytesWritten r.MediaMetadata.FileSizeBytes = bytesWritten
r.MediaMetadata.Base64Hash = hash r.MediaMetadata.Base64Hash = hash
r.MediaMetadata.MediaID, err = r.generateMediaID(ctx, db) r.MediaMetadata.MediaID, err = r.generateMediaID(ctx, db)
if err != nil { if err != nil {
fileutils.RemoveDir(tmpDir, r.Logger)
r.Logger.WithError(err).Error("Failed to generate media ID for new upload") r.Logger.WithError(err).Error("Failed to generate media ID for new upload")
resErr := jsonerror.InternalServerError() resErr := jsonerror.InternalServerError()
return &resErr return &resErr