From 8b22c4270da28bb6e968f0804da8e2dc4a865651 Mon Sep 17 00:00:00 2001 From: S7evinK Date: Mon, 7 Jun 2021 10:17:20 +0200 Subject: [PATCH] Use LimitReader to prevent DoS risk (#1843) * Use LimitReader to prevent DoS risk Signed-off-by: Till Faelligen * Check if bytesWritten is equal to the maxFileSize Add tests Signed-off-by: Till Faelligen * Use oldschool defer to cleanup after the tests * Let LimitReader read MaxFileSizeBytes + 1 Co-authored-by: Kegsay --- mediaapi/routing/upload.go | 3 +- mediaapi/routing/upload_test.go | 132 ++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 mediaapi/routing/upload_test.go diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index f1dd231d..ada02b11 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -147,7 +147,8 @@ func (r *uploadRequest) doUpload( // r.storeFileAndMetadata(ctx, tmpDir, ...) // before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of // nested function to guarantee either storage or cleanup. - hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath) + lr := io.LimitReader(reqReader, int64(*cfg.MaxFileSizeBytes)+1) + hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, lr, cfg.AbsBasePath) if err != nil { r.Logger.WithError(err).WithFields(log.Fields{ "MaxFileSizeBytes": *cfg.MaxFileSizeBytes, diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go new file mode 100644 index 00000000..032437b5 --- /dev/null +++ b/mediaapi/routing/upload_test.go @@ -0,0 +1,132 @@ +package routing + +import ( + "context" + "io" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/mediaapi/fileutils" + "github.com/matrix-org/dendrite/mediaapi/storage" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" +) + +func Test_uploadRequest_doUpload(t *testing.T) { + type fields struct { + MediaMetadata *types.MediaMetadata + Logger *log.Entry + } + type args struct { + ctx context.Context + reqReader io.Reader + cfg *config.MediaAPI + db storage.Database + activeThumbnailGeneration *types.ActiveThumbnailGeneration + } + + wd, err := os.Getwd() + if err != nil { + t.Errorf("failed to get current working directory: %v", err) + } + + maxSize := config.FileSizeBytes(8) + logger := log.New().WithField("mediaapi", "test") + testdataPath := filepath.Join(wd, "./testdata") + + cfg := &config.MediaAPI{ + MaxFileSizeBytes: &maxSize, + BasePath: config.Path(testdataPath), + AbsBasePath: config.Path(testdataPath), + DynamicThumbnails: false, + } + + // create testdata folder and remove when done + _ = os.Mkdir(testdataPath, os.ModePerm) + defer fileutils.RemoveDir(types.Path(testdataPath), nil) + + db, err := storage.Open(&config.DatabaseOptions{ + ConnectionString: "file::memory:?cache=shared", + MaxOpenConnections: 100, + MaxIdleConnections: 2, + ConnMaxLifetimeSeconds: -1, + }) + if err != nil { + t.Errorf("error opening mediaapi database: %v", err) + } + + tests := []struct { + name string + fields fields + args args + want *util.JSONResponse + }{ + { + name: "upload ok", + args: args{ + ctx: context.Background(), + reqReader: strings.NewReader("test"), + cfg: cfg, + db: db, + }, + fields: fields{ + Logger: logger, + MediaMetadata: &types.MediaMetadata{ + MediaID: "1337", + UploadName: "test ok", + }, + }, + want: nil, + }, + { + name: "upload ok (exact size)", + args: args{ + ctx: context.Background(), + reqReader: strings.NewReader("testtest"), + cfg: cfg, + db: db, + }, + fields: fields{ + Logger: logger, + MediaMetadata: &types.MediaMetadata{ + MediaID: "1338", + UploadName: "test ok (exact size)", + }, + }, + want: nil, + }, + { + name: "upload not ok", + args: args{ + ctx: context.Background(), + reqReader: strings.NewReader("test test test"), + cfg: cfg, + db: db, + }, + fields: fields{ + Logger: logger, + MediaMetadata: &types.MediaMetadata{ + MediaID: "1339", + UploadName: "test fail", + }, + }, + want: requestEntityTooLargeJSONResponse(maxSize), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &uploadRequest{ + MediaMetadata: tt.fields.MediaMetadata, + Logger: tt.fields.Logger, + } + if got := r.doUpload(tt.args.ctx, tt.args.reqReader, tt.args.cfg, tt.args.db, tt.args.activeThumbnailGeneration); !reflect.DeepEqual(got, tt.want) { + t.Errorf("doUpload() = %+v, want %+v", got, tt.want) + } + }) + } +}