Use Unique from github.com/matrix-org/util (#28)
* Update github.com/matrix-org/util * Use Unique from github.com/matrix-org/utilmain
parent
8ba9d4af04
commit
84682b33c9
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"sort"
|
"sort"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -88,9 +89,8 @@ func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.R
|
||||||
// Collect all the entries with the same type and key together.
|
// Collect all the entries with the same type and key together.
|
||||||
// We don't care about the order here because the conflict resolution
|
// We don't care about the order here because the conflict resolution
|
||||||
// algorithm doesn't depend on the order of the prev events.
|
// algorithm doesn't depend on the order of the prev events.
|
||||||
sort.Sort(stateEntrySorter(combined))
|
|
||||||
// Remove duplicate entires.
|
// Remove duplicate entires.
|
||||||
combined = combined[:unique(stateEntrySorter(combined))]
|
combined = combined[:util.SortAndUnique(stateEntrySorter(combined))]
|
||||||
|
|
||||||
// Find the conflicts
|
// Find the conflicts
|
||||||
conflicts := findDuplicateStateKeys(combined)
|
conflicts := findDuplicateStateKeys(combined)
|
||||||
|
@ -202,7 +202,7 @@ func loadStateAtSnapshot(db RoomEventDatabase, stateNID types.StateSnapshotNID)
|
||||||
// remains later in the list than the older entries for the same state key.
|
// remains later in the list than the older entries for the same state key.
|
||||||
sort.Stable(stateEntryByStateKeySorter(fullState))
|
sort.Stable(stateEntryByStateKeySorter(fullState))
|
||||||
// Unique returns the last entry and hence the most recent entry for each state key.
|
// Unique returns the last entry and hence the most recent entry for each state key.
|
||||||
fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))]
|
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
|
||||||
return fullState, nil
|
return fullState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,7 +270,7 @@ func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.State
|
||||||
// remains later in the list than the older entries for the same state key.
|
// remains later in the list than the older entries for the same state key.
|
||||||
sort.Stable(stateEntryByStateKeySorter(fullState))
|
sort.Stable(stateEntryByStateKeySorter(fullState))
|
||||||
// Unique returns the last entry and hence the most recent entry for each state key.
|
// Unique returns the last entry and hence the most recent entry for each state key.
|
||||||
fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))]
|
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
|
||||||
// Add the full state for this StateSnapshotNID.
|
// Add the full state for this StateSnapshotNID.
|
||||||
combined = append(combined, fullState...)
|
combined = append(combined, fullState...)
|
||||||
}
|
}
|
||||||
|
@ -357,8 +357,7 @@ func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||||
func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
|
func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
|
||||||
sort.Sort(stateNIDSorter(nids))
|
return nids[:util.SortAndUnique(stateNIDSorter(nids))]
|
||||||
return nids[:unique(stateNIDSorter(nids))]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type stateBlockNIDSorter []types.StateBlockNID
|
type stateBlockNIDSorter []types.StateBlockNID
|
||||||
|
@ -368,37 +367,5 @@ func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||||
func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
|
func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
|
||||||
sort.Sort(stateBlockNIDSorter(nids))
|
return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))]
|
||||||
return nids[:unique(stateBlockNIDSorter(nids))]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove duplicate items from a sorted list.
|
|
||||||
// Takes the same interface as sort.Sort
|
|
||||||
// Returns the length of the data without duplicates
|
|
||||||
// Uses the last occurance of a duplicate.
|
|
||||||
// O(n).
|
|
||||||
func unique(data sort.Interface) int {
|
|
||||||
if data.Len() == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
length := data.Len()
|
|
||||||
// j is the next index to output an element to.
|
|
||||||
j := 0
|
|
||||||
for i := 1; i < length; i++ {
|
|
||||||
// If the previous element is less than this element then they are
|
|
||||||
// not equal. Otherwise they must be equal because the list is sorted.
|
|
||||||
// If they are equal then we move onto the next element.
|
|
||||||
if data.Less(i-1, i) {
|
|
||||||
// "Write" the previous element to the output position by swaping
|
|
||||||
// the elements.
|
|
||||||
// Note that if the list has no duplicates then i-1 == j so the
|
|
||||||
// swap does nothing. (This assumes that data.Swap(a,b) nops if a==b)
|
|
||||||
data.Swap(i-1, j)
|
|
||||||
// Advance to the next output position in the list.
|
|
||||||
j++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Output the last element.
|
|
||||||
data.Swap(length-1, j)
|
|
||||||
return j + 1
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,32 +5,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sortBytes []byte
|
|
||||||
|
|
||||||
func (s sortBytes) Len() int { return len(s) }
|
|
||||||
func (s sortBytes) Less(i, j int) bool { return s[i] < s[j] }
|
|
||||||
func (s sortBytes) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
||||||
func TestUnique(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
Input string
|
|
||||||
Want string
|
|
||||||
}{
|
|
||||||
{"", ""},
|
|
||||||
{"abc", "abc"},
|
|
||||||
{"aaabbbccc", "abc"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range testCases {
|
|
||||||
input := []byte(test.Input)
|
|
||||||
want := string(test.Want)
|
|
||||||
got := string(input[:unique(sortBytes(input))])
|
|
||||||
if got != want {
|
|
||||||
t.Fatal("Wanted ", want, " got ", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindDuplicateStateKeys(t *testing.T) {
|
func TestFindDuplicateStateKeys(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
Input []types.StateEntry
|
Input []types.StateEntry
|
||||||
|
|
|
@ -98,7 +98,7 @@
|
||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/util",
|
"importpath": "github.com/matrix-org/util",
|
||||||
"repository": "https://github.com/matrix-org/util",
|
"repository": "https://github.com/matrix-org/util",
|
||||||
"revision": "28bd7491c8aafbf346ca23821664f0f9911ef52b",
|
"revision": "ec8896cd7d9ba6de6143c5f123d1e45413657e7d",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -80,7 +80,7 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, req *http.Request) {
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
|
logger := GetLogger(req.Context())
|
||||||
logger.WithFields(log.Fields{
|
logger.WithFields(log.Fields{
|
||||||
"panic": r,
|
"panic": r,
|
||||||
}).Errorf(
|
}).Errorf(
|
||||||
|
@ -108,7 +108,7 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc {
|
||||||
ctx = context.WithValue(ctx, ctxValueRequestID, reqID)
|
ctx = context.WithValue(ctx, ctxValueRequestID, reqID)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
|
logger := GetLogger(req.Context())
|
||||||
logger.Print("Incoming request")
|
logger.Print("Incoming request")
|
||||||
|
|
||||||
res := handler.OnIncomingRequest(req)
|
res := handler.OnIncomingRequest(req)
|
||||||
|
@ -122,7 +122,7 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) {
|
func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) {
|
||||||
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
|
logger := GetLogger(req.Context())
|
||||||
|
|
||||||
// Set custom headers
|
// Set custom headers
|
||||||
if res.Headers != nil {
|
if res.Headers != nil {
|
||||||
|
|
|
@ -194,6 +194,28 @@ func TestProtect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProtectWithoutLogger(t *testing.T) {
|
||||||
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
||||||
|
mockWriter := httptest.NewRecorder()
|
||||||
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
|
||||||
|
h := Protect(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
panic("oh noes!")
|
||||||
|
})
|
||||||
|
|
||||||
|
h(mockWriter, mockReq)
|
||||||
|
|
||||||
|
expectCode := 500
|
||||||
|
if mockWriter.Code != expectCode {
|
||||||
|
t.Errorf("TestProtect wanted HTTP status %d, got %d", expectCode, mockWriter.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectBody := `{"message":"Internal Server Error"}`
|
||||||
|
actualBody := mockWriter.Body.String()
|
||||||
|
if actualBody != expectBody {
|
||||||
|
t.Errorf("TestProtect wanted body %s, got %s", expectBody, actualBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithCORSOptions(t *testing.T) {
|
func TestWithCORSOptions(t *testing.T) {
|
||||||
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
||||||
mockWriter := httptest.NewRecorder()
|
mockWriter := httptest.NewRecorder()
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Unique removes duplicate items from a sorted list in place.
|
||||||
|
// Takes the same interface as sort.Sort
|
||||||
|
// Returns the length of the data without duplicates
|
||||||
|
// Uses the last occurrence of a duplicate.
|
||||||
|
// O(n).
|
||||||
|
func Unique(data sort.Interface) int {
|
||||||
|
if !sort.IsSorted(data) {
|
||||||
|
panic(fmt.Errorf("util: the input to Unique() must be sorted"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.Len() == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
length := data.Len()
|
||||||
|
// j is the next index to output an element to.
|
||||||
|
j := 0
|
||||||
|
for i := 1; i < length; i++ {
|
||||||
|
// If the previous element is less than this element then they are
|
||||||
|
// not equal. Otherwise they must be equal because the list is sorted.
|
||||||
|
// If they are equal then we move onto the next element.
|
||||||
|
if data.Less(i-1, i) {
|
||||||
|
// "Write" the previous element to the output position by swapping
|
||||||
|
// the elements.
|
||||||
|
// Note that if the list has no duplicates then i-1 == j so the
|
||||||
|
// swap does nothing. (This assumes that data.Swap(a,b) nops if a==b)
|
||||||
|
data.Swap(i-1, j)
|
||||||
|
// Advance to the next output position in the list.
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Output the last element.
|
||||||
|
data.Swap(length-1, j)
|
||||||
|
return j + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SortAndUnique sorts a list and removes duplicate entries in place.
|
||||||
|
// Takes the same interface as sort.Sort
|
||||||
|
// Returns the length of the data without duplicates
|
||||||
|
// Uses the last occurrence of a duplicate.
|
||||||
|
// O(nlog(n))
|
||||||
|
func SortAndUnique(data sort.Interface) int {
|
||||||
|
sort.Sort(data)
|
||||||
|
return Unique(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UniqueStrings turns a list of strings into a sorted list of unique strings.
|
||||||
|
// O(nlog(n))
|
||||||
|
func UniqueStrings(strings []string) []string {
|
||||||
|
return strings[:SortAndUnique(sort.StringSlice(strings))]
|
||||||
|
}
|
|
@ -0,0 +1,96 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sortBytes []byte
|
||||||
|
|
||||||
|
func (s sortBytes) Len() int { return len(s) }
|
||||||
|
func (s sortBytes) Less(i, j int) bool { return s[i] < s[j] }
|
||||||
|
func (s sortBytes) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
|
func TestUnique(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
Input string
|
||||||
|
Want string
|
||||||
|
}{
|
||||||
|
{"", ""},
|
||||||
|
{"abc", "abc"},
|
||||||
|
{"aaabbbccc", "abc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
input := []byte(test.Input)
|
||||||
|
want := string(test.Want)
|
||||||
|
got := string(input[:Unique(sortBytes(input))])
|
||||||
|
if got != want {
|
||||||
|
t.Fatal("Wanted ", want, " got ", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sortByFirstByte []string
|
||||||
|
|
||||||
|
func (s sortByFirstByte) Len() int { return len(s) }
|
||||||
|
func (s sortByFirstByte) Less(i, j int) bool { return s[i][0] < s[j][0] }
|
||||||
|
func (s sortByFirstByte) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
|
func TestUniquePicksLastDuplicate(t *testing.T) {
|
||||||
|
input := []string{
|
||||||
|
"aardvark",
|
||||||
|
"avacado",
|
||||||
|
"cat",
|
||||||
|
"cucumber",
|
||||||
|
}
|
||||||
|
want := []string{
|
||||||
|
"avacado",
|
||||||
|
"cucumber",
|
||||||
|
}
|
||||||
|
got := input[:Unique(sortByFirstByte(input))]
|
||||||
|
|
||||||
|
if len(want) != len(got) {
|
||||||
|
t.Errorf("Wanted %#v got %#v", want, got)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if want[i] != got[i] {
|
||||||
|
t.Errorf("Wanted %#v got %#v", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniquePanicsIfNotSorted(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Error("Expected Unique() to panic on unsorted input but it didn't")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
Unique(sort.StringSlice{"out", "of", "order"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniqueStrings(t *testing.T) {
|
||||||
|
input := []string{
|
||||||
|
"badger", "badger", "badger", "badger",
|
||||||
|
"badger", "badger", "badger", "badger",
|
||||||
|
"badger", "badger", "badger", "badger",
|
||||||
|
"mushroom", "mushroom",
|
||||||
|
"badger", "badger", "badger", "badger",
|
||||||
|
"badger", "badger", "badger", "badger",
|
||||||
|
"badger", "badger", "badger", "badger",
|
||||||
|
"snake", "snake",
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{"badger", "mushroom", "snake"}
|
||||||
|
|
||||||
|
got := UniqueStrings(input)
|
||||||
|
|
||||||
|
if len(want) != len(got) {
|
||||||
|
t.Errorf("Wanted %#v got %#v", want, got)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if want[i] != got[i] {
|
||||||
|
t.Errorf("Wanted %#v got %#v", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue