gb vendor fetch github.com/matrix-org/gomatrixserverlib/

main
Mark Haines 2017-02-06 14:56:16 +00:00
parent 5b5c2091bf
commit a45a824f41
28 changed files with 8180 additions and 1 deletions

15
vendor/manifest vendored
View File

@ -71,6 +71,12 @@
"revision": "a6657b2386e9b8be76484c08711b02c7cf867ead",
"branch": "master"
},
{
"importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "48ee56a33d195dc412dd919a0e81af70c9aaf4a3",
"branch": "master"
},
{
"importpath": "github.com/matrix-org/util",
"repository": "https://github.com/matrix-org/util",
@ -149,6 +155,13 @@
"revision": "61e43dc76f7ee59a82bdf3d71033dc12bea4c77d",
"branch": "master"
},
{
"importpath": "golang.org/x/crypto/ed25519",
"repository": "https://go.googlesource.com/crypto",
"revision": "77014cf7f9bde4925afeed52b7bf676d5f5b4285",
"branch": "master",
"path": "/ed25519"
},
{
"importpath": "golang.org/x/net/context",
"repository": "https://go.googlesource.com/net",
@ -175,4 +188,4 @@
"branch": "master"
}
]
}
}

View File

@ -0,0 +1,5 @@
gomatrixserverlib
=================
[![GoDoc](https://godoc.org/github.com/matrix-org/gomatrixserverlib?status.svg)](https://godoc.org/github.com/matrix-org/gomatrixserverlib)
Go library for common functions needed by matrix servers.

View File

@ -0,0 +1,48 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/base64"
"encoding/json"
)
// A Base64String is a string of bytes that are base64 encoded when used in JSON.
// The bytes encoded using base64 when marshalled as JSON.
// When the bytes are unmarshalled from JSON they are decoded from base64.
type Base64String []byte
// MarshalJSON encodes the bytes as base64 and then encodes the base64 as a JSON string.
// This takes a value receiver so that maps and slices of Base64String encode correctly.
func (b64 Base64String) MarshalJSON() ([]byte, error) {
// This could be made more efficient by using base64.RawStdEncoding.Encode
// to write the base64 directly to the JSON. We don't need to JSON escape
// any of the characters used in base64.
return json.Marshal(base64.RawStdEncoding.EncodeToString(b64))
}
// UnmarshalJSON decodes a JSON string and then decodes the resulting base64.
// This takes a pointer receiver because it needs to write the result of decoding.
func (b64 *Base64String) UnmarshalJSON(raw []byte) (err error) {
// We could add a fast path that used base64.RawStdEncoding.Decode
// directly on the raw JSON if the JSON didn't contain any escapes.
var str string
if err = json.Unmarshal(raw, &str); err != nil {
return
}
*b64, err = base64.RawStdEncoding.DecodeString(str)
return
}

View File

@ -0,0 +1,82 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"testing"
)
func TestMarshalBase64(t *testing.T) {
input := Base64String("this\xffis\xffa\xfftest")
want := `"dGhpc/9pc/9h/3Rlc3Q"`
got, err := json.Marshal(input)
if err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Fatalf("json.Marshal(Base64String(%q)): wanted %q got %q", string(input), want, string(got))
}
}
func TestUnmarshalBase64(t *testing.T) {
input := []byte(`"dGhpc/9pc/9h/3Rlc3Q"`)
want := "this\xffis\xffa\xfftest"
var got Base64String
err := json.Unmarshal(input, &got)
if err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Fatalf("json.Unmarshal(%q): wanted %q got %q", string(input), want, string(got))
}
}
func TestMarshalBase64Struct(t *testing.T) {
input := struct{ Value Base64String }{Base64String("this\xffis\xffa\xfftest")}
want := `{"Value":"dGhpc/9pc/9h/3Rlc3Q"}`
got, err := json.Marshal(input)
if err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got))
}
}
func TestMarshalBase64Map(t *testing.T) {
input := map[string]Base64String{"Value": Base64String("this\xffis\xffa\xfftest")}
want := `{"Value":"dGhpc/9pc/9h/3Rlc3Q"}`
got, err := json.Marshal(input)
if err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got))
}
}
func TestMarshalBase64Slice(t *testing.T) {
input := []Base64String{Base64String("this\xffis\xffa\xfftest")}
want := `["dGhpc/9pc/9h/3Rlc3Q"]`
got, err := json.Marshal(input)
if err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got))
}
}

View File

@ -0,0 +1,143 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"crypto/tls"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
)
// A Client makes request to the federation listeners of matrix
// homeservers
type Client struct {
client http.Client
}
// UserInfo represents information about a user.
type UserInfo struct {
Sub string `json:"sub"`
}
// NewClient makes a new Client
func NewClient() *Client {
// TODO: Verify ceritificates
tripper := federationTripper{
transport: &http.Transport{
// Set our own DialTLS function to avoid the default net/http SNI.
// By default net/http and crypto/tls set the SNI to the target host.
// By avoiding the default implementation we can keep the ServerName
// as the empty string so that crypto/tls doesn't add SNI.
DialTLS: func(network, addr string) (net.Conn, error) {
rawconn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
// Wrap a raw connection ourselves since tls.Dial defaults the SNI
conn := tls.Client(rawconn, &tls.Config{
ServerName: "",
// TODO: We should be checking that the TLS certificate we see here matches
// one of the allowed SHA-256 fingerprints for the server.
InsecureSkipVerify: true,
})
if err := conn.Handshake(); err != nil {
return nil, err
}
return conn, nil
},
},
}
return &Client{
client: http.Client{Transport: &tripper},
}
}
type federationTripper struct {
transport http.RoundTripper
}
func makeHTTPSURL(u *url.URL, addr string) (httpsURL url.URL) {
httpsURL = *u
httpsURL.Scheme = "https"
httpsURL.Host = addr
return
}
func (f *federationTripper) RoundTrip(r *http.Request) (*http.Response, error) {
host := r.URL.Host
dnsResult, err := LookupServer(host)
if err != nil {
return nil, err
}
var resp *http.Response
for _, addr := range dnsResult.Addrs {
u := makeHTTPSURL(r.URL, addr)
r.URL = &u
resp, err = f.transport.RoundTrip(r)
if err == nil {
return resp, nil
}
}
return nil, fmt.Errorf("no address found for matrix host %v", host)
}
// LookupUserInfo gets information about a user from a given matrix homeserver
// using a bearer access token.
func (fc *Client) LookupUserInfo(matrixServer, token string) (u UserInfo, err error) {
url := url.URL{
Scheme: "matrix",
Host: matrixServer,
Path: "/_matrix/federation/v1/openid/userinfo",
RawQuery: url.Values{"access_token": []string{token}}.Encode(),
}
var response *http.Response
response, err = fc.client.Get(url.String())
if response != nil {
defer response.Body.Close()
}
if err != nil {
return
}
if response.StatusCode < 200 || response.StatusCode >= 300 {
var errorOutput []byte
errorOutput, err = ioutil.ReadAll(response.Body)
if err != nil {
return
}
err = fmt.Errorf("HTTP %d : %s", response.StatusCode, errorOutput)
return
}
err = json.NewDecoder(response.Body).Decode(&u)
if err != nil {
return
}
userParts := strings.SplitN(u.Sub, ":", 2)
if len(userParts) != 2 || userParts[1] != matrixServer {
err = fmt.Errorf("userID doesn't match server name '%v' != '%v'", u.Sub, matrixServer)
return
}
return
}

View File

@ -0,0 +1,360 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"fmt"
"golang.org/x/crypto/ed25519"
"time"
)
// An EventReference is a reference to a matrix event.
type EventReference struct {
// The event ID of the event.
EventID string
// The sha256 of the redacted event.
EventSHA256 Base64String
}
// An EventBuilder is used to build a new event.
type EventBuilder struct {
// The user ID of the user sending the event.
Sender string `json:"sender"`
// The room ID of the room this event is in.
RoomID string `json:"room_id"`
// The type of the event.
Type string `json:"type"`
// The state_key of the event if the event is a state event or nil if the event is not a state event.
StateKey *string `json:"state_key,omitempty"`
// The events that immediately preceeded this event in the room history.
PrevEvents []EventReference `json:"prev_events"`
// The events needed to authenticate this event.
AuthEvents []EventReference `json:"auth_events"`
// The event ID of the event being redacted if this event is a "m.room.redaction".
Redacts string `json:"redacts,omitempty"`
// The depth of the event, This should be one greater than the maximum depth of the previous events.
// The create event has a depth of 1.
Depth int64 `json:"depth"`
content []byte
unsigned []byte
}
// SetContent sets the JSON content key of the event.
func (eb *EventBuilder) SetContent(content interface{}) (err error) {
eb.content, err = json.Marshal(content)
return
}
// SetUnsigned sets the JSON unsigned key of the event.
func (eb *EventBuilder) SetUnsigned(unsigned interface{}) (err error) {
eb.unsigned, err = json.Marshal(unsigned)
return
}
// An Event is a matrix event.
// The event should always contain valid JSON.
// If the event content hash is invalid then the event is redacted.
// Redacted events contain only the fields covered by the event signature.
type Event struct {
redacted bool
eventJSON []byte
fields eventFields
}
type eventFields struct {
RoomID string `json:"room_id"`
EventID string `json:"event_id"`
Sender string `json:"sender"`
Type string `json:"type"`
StateKey *string `json:"state_key"`
Content rawJSON `json:"content"`
PrevEvents []EventReference `json:"prev_events"`
AuthEvents []EventReference `json:"auth_events"`
Redacts string `json:"redacts"`
Depth int64 `json:"depth"`
}
var emptyEventReferenceList = []EventReference{}
// Build a new Event.
// This is used when a local event is created on this server.
// Call this after filling out the necessary fields.
// This can be called mutliple times on the same builder.
// A different event ID must be supplied each time this is called.
func (eb *EventBuilder) Build(eventID string, now time.Time, origin, keyID string, privateKey ed25519.PrivateKey) (result Event, err error) {
var event struct {
EventBuilder
EventID string `json:"event_id"`
RawContent rawJSON `json:"content"`
RawUnsigned rawJSON `json:"unsigned"`
OriginServerTS int64 `json:"origin_server_ts"`
Origin string `json:"origin"`
}
event.EventBuilder = *eb
if event.PrevEvents == nil {
event.PrevEvents = emptyEventReferenceList
}
if event.AuthEvents == nil {
event.AuthEvents = emptyEventReferenceList
}
event.RawContent = rawJSON(event.content)
event.RawUnsigned = rawJSON(event.unsigned)
event.OriginServerTS = now.UnixNano() / 1000000
event.Origin = origin
event.EventID = eventID
// TODO: Check size limits.
var eventJSON []byte
if eventJSON, err = json.Marshal(&event); err != nil {
return
}
if eventJSON, err = addContentHashesToEvent(eventJSON); err != nil {
return
}
if eventJSON, err = signEvent(origin, keyID, privateKey, eventJSON); err != nil {
return
}
if eventJSON, err = CanonicalJSON(eventJSON); err != nil {
return
}
result.eventJSON = eventJSON
err = json.Unmarshal(eventJSON, &result.fields)
return
}
// NewEventFromUntrustedJSON loads a new event from some JSON that may be invalid.
// This checks that the event is valid JSON.
// It also checks the content hashes to ensure the event has not been tampered with.
// This should be used when receiving new events from remote servers.
func NewEventFromUntrustedJSON(eventJSON []byte) (result Event, err error) {
var event map[string]rawJSON
if err = json.Unmarshal(eventJSON, &event); err != nil {
return
}
// Synapse removes these keys from events in case a server accidentally added them.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/crypto/event_signing.py#L57-L62
delete(event, "outlier")
delete(event, "destinations")
delete(event, "age_ts")
// TODO: Check that the event fields are correctly defined.
// TODO: Check size limits.
if eventJSON, err = json.Marshal(event); err != nil {
return
}
if err = checkEventContentHash(eventJSON); err != nil {
result.redacted = true
// If the content hash doesn't match then we have to discard all non-essential fields
// because they've been tampered with.
if eventJSON, err = redactEvent(eventJSON); err != nil {
return
}
}
if eventJSON, err = CanonicalJSON(eventJSON); err != nil {
return
}
result.eventJSON = eventJSON
err = json.Unmarshal(eventJSON, &result.fields)
return
}
// NewEventFromTrustedJSON loads a new event from some JSON that must be valid.
// This will be more efficient than NewEventFromUntrustedJSON since it can skip cryptographic checks.
// This can be used when loading matrix events from a local database.
func NewEventFromTrustedJSON(eventJSON []byte, redacted bool) (result Event, err error) {
result.redacted = redacted
result.eventJSON = eventJSON
err = json.Unmarshal(eventJSON, &result.fields)
return
}
// Redacted returns whether the event is redacted.
func (e Event) Redacted() bool { return e.redacted }
// JSON returns the JSON bytes for the event.
func (e Event) JSON() []byte { return e.eventJSON }
// Redact returns a redacted copy of the event.
func (e Event) Redact() Event {
if e.redacted {
return e
}
eventJSON, err := redactEvent(e.eventJSON)
if err != nil {
// This is unreachable for events created with EventBuilder.Build or NewEventFromUntrustedJSON
panic(fmt.Errorf("gomatrixserverlib: invalid event %v", err))
}
if eventJSON, err = CanonicalJSON(eventJSON); err != nil {
// This is unreachable for events created with EventBuilder.Build or NewEventFromUntrustedJSON
panic(fmt.Errorf("gomatrixserverlib: invalid event %v", err))
}
return Event{
redacted: true,
eventJSON: eventJSON,
}
}
// EventReference returns an EventReference for the event.
// The reference can be used to refer to this event from other events.
func (e Event) EventReference() EventReference {
reference, err := referenceOfEvent(e.eventJSON)
if err != nil {
// This is unreachable for events created with EventBuilder.Build or NewEventFromUntrustedJSON
// This can be reached if NewEventFromTrustedJSON is given JSON from an untrusted source.
panic(fmt.Errorf("gomatrixserverlib: invalid event %v (%q)", err, string(e.eventJSON)))
}
return reference
}
// Sign returns a copy of the event with an additional signature.
func (e Event) Sign(signingName, keyID string, privateKey ed25519.PrivateKey) Event {
eventJSON, err := signEvent(signingName, keyID, privateKey, e.eventJSON)
if err != nil {
// This is unreachable for events created with EventBuilder.Build or NewEventFromUntrustedJSON
panic(fmt.Errorf("gomatrixserverlib: invalid event %v (%q)", err, string(e.eventJSON)))
}
if eventJSON, err = CanonicalJSON(eventJSON); err != nil {
// This is unreachable for events created with EventBuilder.Build or NewEventFromUntrustedJSON
panic(fmt.Errorf("gomatrixserverlib: invalid event %v (%q)", err, string(e.eventJSON)))
}
return Event{
redacted: e.redacted,
eventJSON: eventJSON,
}
}
// KeyIDs returns a list of key IDs that the named entity has signed the event with.
func (e Event) KeyIDs(signingName string) []string {
var event struct {
Signatures map[string]map[string]rawJSON `json:"signatures"`
}
if err := json.Unmarshal(e.eventJSON, &event); err != nil {
// This should unreachable for events created with EventBuilder.Build or NewEventFromUntrustedJSON
panic(fmt.Errorf("gomatrixserverlib: invalid event %v", err))
}
var keyIDs []string
for keyID := range event.Signatures[signingName] {
keyIDs = append(keyIDs, keyID)
}
return keyIDs
}
// Verify checks a ed25519 signature
func (e Event) Verify(signingName, keyID string, publicKey ed25519.PublicKey) error {
return verifyEventSignature(signingName, keyID, publicKey, e.eventJSON)
}
// StateKey returns the "state_key" of the event, or the nil if the event is not a state event.
func (e Event) StateKey() *string {
return e.fields.StateKey
}
// StateKeyEquals returns true if the event is a state event and the "state_key" matches.
func (e Event) StateKeyEquals(stateKey string) bool {
if e.fields.StateKey == nil {
return false
}
return *e.fields.StateKey == stateKey
}
// EventID returns the event ID of the event.
func (e Event) EventID() string {
return e.fields.EventID
}
// Sender returns the user ID of the sender of the event.
func (e Event) Sender() string {
return e.fields.Sender
}
// Type returns the type of the event.
func (e Event) Type() string {
return e.fields.Type
}
// Content returns the content JSON of the event.
func (e Event) Content() []byte {
return []byte(e.fields.Content)
}
// PrevEvents returns references to the direct ancestors of the event.
func (e Event) PrevEvents() []EventReference {
return e.fields.PrevEvents
}
// AuthEvents returns references to the events needed to auth the event.
func (e Event) AuthEvents() []EventReference {
return e.fields.AuthEvents
}
// Redacts returns the event ID of the event this event redacts.
func (e Event) Redacts() string {
return e.fields.Redacts
}
// RoomID returns the room ID of the room the event is in.
func (e Event) RoomID() string {
return e.fields.RoomID
}
// Depth returns the depth of the event.
func (e Event) Depth() int64 {
return e.fields.Depth
}
// UnmarshalJSON implements json.Unmarshaller
func (er *EventReference) UnmarshalJSON(data []byte) error {
var tuple []rawJSON
if err := json.Unmarshal(data, &tuple); err != nil {
return err
}
if len(tuple) != 2 {
return fmt.Errorf("gomatrixserverlib: invalid event reference, invalid length: %d != 2", len(tuple))
}
if err := json.Unmarshal(tuple[0], &er.EventID); err != nil {
return fmt.Errorf("gomatrixserverlib: invalid event reference, first element is invalid: %q %v", string(tuple[0]), err)
}
var hashes struct {
SHA256 Base64String `json:"sha256"`
}
if err := json.Unmarshal(tuple[1], &hashes); err != nil {
return fmt.Errorf("gomatrixserverlib: invalid event reference, second element is invalid: %q %v", string(tuple[1]), err)
}
er.EventSHA256 = hashes.SHA256
return nil
}
// MarshalJSON implements json.Marshaller
func (er EventReference) MarshalJSON() ([]byte, error) {
hashes := struct {
SHA256 Base64String `json:"sha256"`
}{er.EventSHA256}
tuple := []interface{}{er.EventID, hashes}
return json.Marshal(&tuple)
}

