159 lines
4.4 KiB
Go
159 lines
4.4 KiB
Go
|
// Copyright 2017 Vector Creations Ltd
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing, software
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
// See the License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
package test
|
||
|
|
||
|
import (
|
||
|
"crypto/tls"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/matrix-org/gomatrixserverlib"
|
||
|
)
|
||
|
|
||
|
// Request contains the information necessary to issue a request and test its result
|
||
|
type Request struct {
|
||
|
Req *http.Request
|
||
|
WantedBody string
|
||
|
WantedStatusCode int
|
||
|
LastErr *LastRequestErr
|
||
|
}
|
||
|
|
||
|
// LastRequestErr is a synchronised error wrapper
|
||
|
// Useful for obtaining the last error from a set of requests
|
||
|
type LastRequestErr struct {
|
||
|
sync.Mutex
|
||
|
Err error
|
||
|
}
|
||
|
|
||
|
// Set sets the error
|
||
|
func (r *LastRequestErr) Set(err error) {
|
||
|
r.Lock()
|
||
|
defer r.Unlock()
|
||
|
r.Err = err
|
||
|
}
|
||
|
|
||
|
// Get gets the error
|
||
|
func (r *LastRequestErr) Get() error {
|
||
|
r.Lock()
|
||
|
defer r.Unlock()
|
||
|
return r.Err
|
||
|
}
|
||
|
|
||
|
// CanonicalJSONInput canonicalises a slice of JSON strings
|
||
|
// Useful for test input
|
||
|
func CanonicalJSONInput(jsonData []string) []string {
|
||
|
for i := range jsonData {
|
||
|
jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i]))
|
||
|
if err != nil && err != io.EOF {
|
||
|
panic(err)
|
||
|
}
|
||
|
jsonData[i] = string(jsonBytes)
|
||
|
}
|
||
|
return jsonData
|
||
|
}
|
||
|
|
||
|
// Do issues a request and checks the status code and body of the response
|
||
|
func (r *Request) Do() (err error) {
|
||
|
client := &http.Client{
|
||
|
Timeout: 5 * time.Second,
|
||
|
Transport: &http.Transport{
|
||
|
TLSClientConfig: &tls.Config{
|
||
|
InsecureSkipVerify: true,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
res, err := client.Do(r.Req)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer (func() { err = res.Body.Close() })()
|
||
|
|
||
|
if res.StatusCode != r.WantedStatusCode {
|
||
|
return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode)
|
||
|
}
|
||
|
|
||
|
if r.WantedBody != "" {
|
||
|
resBytes, err := ioutil.ReadAll(res.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if string(jsonBytes) != r.WantedBody {
|
||
|
return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body.
|
||
|
// It then closes the given channel and returns.
|
||
|
func (r *Request) DoUntilSuccess(done chan error) {
|
||
|
r.LastErr = &LastRequestErr{}
|
||
|
for {
|
||
|
if err := r.Do(); err != nil {
|
||
|
r.LastErr.Set(err)
|
||
|
time.Sleep(1 * time.Second) // don't tightloop
|
||
|
continue
|
||
|
}
|
||
|
close(done)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Run repeatedly issues a request until success, error or a timeout is reached
|
||
|
func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) {
|
||
|
fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout)
|
||
|
done := make(chan error, 1)
|
||
|
|
||
|
// We need to wait for the server to:
|
||
|
// - have connected to the database
|
||
|
// - have created the tables
|
||
|
// - be listening on the given port
|
||
|
go r.DoUntilSuccess(done)
|
||
|
|
||
|
// wait for one of:
|
||
|
// - the test to pass (done channel is closed)
|
||
|
// - the server to exit with an error (error sent on serverCmdChan)
|
||
|
// - our test timeout to expire
|
||
|
// We don't need to clean up since the main() function handles that in the event we panic
|
||
|
select {
|
||
|
case <-time.After(timeout):
|
||
|
fmt.Printf("==TESTING== %v TIMEOUT\n", label)
|
||
|
if reqErr := r.LastErr.Get(); reqErr != nil {
|
||
|
fmt.Println("Last /sync request error:")
|
||
|
fmt.Println(reqErr)
|
||
|
}
|
||
|
panic(fmt.Sprintf("%v server timed out", label))
|
||
|
case err := <-serverCmdChan:
|
||
|
if err != nil {
|
||
|
fmt.Println("=============================================================================================")
|
||
|
fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label)
|
||
|
fmt.Println(" export PGHOST=/var/run/postgresql")
|
||
|
fmt.Println("=============================================================================================")
|
||
|
panic(err)
|
||
|
}
|
||
|
case <-done:
|
||
|
fmt.Printf("==TESTING== %v PASSED\n", label)
|
||
|
}
|
||
|
}
|