// Copyright 2017 Vector Creations Ltd // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package test import ( "crypto/tls" "fmt" "io" "io/ioutil" "net/http" "sync" "time" "github.com/matrix-org/gomatrixserverlib" ) // Request contains the information necessary to issue a request and test its result type Request struct { Req *http.Request WantedBody string WantedStatusCode int LastErr *LastRequestErr } // LastRequestErr is a synchronised error wrapper // Useful for obtaining the last error from a set of requests type LastRequestErr struct { sync.Mutex Err error } // Set sets the error func (r *LastRequestErr) Set(err error) { r.Lock() defer r.Unlock() r.Err = err } // Get gets the error func (r *LastRequestErr) Get() error { r.Lock() defer r.Unlock() return r.Err } // CanonicalJSONInput canonicalises a slice of JSON strings // Useful for test input func CanonicalJSONInput(jsonData []string) []string { for i := range jsonData { jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i])) if err != nil && err != io.EOF { panic(err) } jsonData[i] = string(jsonBytes) } return jsonData } // Do issues a request and checks the status code and body of the response func (r *Request) Do() (err error) { client := &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, }, } res, err := client.Do(r.Req) if err != nil { return err } defer (func() { err = res.Body.Close() })() if res.StatusCode != r.WantedStatusCode { return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode) } if r.WantedBody != "" { resBytes, err := ioutil.ReadAll(res.Body) if err != nil { return err } jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes) if err != nil { return err } if string(jsonBytes) != r.WantedBody { return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes)) } } return nil } // DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body. // It then closes the given channel and returns. func (r *Request) DoUntilSuccess(done chan error) { r.LastErr = &LastRequestErr{} for { if err := r.Do(); err != nil { r.LastErr.Set(err) time.Sleep(1 * time.Second) // don't tightloop continue } close(done) return } } // Run repeatedly issues a request until success, error or a timeout is reached func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) { fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout) done := make(chan error, 1) // We need to wait for the server to: // - have connected to the database // - have created the tables // - be listening on the given port go r.DoUntilSuccess(done) // wait for one of: // - the test to pass (done channel is closed) // - the server to exit with an error (error sent on serverCmdChan) // - our test timeout to expire // We don't need to clean up since the main() function handles that in the event we panic select { case <-time.After(timeout): fmt.Printf("==TESTING== %v TIMEOUT\n", label) if reqErr := r.LastErr.Get(); reqErr != nil { fmt.Println("Last /sync request error:") fmt.Println(reqErr) } panic(fmt.Sprintf("%v server timed out", label)) case err := <-serverCmdChan: if err != nil { fmt.Println("=============================================================================================") fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label) fmt.Println(" export PGHOST=/var/run/postgresql") fmt.Println("=============================================================================================") panic(err) } case <-done: fmt.Printf("==TESTING== %v PASSED\n", label) } }