View File

@ -0,0 +1,819 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"fmt"
"sort"
)
const (
join = "join"
ban = "ban"
leave = "leave"
invite = "invite"
public = "public"
)
// StateNeeded lists the event types and state_keys needed to authenticate an event.
type StateNeeded struct {
// Is the m.room.create event needed to auth the event.
Create bool
// Is the m.room.join_rules event needed to auth the event.
JoinRules bool
// Is the m.room.power_levels event needed to auth the event.
PowerLevels bool
// List of m.room.member state_keys needed to auth the event
Member []string
// List of m.room.third_party_invite state_keys
ThirdPartyInvite []string
}
// StateNeededForAuth returns the event types and state_keys needed to authenticate an event.
// This takes a list of events to facilitate bulk processing when doing auth checks as part of state conflict resolution.
func StateNeededForAuth(events []Event) (result StateNeeded) {
var members []string
var thirdpartyinvites []string
for _, event := range events {
switch event.Type() {
case "m.room.create":
// The create event doesn't require any state to authenticate.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L123
case "m.room.aliases":
// Alias events need:
// * The create event.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L128
// Alias events need no further authentication.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L160
result.Create = true
case "m.room.member":
// Member events need:
// * The previous membership of the target.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L355
// * The current membership state of the sender.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L348
// * The join rules for the room if the event is a join event.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L361
// * The power levels for the room.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L370
// * And optionally may require a m.third_party_invite event
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L393
content, err := newMemberContentFromEvent(event)
if err != nil {
// If we hit an error decoding the content we ignore it here.
// The event will be rejected when the actual checks encounter the same error.
continue
}
result.Create = true
result.PowerLevels = true
stateKey := event.StateKey()
if stateKey != nil {
members = append(members, event.Sender(), *stateKey)
}
if content.Membership == join {
result.JoinRules = true
}
if content.ThirdPartyInvite != nil {
token, err := thirdPartyInviteToken(content.ThirdPartyInvite)
if err != nil {
// If we hit an error decoding the content we ignore it here.
// The event will be rejected when the actual checks encounter the same error.
continue
} else {
thirdpartyinvites = append(thirdpartyinvites, token)
}
}
default:
// All other events need:
// * The membership of the sender.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L177
// * The power levels for the room.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L196
result.Create = true
result.PowerLevels = true
members = append(members, event.Sender())
}
}
// Deduplicate the state keys.
sort.Strings(members)
result.Member = members[:unique(sort.StringSlice(members))]
sort.Strings(thirdpartyinvites)
result.ThirdPartyInvite = thirdpartyinvites[:unique(sort.StringSlice(thirdpartyinvites))]
return
}
// Remove duplicate items from a sorted list.
// Takes the same interface as sort.Sort
// Returns the length of the data without duplicates
// Uses the last occurrence of a duplicate.
// O(n).
func unique(data sort.Interface) int {
length := data.Len()
if length == 0 {
return 0
}
j := 0
for i := 1; i < length; i++ {
if data.Less(i-1, i) {
data.Swap(i-1, j)
j++
}
}
data.Swap(length-1, j)
return j + 1
}
// thirdPartyInviteToken extracts the token from the third_party_invite.
func thirdPartyInviteToken(thirdPartyInviteData json.RawMessage) (string, error) {
var thirdPartyInvite struct {
Signed struct {
Token string `json:"token"`
} `json:"signed"`
}
if err := json.Unmarshal(thirdPartyInviteData, &thirdPartyInvite); err != nil {
return "", err
}
if thirdPartyInvite.Signed.Token == "" {
return "", fmt.Errorf("missing 'third_party_invite.signed.token' JSON key")
}
return thirdPartyInvite.Signed.Token, nil
}
// AuthEvents are the state events needed to authenticate an event.
type AuthEvents interface {
// Create returns the m.room.create event for the room.
Create() (*Event, error)
// JoinRules returns the m.room.join_rules event for the room.
JoinRules() (*Event, error)
// PowerLevels returns the m.room.power_levels event for the room.
PowerLevels() (*Event, error)
// Member returns the m.room.member event for the given user_id state_key.
Member(stateKey string) (*Event, error)
// ThirdPartyInvite returns the m.room.third_party_invite event for the
// given state_key
ThirdPartyInvite(stateKey string) (*Event, error)
}
// A NotAllowed error is returned if an event does not pass the auth checks.
type NotAllowed struct {
Message string
}
func (a *NotAllowed) Error() string {
return "eventauth: " + a.Message
}
func errorf(message string, args ...interface{}) error {
return &NotAllowed{Message: fmt.Sprintf(message, args...)}
}
// Allowed checks whether an event is allowed by the auth events.
// It returns a NotAllowed error if the event is not allowed.
// If there was an error loading the auth events then it returns that error.
func Allowed(event Event, authEvents AuthEvents) error {
switch event.Type() {
case "m.room.create":
return createEventAllowed(event)
case "m.room.aliases":
return aliasEventAllowed(event, authEvents)
case "m.room.member":
return memberEventAllowed(event, authEvents)
case "m.room.power_levels":
return powerLevelsEventAllowed(event, authEvents)
case "m.room.redaction":
return redactEventAllowed(event, authEvents)
default:
return defaultEventAllowed(event, authEvents)
}
}
// createEventAllowed checks whether the m.room.create event is allowed.
// It returns an error if the event is not allowed.
func createEventAllowed(event Event) error {
if !event.StateKeyEquals("") {
return errorf("create event state key is not empty: %v", event.StateKey())
}
roomIDDomain, err := domainFromID(event.RoomID())
if err != nil {
return err
}
senderDomain, err := domainFromID(event.Sender())
if err != nil {
return err
}
if senderDomain != roomIDDomain {
return errorf("create event room ID domain does not match sender: %q != %q", roomIDDomain, senderDomain)
}
if len(event.PrevEvents()) > 0 {
return errorf("create event must be the first event in the room: found %d prev_events", len(event.PrevEvents()))
}
return nil
}
// memberEventAllowed checks whether the m.room.member event is allowed.
// Membership events have different authentication rules to ordinary events.
func memberEventAllowed(event Event, authEvents AuthEvents) error {
allower, err := newMembershipAllower(authEvents, event)
if err != nil {
return err
}
return allower.membershipAllowed(event)
}
// aliasEventAllowed checks whether the m.room.aliases event is allowed.
// Alias events have different authentication rules to ordinary events.
func aliasEventAllowed(event Event, authEvents AuthEvents) error {
// The alias events have different auth rules to ordinary events.
// In particular we allow any server to send a m.room.aliases event without checking if the sender is in the room.
// This allows server admins to update the m.room.aliases event for their server when they change the aliases on their server.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L143-L160
create, err := newCreateContentFromAuthEvents(authEvents)
senderDomain, err := domainFromID(event.Sender())
if err != nil {
return err
}
if event.RoomID() != create.roomID {
return errorf("create event has different roomID: %q != %q", event.RoomID(), create.roomID)
}
// Check that server is allowed in the room by the m.room.federate flag.
if err := create.domainAllowed(senderDomain); err != nil {
return err
}
// Check that event is a state event.
// Check that the state key matches the server sending this event.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L158
if !event.StateKeyEquals(senderDomain) {
return errorf("alias state_key does not match sender domain, %q != %q", senderDomain, event.StateKey())
}
return nil
}
// powerLevelsEventAllowed checks whether the m.room.power_levels event is allowed.
// It returns an error if the event is not allowed or if there was a problem
// loading the auth events needed.
func powerLevelsEventAllowed(event Event, authEvents AuthEvents) error {
allower, err := newEventAllower(authEvents, event.Sender())
if err != nil {
return err
}
// power level events must pass the default checks.
// These checks will catch if the user has a high enough level to set a m.room.power_levels state event.
if err = allower.commonChecks(event); err != nil {
return err
}
// Parse the power levels.
newPowerLevels, err := newPowerLevelContentFromEvent(event)
if err != nil {
return err
}
// Check that the user levels are all valid user IDs
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L1063
for userID := range newPowerLevels.userLevels {
if !isValidUserID(userID) {
return errorf("Not a valid user ID: %q", userID)
}
}
// Grab the old power level event so that we can check if the event existed.
var oldEvent *Event
if oldEvent, err = authEvents.PowerLevels(); err != nil {
return err
} else if oldEvent == nil {
// If this is the first power level event then it can set the levels to
// any value it wants to.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L1074
return nil
}
// Grab the old levels so that we can compare new the levels against them.
oldPowerLevels := allower.powerLevels
senderLevel := oldPowerLevels.userLevel(event.Sender())
// Check that the changes in event levels are allowed.
if err = checkEventLevels(senderLevel, oldPowerLevels, newPowerLevels); err != nil {
return err
}
// Check that the changes in user levels are allowed.
return checkUserLevels(senderLevel, event.Sender(), oldPowerLevels, newPowerLevels)
}
// checkEventLevels checks that the changes in event levels are allowed.
func checkEventLevels(senderLevel int64, oldPowerLevels, newPowerLevels powerLevelContent) error {
type levelPair struct {
old int64
new int64
}
// Build a list of event levels to check.
// This differs slightly in behaviour from the code in synapse because it will use the
// default value if a level is not present in one of the old or new events.
// First add all the named levels.
levelChecks := []levelPair{
{oldPowerLevels.banLevel, newPowerLevels.banLevel},
{oldPowerLevels.inviteLevel, newPowerLevels.inviteLevel},
{oldPowerLevels.kickLevel, newPowerLevels.kickLevel},
{oldPowerLevels.redactLevel, newPowerLevels.redactLevel},
{oldPowerLevels.stateDefaultLevel, newPowerLevels.stateDefaultLevel},
{oldPowerLevels.eventDefaultLevel, newPowerLevels.eventDefaultLevel},
}
// Then add checks for each event key in the new levels.
// We use the default values for non-state events when applying the checks.
// TODO: the per event levels do not distinguish between state and non-state events.
// However the default values do make that distinction. We may want to change this.
// For example if there is an entry for "my.custom.type" events it sets the level
// for sending the event with and without a "state_key". But if there is no entry
// for "my.custom.type it will use the state default when sent with a "state_key"
// and will use the event default when sent without.
const (
isStateEvent = false
)
for eventType := range newPowerLevels.eventLevels {
levelChecks = append(levelChecks, levelPair{
oldPowerLevels.eventLevel(eventType, isStateEvent),
newPowerLevels.eventLevel(eventType, isStateEvent),
})
}
// Then add checks for each event key in the old levels.
// Some of these will be duplicates of the ones added using the keys from
// the new levels. But it doesn't hurt to run the checks twice for the same level.
for eventType := range oldPowerLevels.eventLevels {
levelChecks = append(levelChecks, levelPair{
oldPowerLevels.eventLevel(eventType, isStateEvent),
newPowerLevels.eventLevel(eventType, isStateEvent),
})
}
// Check each of the levels in the list.
for _, level := range levelChecks {
// Check if the level is being changed.
if level.old == level.new {
// Levels are always allowed to stay the same.
continue
}
// Users are allowed to change the level for an event if:
// * the old level was less than or equal to their own
// * the new level was less than or equal to their own
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L1134
// Check if the user is trying to set any of the levels to above their own.
if senderLevel < level.new {
return errorf(
"sender with level %d is not allowed to change level from %d to %d"+
" because the new level is above the level of the sender",
senderLevel, level.old, level.new,
)
}
// Check if the user is trying to set a level that was above their own.
if senderLevel < level.old {
return errorf(
"sender with level %d is not allowed to change level from %d to %d"+
" because the current level is above the level of the sender",
senderLevel, level.old, level.new,
)
}
}
return nil
}
// checkUserLevels checks that the changes in user levels are allowed.
func checkUserLevels(senderLevel int64, senderID string, oldPowerLevels, newPowerLevels powerLevelContent) error {
type levelPair struct {
old int64
new int64
userID string
}
// Build a list of user levels to check.
// This differs slightly in behaviour from the code in synapse because it will use the
// default value if a level is not present in one of the old or new events.
// First add the user default level.
userLevelChecks := []levelPair{
{oldPowerLevels.userDefaultLevel, newPowerLevels.userDefaultLevel, ""},
}
// Then add checks for each user key in the new levels.
for userID := range newPowerLevels.userLevels {
userLevelChecks = append(userLevelChecks, levelPair{
oldPowerLevels.userLevel(userID), newPowerLevels.userLevel(userID), userID,
})
}
// Then add checks for each user key in the old levels.
// Some of these will be duplicates of the ones added using the keys from
// the new levels. But it doesn't hurt to run the checks twice for the same level.
for userID := range oldPowerLevels.userLevels {
userLevelChecks = append(userLevelChecks, levelPair{
oldPowerLevels.userLevel(userID), newPowerLevels.userLevel(userID), userID,
})
}
// Check each of the levels in the list.
for _, level := range userLevelChecks {
// Check if the level is being changed.
if level.old == level.new {
// Levels are always allowed to stay the same.
continue
}
// Users are allowed to change the level of other users if:
// * the old level was less than their own
// * the new level was less than or equal to their own
// They are allowed to change their own level if:
// * the new level was less than or equal to their own
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L1126-L1127
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L1134
// Check if the user is trying to set any of the levels to above their own.
if senderLevel < level.new {
return errorf(
"sender with level %d is not allowed change user level from %d to %d"+
" because the new level is above the level of the sender",
senderLevel, level.old, level.new,
)
}
// Check if the user is changing their own user level.
if level.userID == senderID {
// Users are always allowed to reduce their own user level.
// We know that the user is reducing their level because of the previous checks.
continue
}
// Check if the user is changing the level that was above or the same as their own.
if senderLevel <= level.old {
return errorf(
"sender with level %d is not allowed to change user level from %d to %d"+
" because the old level is equal to or above the level of the sender",
senderLevel, level.old, level.new,
)
}
}
return nil
}
// redactEventAllowed checks whether the m.room.redaction event is allowed.
// It returns an error if the event is not allowed or if there was a problem
// loading the auth events needed.
func redactEventAllowed(event Event, authEvents AuthEvents) error {
allower, err := newEventAllower(authEvents, event.Sender())
if err != nil {
return err
}
// redact events must pass the default checks,
if err = allower.commonChecks(event); err != nil {
return err
}
senderDomain, err := domainFromID(event.Sender())
if err != nil {
return err
}
redactDomain, err := domainFromID(event.Redacts())
if err != nil {
return err
}
// Servers are always allowed to redact their own messages.
// This is so that users can redact their own messages, but since
// we don't know which user ID sent the message being redacted
// the only check we can do is to compare the domains of the
// sender and the redacted event.
// We leave it up to the sending server to implement the additional checks
// to ensure that only events that should be redacted are redacted.
if senderDomain == redactDomain {
return nil
}
// Otherwise the sender must have enough power.
// This allows room admins and ops to redact messages sent by other servers.
senderLevel := allower.powerLevels.userLevel(event.Sender())
redactLevel := allower.powerLevels.redactLevel
if senderLevel >= redactLevel {
return nil
}
return errorf(
"%q is not allowed to redact message from %q. %d < %d",
event.Sender(), redactDomain, senderLevel, redactLevel,
)
}
// defaultEventAllowed checks whether the event is allowed by the default
// checks for events.
// It returns an error if the event is not allowed or if there was a
// problem loading the auth events needed.
func defaultEventAllowed(event Event, authEvents AuthEvents) error {
allower, err := newEventAllower(authEvents, event.Sender())
if err != nil {
return err
}
return allower.commonChecks(event)
}
// An eventAllower has the information needed to authorise all events types
// other than m.room.create, m.room.member and m.room.aliases which are special.
type eventAllower struct {
// The content of the m.room.create.
create createContent
// The content of the m.room.member event for the sender.
member memberContent
// The content of the m.room.power_levels event for the room.
powerLevels powerLevelContent
}
// newEventAllower loads the information needed to authorise an event sent
// by a given user ID from the auth events.
func newEventAllower(authEvents AuthEvents, senderID string) (e eventAllower, err error) {
if e.create, err = newCreateContentFromAuthEvents(authEvents); err != nil {
return
}
if e.member, err = newMemberContentFromAuthEvents(authEvents, senderID); err != nil {
return
}
if e.powerLevels, err = newPowerLevelContentFromAuthEvents(authEvents, e.create.Creator); err != nil {
return
}
return
}
// commonChecks does the checks that are applied to all events types other than
// m.room.create, m.room.member, or m.room.alias.
func (e *eventAllower) commonChecks(event Event) error {
if event.RoomID() != e.create.roomID {
return errorf("create event has different roomID: %q != %q", event.RoomID(), e.create.roomID)
}
sender := event.Sender()
stateKey := event.StateKey()
if err := e.create.userIDAllowed(sender); err != nil {
return err
}
// Check that the sender is in the room.
// Every event other than m.room.create, m.room.member and m.room.aliases require this.
if e.member.Membership != join {
return errorf("sender %q not in room", sender)
}
senderLevel := e.powerLevels.userLevel(sender)
eventLevel := e.powerLevels.eventLevel(event.Type(), stateKey != nil)
if senderLevel < eventLevel {
return errorf(
"sender %q is not allowed to send event. %d < %d",
event.Sender(), senderLevel, eventLevel,
)
}
// Check that all state_keys that begin with '@' are only updated by users
// with that ID.
if stateKey != nil && len(*stateKey) > 0 && (*stateKey)[0] == '@' {
if *stateKey != sender {
return errorf(
"sender %q is not allowed to modify the state belonging to %q",
sender, *stateKey,
)
}
}
// TODO: Implement other restrictions on state_keys required by the specification.
// However as synapse doesn't implement those checks at the moment we'll hold off
// so that checks between the two codebases don't diverge too much.
return nil
}
// A membershipAllower has the information needed to authenticate a m.room.member event
type membershipAllower struct {
// The user ID of the user whose membership is changing.
targetID string
// The user ID of the user who sent the membership event.
senderID string
// The membership of the user who sent the membership event.
senderMember memberContent
// The previous membership of the user whose membership is changing.
oldMember memberContent
// The new membership of the user if this event is accepted.
newMember memberContent
// The m.room.create content for the room.
create createContent
// The m.room.power_levels content for the room.
powerLevels powerLevelContent
// The m.room.join_rules content for the room.
joinRule joinRuleContent
}
// newMembershipAllower loads the information needed to authenticate the m.room.member event
// from the auth events.
func newMembershipAllower(authEvents AuthEvents, event Event) (m membershipAllower, err error) {
stateKey := event.StateKey()
if stateKey == nil {
err = errorf("m.room.member must be a state event")
return
}
// TODO: Check that the IDs are valid user IDs.
m.targetID = *stateKey
m.senderID = event.Sender()
if m.create, err = newCreateContentFromAuthEvents(authEvents); err != nil {
return
}
if m.newMember, err = newMemberContentFromEvent(event); err != nil {
return
}
if m.oldMember, err = newMemberContentFromAuthEvents(authEvents, m.targetID); err != nil {
return
}
if m.senderMember, err = newMemberContentFromAuthEvents(authEvents, m.senderID); err != nil {
return
}
if m.powerLevels, err = newPowerLevelContentFromAuthEvents(authEvents, m.create.Creator); err != nil {
return
}
// We only need to check the join rules if the proposed membership is "join".
if m.newMember.Membership == "join" {
if m.joinRule, err = newJoinRuleContentFromAuthEvents(authEvents); err != nil {
return
}
}
return
}
// membershipAllowed checks whether the membership event is allowed
func (m *membershipAllower) membershipAllowed(event Event) error {
if m.create.roomID != event.RoomID() {
return errorf("create event has different roomID: %q != %q", event.RoomID(), m.create.roomID)
}
if err := m.create.userIDAllowed(m.senderID); err != nil {
return err
}
if err := m.create.userIDAllowed(m.targetID); err != nil {
return err
}
// Special case the first join event in the room to allow the creator to join.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L328
if m.targetID == m.create.Creator &&
m.newMember.Membership == join &&
m.senderID == m.targetID &&
len(event.PrevEvents()) == 1 {
// Grab the event ID of the previous event.
prevEventID := event.PrevEvents()[0].EventID
if prevEventID == m.create.eventID {
// If this is the room creator joining the room directly after the
// the create event, then allow.
return nil
}
// Otherwise fall back to the normal checks.
}
if m.newMember.Membership == invite && len(m.newMember.ThirdPartyInvite) != 0 {
// Special case third party invites
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L393
panic(fmt.Errorf("ThirdPartyInvite not implemented"))
}
if m.targetID == m.senderID {
// If the state_key and the sender are the same then this is an attempt
// by a user to update their own membership.
return m.membershipAllowedSelf()
}
// Otherwise this is an attempt to modify the membership of somebody else.
return m.membershipAllowedOther()
}
// membershipAllowedSelf determines if the change made by the user to their own membership is allowed.
func (m *membershipAllower) membershipAllowedSelf() error {
if m.newMember.Membership == join {
// A user that is not in the room is allowed to join if the room
// join rules are "public".
if m.oldMember.Membership == leave && m.joinRule.JoinRule == public {
return nil
}
// An invited user is allowed to join if the join rules are "public"
if m.oldMember.Membership == invite && m.joinRule.JoinRule == public {
return nil
}
// An invited user is allowed to join if the join rules are "invite"
if m.oldMember.Membership == invite && m.joinRule.JoinRule == invite {
return nil
}
// A joined user is allowed to update their join.
if m.oldMember.Membership == join {
return nil
}
}
if m.newMember.Membership == leave {
// A joined user is allowed to leave the room.
if m.oldMember.Membership == join {
return nil
}
// An invited user is allowed to reject an invite.
if m.oldMember.Membership == invite {
return nil
}
}
return m.membershipFailed()
}
// membershipAllowedOther determines if the user is allowed to change the membership of another user.
func (m *membershipAllower) membershipAllowedOther() error {
senderLevel := m.powerLevels.userLevel(m.senderID)
targetLevel := m.powerLevels.userLevel(m.targetID)
// You may only modify the membership of another user if you are in the room.
if m.senderMember.Membership != join {
return errorf("sender %q is not in the room", m.senderID)
}
if m.newMember.Membership == ban {
// A user may ban another user if their level is high enough
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L463
if senderLevel >= m.powerLevels.banLevel &&
senderLevel > targetLevel {
return nil
}
}
if m.newMember.Membership == leave {
// A user may unban another user if their level is high enough.
// This is doesn't require the same power_level checks as banning.
// You can unban someone with higher power_level than you.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L451
if m.oldMember.Membership == ban && senderLevel >= m.powerLevels.banLevel {
return nil
}
// A user may kick another user if their level is high enough.
// TODO: You can kick a user that was already kicked, or has left the room, or was
// never in the room in the first place. Do we want to allow these redundant kicks?
if m.oldMember.Membership != ban &&
senderLevel >= m.powerLevels.kickLevel &&
senderLevel > targetLevel {
return nil
}
}
if m.newMember.Membership == invite {
// A user may invite another user if the user has left the room.
// and their level is high enough.
if m.oldMember.Membership == leave && senderLevel >= m.powerLevels.inviteLevel {
return nil
}
// A user may re-invite a user.
if m.oldMember.Membership == invite && senderLevel >= m.powerLevels.inviteLevel {
return nil
}
}
return m.membershipFailed()
}
// membershipFailed returns a error explaining why the membership change was disallowed.
func (m *membershipAllower) membershipFailed() error {
if m.senderID == m.targetID {
return errorf(
"%q is not allowed to change their membership from %q to %q",
m.targetID, m.oldMember.Membership, m.newMember.Membership,
)
}
return errorf(
"%q is not allowed to change the membership of %q from %q to %q",
m.senderID, m.targetID, m.oldMember.Membership, m.newMember.Membership,
)
}

