package mssql

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
)

//go:generate stringer -type token

type token byte

// token ids
const (
	tokenReturnStatus  token = 121 // 0x79
	tokenColMetadata   token = 129 // 0x81
	tokenOrder         token = 169 // 0xA9
	tokenError         token = 170 // 0xAA
	tokenInfo          token = 171 // 0xAB
	tokenReturnValue   token = 0xAC
	tokenLoginAck      token = 173 // 0xad
	tokenFeatureExtAck token = 174 // 0xae
	tokenRow           token = 209 // 0xd1
	tokenNbcRow        token = 210 // 0xd2
	tokenEnvChange     token = 227 // 0xE3
	tokenSSPI          token = 237 // 0xED
	tokenDone          token = 253 // 0xFD
	tokenDoneProc      token = 254
	tokenDoneInProc    token = 255
)

// done flags
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
const (
	doneFinal    = 0
	doneMore     = 1
	doneError    = 2
	doneInxact   = 4
	doneCount    = 0x10
	doneAttn     = 0x20
	doneSrvError = 0x100
)

// ENVCHANGE types
// http://msdn.microsoft.com/en-us/library/dd303449.aspx
const (
	envTypDatabase           = 1
	envTypLanguage           = 2
	envTypCharset            = 3
	envTypPacketSize         = 4
	envSortId                = 5
	envSortFlags             = 6
	envSqlCollation          = 7
	envTypBeginTran          = 8
	envTypCommitTran         = 9
	envTypRollbackTran       = 10
	envEnlistDTC             = 11
	envDefectTran            = 12
	envDatabaseMirrorPartner = 13
	envPromoteTran           = 15
	envTranMgrAddr           = 16
	envTranEnded             = 17
	envResetConnAck          = 18
	envStartedInstanceName   = 19
	envRouting               = 20
)

// COLMETADATA flags
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
const (
	colFlagNullable = 1
	// TODO implement more flags
)

// interface for all tokens
type tokenStruct interface{}

type orderStruct struct {
	ColIds []uint16
}

type doneStruct struct {
	Status   uint16
	CurCmd   uint16
	RowCount uint64
	errors   []Error
}

func (d doneStruct) isError() bool {
	return d.Status&doneError != 0 || len(d.errors) > 0
}

func (d doneStruct) getError() Error {
	if len(d.errors) > 0 {
		return d.errors[len(d.errors)-1]
	} else {
		return Error{Message: "Request failed but didn't provide reason"}
	}
}

type doneInProcStruct doneStruct

var doneFlags2str = map[uint16]string{
	doneFinal:    "final",
	doneMore:     "more",
	doneError:    "error",
	doneInxact:   "inxact",
	doneCount:    "count",
	doneAttn:     "attn",
	doneSrvError: "srverror",
}

func doneFlags2Str(flags uint16) string {
	strs := make([]string, 0, len(doneFlags2str))
	for flag, tag := range doneFlags2str {
		if flags&flag != 0 {
			strs = append(strs, tag)
		}
	}
	return strings.Join(strs, "|")
}

