package rule

import (
	"fmt"
	"go/ast"
	"go/token"

	"github.com/mgechev/revive/lint"
	"golang.org/x/tools/go/ast/astutil"
)

// CognitiveComplexityRule reports functions whose cognitive complexity
// exceeds a configured maximum.
type CognitiveComplexityRule struct{}

// Apply runs the cognitive-complexity check on file and returns the failures
// that were reported.
//
// The first element of arguments is the maximum allowed complexity and must
// be an int64. Apply panics when arguments is empty or when its first element
// has any other type.
func (r *CognitiveComplexityRule) Apply(file *lint.File, arguments lint.Arguments) []lint.Failure {
	const expectedArgumentsCount = 1
	if got := len(arguments); got < expectedArgumentsCount {
		panic(fmt.Sprintf("not enough arguments for cognitive-complexity, expected %d, got %d", expectedArgumentsCount, got))
	}

	threshold, isInt64 := arguments[0].(int64)
	if !isInt64 {
		panic(fmt.Sprintf("invalid argument type for cognitive-complexity, expected int64, got %T", arguments[0]))
	}

	// Failures are gathered through a callback so the linter does not need to
	// know how they are stored.
	var collected []lint.Failure
	collect := func(f lint.Failure) {
		collected = append(collected, f)
	}

	checker := cognitiveComplexityLinter{
		file:          file,
		maxComplexity: int(threshold),
		onFailure:     collect,
	}
	checker.lint()

	return collected
}

// Name returns the identifier of this rule.
func (r *CognitiveComplexityRule) Name() string {
	const ruleName = "cognitive-complexity"
	return ruleName
}

// cognitiveComplexityLinter checks the function declarations of one file
// against a maximum cognitive complexity.
type cognitiveComplexityLinter struct {
	// file is the file whose top-level function declarations are checked.
	file          *lint.File
	// maxComplexity is the threshold; functions scoring strictly above it are
	// reported.
	maxComplexity int
	// onFailure is called once for every function that exceeds maxComplexity.
	onFailure     func(lint.Failure)
}

// lint computes the cognitive complexity of every top-level function
// declaration that has a body and reports, through onFailure, each one whose
// complexity is strictly greater than maxComplexity.
func (w cognitiveComplexityLinter) lint() {
	for _, decl := range w.file.AST.Decls {
		fn, isFunc := decl.(*ast.FuncDecl)
		if !isFunc || fn.Body == nil {
			// Not a function, or a declaration without a body: nothing to measure.
			continue
		}

		visitor := cognitiveComplexityVisitor{}
		score := visitor.subTreeComplexity(fn.Body)
		if score <= w.maxComplexity {
			continue
		}

		w.onFailure(lint.Failure{
			Confidence: 1,
			Category:   "maintenance",
			Failure:    fmt.Sprintf("function %s has cognitive complexity %d (> max enabled %d)", funcName(fn), score, w.maxComplexity),
			Node:       fn,
		})
	}
}

// cognitiveComplexityVisitor is an ast.Visitor that accumulates the cognitive
// complexity of the subtree it walks.
type cognitiveComplexityVisitor struct {
	// complexity is the score accumulated so far.
	complexity   int
	// nestingLevel is the current nesting depth; it is added as a penalty
	// each time a nested control-flow statement is found.
	nestingLevel int
}

// subTreeComplexity walks the AST rooted at n and returns the complexity
// held by the visitor once the walk is done. The receiver is a copy, so the
// caller's visitor is left untouched.
func (v cognitiveComplexityVisitor) subTreeComplexity(n ast.Node) int {
	visitor := &v
	ast.Walk(visitor, n)
	return visitor.complexity
}