View File

@ -0,0 +1,807 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"testing"
)
func stateNeededEquals(a, b StateNeeded) bool {
if a.Create != b.Create {
return false
}
if a.JoinRules != b.JoinRules {
return false
}
if a.PowerLevels != b.PowerLevels {
return false
}
if len(a.Member) != len(b.Member) {
return false
}
if len(a.ThirdPartyInvite) != len(b.ThirdPartyInvite) {
return false
}
for i := range a.Member {
if a.Member[i] != b.Member[i] {
return false
}
}
for i := range a.ThirdPartyInvite {
if a.ThirdPartyInvite[i] != b.ThirdPartyInvite[i] {
return false
}
}
return true
}
type testEventList []Event
func (tel *testEventList) UnmarshalJSON(data []byte) error {
var eventJSONs []rawJSON
var events []Event
if err := json.Unmarshal([]byte(data), &eventJSONs); err != nil {
return err
}
for _, eventJSON := range eventJSONs {
event, err := NewEventFromTrustedJSON([]byte(eventJSON), false)
if err != nil {
return err
}
events = append(events, event)
}
*tel = testEventList(events)
return nil
}
func testStateNeededForAuth(t *testing.T, eventdata string, want StateNeeded) {
var events testEventList
if err := json.Unmarshal([]byte(eventdata), &events); err != nil {
panic(err)
}
got := StateNeededForAuth(events)
if !stateNeededEquals(got, want) {
t.Errorf("Testing StateNeededForAuth(%#v), wanted %#v got %#v", events, want, got)
}
}
func TestStateNeededForCreate(t *testing.T) {
// Create events don't need anything.
testStateNeededForAuth(t, `[{"type": "m.room.create"}]`, StateNeeded{})
}
func TestStateNeededForMessage(t *testing.T) {
// Message events need the create event, the sender and the power_levels.
testStateNeededForAuth(t, `[{
"type": "m.room.message",
"sender": "@u1:a"
}]`, StateNeeded{
Create: true,
PowerLevels: true,
Member: []string{"@u1:a"},
})
}
func TestStateNeededForAlias(t *testing.T) {
// Alias events need only the create event.
testStateNeededForAuth(t, `[{"type": "m.room.aliases"}]`, StateNeeded{
Create: true,
})
}
func TestStateNeededForJoin(t *testing.T) {
testStateNeededForAuth(t, `[{
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"content": {"membership": "join"}
}]`, StateNeeded{
Create: true,
JoinRules: true,
PowerLevels: true,
Member: []string{"@u1:a"},
})
}
func TestStateNeededForInvite(t *testing.T) {
testStateNeededForAuth(t, `[{
"type": "m.room.member",
"state_key": "@u2:b",
"sender": "@u1:a",
"content": {"membership": "invite"}
}]`, StateNeeded{
Create: true,
PowerLevels: true,
Member: []string{"@u1:a", "@u2:b"},
})
}
func TestStateNeededForInvite3PID(t *testing.T) {
testStateNeededForAuth(t, `[{
"type": "m.room.member",
"state_key": "@u2:b",
"sender": "@u1:a",
"content": {
"membership": "invite",
"third_party_invite": {
"signed": {
"token": "my_token"
}
}
}
}]`, StateNeeded{
Create: true,
PowerLevels: true,
Member: []string{"@u1:a", "@u2:b"},
ThirdPartyInvite: []string{"my_token"},
})
}
type testAuthEvents struct {
CreateJSON json.RawMessage `json:"create"`
JoinRulesJSON json.RawMessage `json:"join_rules"`
PowerLevelsJSON json.RawMessage `json:"power_levels"`
MemberJSON map[string]json.RawMessage `json:"member"`
ThirdPartyInviteJSON map[string]json.RawMessage `json:"third_party_invite"`
}
func (tae *testAuthEvents) Create() (*Event, error) {
if len(tae.CreateJSON) == 0 {
return nil, nil
}
var event Event
event, err := NewEventFromTrustedJSON(tae.CreateJSON, false)
if err != nil {
return nil, err
}
return &event, nil
}
func (tae *testAuthEvents) JoinRules() (*Event, error) {
if len(tae.JoinRulesJSON) == 0 {
return nil, nil
}
event, err := NewEventFromTrustedJSON(tae.JoinRulesJSON, false)
if err != nil {
return nil, err
}
return &event, nil
}
func (tae *testAuthEvents) PowerLevels() (*Event, error) {
if len(tae.PowerLevelsJSON) == 0 {
return nil, nil
}
event, err := NewEventFromTrustedJSON(tae.PowerLevelsJSON, false)
if err != nil {
return nil, err
}
return &event, nil
}
func (tae *testAuthEvents) Member(stateKey string) (*Event, error) {
if len(tae.MemberJSON[stateKey]) == 0 {
return nil, nil
}
event, err := NewEventFromTrustedJSON(tae.MemberJSON[stateKey], false)
if err != nil {
return nil, err
}
return &event, nil
}
func (tae *testAuthEvents) ThirdPartyInvite(stateKey string) (*Event, error) {
if len(tae.ThirdPartyInviteJSON[stateKey]) == 0 {
return nil, nil
}
event, err := NewEventFromTrustedJSON(tae.ThirdPartyInviteJSON[stateKey], false)
if err != nil {
return nil, err
}
return &event, nil
}
type testCase struct {
AuthEvents testAuthEvents `json:"auth_events"`
Allowed []json.RawMessage `json:"allowed"`
NotAllowed []json.RawMessage `json:"not_allowed"`
}
func testEventAllowed(t *testing.T, testCaseJSON string) {
var tc testCase
if err := json.Unmarshal([]byte(testCaseJSON), &tc); err != nil {
panic(err)
}
for _, data := range tc.Allowed {
event, err := NewEventFromTrustedJSON(data, false)
if err != nil {
panic(err)
}
if err = Allowed(event, &tc.AuthEvents); err != nil {
t.Fatalf("Expected %q to be allowed but it was not: %q", string(data), err)
}
}
for _, data := range tc.NotAllowed {
event, err := NewEventFromTrustedJSON(data, false)
if err != nil {
panic(err)
}
if err := Allowed(event, &tc.AuthEvents); err == nil {
t.Fatalf("Expected %q to not be allowed but it was: %q", string(data), err)
}
}
}
func TestAllowedEmptyRoom(t *testing.T) {
// Test that only m.room.create events can be sent without auth events.
// TODO: Test the events that aren't m.room.create
testEventAllowed(t, `{
"auth_events": {},
"allowed": [{
"type": "m.room.create",
"state_key": "",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e1:a",
"content": {"creator": "@u1:a"}
}],
"not_allowed": [{
"type": "m.room.create",
"state_key": "",
"sender": "@u1:b",
"room_id": "!r1:a",
"event_id": "$e2:a",
"content": {"creator": "@u1:b"},
"unsigned": {
"not_allowed": "Sent by a different server than the one which made the room_id"
}
}, {
"type": "m.room.create",
"state_key": "",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e3:a",
"prev_events": [["$e1", {}]],
"content": {"creator": "@u1:a"},
"unsigned": {
"not_allowed": "Was not the first event in the room"
}
}, {
"type": "m.room.message",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"body": "Test"},
"unsigned": {
"not_allowed": "No create event"
}
}, {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"membership": "join"},
"unsigned": {
"not_allowed": "No create event"
}
}, {
"type": "m.room.create",
"state_key": "",
"sender": "not_a_user_id",
"room_id": "!r1:a",
"event_id": "$e5:a",
"content": {"creator": "@u1:a"},
"unsigned": {
"not_allowed": "Sender is not a valid user ID"
}
}, {
"type": "m.room.create",
"state_key": "",
"sender": "@u1:a",
"room_id": "not_a_room_id",
"event_id": "$e6:a",
"content": {"creator": "@u1:a"},
"unsigned": {
"not_allowed": "Room is not a valid room ID"
}
}, {
"type": "m.room.create",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e7:a",
"content": {"creator": "@u1:a"},
"unsigned": {
"not_allowed": "Missing state_key"
}
}, {
"type": "m.room.create",
"state_key": "not_empty",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e7:a",
"content": {"creator": "@u1:a"},
"unsigned": {
"not_allowed": "The state_key is not empty"
}
}]
}`)
}
func TestAllowedFirstJoin(t *testing.T) {
testEventAllowed(t, `{
"auth_events": {
"create": {
"type": "m.room.create",
"state_key": "",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e1:a",
"content": {"creator": "@u1:a"}
}
},
"allowed": [{
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e2:a",
"prev_events": [["$e1:a", {}]],
"content": {"membership": "join"}
}],
"not_allowed": [{
"type": "m.room.message",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e3:a",
"content": {"body": "test"},
"unsigned": {
"not_allowed": "Sender is not in the room"
}
}, {
"type": "m.room.member",
"state_key": "@u2:a",
"sender": "@u2:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"prev_events": [["$e1:a", {}]],
"content": {"membership": "join"},
"unsigned": {
"not_allowed": "Only the creator can join the room"
}
}, {
"type": "m.room.member",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"prev_events": [["$e1:a", {}]],
"content": {"membership": "join"},
"unsigned": {
"not_allowed": "Missing state_key"
}
}, {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"prev_events": [["$e2:a", {}]],
"content": {"membership": "join"},
"unsigned": {
"not_allowed": "The prev_event is not the create event"
}
}, {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"membership": "join"},
"unsigned": {
"not_allowed": "There are no prev_events"
}
}, {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"membership": "join"},
"prev_events": [["$e1:a", {}], ["$e2:a", {}]],
"unsigned": {
"not_allowed": "There are too many prev_events"
}
}, {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u2:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"membership": "join"},
"prev_events": [["$e1:a", {}]],
"unsigned": {
"not_allowed": "The sender doesn't match the joining user"
}
}, {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"membership": "invite"},
"prev_events": [["$e1:a", {}]],
"unsigned": {
"not_allowed": "The membership is not 'join'"
}
}]
}`)
}
func TestAllowedWithNoPowerLevels(t *testing.T) {
testEventAllowed(t, `{
"auth_events": {
"create": {
"type": "m.room.create",
"state_key": "",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e1:a",
"content": {"creator": "@u1:a"}
},
"member": {
"@u1:a": {
"type": "m.room.member",
"sender": "@u1:a",
"room_id": "!r1:a",
"state_key": "@u1:a",
"event_id": "$e2:a",
"content": {"membership": "join"}
}
}
},
"allowed": [{
"type": "m.room.message",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e3:a",
"content": {"body": "Test"}
}],
"not_allowed": [{
"type": "m.room.message",
"sender": "@u2:a",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"body": "Test"},
"unsigned": {
"not_allowed": "Sender is not in room"
}
}]
}`)
}
func TestAllowedNoFederation(t *testing.T) {
testEventAllowed(t, `{
"auth_events": {
"create": {
"type": "m.room.create",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e1:a",
"content": {
"creator": "@u1:a",
"m.federate": false
}
},
"member": {
"@u1:a": {
"type": "m.room.member",
"sender": "@u1:a",
"room_id": "!r1:a",
"state_key": "@u1:a",
"event_id": "$e2:a",
"content": {"membership": "join"}
}
}
},
"allowed": [{
"type": "m.room.message",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e3:a",
"content": {"body": "Test"}
}],
"not_allowed": [{
"type": "m.room.message",
"sender": "@u1:b",
"room_id": "!r1:a",
"event_id": "$e4:a",
"content": {"body": "Test"},
"unsigned": {
"not_allowed": "Sender is from a different server."
}
}]
}`)
}
func TestAllowedWithPowerLevels(t *testing.T) {
testEventAllowed(t, `{
"auth_events": {
"create": {
"type": "m.room.create",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e1:a",
"content": {"creator": "@u1:a"}
},
"member": {
"@u1:a": {
"type": "m.room.member",
"sender": "@u1:a",
"room_id": "!r1:a",
"state_key": "@u1:a",
"event_id": "$e2:a",
"content": {"membership": "join"}
},
"@u2:a": {
"type": "m.room.member",
"sender": "@u2:a",
"room_id": "!r1:a",
"state_key": "@u2:a",
"event_id": "$e3:a",
"content": {"membership": "join"}
},
"@u3:b": {
"type": "m.room.member",
"sender": "@u3:b",
"room_id": "!r1:a",
"state_key": "@u3:b",
"event_id": "$e4:a",
"content": {"membership": "join"}
}
},
"power_levels": {
"type": "m.room.power_levels",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e5:a",
"content": {
"users": {
"@u1:a": 100,
"@u2:a": 50
},
"users_default": 0,
"events": {
"m.room.join_rules": 100
},
"state_default": 50,
"events_default": 0
}
}
},
"allowed": [{
"type": "m.room.message",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e6:a",
"content": {"body": "Test from @u1:a"}
}, {
"type": "m.room.message",
"sender": "@u2:a",
"room_id": "!r1:a",
"event_id": "$e7:a",
"content": {"body": "Test from @u2:a"}
}, {
"type": "m.room.message",
"sender": "@u3:b",
"room_id": "!r1:a",
"event_id": "$e8:a",
"content": {"body": "Test from @u3:b"}
},{
"type": "m.room.name",
"state_key": "",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e9:a",
"content": {"name": "Name set by @u1:a"}
}, {
"type": "m.room.name",
"state_key": "",
"sender": "@u2:a",
"room_id": "!r1:a",
"event_id": "$e10:a",
"content": {"name": "Name set by @u2:a"}
}, {
"type": "m.room.join_rules",
"state_key": "",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e11:a",
"content": {"join_rule": "public"}
}, {
"type": "my.custom.state",
"state_key": "@u2:a",
"sender": "@u2:a",
"room_id": "!r1:a",
"event_id": "@e12:a",
"content": {}
}],
"not_allowed": [{
"type": "m.room.name",
"state_key": "",
"sender": "@u3:b",
"room_id": "!r1:a",
"event_id": "$e13:a",
"content": {"name": "Name set by @u3:b"},
"unsigned": {
"not_allowed": "User @u3:b's level is too low to send a state event"
}
}, {
"type": "m.room.join_rules",
"sender": "@u2:a",
"room_id": "!r1:a",
"event_id": "$e14:a",
"content": {"name": "Name set by @u3:b"},
"unsigned": {
"not_allowed": "User @u2:a's level is too low to send m.room.join_rules"
}
}, {
"type": "m.room.message",
"sender": "@u4:a",
"room_id": "!r1:a",
"event_id": "$e15:a",
"content": {"Body": "Test from @u4:a"},
"unsigned": {
"not_allowed": "User @u4:a is not in the room"
}
}, {
"type": "m.room.message",
"sender": "@u1:a",
"room_id": "!r2:a",
"event_id": "$e16:a",
"content": {"body": "Test from @u4:a"},
"unsigned": {
"not_allowed": "Sent from a different room to the create event"
}
}, {
"type": "my.custom.state",
"state_key": "@u2:a",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "@e17:a",
"content": {},
"unsigned": {
"not_allowed": "State key starts with '@' and is for a different user"
}
}]
}`)
}
func TestRedactAllowed(t *testing.T) {
// Test if redacts are allowed correctly in a room with a power level event.
testEventAllowed(t, `{
"auth_events": {
"create": {
"type": "m.room.create",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e1:a",
"content": {"creator": "@u1:a"}
},
"member": {
"@u1:a": {
"type": "m.room.member",
"sender": "@u1:a",
"room_id": "!r1:a",
"state_key": "@u1:a",
"event_id": "$e2:a",
"content": {"membership": "join"}
},
"@u2:a": {
"type": "m.room.member",
"sender": "@u2:a",
"room_id": "!r1:a",
"state_key": "@u2:a",
"event_id": "$e3:a",
"content": {"membership": "join"}
},
"@u1:b": {
"type": "m.room.member",
"sender": "@u1:b",
"room_id": "!r1:a",
"state_key": "@u1:b",
"event_id": "$e4:a",
"content": {"membership": "join"}
}
},
"power_levels": {
"type": "m.room.power_levels",
"state_key": "",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e5:a",
"content": {
"users": {
"@u1:a": 100
},
"redact": 100
}
}
},
"allowed": [{
"type": "m.room.redaction",
"sender": "@u1:b",
"room_id": "!r1:a",
"redacts": "$event_sent_by_b:b",
"event_id": "$e6:b",
"content": {"reason": ""}
}, {
"type": "m.room.redaction",
"sender": "@u2:a",
"room_id": "!r1:a",
"redacts": "$event_sent_by_a:a",
"event_id": "$e7:a",
"content": {"reason": ""}
}, {
"type": "m.room.redaction",
"sender": "@u1:a",
"room_id": "!r1:a",
"redacts": "$event_sent_by_b:b",
"event_id": "$e8:a",
"content": {"reason": ""}
}],
"not_allowed": [{
"type": "m.room.redaction",
"sender": "@u2:a",
"room_id": "!r1:a",
"redacts": "$event_sent_by_b:b",
"event_id": "$e9:a",
"content": {"reason": ""},
"unsigned": {
"not_allowed": "User power level is too low and event is from different server"
}
}, {
"type": "m.room.redaction",
"sender": "@u1:c",
"room_id": "!r1:a",
"redacts": "$event_sent_by_c:c",
"event_id": "$e10:a",
"content": {"reason": ""},
"unsigned": {
"not_allowed": "User is not in the room"
}
}, {
"type": "m.room.redaction",
"sender": "@u1:a",
"room_id": "!r1:a",
"redacts": "not_a_valid_event_id",
"event_id": "$e11:a",
"content": {"reason": ""},
"unsigned": {
"not_allowed": "Invalid redacts event ID"
}
}, {
"type": "m.room.redaction",
"sender": "@u1:a",
"room_id": "!r1:a",
"event_id": "$e11:a",
"content": {"reason": ""},
"unsigned": {
"not_allowed": "Missing redacts event ID"
}
}]
}`)
}