// ENVCHANGE stream
// http://msdn.microsoft.com/en-us/library/dd303449.aspx
func processEnvChg(sess *tdsSession) {
	size := sess.buf.uint16()
	r := &io.LimitedReader{R: sess.buf, N: int64(size)}
	for {
		var err error
		var envtype uint8
		err = binary.Read(r, binary.LittleEndian, &envtype)
		if err == io.EOF {
			return
		}
		if err != nil {
			badStreamPanic(err)
		}
		switch envtype {
		case envTypDatabase:
			sess.database, err = readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			_, err = readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
		case envTypLanguage:
			// currently ignored
			// new value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// old value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envTypCharset:
			// currently ignored
			// new value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// old value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envTypPacketSize:
			packetsize, err := readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			_, err = readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			packetsizei, err := strconv.Atoi(packetsize)
			if err != nil {
				badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error())
			}
			sess.buf.ResizeBuffer(packetsizei)
		case envSortId:
			// currently ignored
			// new value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envSortFlags:
			// currently ignored
			// new value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envSqlCollation:
			// currently ignored
			var collationSize uint8
			err = binary.Read(r, binary.LittleEndian, &collationSize)
			if err != nil {
				badStreamPanic(err)
			}

			// SQL Collation data should contain 5 bytes in length
			if collationSize != 5 {
				badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize)
			}

			// 4 bytes, contains: LCID ColFlags Version
			var info uint32
			err = binary.Read(r, binary.LittleEndian, &info)
			if err != nil {
				badStreamPanic(err)
			}

			// 1 byte, contains: sortID
			var sortID uint8
			err = binary.Read(r, binary.LittleEndian, &sortID)
			if err != nil {
				badStreamPanic(err)
			}

			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envTypBeginTran:
			tranid, err := readBVarByte(r)
			if len(tranid) != 8 {
				badStreamPanicf("invalid size of transaction identifier: %d", len(tranid))
			}
			sess.tranid = binary.LittleEndian.Uint64(tranid)
			if err != nil {
				badStreamPanic(err)
			}
			if sess.logFlags&logTransaction != 0 {
				sess.log.Printf("BEGIN TRANSACTION %x\n", sess.tranid)
			}
			_, err = readBVarByte(r)
			if err != nil {
				badStreamPanic(err)
			}
		case envTypCommitTran, envTypRollbackTran:
			_, err = readBVarByte(r)
			if err != nil {
				badStreamPanic(err)
			}
			_, err = readBVarByte(r)
			if err != nil {
				badStreamPanic(err)
			}
			if sess.logFlags&logTransaction != 0 {
				if envtype == envTypCommitTran {
					sess.log.Printf("COMMIT TRANSACTION %x\n", sess.tranid)
				} else {
					sess.log.Printf("ROLLBACK TRANSACTION %x\n", sess.tranid)
				}
			}
			sess.tranid = 0
		case envEnlistDTC:
			// currently ignored
			// new value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// old value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envDefectTran:
			// currently ignored
			// new value
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envDatabaseMirrorPartner:
			sess.partner, err = readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			_, err = readBVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
		case envPromoteTran:
			// currently ignored
			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// dtc token
			// spec says it should be L_VARBYTE, so this code might be wrong
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envTranMgrAddr:
			// currently ignored
			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// XACT_MANAGER_ADDRESS = B_VARBYTE
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envTranEnded:
			// currently ignored
			// old value, B_VARBYTE
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envResetConnAck:
			// currently ignored
			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envStartedInstanceName:
			// currently ignored
			// old value, should be 0
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
			// instance name
			if _, err = readBVarChar(r); err != nil {
				badStreamPanic(err)
			}
		case envRouting:
			// RoutingData message is:
			// ValueLength                 USHORT
			// Protocol (TCP = 0)          BYTE
			// ProtocolProperty (new port) USHORT
			// AlternateServer             US_VARCHAR
			_, err := readUshort(r)
			if err != nil {
				badStreamPanic(err)
			}
			protocol, err := readByte(r)
			if err != nil || protocol != 0 {
				badStreamPanic(err)
			}
			newPort, err := readUshort(r)
			if err != nil {
				badStreamPanic(err)
			}
			newServer, err := readUsVarChar(r)
			if err != nil {
				badStreamPanic(err)
			}
			// consume the OLDVALUE = %x00 %x00
			_, err = readUshort(r)
			if err != nil {
				badStreamPanic(err)
			}
			sess.routedServer = newServer
			sess.routedPort = newPort
		default:
			// ignore rest of records because we don't know how to skip those
			sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype)
			break
		}

	}
}

// http://msdn.microsoft.com/en-us/library/dd358180.aspx
func parseReturnStatus(r *tdsBuffer) ReturnStatus {
	return ReturnStatus(r.int32())
}

func parseOrder(r *tdsBuffer) (res orderStruct) {
	len := int(r.uint16())
	res.ColIds = make([]uint16, len/2)
	for i := 0; i < len/2; i++ {
		res.ColIds[i] = r.uint16()
	}
	return res
}