// Visit implements the ast.Visitor interface.
//
// Control-flow statements (if, for, range, select, switch and type switch)
// add one plus the current nesting level to the complexity; the children
// listed for each of them are then walked one nesting level deeper. Only
// those listed children are visited (for example, the init statement of an
// if is not), and nil is returned so ast.Walk does not descend a second time.
//
// A function literal adds nothing by itself and only deepens the nesting of
// its body. A binary expression contributes the value computed by
// binExpComplexity. A branch statement that carries a label adds one.
//
// For every other node the visitor itself is returned so that ast.Walk keeps
// descending.
func (v *cognitiveComplexityVisitor) Visit(n ast.Node) ast.Visitor {
	switch n := n.(type) {
	case *ast.IfStmt:
		v.walk(1, n.Cond, n.Body, n.Else)
		return nil
	case *ast.ForStmt:
		v.walk(1, n.Cond, n.Body)
		return nil
	case *ast.RangeStmt:
		v.walk(1, n.Body)
		return nil
	case *ast.SelectStmt:
		v.walk(1, n.Body)
		return nil
	case *ast.SwitchStmt:
		v.walk(1, n.Body)
		return nil
	case *ast.TypeSwitchStmt:
		v.walk(1, n.Body)
		return nil
	case *ast.FuncLit:
		// Do not increment the complexity, just do the nesting. The walk
		// helper is deliberately not used here: even with a zero increment it
		// would add the current nesting level to the complexity.
		if n.Body != nil {
			v.nestingLevel++
			ast.Walk(v, n.Body)
			v.nestingLevel--
		}
		return nil
	case *ast.BinaryExpr:
		v.complexity += v.binExpComplexity(n)
		return nil // skip visiting binexp sub-tree (already visited by binExpComplexity)
	case *ast.BranchStmt:
		// Only labeled branches (break/continue/goto with a label) count.
		if n.Label != nil {
			v.complexity++
		}
	}
	// TODO handle (at least) direct recursion

	return v
}

// walk adds complexityIncrement plus the nesting level in effect on entry to
// the complexity, then visits each target one nesting level deeper.
//
// Targets that are nil interface values are skipped. The nesting level in
// effect on entry is restored before returning.
func (v *cognitiveComplexityVisitor) walk(complexityIncrement int, targets ...ast.Node) {
	outerLevel := v.nestingLevel
	v.complexity += complexityIncrement + outerLevel
	v.nestingLevel = outerLevel + 1

	for _, target := range targets {
		if target != nil {
			ast.Walk(v, target)
		}
	}

	v.nestingLevel = outerLevel
}

// binExpComplexity returns the complexity contributed by the boolean
// operators found in the expression tree rooted at n. The whole subtree is
// traversed with astutil.Apply using a fresh calculator.
func (cognitiveComplexityVisitor) binExpComplexity(n *ast.BinaryExpr) int {
	// The zero value is ready to use: the operator stack starts out empty.
	var calc binExprComplexityCalculator
	astutil.Apply(n, calc.pre, calc.post)
	return calc.complexity
}

// binExprComplexityCalculator computes the complexity of a boolean
// expression. Its pre and post methods are passed to astutil.Apply as the
// pre-order and post-order callbacks.
type binExprComplexityCalculator struct {
	// complexity is the number of boolean-operator increments counted so far.
	complexity    int
	opsStack      []token.Token // stack of bool operators
	// subexpStarted is set when a parenthesized expression is entered; it is
	// cleared when a boolean operator is counted or when the parenthesized
	// expression is left.
	subexpStarted bool
}

// pre is the pre-order callback given to astutil.Apply.
//
// Entering a parenthesized expression flags the start of a sub-expression.
// Entering a && or || expression pushes its operator on the stack and adds
// one to the complexity when the operator begins a new run, that is when it
// is the first boolean operator seen, the first one inside a parenthesized
// sub-expression, or different from the operator on top of the stack.
// All other nodes are ignored. It always returns true.
func (becc *binExprComplexityCalculator) pre(c *astutil.Cursor) bool {
	switch node := c.Node().(type) {
	case *ast.ParenExpr:
		becc.subexpStarted = true
	case *ast.BinaryExpr:
		if node.Op != token.LAND && node.Op != token.LOR {
			// Arithmetic, comparison, etc. do not affect the score.
			return true
		}

		startsNewRun := becc.subexpStarted || len(becc.opsStack) == 0
		if !startsNewRun {
			// The stack is known to be non-empty here.
			startsNewRun = becc.opsStack[len(becc.opsStack)-1] != node.Op
		}
		if startsNewRun {
			becc.complexity++
			becc.subexpStarted = false
		}

		becc.opsStack = append(becc.opsStack, node.Op)
	}

	return true
}

// post is the post-order callback given to astutil.Apply.
//
// Leaving a && or || expression pops its operator from the stack (if the
// stack is not empty); leaving a parenthesized expression clears the
// sub-expression flag. All other nodes are ignored. It always returns true.
func (becc *binExprComplexityCalculator) post(c *astutil.Cursor) bool {
	switch node := c.Node().(type) {
	case *ast.ParenExpr:
		becc.subexpStarted = false
	case *ast.BinaryExpr:
		isBoolOp := node.Op == token.LAND || node.Op == token.LOR
		if depth := len(becc.opsStack); isBoolOp && depth > 0 {
			becc.opsStack = becc.opsStack[:depth-1]
		}
	}

	return true
}