View File

@ -0,0 +1,350 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"strconv"
"strings"
)
// createContent is the JSON content of a m.room.create event along with
// the top level keys needed for auth.
// See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-create for descriptions of the fields.
type createContent struct {
// We need the domain of the create event when checking federatability.
senderDomain string
// We need the roomID to check that events are in the same room as the create event.
roomID string
// We need the eventID to check the first join event in the room.
eventID string
// The "m.federate" flag tells us whether the room can be federated to other servers.
Federate *bool `json:"m.federate"`
// The creator of the room tells us what the default power levels are.
Creator string `json:"creator"`
}
// newCreateContentFromAuthEvents loads the create event content from the create event in the
// auth events.
func newCreateContentFromAuthEvents(authEvents AuthEvents) (c createContent, err error) {
var createEvent *Event
if createEvent, err = authEvents.Create(); err != nil {
return
}
if createEvent == nil {
err = errorf("missing create event")
return
}
if err = json.Unmarshal(createEvent.Content(), &c); err != nil {
err = errorf("unparsable create event content: %s", err.Error())
return
}
c.roomID = createEvent.RoomID()
c.eventID = createEvent.EventID()
if c.senderDomain, err = domainFromID(createEvent.Sender()); err != nil {
return
}
return
}
// domainAllowed checks whether the domain is allowed in the room by the
// "m.federate" flag.
func (c *createContent) domainAllowed(domain string) error {
if domain == c.senderDomain {
// If the domain matches the domain of the create event then the event
// is always allowed regardless of the value of the "m.federate" flag.
return nil
}
if c.Federate == nil || *c.Federate {
// The m.federate field defaults to true.
// If the domains are different then event is only allowed if the
// "m.federate" flag is absent or true.
return nil
}
return errorf("room is unfederatable")
}
// userIDAllowed checks whether the domain part of the user ID is allowed in
// the room by the "m.federate" flag.
func (c *createContent) userIDAllowed(id string) error {
domain, err := domainFromID(id)
if err != nil {
return err
}
return c.domainAllowed(domain)
}
// domainFromID returns everything after the first ":" character to extract
// the domain part of a matrix ID.
func domainFromID(id string) (string, error) {
// IDs have the format: SIGIL LOCALPART ":" DOMAIN
// Split on the first ":" character since the domain can contain ":"
// characters.
parts := strings.SplitN(id, ":", 2)
if len(parts) != 2 {
// The ID must have a ":" character.
return "", errorf("invalid ID: %q", id)
}
// Return everything after the first ":" character.
return parts[1], nil
}
// memberContent is the JSON content of a m.room.member event needed for auth checks.
// See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member for descriptions of the fields.
type memberContent struct {
// We use the membership key in order to check if the user is in the room.
Membership string `json:"membership"`
// We use the third_party_invite key to special case thirdparty invites.
ThirdPartyInvite json.RawMessage `json:"third_party_invite"`
}
// newMemberContentFromAuthEvents loads the member content from the member event for the user ID in the auth events.
// Returns an error if there was an error loading the member event or parsing the event content.
func newMemberContentFromAuthEvents(authEvents AuthEvents, userID string) (c memberContent, err error) {
var memberEvent *Event
if memberEvent, err = authEvents.Member(userID); err != nil {
return
}
if memberEvent == nil {
// If there isn't a member event then the membership for the user
// defaults to leave.
c.Membership = leave
return
}
return newMemberContentFromEvent(*memberEvent)
}
// newMemberContentFromEvent parse the member content from an event.
// Returns an error if the content couldn't be parsed.
func newMemberContentFromEvent(event Event) (c memberContent, err error) {
if err = json.Unmarshal(event.Content(), &c); err != nil {
err = errorf("unparsable member event content: %s", err.Error())
return
}
return
}
// joinRuleContent is the JSON content of a m.room.join_rules event needed for auth checks.
// See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-join-rules for descriptions of the fields.
type joinRuleContent struct {
// We use the join_rule key to check whether join m.room.member events are allowed.
JoinRule string `json:"join_rule"`
}
// newJoinRuleContentFromAuthEvents loads the join rule content from the join rules event in the auth event.
// Returns an error if there was an error loading the join rule event or parsing the content.
func newJoinRuleContentFromAuthEvents(authEvents AuthEvents) (c joinRuleContent, err error) {
var joinRulesEvent *Event
if joinRulesEvent, err = authEvents.JoinRules(); err != nil {
return
}
if joinRulesEvent == nil {
// Default to "invite"
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L368
c.JoinRule = invite
return
}
if err = json.Unmarshal(joinRulesEvent.Content(), &c); err != nil {
err = errorf("unparsable join_rules event content: %s", err.Error())
return
}
return
}
// powerLevelContent is the JSON content of a m.room.power_levels event needed for auth checks.
// We can't unmarshal the content directly from JSON because we need to set
// defaults and convert string values to int values.
// See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels for descriptions of the fields.
type powerLevelContent struct {
banLevel int64
inviteLevel int64
kickLevel int64
redactLevel int64
userLevels map[string]int64
userDefaultLevel int64
eventLevels map[string]int64
eventDefaultLevel int64
stateDefaultLevel int64
}
// userLevel returns the power level a user has in the room.
func (c *powerLevelContent) userLevel(userID string) int64 {
level, ok := c.userLevels[userID]
if ok {
return level
}
return c.userDefaultLevel
}
// eventLevel returns the power level needed to send an event in the room.
func (c *powerLevelContent) eventLevel(eventType string, isState bool) int64 {
if eventType == "m.room.third_party_invite" {
// Special case third_party_invite events to have the same level as
// m.room.member invite events.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L182
return c.inviteLevel
}
level, ok := c.eventLevels[eventType]
if ok {
return level
}
if isState {
return c.stateDefaultLevel
}
return c.eventDefaultLevel
}
// newPowerLevelContentFromAuthEvents loads the power level content from the
// power level event in the auth events or returns the default values if there
// is no power level event.
func newPowerLevelContentFromAuthEvents(authEvents AuthEvents, creatorUserID string) (c powerLevelContent, err error) {
powerLevelsEvent, err := authEvents.PowerLevels()
if err != nil {
return
}
if powerLevelsEvent != nil {
return newPowerLevelContentFromEvent(*powerLevelsEvent)
}
// If there are no power levels then fall back to defaults.
c.defaults()
// If there is no power level event then the creator gets level 100
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L569
c.userLevels = map[string]int64{creatorUserID: 100}
return
}
// defaults sets the power levels to their default values.
func (c *powerLevelContent) defaults() {
// Default invite level is 0.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L426
c.inviteLevel = 0
// Default ban, kick and redacts levels are 50
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L376
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L456
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L1041
c.banLevel = 50
c.kickLevel = 50
c.redactLevel = 50
// Default user level is 0
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L558
c.userDefaultLevel = 0
// Default event level is 0, Default state level is 50
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L987
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/api/auth.py#L991
c.eventDefaultLevel = 0
c.stateDefaultLevel = 50
}
// newPowerLevelContentFromEvent loads the power level content from an event.
func newPowerLevelContentFromEvent(event Event) (c powerLevelContent, err error) {
// Set the levels to their default values.
c.defaults()
// We can't extract the JSON directly to the powerLevelContent because we
// need to convert string values to int values.
var content struct {
InviteLevel levelJSONValue `json:"invite"`
BanLevel levelJSONValue `json:"ban"`
KickLevel levelJSONValue `json:"kick"`
RedactLevel levelJSONValue `json:"redact"`
UserLevels map[string]levelJSONValue `json:"users"`
UsersDefaultLevel levelJSONValue `json:"users_default"`
EventLevels map[string]levelJSONValue `json:"events"`
StateDefaultLevel levelJSONValue `json:"state_default"`
EventDefaultLevel levelJSONValue `json:"event_default"`
}
if err = json.Unmarshal(event.Content(), &content); err != nil {
err = errorf("unparsable power_levels event content: %s", err.Error())
return
}
// Update the levels with the values that are present in the event content.
content.InviteLevel.assignIfExists(&c.inviteLevel)
content.BanLevel.assignIfExists(&c.banLevel)
content.KickLevel.assignIfExists(&c.kickLevel)
content.RedactLevel.assignIfExists(&c.redactLevel)
content.UsersDefaultLevel.assignIfExists(&c.userDefaultLevel)
content.StateDefaultLevel.assignIfExists(&c.stateDefaultLevel)
content.EventDefaultLevel.assignIfExists(&c.eventDefaultLevel)
for k, v := range content.UserLevels {
if c.userLevels == nil {
c.userLevels = make(map[string]int64)
}
c.userLevels[k] = v.value
}
for k, v := range content.EventLevels {
if c.eventLevels == nil {
c.eventLevels = make(map[string]int64)
}
c.eventLevels[k] = v.value
}
return
}
// A levelJSONValue is used for unmarshalling power levels from JSON.
// It is intended to replicate the effects of x = int(content["key"]) in python.
type levelJSONValue struct {
// Was a value loaded from the JSON?
exists bool
// The integer value of the power level.
value int64
}
func (v *levelJSONValue) UnmarshalJSON(data []byte) error {
var stringValue string
var int64Value int64
var floatValue float64
var err error
// First try to unmarshal as an int64.
if err = json.Unmarshal(data, &int64Value); err != nil {
// If unmarshalling as an int64 fails try as a string.
if err = json.Unmarshal(data, &stringValue); err != nil {
// If unmarshalling as a string fails try as a float.
if err = json.Unmarshal(data, &floatValue); err != nil {
return err
}
int64Value = int64(floatValue)
} else {
// If we managed to get a string, try parsing the string as an int.
int64Value, err = strconv.ParseInt(stringValue, 10, 64)
if err != nil {
return err
}
}
}
v.exists = true
v.value = int64Value
return nil
}
// assign the power level if a value was present in the JSON.
func (v *levelJSONValue) assignIfExists(to *int64) {
if v.exists {
*to = v.value
}
}
// Check if the user ID is a valid user ID.
func isValidUserID(userID string) bool {
// TODO: Do we want to add anymore checks beyond checking the sigil and that it has a domain part?
return userID[0] == '@' && strings.IndexByte(userID, ':') != -1
}

