// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package agent

import (
	"crypto/dsa"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rsa"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"math/big"

	"golang.org/x/crypto/ed25519"
	"golang.org/x/crypto/ssh"
)

// Server wraps an Agent and uses it to implement the agent side of
// the SSH-agent, wire protocol.
type server struct {
	agent Agent
}

func (s *server) processRequestBytes(reqData []byte) []byte {
	rep, err := s.processRequest(reqData)
	if err != nil {
		if err != errLocked {
			// TODO(hanwen): provide better logging interface?
			log.Printf("agent %d: %v", reqData[0], err)
		}
		return []byte{agentFailure}
	}

	if err == nil && rep == nil {
		return []byte{agentSuccess}
	}

	return ssh.Marshal(rep)
}

func marshalKey(k *Key) []byte {
	var record struct {
		Blob    []byte
		Comment string
	}
	record.Blob = k.Marshal()
	record.Comment = k.Comment

	return ssh.Marshal(&record)
}

// See [PROTOCOL.agent], section 2.5.1.
const agentV1IdentitiesAnswer = 2

type agentV1IdentityMsg struct {
	Numkeys uint32 `sshtype:"2"`
}

type agentRemoveIdentityMsg struct {
	KeyBlob []byte `sshtype:"18"`
}

type agentLockMsg struct {
	Passphrase []byte `sshtype:"22"`
}

type agentUnlockMsg struct {
	Passphrase []byte `sshtype:"23"`
}

func (s *server) processRequest(data []byte) (interface{}, error) {
	switch data[0] {
	case agentRequestV1Identities:
		return &agentV1IdentityMsg{0}, nil

	case agentRemoveAllV1Identities:
		return nil, nil

	case agentRemoveIdentity:
		var req agentRemoveIdentityMsg
		if err := ssh.Unmarshal(data, &req); err != nil {
			return nil, err
		}

		var wk wireKey
		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
			return nil, err
		}

		return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})

	case agentRemoveAllIdentities:
		return nil, s.agent.RemoveAll()

	case agentLock:
		var req agentLockMsg
		if err := ssh.Unmarshal(data, &req); err != nil {
			return nil, err
		}

		return nil, s.agent.Lock(req.Passphrase)

	case agentUnlock:
		var req agentUnlockMsg
		if err := ssh.Unmarshal(data, &req); err != nil {
			return nil, err
		}
		return nil, s.agent.Unlock(req.Passphrase)

	case agentSignRequest:
		var req signRequestAgentMsg
		if err := ssh.Unmarshal(data, &req); err != nil {
			return nil, err
		}

		var wk wireKey
		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
			return nil, err
		}

		k := &Key{
			Format: wk.Format,
			Blob:   req.KeyBlob,
		}

		var sig *ssh.Signature
		var err error
		if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
			sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
		} else {
			sig, err = s.agent.Sign(k, req.Data)
		}

		if err != nil {
			return nil, err
		}
		return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil

	case agentRequestIdentities:
		keys, err := s.agent.List()
		if err != nil {
			return nil, err
		}

		rep := identitiesAnswerAgentMsg{
			NumKeys: uint32(len(keys)),
		}
		for _, k := range keys {
			rep.Keys = append(rep.Keys, marshalKey(k)...)
		}
		return rep, nil

	case agentAddIDConstrained, agentAddIdentity:
		return nil, s.insertIdentity(data)

	case agentExtension:
		// Return a stub object where the whole contents of the response gets marshaled.
		var responseStub struct {
			Rest []byte `ssh:"rest"`
		}

		if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
			// If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
			// requires that we return a standard SSH_AGENT_FAILURE message.
			responseStub.Rest = []byte{agentFailure}
		} else {
			var req extensionAgentMsg
			if err := ssh.Unmarshal(data, &req); err != nil {
				return nil, err
			}
			res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
			if err != nil {
				// If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
				// message as required by [PROTOCOL.agent] section 4.7.
				if err == ErrExtensionUnsupported {
					responseStub.Rest = []byte{agentFailure}
				} else {
					// As the result of any other error processing an extension request,
					// [PROTOCOL.agent] section 4.7 requires that we return a
					// SSH_AGENT_EXTENSION_FAILURE code.
					responseStub.Rest = []byte{agentExtensionFailure}
				}
			} else {
				if len(res) == 0 {
					return nil, nil
				}
				responseStub.Rest = res
			}
		}

		return responseStub, nil
	}

	return nil, fmt.Errorf("unknown opcode %d", data[0])
}

