package main

import (
	"errors"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"path"
	"reflect"
	"strings"
)

var (
	// errBadReturn is returned by parseArgs when a field of a method's
	// parameter or result list declares no name.
	errBadReturn = errors.New("found return arg with no name: all args must be named")
)

// errUnexpectedType reports that an AST node (or other value) had a different
// dynamic type than the parser expected.
type errUnexpectedType struct {
	expected string      // human-readable description of the expected type
	actual   interface{} // the value actually encountered; only its type is reported
}

// Error names the expected type and the dynamic type of the actual value
// ("<nil>" when actual is nil).
func (u errUnexpectedType) Error() string {
	gotType := reflect.TypeOf(u.actual)
	return fmt.Sprintf("got wrong type expecting %s, got: %v", u.expected, gotType)
}

// ParsedPkg holds information about a package that has been parsed:
// its name, the list of functions, and the imports those functions need.
//
// Name is the package name declared in the parsed file. Functions are the
// methods of the requested interface, including methods of embedded
// interfaces. Imports are the file's imports that are referenced through a
// package selector in the functions' argument and return types.
type ParsedPkg struct {
	Name      string
	Functions []function
	Imports   []importSpec
}

// function describes a single method of the parsed interface: its Name, its
// parameters (Args) and results (Returns) in declaration order, and Doc.
// NOTE(review): Doc is not populated by parseFunc — confirm whether any
// consumer expects it to be set.
type function struct {
	Name          string
	Args, Returns []arg
	Doc           string
}

// arg is one named parameter or result of an interface method.
//
// Name is the identifier from the method signature, ArgType is the type
// rendered as Go source (e.g. "*foo.Bar"), and PackageSelector is the package
// qualifier referenced by ArgType, or "" when there is none.
type arg struct {
	Name, ArgType, PackageSelector string
}

// String renders the arg the way it appears in a parameter list: "name type".
func (a *arg) String() string {
	return strings.Join([]string{a.Name, a.ArgType}, " ")
}

// importSpec is a single import required by the parsed functions.
type importSpec struct {
	// Name is the explicit local name the file imported the package under,
	// or "" when the package is imported under its own name.
	Name string
	// Path is the import path literal exactly as written in the source,
	// including its surrounding quotes.
	Path string
}

// String renders the spec as the body of an import line, e.g.
// `foo "github.com/x/foo"` for a named import or `"fmt"` otherwise.
func (s *importSpec) String() string {
	if len(s.Name) == 0 {
		return s.Path
	}
	// Separate the name from the quoted path so the output is in gofmt form
	// rather than the run-together foo"path".
	return s.Name + " " + s.Path
}

// Parse parses the Go source file at filePath and extracts the interface type
// named objName from it.
//
// The returned ParsedPkg carries the file's package name, the interface's
// methods (see parseInterface), and the subset of the file's imports that the
// methods' argument and return types reference through a package selector,
// listed in order of first use so the result is deterministic. An error is
// returned if the file cannot be parsed, objName is not declared at file
// scope, or objName does not name an interface type.
func Parse(filePath string, objName string) (*ParsedPkg, error) {
	fs := token.NewFileSet()
	file, err := parser.ParseFile(fs, filePath, nil, parser.AllErrors)
	if err != nil {
		return nil, err
	}
	obj, exists := file.Scope.Objects[objName]
	if !exists {
		return nil, fmt.Errorf("could not find object %s in %s", objName, filePath)
	}
	if obj.Kind != ast.Typ {
		return nil, fmt.Errorf("expected type, got %s", obj.Kind)
	}
	spec, ok := obj.Decl.(*ast.TypeSpec)
	if !ok {
		return nil, errUnexpectedType{"*ast.TypeSpec", obj.Decl}
	}
	iface, ok := spec.Type.(*ast.InterfaceType)
	if !ok {
		return nil, errUnexpectedType{"*ast.InterfaceType", spec.Type}
	}

	functions, err := parseInterface(iface)
	if err != nil {
		return nil, err
	}

	// Build the result only once everything succeeded so callers never see a
	// partially populated ParsedPkg.
	return &ParsedPkg{
		Name:      file.Name.Name,
		Functions: functions,
		Imports:   resolveImports(file.Imports, functions),
	}, nil
}