View File

@ -0,0 +1,51 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"testing"
)
func TestLevelJSONValueValid(t *testing.T) {
var values []levelJSONValue
input := `[0,"1",2.0]`
if err := json.Unmarshal([]byte(input), &values); err != nil {
t.Fatal("Unexpected error unmarshalling ", input, ": ", err)
}
for i, got := range values {
want := i
if !got.exists {
t.Fatalf("Wanted entry %d to exist", want)
}
if int64(want) != got.value {
t.Fatalf("Wanted %d got %q", want, got.value)
}
}
}
func TestLevelJSONValueInvalid(t *testing.T) {
var values []levelJSONValue
inputs := []string{
`[{}]`, `[[]]`, `["not a number"]`, `["0.0"]`,
}
for _, input := range inputs {
if err := json.Unmarshal([]byte(input), &values); err == nil {
t.Fatalf("Unexpected success when unmarshalling %q", input)
}
}
}

View File

@ -0,0 +1,186 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"golang.org/x/crypto/ed25519"
)
// addContentHashesToEvent sets the "hashes" key of the event with a SHA-256 hash of the unredacted event content.
// This hash is used to detect whether the unredacted content of the event is valid.
// Returns the event JSON with a "hashes" key added to it.
func addContentHashesToEvent(eventJSON []byte) ([]byte, error) {
var event map[string]rawJSON
if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, err
}
unsignedJSON := event["unsigned"]
delete(event, "unsigned")
delete(event, "hashes")
hashableEventJSON, err := json.Marshal(event)
if err != nil {
return nil, err
}
hashableEventJSON, err = CanonicalJSON(hashableEventJSON)
if err != nil {
return nil, err
}
sha256Hash := sha256.Sum256(hashableEventJSON)
hashes := struct {
Sha256 Base64String `json:"sha256"`
}{Base64String(sha256Hash[:])}
hashesJSON, err := json.Marshal(&hashes)
if err != nil {
return nil, err
}
if len(unsignedJSON) > 0 {
event["unsigned"] = unsignedJSON
}
event["hashes"] = rawJSON(hashesJSON)
return json.Marshal(event)
}
// checkEventContentHash checks if the unredacted content of the event matches the SHA-256 hash under the "hashes" key.
func checkEventContentHash(eventJSON []byte) error {
var event map[string]rawJSON
if err := json.Unmarshal(eventJSON, &event); err != nil {
return err
}
hashesJSON := event["hashes"]
delete(event, "signatures")
delete(event, "unsigned")
delete(event, "hashes")
var hashes struct {
Sha256 Base64String `json:"sha256"`
}
if err := json.Unmarshal(hashesJSON, &hashes); err != nil {
return err
}
hashableEventJSON, err := json.Marshal(event)
if err != nil {
return err
}
hashableEventJSON, err = CanonicalJSON(hashableEventJSON)
if err != nil {
return err
}
sha256Hash := sha256.Sum256(hashableEventJSON)
if bytes.Compare(sha256Hash[:], []byte(hashes.Sha256)) != 0 {
return fmt.Errorf("Invalid Sha256 content hash: %v != %v", sha256Hash[:], []byte(hashes.Sha256))
}
return nil
}
// ReferenceSha256HashOfEvent returns the SHA-256 hash of the redacted event content.
// This is used when referring to this event from other events.
func referenceOfEvent(eventJSON []byte) (EventReference, error) {
redactedJSON, err := redactEvent(eventJSON)
if err != nil {
return EventReference{}, err
}
var event map[string]rawJSON
if err = json.Unmarshal(redactedJSON, &event); err != nil {
return EventReference{}, err
}
delete(event, "signatures")
delete(event, "unsigned")
hashableEventJSON, err := json.Marshal(event)
if err != nil {
return EventReference{}, err
}
hashableEventJSON, err = CanonicalJSON(hashableEventJSON)
if err != nil {
return EventReference{}, err
}
sha256Hash := sha256.Sum256(hashableEventJSON)
var eventID string
if err = json.Unmarshal(event["event_id"], &eventID); err != nil {
return EventReference{}, err
}
return EventReference{eventID, sha256Hash[:]}, nil
}
// SignEvent adds a ED25519 signature to the event for the given key.
func signEvent(signingName, keyID string, privateKey ed25519.PrivateKey, eventJSON []byte) ([]byte, error) {
// Redact the event before signing so signature that will remain valid even if the event is redacted.
redactedJSON, err := redactEvent(eventJSON)
if err != nil {
return nil, err
}
// Sign the JSON, this adds a "signatures" key to the redacted event.
// TODO: Make an internal version of SignJSON that returns just the signatures so that we don't have to parse it out of the JSON.
signedJSON, err := SignJSON(signingName, keyID, privateKey, redactedJSON)
if err != nil {
return nil, err
}
var signedEvent struct {
Signatures rawJSON `json:"signatures"`
}
if err := json.Unmarshal(signedJSON, &signedEvent); err != nil {
return nil, err
}
// Unmarshal the event JSON so that we can replace the signatures key.
var event map[string]rawJSON
if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, err
}
event["signatures"] = signedEvent.Signatures
return json.Marshal(event)
}
// VerifyEventSignature checks if the event has been signed by the given ED25519 key.
func verifyEventSignature(signingName, keyID string, publicKey ed25519.PublicKey, eventJSON []byte) error {
redactedJSON, err := redactEvent(eventJSON)
if err != nil {
return err
}
return VerifyJSON(signingName, keyID, publicKey, redactedJSON)
}

View File

@ -0,0 +1,259 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"bytes"
"encoding/base64"
"golang.org/x/crypto/ed25519"
"testing"
)
func TestVerifyEventSignatureTestVectors(t *testing.T) {
// Check JSON verification using the test vectors from https://matrix.org/docs/spec/appendices.html
seed, err := base64.RawStdEncoding.DecodeString("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
if err != nil {
t.Fatal(err)
}
random := bytes.NewBuffer(seed)
entityName := "domain"
keyID := "ed25519:1"
publicKey, _, err := ed25519.GenerateKey(random)
if err != nil {
t.Fatal(err)
}
testVerifyOK := func(input string) {
err := verifyEventSignature(entityName, keyID, publicKey, []byte(input))
if err != nil {
t.Fatal(err)
}
}
testVerifyNotOK := func(reason, input string) {
err := verifyEventSignature(entityName, keyID, publicKey, []byte(input))
if err == nil {
t.Fatalf("Expected VerifyJSON to fail for input %v because %v", input, reason)
}
}
testVerifyOK(`{
"event_id": "$0:domain",
"hashes": {
"sha256": "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
},
"origin": "domain",
"origin_server_ts": 1000000,
"signatures": {
"domain": {
"ed25519:1": "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA"
}
},
"type": "X",
"unsigned": {
"age_ts": 1000000
}
}`)
// It should still pass signature checks, even if we remove the unsigned data.
testVerifyOK(`{
"event_id": "$0:domain",
"hashes": {
"sha256": "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
},
"origin": "domain",
"origin_server_ts": 1000000,
"signatures": {
"domain": {
"ed25519:1": "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA"
}
},
"type": "X",
"unsigned": {}
}`)
testVerifyOK(`{
"content": {
"body": "Here is the message content"
},
"event_id": "$0:domain",
"hashes": {
"sha256": "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
},
"origin": "domain",
"origin_server_ts": 1000000,
"type": "m.room.message",
"room_id": "!r:domain",
"sender": "@u:domain",
"signatures": {
"domain": {
"ed25519:1": "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUwu6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
}
},
"unsigned": {
"age_ts": 1000000
}
}`)
// It should still pass signature checks, even if we redact the content.
testVerifyOK(`{
"content": {},
"event_id": "$0:domain",
"hashes": {
"sha256": "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
},
"origin": "domain",
"origin_server_ts": 1000000,
"type": "m.room.message",
"room_id": "!r:domain",
"sender": "@u:domain",
"signatures": {
"domain": {
"ed25519:1": "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUwu6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
}
},
"unsigned": {}
}`)
testVerifyNotOK("The event is modified", `{
"event_id": "$0:domain",
"hashes": {
"sha256": "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
},
"origin": "domain",
"origin_server_ts": 1000000,
"signatures": {
"domain": {
"ed25519:1": "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA"
}
},
"type": "modified",
"unsigned": {}
}`)
testVerifyNotOK("The content hash is modified", `{
"content": {},
"event_id": "$0:domain",
"hashes": {
"sha256": "adifferenthashvalueaP9SorVmRQNdN5aM2JYU2n/g"
},
"origin": "domain",
"origin_server_ts": 1000000,
"type": "m.room.message",
"room_id": "!r:domain",
"sender": "@u:domain",
"signatures": {
"domain": {
"ed25519:1": "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUwu6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
}
},
"unsigned": {}
}`)
}
func TestSignEventTestVectors(t *testing.T) {
// Check matrix event signing using the test vectors from https://matrix.org/docs/spec/appendices.html
seed, err := base64.RawStdEncoding.DecodeString("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
if err != nil {
t.Fatal(err)
}
random := bytes.NewBuffer(seed)
entityName := "domain"
keyID := "ed25519:1"
_, privateKey, err := ed25519.GenerateKey(random)
if err != nil {
t.Fatal(err)
}
testSign := func(input string, want string) {
hashed, err := addContentHashesToEvent([]byte(input))
if err != nil {
t.Fatal(err)
}
signed, err := signEvent(entityName, keyID, privateKey, hashed)
if err != nil {
t.Fatal(err)
}
if !IsJSONEqual([]byte(want), signed) {
t.Fatalf("SignEvent(%q): want %v got %v", input, want, string(signed))
}
}
testSign(`{
"event_id": "$0:domain",
"origin": "domain",
"origin_server_ts": 1000000,
"type": "X",
"unsigned": {
"age_ts": 1000000
}
}`, `{
"event_id": "$0:domain",
"hashes": {
"sha256": "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
},
"origin": "domain",
"origin_server_ts": 1000000,
"signatures": {
"domain": {
"ed25519:1": "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA"
}
},
"type": "X",
"unsigned": {
"age_ts": 1000000
}
}`)
testSign(`{
"content": {
"body": "Here is the message content"
},
"event_id": "$0:domain",
"origin": "domain",
"origin_server_ts": 1000000,
"type": "m.room.message",
"room_id": "!r:domain",
"sender": "@u:domain",
"unsigned": {
"age_ts": 1000000
}
}`, `{
"content": {
"body": "Here is the message content"
},
"event_id": "$0:domain",
"hashes": {
"sha256": "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
},
"origin": "domain",
"origin_server_ts": 1000000,
"type": "m.room.message",
"room_id": "!r:domain",
"sender": "@u:domain",
"signatures": {
"domain": {
"ed25519:1": "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUwu6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
}
},
"unsigned": {
"age_ts": 1000000
}
}`)
}

View File

@ -0,0 +1,5 @@
#! /bin/bash
DOT_GIT="$(dirname $0)/../.git"
ln -s "../../hooks/pre-commit" "$DOT_GIT/hooks/pre-commit"

View File

@ -0,0 +1,9 @@
#! /bin/bash
set -eu
golint ./...
go fmt
go tool vet --all --shadow .
gocyclo -over 16 .
go test -timeout 5s . ./...

View File

@ -0,0 +1,239 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"bytes"
"encoding/binary"
"encoding/json"
"sort"
"unicode/utf8"
)
// CanonicalJSON re-encodes the JSON in a cannonical encoding. The encoding is
// the shortest possible encoding using integer values with sorted object keys.
// https://matrix.org/docs/spec/server_server/unstable.html#canonical-json
func CanonicalJSON(input []byte) ([]byte, error) {
sorted, err := SortJSON(input, make([]byte, 0, len(input)))
if err != nil {
return nil, err
}
return CompactJSON(sorted, make([]byte, 0, len(sorted))), nil
}
// SortJSON reencodes the JSON with the object keys sorted by lexicographically
// by codepoint. The input must be valid JSON.
func SortJSON(input, output []byte) ([]byte, error) {
// Skip to the first character that isn't whitespace.
var decoded interface{}
decoder := json.NewDecoder(bytes.NewReader(input))
decoder.UseNumber()
if err := decoder.Decode(&decoded); err != nil {
return nil, err
}
return sortJSONValue(decoded, output)
}
func sortJSONValue(input interface{}, output []byte) ([]byte, error) {
switch value := input.(type) {
case []interface{}:
// If the JSON is an array then we need to sort the keys of its children.
return sortJSONArray(value, output)
case map[string]interface{}:
// If the JSON is an object then we need to sort its keys and the keys of its children.
return sortJSONObject(value, output)
default:
// Otherwise the JSON is a value and can be encoded without any further sorting.
bytes, err := json.Marshal(value)
if err != nil {
return nil, err
}
return append(output, bytes...), nil
}
}
func sortJSONArray(input []interface{}, output []byte) ([]byte, error) {
var err error
sep := byte('[')
for _, value := range input {
output = append(output, sep)
sep = ','
if output, err = sortJSONValue(value, output); err != nil {
return nil, err
}
}
if sep == '[' {
// If sep is still '[' then the array was empty and we never wrote the
// initial '[', so we write it now along with the closing ']'.
output = append(output, '[', ']')
} else {
// Otherwise we end the array by writing a single ']'
output = append(output, ']')
}
return output, nil
}
func sortJSONObject(input map[string]interface{}, output []byte) ([]byte, error) {
var err error
keys := make([]string, len(input))
var j int
for key := range input {
keys[j] = key
j++
}
sort.Strings(keys)
sep := byte('{')
for _, key := range keys {
output = append(output, sep)
sep = ','
var encoded []byte
if encoded, err = json.Marshal(key); err != nil {
return nil, err
}
output = append(output, encoded...)
output = append(output, ':')
if output, err = sortJSONValue(input[key], output); err != nil {
return nil, err
}
}
if sep == '{' {
// If sep is still '{' then the object was empty and we never wrote the
// initial '{', so we write it now along with the closing '}'.
output = append(output, '{', '}')
} else {
// Otherwise we end the object by writing a single '}'
output = append(output, '}')
}
return output, nil
}
// CompactJSON makes the encoded JSON as small as possible by removing
// whitespace and unneeded unicode escapes
func CompactJSON(input, output []byte) []byte {
var i int
for i < len(input) {
c := input[i]
i++
// The valid whitespace characters are all less than or equal to SPACE 0x20.
// The valid non-white characters are all greater than SPACE 0x20.
// So we can check for whitespace by comparing against SPACE 0x20.
if c <= ' ' {
// Skip over whitespace.
continue
}
// Add the non-whitespace character to the output.
output = append(output, c)
if c == '"' {
// We are inside a string.
for i < len(input) {
c = input[i]
i++
// Check if this is an escape sequence.
if c == '\\' {
escape := input[i]
i++
if escape == 'u' {
// If this is a unicode escape then we need to handle it specially
output, i = compactUnicodeEscape(input, output, i)
} else if escape == '/' {
// JSON does not require escaping '/', but allows encoders to escape it as a special case.
// Since the escape isn't required we remove it.
output = append(output, escape)
} else {
// All other permitted escapes are single charater escapes that are already in their shortest form.
output = append(output, '\\', escape)
}
} else {
output = append(output, c)
}
if c == '"' {
break
}
}
}
}
return output
}
// compactUnicodeEscape unpacks a 4 byte unicode escape starting at index.
// If the escape is a surrogate pair then decode the 6 byte \uXXXX escape
// that follows. Returns the output slice and a new input index.
func compactUnicodeEscape(input, output []byte, index int) ([]byte, int) {
const (
ESCAPES = "uuuuuuuubtnufruuuuuuuuuuuuuuuuuu"
HEX = "0123456789ABCDEF"
)
// If there aren't enough bytes to decode the hex escape then return.
if len(input)-index < 4 {
return output, len(input)
}
// Decode the 4 hex digits.
c := readHexDigits(input[index:])
index += 4
if c < ' ' {
// If the character is less than SPACE 0x20 then it will need escaping.
escape := ESCAPES[c]
output = append(output, '\\', escape)
if escape == 'u' {
output = append(output, '0', '0', byte('0'+(c>>4)), HEX[c&0xF])
}
} else if c == '\\' || c == '"' {
// Otherwise the character only needs escaping if it is a QUOTE '"' or BACKSLASH '\\'.
output = append(output, '\\', byte(c))
} else if c < 0xD800 || c >= 0xE000 {
// If the character isn't a surrogate pair then encoded it directly as UTF-8.
var buffer [4]byte
n := utf8.EncodeRune(buffer[:], rune(c))
output = append(output, buffer[:n]...)
} else {
// Otherwise the escaped character was the first part of a UTF-16 style surrogate pair.
// The next 6 bytes MUST be a '\uXXXX'.
// If there aren't enough bytes to decode the hex escape then return.
if len(input)-index < 6 {
return output, len(input)
}
// Decode the 4 hex digits from the '\uXXXX'.
surrogate := readHexDigits(input[index+2:])
index += 6
// Reconstruct the UCS4 codepoint from the surrogates.
codepoint := 0x10000 + (((c & 0x3FF) << 10) | (surrogate & 0x3FF))
// Encode the charater as UTF-8.
var buffer [4]byte
n := utf8.EncodeRune(buffer[:], rune(codepoint))
output = append(output, buffer[:n]...)
}
return output, index
}
// Read 4 hex digits from the input slice.
// Taken from https://github.com/NegativeMjark/indolentjson-rust/blob/8b959791fe2656a88f189c5d60d153be05fe3deb/src/readhex.rs#L21
func readHexDigits(input []byte) uint32 {
hex := binary.BigEndian.Uint32(input)
// substract '0'
hex -= 0x30303030
// strip the higher bits, maps 'a' => 'A'
hex &= 0x1F1F1F1F
mask := hex & 0x10101010
// subtract 'A' - 10 - '9' - 9 = 7 from the letters.
hex -= mask >> 1
hex += mask >> 4
// collect the nibbles
hex |= hex >> 4
hex &= 0xFF00FF
hex |= hex >> 8
return hex & 0xFFFF
}

View File

@ -0,0 +1,92 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"testing"
)
func testSortJSON(t *testing.T, input, want string) {
got, err := SortJSON([]byte(input), nil)
if err != nil {
t.Error(err)
}
// Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
if string(CompactJSON(got, nil)) != want {
t.Errorf("SortJSON(%q): want %q got %q", input, want, got)
}
}
func TestSortJSON(t *testing.T) {
testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`)
testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`,
`{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`)
testSortJSON(t, `[true,false,null]`, `[true,false,null]`)
testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`)
}
func testCompactJSON(t *testing.T, input, want string) {
got := string(CompactJSON([]byte(input), nil))
if got != want {
t.Errorf("CompactJSON(%q): want %q got %q", input, want, got)
}
}
func TestCompactJSON(t *testing.T) {
testCompactJSON(t, "{ }", "{}")
input := `["\u0000\u0001\u0002\u0003\u0004\u0005\u0006\u0007"]`
want := input
testCompactJSON(t, input, want)
input = `["\u0008\u0009\u000A\u000B\u000C\u000D\u000E\u000F"]`
want = `["\b\t\n\u000B\f\r\u000E\u000F"]`
testCompactJSON(t, input, want)
input = `["\u0010\u0011\u0012\u0013\u0014\u0015\u0016\u0017"]`
want = input
testCompactJSON(t, input, want)
input = `["\u0018\u0019\u001A\u001B\u001C\u001D\u001E\u001F"]`
want = input
testCompactJSON(t, input, want)
testCompactJSON(t, `["\u0061\u005C\u0042\u0022"]`, `["a\\B\""]`)
testCompactJSON(t, `["\u0120"]`, "[\"\u0120\"]")
testCompactJSON(t, `["\u0FFF"]`, "[\"\u0FFF\"]")
testCompactJSON(t, `["\u1820"]`, "[\"\u1820\"]")
testCompactJSON(t, `["\uFFFF"]`, "[\"\uFFFF\"]")
testCompactJSON(t, `["\uD842\uDC20"]`, "[\"\U00020820\"]")
testCompactJSON(t, `["\uDBFF\uDFFF"]`, "[\"\U0010FFFF\"]")
testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`)
}
func testReadHex(t *testing.T, input string, want uint32) {
got := readHexDigits([]byte(input))
if want != got {
t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got)
}
}
func TestReadHex(t *testing.T) {
testReadHex(t, "0123", 0x0123)
testReadHex(t, "4567", 0x4567)
testReadHex(t, "89AB", 0x89AB)
testReadHex(t, "CDEF", 0xCDEF)
testReadHex(t, "89ab", 0x89AB)
testReadHex(t, "cdef", 0xCDEF)
}

