// Package logx is logrus plus:
//   - context logging (fields stored in a context.Context are added to
//     entries logged with that context)
//   - file & line number logging
//   - multiple default loggers (useful if you want to log different formats to
//     different places with one call)
package logx

import (
	"io"
	"io/ioutil"
	"runtime"

	"code.justin.tv/chat/errx"

	"golang.org/x/net/context"

	"github.com/Sirupsen/logrus"
)

// contextKey is the unexported type for keys logx stores in a
// context.Context. A private key type keeps these keys from colliding with
// keys defined by other packages.
type contextKey int

// contextFieldsKey is the context key under which logx keeps its
// *contextFields value.
const contextFieldsKey contextKey = 0

// callerFields describes a frame on the current goroutine's stack as
// "file" and "line" fields. skip is passed unchanged to runtime.Caller, so 0
// identifies callerFields itself, 1 its caller, and so on. If the frame
// cannot be resolved, "???" and 0 are reported instead.
func callerFields(skip int) Fields {
	file, line := "???", 0
	if _, f, l, ok := runtime.Caller(skip); ok {
		file, line = f, l
	}
	return Fields{"file": file, "line": line}
}

// New returns a Logger backed by a fresh logrus.Logger.
//
// Output goes to ioutil.Discard, so nothing is written until SetOutput is
// called. Caller (file/line) logging starts disabled; enable it with
// SetCallerLogging.
func New() Logger {
	impl := &loggerImpl{Logger: logrus.New()}
	impl.Logger.Out = ioutil.Discard
	return impl
}

// Logger is a leveled, context-aware logger built on logrus.
//
// Each logging method takes three inputs:
//   - a context, whose stored fields (if any) are added to the entry;
//   - a message, from which errx may extract further fields;
//   - optional extra Fields.
//
// For the Logger returned by New, later sources override earlier ones when
// keys collide.
type Logger interface {
	// SetCallerLogging toggles whether entries include "file" and "line"
	// fields describing the call site.
	SetCallerLogging(logCaller bool)
	// AddHook registers a logrus hook. The Logger returned by New ignores
	// nil hooks.
	AddHook(hook logrus.Hook)
	// SetFormatter replaces the logrus formatter used to render entries.
	SetFormatter(formatter logrus.Formatter)
	// SetOutput sets the writer that rendered entries are sent to.
	SetOutput(out io.Writer)
	// SetLevel sets the least severe logrus level that is still logged.
	SetLevel(level logrus.Level)

	// Info logs msg at info level.
	Info(ctx context.Context, msg interface{}, fields ...Fields)
	// Warn logs msg at warning level.
	Warn(ctx context.Context, msg interface{}, fields ...Fields)
	// Error logs msg at error level.
	Error(ctx context.Context, msg interface{}, fields ...Fields)
	// Fatal logs msg at fatal level. Logrus exits the process after writing
	// a fatal entry.
	Fatal(ctx context.Context, msg interface{}, fields ...Fields)
}

// loggerImpl is the Logger implementation returned by New. It embeds
// *logrus.Logger, which builds and writes the entries.
type loggerImpl struct {
	*logrus.Logger

	// LogCaller, when true, adds "file" and "line" fields describing the
	// call site to every entry.
	LogCaller bool
}

// SetCallerLogging enables or disables the "file"/"line" fields on entries.
func (l *loggerImpl) SetCallerLogging(enabled bool) {
	l.LogCaller = enabled
}

// AddHook registers h with the underlying logrus logger. A nil h is
// ignored.
func (l *loggerImpl) AddHook(h logrus.Hook) {
	if h == nil {
		return
	}
	l.Logger.Hooks.Add(h)
}

// SetFormatter replaces the formatter of the underlying logrus logger.
func (l *loggerImpl) SetFormatter(f logrus.Formatter) {
	l.Logger.Formatter = f
}

// SetOutput sets the writer the underlying logrus logger writes to.
func (l *loggerImpl) SetOutput(w io.Writer) {
	l.Logger.Out = w
}

// SetLevel sets the level threshold of the underlying logrus logger.
func (l *loggerImpl) SetLevel(lvl logrus.Level) {
	l.Logger.Level = lvl
}