// resolveImports returns the imports from fileImports that are needed by the
// package selectors used in the functions' arguments and return values, in
// order of first use. Each distinct (name, path) pair appears once, so a path
// imported under two different names yields two specs.
func resolveImports(fileImports []*ast.ImportSpec, functions []function) []importSpec {
	var specs []importSpec
	seen := make(map[importSpec]bool)
	for _, f := range functions {
		// Iterate the two lists separately instead of appending them, which
		// could write into the backing array of f.Args.
		for _, args := range [][]arg{f.Args, f.Returns} {
			for _, a := range args {
				if len(a.PackageSelector) == 0 {
					continue
				}
				spec, ok := lookupImport(fileImports, a.PackageSelector)
				if !ok || seen[spec] {
					continue
				}
				seen[spec] = true
				specs = append(specs, spec)
			}
		}
	}
	return specs
}

// lookupImport finds the import in fileImports whose local name is selector.
// Named imports match on their explicit name only; unnamed imports match on
// the name guessed by implicitImportName.
func lookupImport(fileImports []*ast.ImportSpec, selector string) (importSpec, bool) {
	for _, imp := range fileImports {
		if imp.Name != nil {
			// An explicit name replaces the package's own name, so the
			// path-based guess must not be consulted for this import.
			if imp.Name.Name == selector {
				return importSpec{Name: selector, Path: imp.Path.Value}, true
			}
			continue
		}
		if implicitImportName(imp.Path.Value) == selector {
			return importSpec{Path: imp.Path.Value}, true
		}
	}
	return importSpec{}, false
}

// implicitImportName guesses the package name of an unnamed import from its
// path literal (quoted or backquoted): the last path element with everything
// up to its last '-' removed, so "github.com/docker/go-units" yields "units".
// This is a heuristic; the real name is declared by the imported package.
func implicitImportName(pathLit string) string {
	name := path.Base(strings.Trim(pathLit, "\"`"))
	if i := strings.LastIndex(name, "-"); i >= 0 {
		name = name[i+1:]
	}
	return name
}

// parseInterface returns the methods of iface in declaration order.
//
// Methods for which parseFunc returns nil (skipped methods) are omitted.
// Embedded interfaces are expanded recursively, but only when they are plain
// identifiers the parser resolved within the same file; anything else, such
// as the predeclared error interface or a qualified type like io.Reader, is
// reported as an error. Errors from embedded interfaces are returned with the
// embedded interface's name as context.
func parseInterface(iface *ast.InterfaceType) ([]function, error) {
	var functions []function
	for _, field := range iface.Methods.List {
		switch f := field.Type.(type) {
		case *ast.FuncType:
			method, err := parseFunc(field)
			if err != nil {
				return nil, err
			}
			if method == nil {
				// parseFunc signals a deliberately skipped method with nil, nil.
				continue
			}
			functions = append(functions, *method)
		case *ast.Ident:
			// Obj is nil when the parser could not resolve the name within
			// this file (e.g. the predeclared error); dereferencing it would
			// panic.
			if f.Obj == nil {
				return nil, fmt.Errorf("cannot resolve embedded interface %s: it must be declared in the same file", f.Name)
			}
			spec, ok := f.Obj.Decl.(*ast.TypeSpec)
			if !ok {
				return nil, errUnexpectedType{"*ast.TypeSpec", f.Obj.Decl}
			}
			embedded, ok := spec.Type.(*ast.InterfaceType)
			if !ok {
				return nil, errUnexpectedType{"*ast.InterfaceType", spec.Type}
			}
			funcs, err := parseInterface(embedded)
			if err != nil {
				return nil, fmt.Errorf("error parsing embedded interface %s: %v", f.Name, err)
			}
			functions = append(functions, funcs...)
		default:
			return nil, errUnexpectedType{"*ast.FuncType or *ast.Ident", f}
		}
	}
	return functions, nil
}

// parseFunc converts an interface method field into a function description.
//
// field.Type must be an *ast.FuncType (parseInterface guarantees this). When
// the method's name is present in skipFuncs (declared elsewhere in the
// package), a notice is printed to stdout and (nil, nil) is returned.
// Argument and result parse errors are both wrapped with the method name so
// the failing signature can be located.
func parseFunc(field *ast.Field) (*function, error) {
	f := field.Type.(*ast.FuncType)
	method := &function{Name: field.Names[0].Name}
	if _, exists := skipFuncs[method.Name]; exists {
		fmt.Println("skipping:", method.Name)
		return nil, nil
	}
	if f.Params != nil {
		args, err := parseArgs(f.Params.List)
		if err != nil {
			return nil, fmt.Errorf("error parsing function args for %q: %v", method.Name, err)
		}
		method.Args = args
	}
	if f.Results != nil {
		returns, err := parseArgs(f.Results.List)
		if err != nil {
			return nil, fmt.Errorf("error parsing function returns for %q: %v", method.Name, err)
		}
		method.Returns = returns
	}
	return method, nil
}