View File

@ -0,0 +1,214 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"bufio"
"bytes"
"crypto/sha256"
"crypto/tls"
"encoding/json"
"io/ioutil"
"net"
"net/http"
"strings"
"time"
)
// ServerKeys are the ed25519 signing keys published by a matrix server.
// Contains SHA256 fingerprints of the TLS X509 certificates used by the server.
type ServerKeys struct {
Raw []byte `json:"-"` // Copy of the raw JSON for signature checking.
ServerName string `json:"server_name"` // The name of the server.
TLSFingerprints []struct { // List of SHA256 fingerprints of X509 certificates.
SHA256 Base64String `json:"sha256"`
} `json:"tls_fingerprints"`
VerifyKeys map[string]struct { // The current signing keys in use on this server.
Key Base64String `json:"key"` // The public key.
} `json:"verify_keys"`
ValidUntilTS int64 `json:"valid_until_ts"` // When this result is valid until in milliseconds.
OldVerifyKeys map[string]struct { // Old keys that are now only valid for checking historic events.
Key Base64String `json:"key"` // The public key.
ExpiredTS uint64 `json:"expired_ts"` // When this key stopped being valid for event signing.
} `json:"old_verify_keys"`
}
// FetchKeysDirect fetches the matrix keys directly from the given address.
// Optionally sets a SNI header if ``sni`` is not empty.
// Returns the server keys and the state of the TLS connection used to retrieve them.
func FetchKeysDirect(serverName, addr, sni string) (*ServerKeys, *tls.ConnectionState, error) {
// Create a TLS connection.
tcpconn, err := net.Dial("tcp", addr)
if err != nil {
return nil, nil, err
}
defer tcpconn.Close()
tlsconn := tls.Client(tcpconn, &tls.Config{
ServerName: sni,
InsecureSkipVerify: true, // This must be specified even though the TLS library will ignore it.
})
if err = tlsconn.Handshake(); err != nil {
return nil, nil, err
}
connectionState := tlsconn.ConnectionState()
// Write a GET /_matrix/key/v2/server down the connection.
requestURL := "matrix://" + serverName + "/_matrix/key/v2/server"
request, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return nil, nil, err
}
request.Header.Set("Connection", "close")
if err = request.Write(tlsconn); err != nil {
return nil, nil, err
}
// Read the 200 OK from the server.
response, err := http.ReadResponse(bufio.NewReader(tlsconn), request)
if response != nil {
defer response.Body.Close()
}
if err != nil {
return nil, nil, err
}
var keys ServerKeys
if keys.Raw, err = ioutil.ReadAll(response.Body); err != nil {
return nil, nil, err
}
if err = json.Unmarshal(keys.Raw, &keys); err != nil {
return nil, nil, err
}
return &keys, &connectionState, nil
}
// Ed25519Checks are the checks that are applied to Ed25519 keys in ServerKey responses.
type Ed25519Checks struct {
ValidEd25519 bool // The verify key is valid Ed25519 keys.
MatchingSignature bool // The verify key has a valid signature.
}
// TLSFingerprintChecks are the checks that are applied to TLS fingerprints in ServerKey responses.
type TLSFingerprintChecks struct {
ValidSHA256 bool // The TLS fingerprint includes a valid SHA-256 hash.
}
// KeyChecks are the checks that should be applied to ServerKey responses.
type KeyChecks struct {
AllChecksOK bool // Did all the checks pass?
MatchingServerName bool // Does the server name match what was requested.
FutureValidUntilTS bool // The valid until TS is in the future.
HasEd25519Key bool // The server has at least one ed25519 key.
AllEd25519ChecksOK *bool // All the Ed25519 checks are ok. or null if there weren't any to check.
Ed25519Checks map[string]Ed25519Checks // Checks for Ed25519 keys.
HasTLSFingerprint bool // The server has at least one fingerprint.
AllTLSFingerprintChecksOK *bool // All the fingerpint checks are ok.
TLSFingerprintChecks []TLSFingerprintChecks // Checks for TLS fingerprints.
MatchingTLSFingerprint *bool // The TLS fingerprint for the connection matches one of the listed fingerprints.
}
// CheckKeys checks the keys returned from a server to make sure they are valid.
// If the checks pass then also return a map of key_id to Ed25519 public key and a list of SHA256 TLS fingerprints.
func CheckKeys(serverName string, now time.Time, keys ServerKeys, connState *tls.ConnectionState) (
checks KeyChecks, ed25519Keys map[string]Base64String, sha256Fingerprints []Base64String,
) {
checks.MatchingServerName = serverName == keys.ServerName
checks.FutureValidUntilTS = now.UnixNano() < keys.ValidUntilTS*1000000
checks.AllChecksOK = checks.MatchingServerName && checks.FutureValidUntilTS
ed25519Keys = checkVerifyKeys(keys, &checks)
sha256Fingerprints = checkTLSFingerprints(keys, &checks)
// Only check the fingerprint if we have the TLS connection state.
if connState != nil {
// Check the peer certificates.
matches := checkFingerprint(connState, sha256Fingerprints)
checks.MatchingTLSFingerprint = &matches
checks.AllChecksOK = checks.AllChecksOK && matches
}
if !checks.AllChecksOK {
sha256Fingerprints = nil
ed25519Keys = nil
}
return
}
func checkFingerprint(connState *tls.ConnectionState, sha256Fingerprints []Base64String) bool {
if len(connState.PeerCertificates) == 0 {
return false
}
cert := connState.PeerCertificates[0]
digest := sha256.Sum256(cert.Raw)
for _, fingerprint := range sha256Fingerprints {
if bytes.Compare(digest[:], fingerprint) == 0 {
return true
}
}
return false
}
func checkVerifyKeys(keys ServerKeys, checks *KeyChecks) map[string]Base64String {
allEd25519ChecksOK := true
checks.Ed25519Checks = map[string]Ed25519Checks{}
verifyKeys := map[string]Base64String{}
for keyID, keyData := range keys.VerifyKeys {
algorithm := strings.SplitN(keyID, ":", 2)[0]
publicKey := keyData.Key
if algorithm == "ed25519" {
checks.HasEd25519Key = true
checks.AllEd25519ChecksOK = &allEd25519ChecksOK
entry := Ed25519Checks{
ValidEd25519: len(publicKey) == 32,
}
if entry.ValidEd25519 {
err := VerifyJSON(keys.ServerName, keyID, []byte(publicKey), keys.Raw)
entry.MatchingSignature = err == nil
}
checks.Ed25519Checks[keyID] = entry
if entry.MatchingSignature {
verifyKeys[keyID] = publicKey
} else {
allEd25519ChecksOK = false
}
}
}
if checks.AllChecksOK {
checks.AllChecksOK = checks.HasEd25519Key && allEd25519ChecksOK
}
return verifyKeys
}
func checkTLSFingerprints(keys ServerKeys, checks *KeyChecks) []Base64String {
var fingerprints []Base64String
allTLSFingerprintChecksOK := true
for _, fingerprint := range keys.TLSFingerprints {
checks.HasTLSFingerprint = true
checks.AllTLSFingerprintChecksOK = &allTLSFingerprintChecksOK
entry := TLSFingerprintChecks{
ValidSHA256: len(fingerprint.SHA256) == sha256.Size,
}
checks.TLSFingerprintChecks = append(checks.TLSFingerprintChecks, entry)
if entry.ValidSHA256 {
fingerprints = append(fingerprints, fingerprint.SHA256)
} else {
allTLSFingerprintChecksOK = false
}
}
if checks.AllChecksOK {
checks.AllChecksOK = checks.HasTLSFingerprint && allTLSFingerprintChecksOK
}
return fingerprints
}

View File

@ -0,0 +1,161 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
)
// rawJSON is a reimplementation of json.RawMessage that supports being used as a value type
//
// For example:
//
// jsonBytes, _ := json.Marshal(struct{
// RawMessage json.RawMessage
// RawJSON rawJSON
// }{
// json.RawMessage(`"Hello"`),
// rawJSON(`"World"`),
// })
//
// Results in:
//
// {"RawMessage":"IkhlbGxvIg==","RawJSON":"World"}
//
// See https://play.golang.org/p/FzhKIJP8-I for a full example.
type rawJSON []byte
// MarshalJSON implements the json.Marshaller interface using a value receiver.
// This means that rawJSON used as an embedded value will still encode correctly.
func (r rawJSON) MarshalJSON() ([]byte, error) {
return []byte(r), nil
}
// UnmarshalJSON implements the json.Unmarshaller interface using a pointer receiver.
func (r *rawJSON) UnmarshalJSON(data []byte) error {
*r = rawJSON(data)
return nil
}
// redactEvent strips the user controlled fields from an event, but leaves the
// fields necessary for authenticating the event.
func redactEvent(eventJSON []byte) ([]byte, error) {
// createContent keeps the fields needed in a m.room.create event.
// Create events need to keep the creator.
// (In an ideal world they would keep the m.federate flag see matrix-org/synapse#1831)
type createContent struct {
Creator rawJSON `json:"creator,omitempty"`
}
// joinRulesContent keeps the fields needed in a m.room.join_rules event.
// Join rules events need to keep the join_rule key.
type joinRulesContent struct {
JoinRule rawJSON `json:"join_rule,omitempty"`
}
// powerLevelContent keeps the fields needed in a m.room.power_levels event.
// Power level events need to keep all the levels.
type powerLevelContent struct {
Users rawJSON `json:"users,omitempty"`
UsersDefault rawJSON `json:"users_default,omitempty"`
Events rawJSON `json:"events,omitempty"`
EventsDefault rawJSON `json:"events_default,omitempty"`
StateDefault rawJSON `json:"state_default,omitempty"`
Ban rawJSON `json:"ban,omitempty"`
Kick rawJSON `json:"kick,omitempty"`
Redact rawJSON `json:"redact,omitempty"`
}
// memberContent keeps the fields needed in a m.room.member event.
// Member events keep the membership.
// (In an ideal world they would keep the third_party_invite see matrix-org/synapse#1831)
type memberContent struct {
Membership rawJSON `json:"membership,omitempty"`
}
// aliasesContent keeps the fields needed in a m.room.aliases event.
// TODO: Alias events probably don't need to keep the aliases key, but we need to match synapse here.
type aliasesContent struct {
Aliases rawJSON `json:"aliases,omitempty"`
}
// historyVisibilityContent keeps the fields needed in a m.room.history_visibility event
// History visibility events need to keep the history_visibility key.
type historyVisibilityContent struct {
HistoryVisibility rawJSON `json:"history_visibility,omitempty"`
}
// allContent keeps the union of all the content fields needed across all the event types.
// All the content JSON keys we are keeping are distinct across the different event types.
type allContent struct {
createContent
joinRulesContent
powerLevelContent
memberContent
aliasesContent
historyVisibilityContent
}
// eventFields keeps the top level keys needed by all event types.
// (In an ideal world they would include the "redacts" key for m.room.redaction events, see matrix-org/synapse#1831)
// See https://github.com/matrix-org/synapse/blob/v0.18.7/synapse/events/utils.py#L42-L56 for the list of fields
type eventFields struct {
EventID rawJSON `json:"event_id,omitempty"`
Sender rawJSON `json:"sender,omitempty"`
RoomID rawJSON `json:"room_id,omitempty"`
Hashes rawJSON `json:"hashes,omitempty"`
Signatures rawJSON `json:"signatures,omitempty"`
Content allContent `json:"content"`
Type string `json:"type"`
StateKey rawJSON `json:"state_key,omitempty"`
Depth rawJSON `json:"depth,omitempty"`
PrevEvents rawJSON `json:"prev_events,omitempty"`
PrevState rawJSON `json:"prev_state,omitempty"`
AuthEvents rawJSON `json:"auth_events,omitempty"`
Origin rawJSON `json:"origin,omitempty"`
OriginServerTS rawJSON `json:"origin_server_ts,omitempty"`
Membership rawJSON `json:"membership,omitempty"`
}
var event eventFields
// Unmarshalling into a struct will discard any extra fields from the event.
if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, err
}
var newContent allContent
// Copy the content fields that we should keep for the event type.
// By default we copy nothing leaving the content object empty.
switch event.Type {
case "m.room.create":
newContent.createContent = event.Content.createContent
case "m.room.member":
newContent.memberContent = event.Content.memberContent
case "m.room.join_rules":
newContent.joinRulesContent = event.Content.joinRulesContent
case "m.room.power_levels":
newContent.powerLevelContent = event.Content.powerLevelContent
case "m.room.history_visibility":
newContent.historyVisibilityContent = event.Content.historyVisibilityContent
case "m.room.aliases":
newContent.aliasesContent = event.Content.aliasesContent
}
// Replace the content with our new filtered content.
// This will zero out any keys that weren't copied in the switch statement above.
event.Content = newContent
// Return the redacted event encoded as JSON.
return json.Marshal(&event)
}

