// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package xorm

import (
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"strings"

	"github.com/go-xorm/builder"
	"github.com/go-xorm/core"
)

// cacheUpdate mirrors an UPDATE statement into the beans held by the cacher of
// the statement's reference table.
//
// sqlStr is the UPDATE statement and args are its arguments: the values of
// the SET clauses first, followed by the arguments of the condition part.
//
// The statement is rewritten by convertUpdateSQL into a query (newsql) that is
// used to find the primary keys of the affected rows, either through the
// cached sql ids or, when that lookup fails, by querying the database. Every
// affected bean that is currently cached gets its fields overwritten with the
// SET values and is put back into the cache. Finally all cached sql ids of the
// table are cleared.
//
// It returns ErrCacheFailed when there is no reference table, when the session
// is inside a transaction, or when the statement cannot be converted, parsed
// or applied to the cached bean; errors from the database or from primary key
// conversion are returned as they are.
func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
	if session.Statement.RefTable == nil ||
		session.Tx != nil {
		return ErrCacheFailed
	}

	oldhead, newsql := session.Statement.convertUpdateSQL(sqlStr)
	if newsql == "" {
		return ErrCacheFailed
	}
	for _, filter := range session.Engine.dialect.Filters() {
		newsql = filter.Do(newsql, session.Engine.dialect, session.Statement.RefTable)
	}
	session.Engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)

	// nStart is the number of placeholders in the part of the statement that
	// was cut off (oldhead); the arguments after it belong to newsql.
	var nStart int
	if len(args) > 0 {
		if strings.Contains(sqlStr, "?") {
			nStart = strings.Count(oldhead, "?")
		} else {
			// only for pq, TODO: if any other database?
			nStart = strings.Count(oldhead, "$")
		}
	}
	if nStart > len(args) {
		// More placeholders than arguments: slicing below would panic.
		return ErrCacheFailed
	}
	condArgs := args[nStart:]

	table := session.Statement.RefTable
	cacher := session.Engine.getCacher2(table)
	tableName := session.Statement.TableName()
	session.Engine.logger.Debug("[cacheUpdate] get cache sql", newsql, condArgs)
	ids, err := core.GetCacheSql(cacher, tableName, newsql, condArgs)
	if err != nil {
		// The ids are not cached, so ask the database for them.
		rows, err := session.DB().Query(newsql, condArgs...)
		if err != nil {
			return err
		}
		defer rows.Close()

		ids = make([]core.PK, 0)
		for rows.Next() {
			var res = make([]string, len(table.PrimaryKeys))
			err = rows.ScanSlice(&res)
			if err != nil {
				return err
			}
			var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
			for i, col := range table.PKColumns() {
				if col.SQLType.IsNumeric() {
					n, err := strconv.ParseInt(res[i], 10, 64)
					if err != nil {
						return err
					}
					pk[i] = n
				} else if col.SQLType.IsText() {
					pk[i] = res[i]
				} else {
					return errors.New("not supported")
				}
			}

			ids = append(ids, pk)
		}
		// Next returning false may also mean the iteration failed; do not
		// treat a truncated id list as complete.
		if err = rows.Err(); err != nil {
			return err
		}
		session.Engine.logger.Debug("[cacheUpdate] find updated id", ids)
	}

	for _, id := range ids {
		sid, err := id.ToString()
		if err != nil {
			return err
		}
		if bean := cacher.GetBean(tableName, sid); bean != nil {
			// Extract the "col = value" pairs located between SET and WHERE.
			sqls := splitNNoCase(sqlStr, "where", 2)
			if len(sqls) == 0 || len(sqls) > 2 {
				return ErrCacheFailed
			}

			sqls = splitNNoCase(sqls[0], "set", 2)
			if len(sqls) != 2 {
				return ErrCacheFailed
			}
			kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
			for idx, kv := range kvs {
				sps := strings.SplitN(kv, "=", 2)
				// Drop an optional "table." prefix in front of the column.
				sps2 := strings.Split(sps[0], ".")
				colName := sps2[len(sps2)-1]
				if strings.Contains(colName, "`") {
					colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
				} else if strings.Contains(colName, session.Engine.QuoteStr()) {
					colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1))
				} else {
					session.Engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
					return ErrCacheFailed
				}

				if col := table.GetColumn(colName); col != nil {
					fieldValue, err := col.ValueOf(bean)
					if err != nil {
						session.Engine.logger.Error(err)
					} else {
						session.Engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
						if col.IsVersion && session.Statement.checkVersion {
							fieldValue.SetInt(fieldValue.Int() + 1)
						} else {
							// The idx-th SET clause is expected to own the
							// idx-th argument. Give up instead of panicking
							// when it is missing or cannot be stored in the
							// bean field.
							if idx >= len(args) {
								return ErrCacheFailed
							}
							newValue := reflect.ValueOf(args[idx])
							if !newValue.IsValid() || !fieldValue.CanSet() ||
								!newValue.Type().AssignableTo(fieldValue.Type()) {
								return ErrCacheFailed
							}
							fieldValue.Set(newValue)
						}
					}
				} else {
					session.Engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
						colName, table.Name)
				}
			}

			session.Engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
			cacher.PutBean(tableName, sid, bean)
		}
	}
	session.Engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
	cacher.ClearIds(tableName)
	return nil
}