// https://msdn.microsoft.com/en-us/library/dd340421.aspx
func parseDone(r *tdsBuffer) (res doneStruct) {
	res.Status = r.uint16()
	res.CurCmd = r.uint16()
	res.RowCount = r.uint64()
	return res
}

// https://msdn.microsoft.com/en-us/library/dd340553.aspx
func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) {
	res.Status = r.uint16()
	res.CurCmd = r.uint16()
	res.RowCount = r.uint64()
	return res
}

type sspiMsg []byte

func parseSSPIMsg(r *tdsBuffer) sspiMsg {
	size := r.uint16()
	buf := make([]byte, size)
	r.ReadFull(buf)
	return sspiMsg(buf)
}

type loginAckStruct struct {
	Interface  uint8
	TDSVersion uint32
	ProgName   string
	ProgVer    uint32
}

func parseLoginAck(r *tdsBuffer) loginAckStruct {
	size := r.uint16()
	buf := make([]byte, size)
	r.ReadFull(buf)
	var res loginAckStruct
	res.Interface = buf[0]
	res.TDSVersion = binary.BigEndian.Uint32(buf[1:])
	prognamelen := buf[1+4]
	var err error
	if res.ProgName, err = ucs22str(buf[1+4+1 : 1+4+1+prognamelen*2]); err != nil {
		badStreamPanic(err)
	}
	res.ProgVer = binary.BigEndian.Uint32(buf[size-4:])
	return res
}

// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a
func parseFeatureExtAck(r *tdsBuffer) {
	// at most 1 featureAck per feature in featureExt
	// go-mssqldb will add at most 1 feature, the spec defines 7 different features
	for i := 0; i < 8; i++ {
		featureID := r.byte() // FeatureID
		if featureID == 0xff {
			return
		}
		size := r.uint32() // FeatureAckDataLen
		d := make([]byte, size)
		r.ReadFull(d)
	}
	panic("parsed more than 7 featureAck's, protocol implementation error?")
}

// http://msdn.microsoft.com/en-us/library/dd357363.aspx
func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) {
	count := r.uint16()
	if count == 0xffff {
		// no metadata is sent
		return nil
	}
	columns = make([]columnStruct, count)
	for i := range columns {
		column := &columns[i]
		column.UserType = r.uint32()
		column.Flags = r.uint16()

		// parsing TYPE_INFO structure
		column.ti = readTypeInfo(r)
		column.ColName = r.BVarChar()
	}
	return columns
}

// http://msdn.microsoft.com/en-us/library/dd357254.aspx
func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
	for i, column := range columns {
		row[i] = column.ti.Reader(&column.ti, r)
	}
}

// http://msdn.microsoft.com/en-us/library/dd304783.aspx
func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
	bitlen := (len(columns) + 7) / 8
	pres := make([]byte, bitlen)
	r.ReadFull(pres)
	for i, col := range columns {
		if pres[i/8]&(1<<(uint(i)%8)) != 0 {
			row[i] = nil
			continue
		}
		row[i] = col.ti.Reader(&col.ti, r)
	}
}

// http://msdn.microsoft.com/en-us/library/dd304156.aspx
func parseError72(r *tdsBuffer) (res Error) {
	length := r.uint16()
	_ = length // ignore length
	res.Number = r.int32()
	res.State = r.byte()
	res.Class = r.byte()
	res.Message = r.UsVarChar()
	res.ServerName = r.BVarChar()
	res.ProcName = r.BVarChar()
	res.LineNo = r.int32()
	return
}

// http://msdn.microsoft.com/en-us/library/dd304156.aspx
func parseInfo(r *tdsBuffer) (res Error) {
	length := r.uint16()
	_ = length // ignore length
	res.Number = r.int32()
	res.State = r.byte()
	res.Class = r.byte()
	res.Message = r.UsVarChar()
	res.ServerName = r.BVarChar()
	res.ProcName = r.BVarChar()
	res.LineNo = r.int32()
	return
}