View File

@ -0,0 +1,108 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"net"
"strconv"
"strings"
)
// A HostResult is the result of looking up the IP addresses for a host.
type HostResult struct {
CName string // The canonical name for the host.
Addrs []string // The IP addresses for the host.
Error error // If there was an error getting the IP addresses.
}
// A DNSResult is the result of looking up a matrix server in DNS.
type DNSResult struct {
SRVCName string // The canonical name for the SRV record in DNS
SRVRecords []*net.SRV // List of SRV record for the matrix server.
SRVError error // If there was an error getting the SRV records.
Hosts map[string]HostResult // The results of looking up the SRV record targets.
Addrs []string // List of "<ip>:<port>" strings that the server is listening on. These strings can be passed to `net.Dial()`.
}
// LookupServer looks up a matrix server in DNS.
func LookupServer(serverName string) (*DNSResult, error) {
var result DNSResult
result.Hosts = map[string]HostResult{}
hosts := map[string][]net.SRV{}
if strings.Index(serverName, ":") == -1 {
// If there isn't an explicit port set then try to look up the SRV record.
var err error
result.SRVCName, result.SRVRecords, err = net.LookupSRV("matrix", "tcp", serverName)
result.SRVError = err
if err != nil {
if dnserr, ok := err.(*net.DNSError); ok {
// If the error is a network timeout talking to the DNS server
// then give up now rather than trying to fallback.
if dnserr.Timeout() {
return nil, err
}
// If there isn't a SRV record in DNS then fallback to "serverName:8448".
hosts[serverName] = []net.SRV{net.SRV{
Target: serverName,
Port: 8448,
}}
}
} else {
// Group the SRV records by target host.
for _, record := range result.SRVRecords {
hosts[record.Target] = append(hosts[record.Target], *record)
}
}
} else {
// There is a explicit port set in the server name.
// We don't need to look up any SRV records.
host, portStr, err := net.SplitHostPort(serverName)
if err != nil {
return nil, err
}
var port uint64
port, err = strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
hosts[host] = []net.SRV{net.SRV{
Target: host,
Port: uint16(port),
}}
}
// Look up the IP addresses for each host.
for host, records := range hosts {
// Ignore any DNS errors when looking up the CNAME. We only are interested in it for debugging.
cname, _ := net.LookupCNAME(host)
addrs, err := net.LookupHost(host)
result.Hosts[host] = HostResult{
CName: cname,
Addrs: addrs,
Error: err,
}
// For each SRV record, for each IP address add a "<ip>:<port>" entry to the list of addresses.
for _, record := range records {
for _, addr := range addrs {
ipPort := net.JoinHostPort(addr, strconv.Itoa(int(record.Port)))
result.Addrs = append(result.Addrs, ipPort)
}
}
}
return &result, nil
}

View File

@ -0,0 +1,120 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"fmt"
"golang.org/x/crypto/ed25519"
)
// SignJSON signs a JSON object returning a copy signed with the given key.
// https://matrix.org/docs/spec/server_server/unstable.html#signing-json
func SignJSON(signingName, keyID string, privateKey ed25519.PrivateKey, message []byte) ([]byte, error) {
var object map[string]*json.RawMessage
var signatures map[string]map[string]Base64String
if err := json.Unmarshal(message, &object); err != nil {
return nil, err
}
rawUnsigned, hasUnsigned := object["unsigned"]
delete(object, "unsigned")
if rawSignatures := object["signatures"]; rawSignatures != nil {
if err := json.Unmarshal(*rawSignatures, &signatures); err != nil {
return nil, err
}
delete(object, "signatures")
} else {
signatures = map[string]map[string]Base64String{}
}
unsorted, err := json.Marshal(object)
if err != nil {
return nil, err
}
canonical, err := CanonicalJSON(unsorted)
if err != nil {
return nil, err
}
signature := Base64String(ed25519.Sign(privateKey, canonical))
signaturesForEntity := signatures[signingName]
if signaturesForEntity != nil {
signaturesForEntity[keyID] = signature
} else {
signatures[signingName] = map[string]Base64String{keyID: signature}
}
var rawSignatures json.RawMessage
rawSignatures, err = json.Marshal(signatures)
if err != nil {
return nil, err
}
object["signatures"] = &rawSignatures
if hasUnsigned {
object["unsigned"] = rawUnsigned
}
return json.Marshal(object)
}
// VerifyJSON checks that the entity has signed the message using a particular key.
func VerifyJSON(signingName, keyID string, publicKey ed25519.PublicKey, message []byte) error {
var object map[string]*json.RawMessage
var signatures map[string]map[string]Base64String
if err := json.Unmarshal(message, &object); err != nil {
return err
}
delete(object, "unsigned")
if object["signatures"] == nil {
return fmt.Errorf("No signatures")
}
if err := json.Unmarshal(*object["signatures"], &signatures); err != nil {
return err
}
delete(object, "signatures")
signature, ok := signatures[signingName][keyID]
if !ok {
return fmt.Errorf("No signature from %q with ID %q", signingName, keyID)
}
if len(signature) != ed25519.SignatureSize {
return fmt.Errorf("Bad signature length from %q with ID %q", signingName, keyID)
}
unsorted, err := json.Marshal(object)
if err != nil {
return err
}
canonical, err := CanonicalJSON(unsorted)
if err != nil {
return err
}
if !ed25519.Verify(publicKey, canonical, signature) {
return fmt.Errorf("Bad signature from %q with ID %q", signingName, keyID)
}
return nil
}

View File

