Don't use more than 999 variables in SQLite querys. (#1224)

Closes #1223

Signed-off-by: Henrik Sölver <henrik.solver@gmail.com>

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
main
Henrik Sölver 2020-07-27 14:19:30 +02:00 committed by GitHub
parent c8d476a3cc
commit 83f038e12b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 24 additions and 6 deletions

View File

@ -243,6 +243,13 @@ func (s *currentRoomStateStatements) UpsertRoomState(
}) })
} }
func minOfInts(a, b int) int {
if a <= b {
return a
}
return b
}
func (s *currentRoomStateStatements) SelectEventsWithEventIDs( func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
@ -250,13 +257,24 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
for k, v := range eventIDs { for k, v := range eventIDs {
iEventIDs[k] = v iEventIDs[k] = v
} }
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) res := make([]types.StreamEvent, 0, len(eventIDs))
rows, err := txn.QueryContext(ctx, query, iEventIDs...) var start int
for start < len(eventIDs) {
n := minOfInts(len(eventIDs)-start, 999)
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1)
rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") start = start + n
return rowsToStreamEvents(rows) events, err := rowsToStreamEvents(rows)
internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed")
if err != nil {
return nil, err
}
res = append(res, events...)
}
return res, nil
} }
func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) {