Implement /logout/all (#307)

Signed-off-by: Remi Reuvekamp <git@remireuvekamp.nl>
main
Remi Reuvekamp 2017-10-15 10:29:47 +00:00 committed by Erik Johnston
parent 32a2b3a5c0
commit 1a026f16d5
4 changed files with 59 additions and 4 deletions

View File

@ -57,12 +57,17 @@ const selectDeviceByTokenSQL = "" +
const deleteDeviceSQL = "" + const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
// TODO: List devices? // TODO: List devices?
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -80,6 +85,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return return
} }
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -110,6 +118,14 @@ func (s *devicesStatements) deleteDevice(
return err return err
} }
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
}
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string, ctx context.Context, accessToken string,
) (*authtypes.Device, error) { ) (*authtypes.Device, error) {

View File

@ -109,3 +109,17 @@ func (d *Database) RemoveDevice(
return nil return nil
}) })
} }
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View File

@ -50,3 +50,22 @@ func Logout(
JSON: struct{}{}, JSON: struct{}{},
} }
} }
// LogoutAll handles POST /logout/all
func LogoutAll(
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return httputil.LogThenError(req, err)
}
if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{
Code: 200,
JSON: struct{}{},
}
}

View File

@ -160,6 +160,12 @@ func Setup(
}), }),
).Methods("POST", "OPTIONS") ).Methods("POST", "OPTIONS")
r0mux.Handle("/logout/all",
common.MakeAuthAPI("logout", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return LogoutAll(req, deviceDB, device)
}),
).Methods("POST", "OPTIONS")
// Stub endpoints required by Riot // Stub endpoints required by Riot
r0mux.Handle("/login", r0mux.Handle("/login",