@ -0,0 +1,223 @@
/* Copyright 2016-2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"bytes"
"encoding/base64"
"encoding/json"
"testing"
"golang.org/x/crypto/ed25519"
)
func TestVerifyJSON(t *testing.T) {
// Check JSON verification using the test vectors from https://matrix.org/docs/spec/appendices.html
seed, err := base64.RawStdEncoding.DecodeString("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
if err != nil {
t.Fatal(err)
}
random := bytes.NewBuffer(seed)
entityName := "domain"
keyID := "ed25519:1"
publicKey, _, err := ed25519.GenerateKey(random)
if err != nil {
t.Fatal(err)
}
testVerifyOK := func(input string) {
err := VerifyJSON(entityName, keyID, publicKey, []byte(input))
if err != nil {
t.Fatal(err)
}
}
testVerifyNotOK := func(reason, input string) {
err := VerifyJSON(entityName, keyID, publicKey, []byte(input))
if err == nil {
t.Fatalf("Expected VerifyJSON to fail for input %v because %v", input, reason)
}
}
testVerifyOK(`{
"signatures": {
"domain": {
"ed25519:1": "K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ"
}
}
}`)
testVerifyNotOK("the json is modified", `{
"a new key": "a new value",
"signatures": {
"domain": {
"ed25519:1": "K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ"
}
}
}`)
testVerifyNotOK("the signature is modified", `{
"a new key": "a new value",
"signatures": {
"domain": {
"ed25519:1": "modifiedSSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ"
}
}
}`)
testVerifyNotOK("there are no signatures", `{}`)
testVerifyNotOK("there are no signatures", `{"signatures": {}}`)
testVerifyNotOK("there are not signatures for domain", `{
"signatures": {"domain": {}}
}`)
testVerifyNotOK("the signature has the wrong key_id", `{
"signatures": { "domain": {
"ed25519:2":"KqmLSbO39/Bzb0QIYE82zqLwsA+PDzYIpIRA2sRQ4sL53+sN6/fpNSoqE7BP7vBZhG6kYdD13EIMJpvhJI+6Bw"
}}
}`)
testVerifyNotOK("the signature is too short for ed25519", `{"signatures": {"domain": {"ed25519:1":"not/a/valid/signature"}}}`)
testVerifyNotOK("the signature has base64 padding that it shouldn't have", `{
"signatures": { "domain": {
"ed25519:1": "K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ=="
}}
}`)
}
func TestSignJSON(t *testing.T) {
random := bytes.NewBuffer([]byte("Some 32 randomly generated bytes"))
entityName := "example.com"
keyID := "ed25519:my_key_id"
input := []byte(`{"this":"is","my":"message"}`)
publicKey, privateKey, err := ed25519.GenerateKey(random)
if err != nil {
t.Fatal(err)
}
signed, err := SignJSON(entityName, keyID, privateKey, input)
if err != nil {
t.Fatal(err)
}
err = VerifyJSON(entityName, keyID, publicKey, signed)
if err != nil {
t.Errorf("VerifyJSON(%q)", signed)
t.Fatal(err)
}
}
func IsJSONEqual(a, b []byte) bool {
canonicalA, err := CanonicalJSON(a)
if err != nil {
panic(err)
}
canonicalB, err := CanonicalJSON(b)
if err != nil {
panic(err)
}
return string(canonicalA) == string(canonicalB)
}
func TestSignJSONTestVectors(t *testing.T) {
// Check JSON signing using the test vectors from https://matrix.org/docs/spec/appendices.html
seed, err := base64.RawStdEncoding.DecodeString("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
if err != nil {
t.Fatal(err)
}
random := bytes.NewBuffer(seed)
entityName := "domain"
keyID := "ed25519:1"
_, privateKey, err := ed25519.GenerateKey(random)
if err != nil {
t.Fatal(err)
}
testSign := func(input string, want string) {
signed, err := SignJSON(entityName, keyID, privateKey, []byte(input))
if err != nil {
t.Fatal(err)
}
if !IsJSONEqual([]byte(want), signed) {
t.Fatalf("VerifyJSON(%q): want %v got %v", input, want, string(signed))
}
}
testSign(`{}`, `{
"signatures":{
"domain":{
"ed25519:1":"K8280/U9SSy9IVtjBuVeLr+HpOB4BQFWbg+UZaADMtTdGYI7Geitb76LTrr5QV/7Xg4ahLwYGYZzuHGZKM5ZAQ"
}
}
}`)
testSign(`{"one":1,"two":"Two"}`, `{
"one": 1,
"signatures": {
"domain": {
"ed25519:1": "KqmLSbO39/Bzb0QIYE82zqLwsA+PDzYIpIRA2sRQ4sL53+sN6/fpNSoqE7BP7vBZhG6kYdD13EIMJpvhJI+6Bw"
}
},
"two": "Two"
}`)
}
type MyMessage struct {
Unsigned *json.RawMessage `json:"unsigned"`
Content *json.RawMessage `json:"content"`
Signatures *json.RawMessage `json:"signature,omitempty"`
}
func TestSignJSONWithUnsigned(t *testing.T) {
random := bytes.NewBuffer([]byte("Some 32 randomly generated bytes"))
entityName := "example.com"
keyID := "ed25519:my_key_id"
content := json.RawMessage(`{"signed":"data"}`)
unsigned := json.RawMessage(`{"unsigned":"data"}`)
message := MyMessage{&unsigned, &content, nil}
input, err := json.Marshal(&message)
if err != nil {
t.Fatal(err)
}
publicKey, privateKey, err := ed25519.GenerateKey(random)
if err != nil {
t.Fatal(err)
}
signed, err := SignJSON(entityName, keyID, privateKey, input)
if err != nil {
t.Fatal(err)
}
if err2 := json.Unmarshal(signed, &message); err2 != nil {
t.Fatal(err2)
}
newUnsigned := json.RawMessage(`{"different":"data"}`)
message.Unsigned = &newUnsigned
input, err = json.Marshal(&message)
if err != nil {
t.Fatal(err)
}
err = VerifyJSON(entityName, keyID, publicKey, signed)
if err != nil {
t.Errorf("VerifyJSON(%q)", signed)
t.Fatal(err)
}
}

View File

@ -0,0 +1,292 @@
package gomatrixserverlib
import (
"bytes"
"crypto/sha1"
"fmt"
"sort"
)
// ResolveStateConflicts takes a list of state events with conflicting state keys
// and works out which event should be used for each state event.
func ResolveStateConflicts(conflicted []Event, authEvents []Event) []Event {
var r stateResolver
// Group the conflicted events by type and state key.
r.addConflicted(conflicted)
// Add the unconflicted auth events needed for auth checks.
for i := range authEvents {
r.addAuthEvent(&authEvents[i])
}
// Resolve the conflicted auth events.
r.resolveAndAddAuthBlocks([][]Event{r.creates})
r.resolveAndAddAuthBlocks([][]Event{r.powerLevels})
r.resolveAndAddAuthBlocks([][]Event{r.joinRules})
r.resolveAndAddAuthBlocks(r.thirdPartyInvites)
r.resolveAndAddAuthBlocks(r.members)
// Resolve any other conflicted state events.
for _, block := range r.others {
if event := r.resolveNormalBlock(block); event != nil {
r.result = append(r.result, *event)
}
}
return r.result
}
// A stateResolver tracks the internal state of the state resolution algorithm
// It has 3 sections:
//
// * Lists of lists of events to resolve grouped by event type and state key.
// * The resolved auth events grouped by type and state key.
// * A List of resolved events.
//
// It implements the AuthEvents interface and can be used for running auth checks.
type stateResolver struct {
// Lists of lists of events to resolve grouped by event type and state key:
// * creates, powerLevels, joinRules have empty state keys.
// * members and thirdPartyInvites are grouped by state key.
// * the others are grouped by the pair of type and state key.
creates []Event
powerLevels []Event
joinRules []Event
thirdPartyInvites [][]Event
members [][]Event
others [][]Event
// The resolved auth events grouped by type and state key.
resolvedCreate *Event
resolvedPowerLevels *Event
resolvedJoinRules *Event
resolvedThirdPartyInvites map[string]*Event
resolvedMembers map[string]*Event
// The list of resolved events.
// This will contain one entry for each conflicted event type and state key.
result []Event
}
func (r *stateResolver) Create() (*Event, error) {
return r.resolvedCreate, nil
}
func (r *stateResolver) PowerLevels() (*Event, error) {
return r.resolvedPowerLevels, nil
}
func (r *stateResolver) JoinRules() (*Event, error) {
return r.resolvedJoinRules, nil
}
func (r *stateResolver) ThirdPartyInvite(key string) (*Event, error) {
return r.resolvedThirdPartyInvites[key], nil
}
func (r *stateResolver) Member(key string) (*Event, error) {
return r.resolvedMembers[key], nil
}
func (r *stateResolver) addConflicted(events []Event) {
type conflictKey struct {
eventType string
stateKey string
}
offsets := map[conflictKey]int{}
// Split up the conflicted events into blocks with the same type and state key.
// Separate the auth events into specifically named lists because they have
// special rules for state resolution.
for _, event := range events {
key := conflictKey{event.Type(), *event.StateKey()}
// Work out which block to add the event to.
// By default we add the event to a block in the others list.
blockList := &r.others
switch key.eventType {
case "m.room.create":
if key.stateKey == "" {
r.creates = append(r.creates, event)
continue
}
case "m.room.power_levels":
if key.stateKey == "" {
r.powerLevels = append(r.powerLevels, event)
continue
}
case "m.room.join_rules":
if key.stateKey == "" {
r.joinRules = append(r.joinRules, event)
continue
}
case "m.room.member":
blockList = &r.members
case "m.room.third_party_invite":
blockList = &r.thirdPartyInvites
}
// We need to find an entry for the state key in a block list.
offset, ok := offsets[key]
if !ok {
// This is the first time we've seen that state key so we add a
// new block to the block list.
offset = len(*blockList)
*blockList = append(*blockList, nil)
}
// Get the address of the block in the block list.
block := &(*blockList)[offset]
// Add the event to the block.
*block = append(*block, event)
}
}
// Add an event to the resolved auth events.
func (r *stateResolver) addAuthEvent(event *Event) {
switch event.Type() {
case "m.room.create":
if event.StateKeyEquals("") {
r.resolvedCreate = event
}
case "m.room.power_levels":
if event.StateKeyEquals("") {
r.resolvedPowerLevels = event
}
case "m.room.join_rules":
if event.StateKeyEquals("") {
r.resolvedJoinRules = event
}
case "m.room.member":
r.resolvedMembers[*event.StateKey()] = event
case "m.room.third_party_invite":
r.resolvedThirdPartyInvites[*event.StateKey()] = event
default:
panic(fmt.Errorf("Unexpected auth event with type %q", event.Type()))
}
}
// Remove the auth event with the given type and state key.
func (r *stateResolver) removeAuthEvent(eventType, stateKey string) {
switch eventType {
case "m.room.create":
if stateKey == "" {
r.resolvedCreate = nil
}
case "m.room.power_levels":
if stateKey == "" {
r.resolvedPowerLevels = nil
}
case "m.room.join_rules":
if stateKey == "" {
r.resolvedJoinRules = nil
}
case "m.room.member":
r.resolvedMembers[stateKey] = nil
case "m.room.third_party_invite":
r.resolvedThirdPartyInvites[stateKey] = nil
default:
panic(fmt.Errorf("Unexpected auth event with type %q", eventType))
}
}
// resolveAndAddAuthBlocks resolves each block of conflicting auth state events in a list of blocks
// where all the blocks have the same event type.
// Once every block has been resolved the resulting events are added to the events used for auth checks.
// This is called once per auth event type and state key pair.
func (r *stateResolver) resolveAndAddAuthBlocks(blocks [][]Event) {
start := len(r.result)
for _, block := range blocks {
if event := r.resolveAuthBlock(block); event != nil {
r.result = append(r.result, *event)
}
}
// Only add the events to the auth events once all of the events with that type have been resolved.
// (SPEC: This is done to avoid the result of state resolution depending on the iteration order)
for i := start; i < len(r.result); i++ {
r.addAuthEvent(&r.result[i])
}
}
// resolveAuthBlock resolves a block of auth events with the same state key to a single event.
func (r *stateResolver) resolveAuthBlock(events []Event) *Event {
// Sort the events by depth and sha1 of event ID
block := sortConflictedEventsByDepthAndSHA1(events)
// Pick the "oldest" event, that is the one with the lowest depth, as the first candidate.
// If none of the newer events pass auth checks against this event then we pick the "oldest" event.
// (SPEC: This ensures that we always pick a state event for this type and state key.
// Note that if all the events fail auth checks we will still pick the "oldest" event.)
result := block[0].event
// Temporarily add the candidate event to the auth events.
r.addAuthEvent(result)
for i := 1; i < len(block); i++ {
event := block[i].event
// Check if the next event passes authentication checks against the current candidate.
// (SPEC: This ensures that "ban" events cannot be replaced by "join" events through a conflict)
if Allowed(*event, r) == nil {
// If the event passes authentication checks pick it as the current candidate.
// (SPEC: This prefers newer events so that we don't flip a valid state back to a previous version)
result = event
r.addAuthEvent(result)
} else {
// If the authentication check fails then we stop iterating the list and return the current candidate.
break
}
}
// Discard the event from the auth events.
// We'll add it back later when all events of the same type have been resolved.
// (SPEC: This is done to avoid the result of state resolution depending on the iteration order)
r.removeAuthEvent(result.Type(), *result.StateKey())
return result
}
// resolveNormalBlock resolves a block of normal state events with the same state key to a single event.
func (r *stateResolver) resolveNormalBlock(events []Event) *Event {
// Sort the events by depth and sha1 of event ID
block := sortConflictedEventsByDepthAndSHA1(events)
// Start at the "newest" event, that is the one with the highest depth, and go
// backward through the list until we find one that passes authentication checks.
// (SPEC: This prefers newer events so that we don't flip a valid state back to a previous version)
for i := len(block) - 1; i > 0; i-- {
event := block[i].event
if Allowed(*event, r) == nil {
return event
}
}
// If all the auth checks for newer events fail then we pick the oldest event.
// (SPEC: This ensures that we always pick a state event for this type and state key.
// Note that if all the events fail auth checks we will still pick the "oldest" event.)
return block[0].event
}
// sortConflictedEventsByDepthAndSHA1 sorts by ascending depth and descending sha1 of event ID.
func sortConflictedEventsByDepthAndSHA1(events []Event) []conflictedEvent {
block := make([]conflictedEvent, len(events))
for i := range events {
event := &events[i]
block[i] = conflictedEvent{
depth: event.Depth(),
eventIDSHA1: sha1.Sum([]byte(event.EventID())),
event: event,
}
}
sort.Sort(conflictedEventSorter(block))
return block
}
// A conflictedEvent is used to sort the events in a block by ascending depth and descending sha1 of event ID.
// (SPEC: We use the SHA1 of the event ID as an arbitrary tie breaker between events with the same depth)
type conflictedEvent struct {
depth int64
eventIDSHA1 [sha1.Size]byte
event *Event
}
// A conflictedEventSorter is used to sort the events using sort.Sort.
type conflictedEventSorter []conflictedEvent
func (s conflictedEventSorter) Len() int {
return len(s)
}
func (s conflictedEventSorter) Less(i, j int) bool {
if s[i].depth == s[j].depth {
return bytes.Compare(s[i].eventIDSHA1[:], s[j].eventIDSHA1[:]) > 0
}
return s[i].depth < s[j].depth
}
func (s conflictedEventSorter) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}

View File

@ -0,0 +1,36 @@
package gomatrixserverlib
import (
"testing"
)
const (
sha1OfEventID1A = "\xe5\x89,\xa2\x1cF<&\xf3\rf}\xde\xa5\xef;\xddK\xaaS"
sha1OfEventID2A = "\xa4\xe4\x10\x1b}\x1a\xf9`\x94\x10\xa3\x84+\xae\x06\x8d\x16A\xfc>"
sha1OfEventID3B = "\xca\xe8\xde\xb6\xa3\xb6\xee\x01\xc4\xbc\xd0/\x1b\x1c2\x0c\xd3\xa4\xe9\xcb"
)
func TestConflictEventSorter(t *testing.T) {
input := []Event{
Event{fields: eventFields{Depth: 1, EventID: "@1:a"}},
Event{fields: eventFields{Depth: 2, EventID: "@2:a"}},
Event{fields: eventFields{Depth: 2, EventID: "@3:b"}},
}
got := sortConflictedEventsByDepthAndSHA1(input)
want := []conflictedEvent{
conflictedEvent{depth: 1, event: &input[0]},
conflictedEvent{depth: 2, event: &input[2]},
conflictedEvent{depth: 2, event: &input[1]},
}
copy(want[0].eventIDSHA1[:], sha1OfEventID1A)
copy(want[1].eventIDSHA1[:], sha1OfEventID3B)
copy(want[2].eventIDSHA1[:], sha1OfEventID2A)
if len(want) != len(got) {
t.Fatalf("Different length: wanted %d, got %d", len(want), len(got))
}
for i := range want {
if want[i] != got[i] {
t.Fatalf("Different element at index %d: wanted %#v got %#v", i, want[i], got[i])
}
}
}

View File

@ -0,0 +1,181 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ed25519 implements the Ed25519 signature algorithm. See
// http://ed25519.cr.yp.to/.
//
// These functions are also compatible with the “Ed25519” function defined in
// https://tools.ietf.org/html/draft-irtf-cfrg-eddsa-05.
package ed25519
// This code is a port of the public domain, “ref10” implementation of ed25519
// from SUPERCOP.
import (
"crypto"
cryptorand "crypto/rand"
"crypto/sha512"
"crypto/subtle"
"errors"
"io"
"strconv"
"golang.org/x/crypto/ed25519/internal/edwards25519"
)
const (
// PublicKeySize is the size, in bytes, of public keys as used in this package.
PublicKeySize = 32
// PrivateKeySize is the size, in bytes, of private keys as used in this package.
PrivateKeySize = 64
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = 64
)
// PublicKey is the type of Ed25519 public keys.
type PublicKey []byte
// PrivateKey is the type of Ed25519 private keys. It implements crypto.Signer.
type PrivateKey []byte
// Public returns the PublicKey corresponding to priv.
func (priv PrivateKey) Public() crypto.PublicKey {
publicKey := make([]byte, PublicKeySize)
copy(publicKey, priv[32:])
return PublicKey(publicKey)
}
// Sign signs the given message with priv.
// Ed25519 performs two passes over messages to be signed and therefore cannot
// handle pre-hashed messages. Thus opts.HashFunc() must return zero to
// indicate the message hasn't been hashed. This can be achieved by passing
// crypto.Hash(0) as the value for opts.
func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) {
if opts.HashFunc() != crypto.Hash(0) {
return nil, errors.New("ed25519: cannot sign hashed message")
}
return Sign(priv, message), nil
}
// GenerateKey generates a public/private key pair using entropy from rand.
// If rand is nil, crypto/rand.Reader will be used.
func GenerateKey(rand io.Reader) (publicKey PublicKey, privateKey PrivateKey, err error) {
if rand == nil {
rand = cryptorand.Reader
}
privateKey = make([]byte, PrivateKeySize)
publicKey = make([]byte, PublicKeySize)
_, err = io.ReadFull(rand, privateKey[:32])
if err != nil {
return nil, nil, err
}
digest := sha512.Sum512(privateKey[:32])
digest[0] &= 248
digest[31] &= 127
digest[31] |= 64
var A edwards25519.ExtendedGroupElement
var hBytes [32]byte
copy(hBytes[:], digest[:])
edwards25519.GeScalarMultBase(&A, &hBytes)
var publicKeyBytes [32]byte
A.ToBytes(&publicKeyBytes)
copy(privateKey[32:], publicKeyBytes[:])
copy(publicKey, publicKeyBytes[:])
return publicKey, privateKey, nil
}
// Sign signs the message with privateKey and returns a signature. It will
// panic if len(privateKey) is not PrivateKeySize.
func Sign(privateKey PrivateKey, message []byte) []byte {
if l := len(privateKey); l != PrivateKeySize {
panic("ed25519: bad private key length: " + strconv.Itoa(l))
}
h := sha512.New()
h.Write(privateKey[:32])
var digest1, messageDigest, hramDigest [64]byte
var expandedSecretKey [32]byte
h.Sum(digest1[:0])
copy(expandedSecretKey[:], digest1[:])
expandedSecretKey[0] &= 248
expandedSecretKey[31] &= 63
expandedSecretKey[31] |= 64
h.Reset()
h.Write(digest1[32:])
h.Write(message)
h.Sum(messageDigest[:0])
var messageDigestReduced [32]byte
edwards25519.ScReduce(&messageDigestReduced, &messageDigest)
var R edwards25519.ExtendedGroupElement
edwards25519.GeScalarMultBase(&R, &messageDigestReduced)
var encodedR [32]byte
R.ToBytes(&encodedR)
h.Reset()
h.Write(encodedR[:])
h.Write(privateKey[32:])
h.Write(message)
h.Sum(hramDigest[:0])
var hramDigestReduced [32]byte
edwards25519.ScReduce(&hramDigestReduced, &hramDigest)
var s [32]byte
edwards25519.ScMulAdd(&s, &hramDigestReduced, &expandedSecretKey, &messageDigestReduced)
signature := make([]byte, SignatureSize)
copy(signature[:], encodedR[:])
copy(signature[32:], s[:])
return signature
}
// Verify reports whether sig is a valid signature of message by publicKey. It
// will panic if len(publicKey) is not PublicKeySize.
func Verify(publicKey PublicKey, message, sig []byte) bool {
if l := len(publicKey); l != PublicKeySize {
panic("ed25519: bad public key length: " + strconv.Itoa(l))
}
if len(sig) != SignatureSize || sig[63]&224 != 0 {
return false
}
var A edwards25519.ExtendedGroupElement
var publicKeyBytes [32]byte
copy(publicKeyBytes[:], publicKey)
if !A.FromBytes(&publicKeyBytes) {
return false
}
edwards25519.FeNeg(&A.X, &A.X)
edwards25519.FeNeg(&A.T, &A.T)
h := sha512.New()
h.Write(sig[:32])
h.Write(publicKey[:])
h.Write(message)
var digest [64]byte
h.Sum(digest[:0])
var hReduced [32]byte
edwards25519.ScReduce(&hReduced, &digest)
var R edwards25519.ProjectiveGroupElement
var b [32]byte
copy(b[:], sig[32:])
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &b)
var checkR [32]byte
R.ToBytes(&checkR)
return subtle.ConstantTimeCompare(sig[:32], checkR[:]) == 1
}

View File

@ -0,0 +1,183 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ed25519
import (
"bufio"
"bytes"
"compress/gzip"
"crypto"
"crypto/rand"
"encoding/hex"
"os"
"strings"
"testing"
"golang.org/x/crypto/ed25519/internal/edwards25519"
)
type zeroReader struct{}
func (zeroReader) Read(buf []byte) (int, error) {
for i := range buf {
buf[i] = 0
}
return len(buf), nil
}
func TestUnmarshalMarshal(t *testing.T) {
pub, _, _ := GenerateKey(rand.Reader)
var A edwards25519.ExtendedGroupElement
var pubBytes [32]byte
copy(pubBytes[:], pub)
if !A.FromBytes(&pubBytes) {
t.Fatalf("ExtendedGroupElement.FromBytes failed")
}
var pub2 [32]byte
A.ToBytes(&pub2)
if pubBytes != pub2 {
t.Errorf("FromBytes(%v)->ToBytes does not round-trip, got %x\n", pubBytes, pub2)
}
}
func TestSignVerify(t *testing.T) {
var zero zeroReader
public, private, _ := GenerateKey(zero)
message := []byte("test message")
sig := Sign(private, message)
if !Verify(public, message, sig) {
t.Errorf("valid signature rejected")
}
wrongMessage := []byte("wrong message")
if Verify(public, wrongMessage, sig) {
t.Errorf("signature of different message accepted")
}
}
func TestCryptoSigner(t *testing.T) {
var zero zeroReader
public, private, _ := GenerateKey(zero)
signer := crypto.Signer(private)
publicInterface := signer.Public()
public2, ok := publicInterface.(PublicKey)
if !ok {
t.Fatalf("expected PublicKey from Public() but got %T", publicInterface)
}
if !bytes.Equal(public, public2) {
t.Errorf("public keys do not match: original:%x vs Public():%x", public, public2)
}
message := []byte("message")
var noHash crypto.Hash
signature, err := signer.Sign(zero, message, noHash)
if err != nil {
t.Fatalf("error from Sign(): %s", err)
}
if !Verify(public, message, signature) {
t.Errorf("Verify failed on signature from Sign()")
}
}
func TestGolden(t *testing.T) {
// sign.input.gz is a selection of test cases from
// http://ed25519.cr.yp.to/python/sign.input
testDataZ, err := os.Open("testdata/sign.input.gz")
if err != nil {
t.Fatal(err)
}
defer testDataZ.Close()
testData, err := gzip.NewReader(testDataZ)
if err != nil {
t.Fatal(err)
}
defer testData.Close()
scanner := bufio.NewScanner(testData)
lineNo := 0
for scanner.Scan() {
lineNo++
line := scanner.Text()
parts := strings.Split(line, ":")
if len(parts) != 5 {
t.Fatalf("bad number of parts on line %d", lineNo)
}
privBytes, _ := hex.DecodeString(parts[0])
pubKey, _ := hex.DecodeString(parts[1])
msg, _ := hex.DecodeString(parts[2])
sig, _ := hex.DecodeString(parts[3])
// The signatures in the test vectors also include the message
// at the end, but we just want R and S.
sig = sig[:SignatureSize]
if l := len(pubKey); l != PublicKeySize {
t.Fatalf("bad public key length on line %d: got %d bytes", lineNo, l)
}
var priv [PrivateKeySize]byte
copy(priv[:], privBytes)
copy(priv[32:], pubKey)
sig2 := Sign(priv[:], msg)
if !bytes.Equal(sig, sig2[:]) {
t.Errorf("different signature result on line %d: %x vs %x", lineNo, sig, sig2)
}
if !Verify(pubKey, msg, sig2) {
t.Errorf("signature failed to verify on line %d", lineNo)
}
}
if err := scanner.Err(); err != nil {
t.Fatalf("error reading test data: %s", err)
}
}
func BenchmarkKeyGeneration(b *testing.B) {
var zero zeroReader
for i := 0; i < b.N; i++ {
if _, _, err := GenerateKey(zero); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSigning(b *testing.B) {
var zero zeroReader
_, priv, err := GenerateKey(zero)
if err != nil {
b.Fatal(err)
}
message := []byte("Hello, world!")
b.ResetTimer()
for i := 0; i < b.N; i++ {
Sign(priv, message)
}
}
func BenchmarkVerification(b *testing.B) {
var zero zeroReader
pub, priv, err := GenerateKey(zero)
if err != nil {
b.Fatal(err)
}
message := []byte("Hello, world!")
signature := Sign(priv, message)
b.ResetTimer()
for i := 0; i < b.N; i++ {
Verify(pub, message, signature)
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

Binary file not shown.