// Update records, bean's non-empty fields are updated contents,
// condiBean's non-empty fields are conditions.
// bean must be a struct or a map (or a pointer to one), otherwise
// ErrParamsType is returned. It returns the number of affected rows.
// CAUTION:
//        1.bool fields are by default neither updated content nor conditions.
//         You should call UseBool if you have bool to use.
//        2.float32 & float64 may be inexact as conditions
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
	// Registered once up front so that every return path resets the statement.
	defer session.resetStatement()
	if session.IsAutoClose {
		defer session.Close()
	}

	v := rValue(bean)
	t := v.Type()

	var colNames []string
	var args []interface{}

	// handle before update processors
	for _, closure := range session.beforeClosures {
		closure(bean)
	}
	cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
	if processor, ok := interface{}(bean).(BeforeUpdateProcessor); ok {
		processor.BeforeUpdate()
	}
	// --

	var err error
	var isMap = t.Kind() == reflect.Map
	var isStruct = t.Kind() == reflect.Struct
	if isStruct {
		session.Statement.setRefValue(v)

		if len(session.Statement.TableName()) <= 0 {
			return 0, ErrTableNotFound
		}

		if session.Statement.ColumnStr == "" {
			colNames, args = buildUpdates(session.Engine, session.Statement.RefTable, bean, false, false,
				false, false, session.Statement.allUseBool, session.Statement.useAllCols,
				session.Statement.mustColumnMap, session.Statement.nullableMap,
				session.Statement.columnMap, true, session.Statement.unscoped)
		} else {
			colNames, args, err = genCols(session.Statement.RefTable, session, bean, true, true)
			if err != nil {
				return 0, err
			}
		}
	} else if isMap {
		// Every map entry becomes a "key = ?" clause with its value as argument.
		colNames = make([]string, 0)
		args = make([]interface{}, 0)
		bValue := reflect.Indirect(reflect.ValueOf(bean))

		for _, v := range bValue.MapKeys() {
			colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
			args = append(args, bValue.MapIndex(v).Interface())
		}
	} else {
		return 0, ErrParamsType
	}

	table := session.Statement.RefTable

	if session.Statement.UseAutoTime && table != nil && table.Updated != "" {
		colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
		col := table.UpdatedColumn()
		val, now := session.Engine.NowTime2(col.SQLType.Name)
		args = append(args, val)

		var colName = col.Name
		if isStruct {
			// Write the same timestamp back into the bean once the update is done.
			session.afterClosures = append(session.afterClosures, func(bean interface{}) {
				col := table.GetColumn(colName)
				setColumnTime(bean, col, now)
			})
		}
	}

	//for update action to like "column = column + ?"
	incColumns := session.Statement.getInc()
	for _, v := range incColumns {
		colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" + ?")
		args = append(args, v.arg)
	}
	//for update action to like "column = column - ?"
	decColumns := session.Statement.getDec()
	for _, v := range decColumns {
		colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" - ?")
		args = append(args, v.arg)
	}
	//for update action to like "column = expression"
	exprColumns := session.Statement.getExpr()
	for _, v := range exprColumns {
		colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
	}

	session.Statement.processIDParam()

	var autoCond builder.Cond
	if !session.Statement.noAutoCondition && len(condiBean) > 0 {
		autoCond, err = session.Statement.buildConds(session.Statement.RefTable, condiBean[0], true, true, false, true, false)
		if err != nil {
			return 0, err
		}
	}

	cond := session.Statement.cond.And(autoCond)

	// With version checking enabled the row is only updated when it still
	// carries the version held by the bean, and the version column is
	// incremented by the same statement.
	doIncVer := table != nil && table.Version != "" && session.Statement.checkVersion
	var verValue *reflect.Value
	if doIncVer {
		verValue, err = table.VersionColumn().ValueOf(bean)
		if err != nil {
			return 0, err
		}

		cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()})
	}

	// A condition that cannot be rendered must abort the update: executing
	// with a partial or empty WHERE clause could touch unintended rows.
	condSQL, condArgs, err := builder.ToSQL(cond)
	if err != nil {
		return 0, err
	}
	if len(condSQL) > 0 {
		condSQL = "WHERE " + condSQL
	}
	if session.Statement.LimitN > 0 {
		condSQL = condSQL + fmt.Sprintf(" LIMIT %d", session.Statement.LimitN)
	}

	// An empty SET list can only produce invalid SQL, so fail before
	// talking to the database.
	if len(colNames) == 0 {
		return 0, errors.New("no content found to be updated")
	}

	setClauses := colNames
	if doIncVer {
		setClauses = append(setClauses,
			session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
	}

	sqlStr := fmt.Sprintf("UPDATE %v SET %v %v",
		session.Engine.Quote(session.Statement.TableName()),
		strings.Join(setClauses, ", "),
		condSQL)

	res, err := session.exec(sqlStr, append(args, condArgs...)...)
	if err != nil {
		return 0, err
	}
	// Keep the bean's version field in step with the database.
	if doIncVer && verValue != nil && verValue.IsValid() && verValue.CanSet() {
		verValue.SetInt(verValue.Int() + 1)
	}

	if table != nil {
		if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
			cacher.ClearIds(session.Statement.TableName())
			cacher.ClearBeans(session.Statement.TableName())
		}
	}

	// handle after update processors
	if session.IsAutoCommit {
		for _, closure := range session.afterClosures {
			closure(bean)
		}
		if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
			session.Engine.logger.Debug("[event]", session.Statement.TableName(), " has after update processor")
			processor.AfterUpdate()
		}
	} else {
		// Not auto-committing: remember the bean in afterUpdateBeans
		// together with its pending closures instead of running them now.
		lenAfterClosures := len(session.afterClosures)
		_, isProcessor := interface{}(bean).(AfterUpdateProcessor)
		if lenAfterClosures > 0 || isProcessor {
			if !reflect.TypeOf(bean).Comparable() {
				// The bean is used as a map key below; a bean of a
				// non-comparable type (e.g. a map) would panic there
				// after the statement has already been executed.
				session.Engine.logger.Debug("[event]", session.Statement.TableName(),
					" bean type cannot be used as map key, after update processors are skipped")
			} else if lenAfterClosures > 0 {
				if value, has := session.afterUpdateBeans[bean]; has && value != nil {
					*value = append(*value, session.afterClosures...)
				} else {
					afterClosures := make([]func(interface{}), lenAfterClosures)
					copy(afterClosures, session.afterClosures)
					session.afterUpdateBeans[bean] = &afterClosures
				}
			} else {
				session.afterUpdateBeans[bean] = nil
			}
		}
	}
	cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
	// --

	return res.RowsAffected()
}