diff --git a/serverkeyapi/api/api.go b/serverkeyapi/api/api.go index f108d437..a43f9c0a 100644 --- a/serverkeyapi/api/api.go +++ b/serverkeyapi/api/api.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "time" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/gomatrixserverlib" @@ -69,9 +70,12 @@ func (s *httpServerKeyInternalAPI) FetcherName() string { } func (s *httpServerKeyInternalAPI) StoreKeys( - ctx context.Context, + _ context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) error { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() request := InputPublicKeysRequest{ Keys: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), } @@ -84,9 +88,12 @@ func (s *httpServerKeyInternalAPI) StoreKeys( } func (s *httpServerKeyInternalAPI) FetchKeys( - ctx context.Context, + _ context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() result := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) request := QueryPublicKeysRequest{ Requests: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp), @@ -94,8 +101,12 @@ func (s *httpServerKeyInternalAPI) FetchKeys( response := QueryPublicKeysResponse{ Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), } + now := gomatrixserverlib.AsTimestamp(time.Now()) for req, ts := range requests { if res, ok := s.immutableCache.GetServerKey(req); ok { + if now > res.ValidUntilTS && res.ExpiredTS == gomatrixserverlib.PublicKeyNotExpired { + continue + } result[req] = res continue } diff --git a/serverkeyapi/internal/api.go b/serverkeyapi/internal/api.go index cbae59d9..c63b23f0 100644 --- a/serverkeyapi/internal/api.go +++ b/serverkeyapi/internal/api.go @@ -3,6 +3,7 @@ package internal import ( "context" "fmt" + "time" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/serverkeyapi/api" @@ -24,25 +25,35 @@ func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { } func (s *ServerKeyAPI) StoreKeys( - ctx context.Context, + _ context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, ) error { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() // Store any keys that we were given in our database. return s.OurKeyRing.KeyDatabase.StoreKeys(ctx, results) } func (s *ServerKeyAPI) FetchKeys( - ctx context.Context, + _ context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + // Run in a background context - we don't want to stop this work just + // because the caller gives up waiting. + ctx := context.Background() results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} // First consult our local database and see if we have the requested // keys. These might come from a cache, depending on the database // implementation used. + now := gomatrixserverlib.AsTimestamp(time.Now()) if dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests); err == nil { // We successfully got some keys. Add them to the results and // remove them from the request list. for req, res := range dbResults { + if now > res.ValidUntilTS && res.ExpiredTS == gomatrixserverlib.PublicKeyNotExpired { + continue + } results[req] = res delete(requests, req) } @@ -61,6 +72,9 @@ func (s *ServerKeyAPI) FetchKeys( results[req] = res delete(requests, req) } + if err = s.OurKeyRing.KeyDatabase.StoreKeys(ctx, fetcherResults); err != nil { + return nil, fmt.Errorf("server key API failed to store retrieved keys: %w", err) + } } } // If we failed to fetch any keys then we should report an error. diff --git a/serverkeyapi/internal/http.go b/serverkeyapi/internal/http.go index 30327571..eef66b76 100644 --- a/serverkeyapi/internal/http.go +++ b/serverkeyapi/internal/http.go @@ -14,28 +14,16 @@ import ( func (s *ServerKeyAPI) SetupHTTP(internalAPIMux *mux.Router) { internalAPIMux.Handle(api.ServerKeyQueryPublicKeyPath, internal.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { - result := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} request := api.QueryPublicKeysRequest{} response := api.QueryPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - lookup := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) - for req, timestamp := range request.Requests { - if res, ok := s.ImmutableCache.GetServerKey(req); ok { - result[req] = res - continue - } - lookup[req] = timestamp - } - keys, err := s.FetchKeys(req.Context(), lookup) + keys, err := s.FetchKeys(req.Context(), request.Requests) if err != nil { return util.ErrorResponse(err) } - for req, res := range keys { - result[req] = res - } - response.Results = result + response.Results = keys return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), )