// parseArgs converts a parameter or result field list into args, one per
// declared name, in declaration order. A grouped declaration such as
// "a, b string" yields one arg per name, all sharing the same type.
//
// It returns errBadReturn if any field is unnamed, and any error from
// parseExpr for an unsupported type expression.
func parseArgs(fields []*ast.Field) ([]arg, error) {
	var args []arg
	for _, f := range fields {
		if len(f.Names) == 0 {
			return nil, errBadReturn
		}
		// The type expression is shared by every name in the field, so it
		// only needs to be parsed once.
		p, err := parseExpr(f.Type)
		if err != nil {
			return nil, err
		}
		for _, name := range f.Names {
			args = append(args, arg{Name: name.Name, ArgType: p.value, PackageSelector: p.pkg})
		}
	}
	return args, nil
}

// parsedExpr is the result of parseExpr: value is the type expression
// rendered as Go source, and pkg is the package selector it references
// ("" when it references none).
type parsedExpr struct {
	value, pkg string
}

// parseExpr renders the type expression e as Go source and reports the
// package selector it references, if any.
//
// Supported forms are identifiers, pointers, qualified identifiers (pkg.T),
// maps, slices, fixed-length arrays whose length is a literal or an
// identifier, and the empty interface. Anything else yields an error.
//
// parsedExpr holds a single package selector, so for a map whose key and
// value both come from other packages only the value's package is reported.
func parseExpr(e ast.Expr) (parsedExpr, error) {
	switch i := e.(type) {
	case *ast.Ident:
		return parsedExpr{value: i.Name}, nil
	case *ast.StarExpr:
		elem, err := parseExpr(i.X)
		if err != nil {
			return parsedExpr{}, err
		}
		return parsedExpr{value: "*" + elem.value, pkg: elem.pkg}, nil
	case *ast.SelectorExpr:
		qual, err := parseExpr(i.X)
		if err != nil {
			return parsedExpr{}, err
		}
		// In pkg.T the rendered X is the package selector itself.
		return parsedExpr{value: qual.value + "." + i.Sel.Name, pkg: qual.value}, nil
	case *ast.MapType:
		key, err := parseExpr(i.Key)
		if err != nil {
			return parsedExpr{}, err
		}
		val, err := parseExpr(i.Value)
		if err != nil {
			return parsedExpr{}, err
		}
		pkg := val.pkg
		if pkg == "" {
			// Keep the key's package so map[pkg.K]V still gets its import.
			pkg = key.pkg
		}
		return parsedExpr{value: "map[" + key.value + "]" + val.value, pkg: pkg}, nil
	case *ast.ArrayType:
		elt, err := parseExpr(i.Elt)
		if err != nil {
			return parsedExpr{}, err
		}
		// Len is nil for slices; for arrays it must be preserved, otherwise
		// [N]T would be rendered as the different type []T.
		var length string
		switch l := i.Len.(type) {
		case nil:
		case *ast.BasicLit:
			length = l.Value
		case *ast.Ident:
			length = l.Name
		default:
			return parsedExpr{}, errUnexpectedType{"array length literal or identifier", l}
		}
		return parsedExpr{value: "[" + length + "]" + elt.value, pkg: elt.pkg}, nil
	case *ast.InterfaceType:
		// Only the empty interface has a self-contained spelling; an interface
		// with methods would need its whole method set rendered.
		if i.Methods != nil && len(i.Methods.List) > 0 {
			return parsedExpr{}, fmt.Errorf("unsupported non-empty interface type: only interface{} is supported")
		}
		return parsedExpr{value: "interface{}"}, nil
	default:
		return parsedExpr{}, errUnexpectedType{"*ast.Ident, *ast.StarExpr, *ast.SelectorExpr, *ast.MapType, *ast.ArrayType or *ast.InterfaceType", i}
	}
}