// https://msdn.microsoft.com/en-us/library/dd303881.aspx
func parseReturnValue(r *tdsBuffer) (nv namedValue) {
	/*
		ParamOrdinal
		ParamName
		Status
		UserType
		Flags
		TypeInfo
		CryptoMetadata
		Value
	*/
	r.uint16()
	nv.Name = r.BVarChar()
	r.byte()
	r.uint32() // UserType (uint16 prior to 7.2)
	r.uint16()
	ti := readTypeInfo(r)
	nv.Value = ti.Reader(&ti, r)
	return
}

func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
	defer func() {
		if err := recover(); err != nil {
			if sess.logFlags&logErrors != 0 {
				sess.log.Printf("ERROR: Intercepted panic %v", err)
			}
			ch <- err
		}
		close(ch)
	}()

	packet_type, err := sess.buf.BeginRead()
	if err != nil {
		if sess.logFlags&logErrors != 0 {
			sess.log.Printf("ERROR: BeginRead failed %v", err)
		}
		ch <- err
		return
	}
	if packet_type != packReply {
		badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply))
	}
	var columns []columnStruct
	errs := make([]Error, 0, 5)
	for {
		token := token(sess.buf.byte())
		if sess.logFlags&logDebug != 0 {
			sess.log.Printf("got token %v", token)
		}
		switch token {
		case tokenSSPI:
			ch <- parseSSPIMsg(sess.buf)
			return
		case tokenReturnStatus:
			returnStatus := parseReturnStatus(sess.buf)
			ch <- returnStatus
		case tokenLoginAck:
			loginAck := parseLoginAck(sess.buf)
			ch <- loginAck
		case tokenFeatureExtAck:
			parseFeatureExtAck(sess.buf)
		case tokenOrder:
			order := parseOrder(sess.buf)
			ch <- order
		case tokenDoneInProc:
			done := parseDoneInProc(sess.buf)
			if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
				sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
			}
			ch <- done
		case tokenDone, tokenDoneProc:
			done := parseDone(sess.buf)
			done.errors = errs
			if sess.logFlags&logDebug != 0 {
				sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
			}
			if done.Status&doneSrvError != 0 {
				ch <- errors.New("SQL Server had internal error")
				return
			}
			if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
				sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
			}
			ch <- done
			if done.Status&doneMore == 0 {
				return
			}
		case tokenColMetadata:
			columns = parseColMetadata72(sess.buf)
			ch <- columns
		case tokenRow:
			row := make([]interface{}, len(columns))
			parseRow(sess.buf, columns, row)
			ch <- row
		case tokenNbcRow:
			row := make([]interface{}, len(columns))
			parseNbcRow(sess.buf, columns, row)
			ch <- row
		case tokenEnvChange:
			processEnvChg(sess)
		case tokenError:
			err := parseError72(sess.buf)
			if sess.logFlags&logDebug != 0 {
				sess.log.Printf("got ERROR %d %s", err.Number, err.Message)
			}
			errs = append(errs, err)
			if sess.logFlags&logErrors != 0 {
				sess.log.Println(err.Message)
			}
		case tokenInfo:
			info := parseInfo(sess.buf)
			if sess.logFlags&logDebug != 0 {
				sess.log.Printf("got INFO %d %s", info.Number, info.Message)
			}
			if sess.logFlags&logMessages != 0 {
				sess.log.Println(info.Message)
			}
		case tokenReturnValue:
			nv := parseReturnValue(sess.buf)
			if len(nv.Name) > 0 {
				name := nv.Name[1:] // Remove the leading "@".
				if ov, has := outs[name]; has {
					err = scanIntoOut(name, nv.Value, ov)
					if err != nil {
						fmt.Println("scan error", err)
						ch <- err
					}
				}
			}
		default:
			badStreamPanic(fmt.Errorf("unknown token type returned: %v", token))
		}
	}
}

type parseRespIter byte

const (
	parseRespIterContinue parseRespIter = iota // Continue parsing current token.
	parseRespIterNext                          // Fetch the next token.
	parseRespIterDone                          // Done with parsing the response.
)