// entryFromContext builds the logrus entry for a single log call.
//
// Fields are merged into one map in increasing order of precedence. A later
// source overwrites keys set by an earlier one:
//  1. caller "file"/"line" (only when LogCaller is enabled)
//  2. each Fields set stored in ctx under contextFieldsKey, in slice order
//  3. fields extracted from msg by msgFields
//  4. the explicit fields passed to this call, in argument order
//
// callerFields(4) and msgFields use fixed stack-skip counts. They rely on
// being called directly from here, and on this function being called
// directly from Info/Warn/Error/Fatal. Do not add intermediate calls
// without adjusting those counts.
func (l *loggerImpl) entryFromContext(ctx context.Context, msg interface{}, fields ...Fields) *logrus.Entry {
	merged := Fields{}

	// Caller location. Skip 4 ascends past these frames:
	//   - callerFields (0)
	//   - entryFromContext (1)
	//   - the Info/Warn/Error/Fatal method (2)
	//   - that method's caller (3)
	// So it reports the frame one level above the method's direct caller.
	// NOTE(review): that is right when logging goes through a wrapper. For a
	// direct l.Info(...) call it reports the caller's caller — confirm.
	if l.LogCaller {
		mergeFields(merged, callerFields(4))
	}

	// Fields stored in the context. Merge while holding cf.lock rather than
	// reading cf.fields after unlocking. Reslicing a slice shares its backing
	// array, so it is not a snapshot. This assumes writers to cf also hold
	// cf.lock.
	if ctx != nil {
		if cf, ok := ctx.Value(contextFieldsKey).(*contextFields); ok {
			cf.lock.Lock()
			for _, f := range cf.fields {
				mergeFields(merged, f)
			}
			cf.lock.Unlock()
		}
	}

	// Fields extracted from msg. Called directly from here so that the skip
	// count inside msgFields stays valid.
	mergeFields(merged, msgFields(msg))

	// Fields passed in with the call to logx take final precedence.
	for _, f := range fields {
		mergeFields(merged, f)
	}

	return l.Logger.WithFields(logrus.Fields(merged))
}

// Info logs msg at info level, merging fields from ctx, msg and extra.
// It calls entryFromContext directly because the caller-location skip
// counts assume exactly that call depth.
func (l *loggerImpl) Info(ctx context.Context, msg interface{}, extra ...Fields) {
	entry := l.entryFromContext(ctx, msg, extra...)
	entry.Info(msg)
}

// Warn logs msg at warning level, merging fields from ctx, msg and extra.
// It calls entryFromContext directly because the caller-location skip
// counts assume exactly that call depth.
func (l *loggerImpl) Warn(ctx context.Context, msg interface{}, extra ...Fields) {
	entry := l.entryFromContext(ctx, msg, extra...)
	entry.Warn(msg)
}

// Error logs msg at error level, merging fields from ctx, msg and extra.
// It calls entryFromContext directly because the caller-location skip
// counts assume exactly that call depth.
func (l *loggerImpl) Error(ctx context.Context, msg interface{}, extra ...Fields) {
	entry := l.entryFromContext(ctx, msg, extra...)
	entry.Error(msg)
}

// Fatal logs msg at fatal level, merging fields from ctx, msg and extra.
// Logrus exits the process after writing a fatal entry. It calls
// entryFromContext directly because the caller-location skip counts assume
// exactly that call depth.
func (l *loggerImpl) Fatal(ctx context.Context, msg interface{}, extra ...Fields) {
	entry := l.entryFromContext(ctx, msg, extra...)
	entry.Fatal(msg)
}

// msgFields returns a copy of the structured fields errx reports for msg.
// A nil msg yields nil; any other msg yields a non-nil, possibly empty, map.
//
// errx.NewWithSkip must stay called directly from this function, and this
// function directly from entryFromContext. NOTE(review): the exact meaning
// of the skip count 4 is defined by errx — confirm before changing call
// depth.
func msgFields(msg interface{}) Fields {
	if msg == nil {
		return nil
	}

	out := Fields{}
	// Ranging over a nil map is a no-op, so no separate nil check is needed.
	for k, v := range errx.NewWithSkip(msg, 4).Fields() {
		out[k] = v
	}
	return out
}

// mergeFields copies every entry of src into dst, overwriting keys dst
// already holds. dst must be non-nil (writing to a nil map panics); a nil
// src is a no-op.
func mergeFields(dst Fields, src Fields) {
	for key := range src {
		dst[key] = src[key]
	}
}