func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
	for len(constraints) != 0 {
		switch constraints[0] {
		case agentConstrainLifetime:
			lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
			constraints = constraints[5:]
		case agentConstrainConfirm:
			confirmBeforeUse = true
			constraints = constraints[1:]
		case agentConstrainExtension:
			var msg constrainExtensionAgentMsg
			if err = ssh.Unmarshal(constraints, &msg); err != nil {
				return 0, false, nil, err
			}
			extensions = append(extensions, ConstraintExtension{
				ExtensionName:    msg.ExtensionName,
				ExtensionDetails: msg.ExtensionDetails,
			})
			constraints = msg.Rest
		default:
			return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
		}
	}
	return
}

func setConstraints(key *AddedKey, constraintBytes []byte) error {
	lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
	if err != nil {
		return err
	}

	key.LifetimeSecs = lifetimeSecs
	key.ConfirmBeforeUse = confirmBeforeUse
	key.ConstraintExtensions = constraintExtensions
	return nil
}

func parseRSAKey(req []byte) (*AddedKey, error) {
	var k rsaKeyMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}
	if k.E.BitLen() > 30 {
		return nil, errors.New("agent: RSA public exponent too large")
	}
	priv := &rsa.PrivateKey{
		PublicKey: rsa.PublicKey{
			E: int(k.E.Int64()),
			N: k.N,
		},
		D:      k.D,
		Primes: []*big.Int{k.P, k.Q},
	}
	priv.Precompute()

	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func parseEd25519Key(req []byte) (*AddedKey, error) {
	var k ed25519KeyMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}
	priv := ed25519.PrivateKey(k.Priv)

	addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func parseDSAKey(req []byte) (*AddedKey, error) {
	var k dsaKeyMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}
	priv := &dsa.PrivateKey{
		PublicKey: dsa.PublicKey{
			Parameters: dsa.Parameters{
				P: k.P,
				Q: k.Q,
				G: k.G,
			},
			Y: k.Y,
		},
		X: k.X,
	}

	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
	priv = &ecdsa.PrivateKey{
		D: privScalar,
	}

	switch curveName {
	case "nistp256":
		priv.Curve = elliptic.P256()
	case "nistp384":
		priv.Curve = elliptic.P384()
	case "nistp521":
		priv.Curve = elliptic.P521()
	default:
		return nil, fmt.Errorf("agent: unknown curve %q", curveName)
	}

	priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes)
	if priv.X == nil || priv.Y == nil {
		return nil, errors.New("agent: point not on curve")
	}

	return priv, nil
}