type parseRespState byte

const (
	parseRespStateNormal  parseRespState = iota // Normal response state.
	parseRespStateCancel                        // Query is canceled, wait for server to confirm.
	parseRespStateClosing                       // Waiting for tokens to come through.
)

type parseResp struct {
	sess        *tdsSession
	ctxDone     <-chan struct{}
	state       parseRespState
	cancelError error
}

func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
	if err := sendAttention(ts.sess.buf); err != nil {
		ts.dlogf("failed to send attention signal %v", err)
		ch <- err
		return parseRespIterDone
	}
	ts.state = parseRespStateCancel
	return parseRespIterContinue
}

func (ts *parseResp) dlog(msg string) {
	// logging from goroutine is disabled to prevent
	// data race detection from firing
	// The race is probably happening when
	// test logger changes between tests.
	/*if ts.sess.logFlags&logDebug != 0 {
		ts.sess.log.Println(msg)
	}*/
}
func (ts *parseResp) dlogf(f string, v ...interface{}) {
	/*if ts.sess.logFlags&logDebug != 0 {
		ts.sess.log.Printf(f, v...)
	}*/
}

func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
	switch ts.state {
	default:
		panic("unknown state")
	case parseRespStateNormal:
		select {
		case tok, ok := <-tokChan:
			if !ok {
				ts.dlog("response finished")
				return parseRespIterDone
			}
			if err, ok := tok.(net.Error); ok && err.Timeout() {
				ts.cancelError = err
				ts.dlog("got timeout error, sending attention signal to server")
				return ts.sendAttention(ch)
			}
			// Pass the token along.
			ch <- tok
			return parseRespIterContinue

		case <-ts.ctxDone:
			ts.ctxDone = nil
			ts.dlog("got cancel message, sending attention signal to server")
			return ts.sendAttention(ch)
		}
	case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth
		select {
		case tok, ok := <-tokChan:
			if !ok {
				ts.dlog("response finished but waiting for attention ack")
				return parseRespIterNext
			}
			switch tok := tok.(type) {
			default:
				// Ignore all other tokens while waiting.
				// The TDS spec says other tokens may arrive after an attention
				// signal is sent. Ignore these tokens and continue looking for
				// a DONE with attention confirm mark.
			case doneStruct:
				if tok.Status&doneAttn != 0 {
					ts.dlog("got cancellation confirmation from server")
					if ts.cancelError != nil {
						ch <- ts.cancelError
						ts.cancelError = nil
					} else {
						ch <- ctx.Err()
					}
					return parseRespIterDone
				}

			// If an error happens during cancel, pass it along and just stop.
			// We are uncertain to receive more tokens.
			case error:
				ch <- tok
				ts.state = parseRespStateClosing
			}
			return parseRespIterContinue
		case <-ts.ctxDone:
			ts.ctxDone = nil
			ts.state = parseRespStateClosing
			return parseRespIterContinue
		}
	case parseRespStateClosing: // Wait for current token chan to close.
		if _, ok := <-tokChan; !ok {
			ts.dlog("response finished")
			return parseRespIterDone
		}
		return parseRespIterContinue
	}
}

func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
	ts := &parseResp{
		sess:    sess,
		ctxDone: ctx.Done(),
	}
	defer func() {
		// Ensure any remaining error is piped through
		// or the query may look like it executed when it actually failed.
		if ts.cancelError != nil {
			ch <- ts.cancelError
			ts.cancelError = nil
		}
		close(ch)
	}()

	// Loop over multiple responses.
	for {
		ts.dlog("initiating response reading")

		tokChan := make(chan tokenStruct)
		go processSingleResponse(sess, tokChan, outs)

		// Loop over multiple tokens in response.
	tokensLoop:
		for {
			switch ts.iter(ctx, ch, tokChan) {
			case parseRespIterContinue:
				// Nothing, continue to next token.
			case parseRespIterNext:
				break tokensLoop
			case parseRespIterDone:
				return
			}
		}
	}
}