func parseEd25519Cert(req []byte) (*AddedKey, error) {
	var k ed25519CertMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}
	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
	if err != nil {
		return nil, err
	}
	priv := ed25519.PrivateKey(k.Priv)
	cert, ok := pubKey.(*ssh.Certificate)
	if !ok {
		return nil, errors.New("agent: bad ED25519 certificate")
	}

	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func parseECDSAKey(req []byte) (*AddedKey, error) {
	var k ecdsaKeyMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}

	priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
	if err != nil {
		return nil, err
	}

	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func parseRSACert(req []byte) (*AddedKey, error) {
	var k rsaCertMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}

	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
	if err != nil {
		return nil, err
	}

	cert, ok := pubKey.(*ssh.Certificate)
	if !ok {
		return nil, errors.New("agent: bad RSA certificate")
	}

	// An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
	var rsaPub struct {
		Name string
		E    *big.Int
		N    *big.Int
	}
	if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
	}

	if rsaPub.E.BitLen() > 30 {
		return nil, errors.New("agent: RSA public exponent too large")
	}

	priv := rsa.PrivateKey{
		PublicKey: rsa.PublicKey{
			E: int(rsaPub.E.Int64()),
			N: rsaPub.N,
		},
		D:      k.D,
		Primes: []*big.Int{k.Q, k.P},
	}
	priv.Precompute()

	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func parseDSACert(req []byte) (*AddedKey, error) {
	var k dsaCertMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}
	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
	if err != nil {
		return nil, err
	}
	cert, ok := pubKey.(*ssh.Certificate)
	if !ok {
		return nil, errors.New("agent: bad DSA certificate")
	}

	// A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
	var w struct {
		Name       string
		P, Q, G, Y *big.Int
	}
	if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
	}

	priv := &dsa.PrivateKey{
		PublicKey: dsa.PublicKey{
			Parameters: dsa.Parameters{
				P: w.P,
				Q: w.Q,
				G: w.G,
			},
			Y: w.Y,
		},
		X: k.X,
	}

	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func parseECDSACert(req []byte) (*AddedKey, error) {
	var k ecdsaCertMsg
	if err := ssh.Unmarshal(req, &k); err != nil {
		return nil, err
	}

	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
	if err != nil {
		return nil, err
	}
	cert, ok := pubKey.(*ssh.Certificate)
	if !ok {
		return nil, errors.New("agent: bad ECDSA certificate")
	}

	// An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
	var ecdsaPub struct {
		Name string
		ID   string
		Key  []byte
	}
	if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
		return nil, err
	}

	priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
	if err != nil {
		return nil, err
	}

	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
	if err := setConstraints(addedKey, k.Constraints); err != nil {
		return nil, err
	}
	return addedKey, nil
}

func (s *server) insertIdentity(req []byte) error {
	var record struct {
		Type string `sshtype:"17|25"`
		Rest []byte `ssh:"rest"`
	}

	if err := ssh.Unmarshal(req, &record); err != nil {
		return err
	}

	var addedKey *AddedKey
	var err error

	switch record.Type {
	case ssh.KeyAlgoRSA:
		addedKey, err = parseRSAKey(req)
	case ssh.KeyAlgoDSA:
		addedKey, err = parseDSAKey(req)
	case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
		addedKey, err = parseECDSAKey(req)
	case ssh.KeyAlgoED25519:
		addedKey, err = parseEd25519Key(req)
	case ssh.CertAlgoRSAv01:
		addedKey, err = parseRSACert(req)
	case ssh.CertAlgoDSAv01:
		addedKey, err = parseDSACert(req)
	case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
		addedKey, err = parseECDSACert(req)
	case ssh.CertAlgoED25519v01:
		addedKey, err = parseEd25519Cert(req)
	default:
		return fmt.Errorf("agent: not implemented: %q", record.Type)
	}

	if err != nil {
		return err
	}
	return s.agent.Add(*addedKey)
}

// ServeAgent serves the agent protocol on the given connection. It
// returns when an I/O error occurs.
func ServeAgent(agent Agent, c io.ReadWriter) error {
	s := &server{agent}

	var length [4]byte
	for {
		if _, err := io.ReadFull(c, length[:]); err != nil {
			return err
		}
		l := binary.BigEndian.Uint32(length[:])
		if l == 0 {
			return fmt.Errorf("agent: request size is 0")
		}
		if l > maxAgentResponseBytes {
			// We also cap requests.
			return fmt.Errorf("agent: request too large: %d", l)
		}

		req := make([]byte, l)
		if _, err := io.ReadFull(c, req); err != nil {
			return err
		}

		repData := s.processRequestBytes(req)
		if len(repData) > maxAgentResponseBytes {
			return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
		}

		binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
		if _, err := c.Write(length[:]); err != nil {
			return err
		}
		if _, err := c.Write(repData); err != nil {
			return err
		}
	}
}