package dbplus

import (
	"database/sql"
	"fmt"

	"code.justin.tv/chat/db"
	"code.justin.tv/chat/timing"
	"golang.org/x/net/context"
)

// txOrDb is the set of statement-execution methods DbQuerier needs.
// Judging by its name, it is presumably satisfied by both a database handle
// and an open transaction, so the same helpers work in either setting.
// Each method receives a name alongside the SQL text; DbQuerier always
// passes its queryId there.
type txOrDb interface {
	// Exec runs a statement that returns no rows.
	Exec(ctx context.Context, name, query string, args ...interface{}) (db.Result, error)
	// QueryRow runs a query whose result is consumed as a single row.
	QueryRow(ctx context.Context, name, query string, args ...interface{}) db.Row
	// Query runs a query whose result is iterated row by row.
	Query(ctx context.Context, name, query string, args ...interface{}) (db.Rows, error)
}

// NewQuerier returns a DbQuerier configured with the given settings.
//
// caller is appended as an SQL comment to every statement the querier runs,
// which helps identify where potentially bad DB queries are coming from.
// xactGroup is the name of the timing sub-transaction each statement is
// recorded under, and it is also handed to logger. logger receives each
// statement just before it is executed.
func NewQuerier(caller, xactGroup string, logger *DbLogger) *DbQuerier {
	q := new(DbQuerier)
	q.caller = caller
	q.xactGroup = xactGroup
	q.logger = logger
	return q
}

// DbQuerier runs SQL against a txOrDb. Before execution, each statement is
// tagged with its origin and logged. When the context carries a timing
// transaction, the statement is also timed. Build one with NewQuerier.
type DbQuerier struct {
	// caller is appended to every statement as an SQL comment.
	// xactGroup names the timing sub-transaction and is passed to the logger.
	caller, xactGroup string

	// logger receives each statement just before it is sent to the database.
	logger *DbLogger
}

// QueryHasResults reports whether query yields at least one row.
//
// The first row is scanned into a single throwaway destination. With
// database/sql semantics, that means the query presumably must select exactly
// one column (e.g. `SELECT 1 ...`); confirm against the db package.
// A db.ErrNoRows result is reported as (false, nil). Any other error is
// returned as (false, err).
func (q *DbQuerier) QueryHasResults(ctx context.Context, database txOrDb, queryId, query string, params ...interface{}) (bool, error) {
	var discard interface{}
	err := q.QueryRow(ctx, database, queryId, func(row db.Row) error {
		return row.Scan(&discard)
	}, query, params...)

	switch {
	case err == db.ErrNoRows:
		return false, nil
	case err != nil:
		return false, err
	default:
		return true, nil
	}
}

// QueryInt runs query and scans the first row's single column into an int.
//
// On any error it returns 0 and that error. This includes the case where no
// row matched; presumably the db package reports that as db.ErrNoRows.
func (q *DbQuerier) QueryInt(ctx context.Context, database txOrDb, queryId, query string, params ...interface{}) (int, error) {
	var n int
	scan := func(row db.Row) error { return row.Scan(&n) }

	if err := q.QueryRow(ctx, database, queryId, scan, query, params...); err != nil {
		return 0, err
	}
	return n, nil
}

// QueryIntSlice runs query and collects the single integer column of every
// returned row, in row order.
//
// A query that matches no rows yields an empty, non-nil slice. On any error
// (query, scan, or iteration) the slice is nil, matching the zero value that
// QueryInt returns on failure. Callers therefore never receive a partially
// populated result alongside an error.
func (q *DbQuerier) QueryIntSlice(ctx context.Context, database txOrDb, queryId, query string, params ...interface{}) ([]int, error) {
	results := []int{}
	fn := func(rows db.Rows) error {
		var result int
		if err := rows.Scan(&result); err != nil {
			return err
		}
		results = append(results, result)
		return nil
	}
	if err := q.Query(ctx, database, queryId, fn, query, params...); err != nil {
		return nil, err
	}
	return results, nil
}

// QueryString runs query and scans the first row's single column as a string.
//
// A SQL NULL value is returned as "" with a nil error. On any error it returns
// "" and that error.
func (q *DbQuerier) QueryString(ctx context.Context, database txOrDb, queryId, query string, params ...interface{}) (string, error) {
	var s sql.NullString
	err := q.QueryRow(ctx, database, queryId, func(row db.Row) error {
		return row.Scan(&s)
	}, query, params...)

	// Both failure and NULL collapse to the empty string; err is nil in the
	// NULL case.
	if err != nil || !s.Valid {
		return "", err
	}
	return s.String, nil
}

// QueryStringSlice runs query and collects the single string column of every
// returned row, in row order.
//
// SQL NULL values are returned as "", consistent with QueryString. Scanning
// NULL straight into a Go string is rejected by database/sql and would fail
// the whole call. A query that matches no rows yields an empty, non-nil
// slice. On any error the slice is nil, so callers never receive a partially
// populated result alongside an error.
func (q *DbQuerier) QueryStringSlice(ctx context.Context, database txOrDb, queryId, query string, params ...interface{}) ([]string, error) {
	results := []string{}
	fn := func(rows db.Rows) error {
		var value sql.NullString
		if err := rows.Scan(&value); err != nil {
			return err
		}

		s := ""
		if value.Valid {
			s = value.String
		}
		results = append(results, s)
		return nil
	}
	if err := q.Query(ctx, database, queryId, fn, query, params...); err != nil {
		return nil, err
	}
	return results, nil
}

// QueryRow runs a single-row query against database and passes the resulting
// row to fn, returning whatever fn returns.
//
// The SQL is first tagged via buildQuery and sent to the logger. The whole
// operation is wrapped by do, which times it under q.xactGroup when ctx
// carries a timing transaction. Any query error is presumably surfaced by
// row.Scan inside fn rather than here (database/sql Row semantics); confirm
// against the db package.
func (q *DbQuerier) QueryRow(ctx context.Context, database txOrDb, queryId string, fn func(db.Row) error, query string, params ...interface{}) error {
	run := func() error {
		annotated, args := q.buildQuery(ctx, queryId, query, params...)
		q.logger.printQuery(ctx, q.xactGroup, annotated, args)

		return fn(database.QueryRow(ctx, queryId, annotated, args...))
	}
	return q.do(ctx, run)
}

// Query runs a multi-row query against database and calls fn once per row.
//
// Iteration stops at the first error fn returns, and that error is returned.
// Otherwise any error reported by rows.Err after iteration is returned. The
// rows are closed before Query returns, so fn must not retain them. The SQL
// is tagged via buildQuery, logged, and timed by do like the other helpers.
func (q *DbQuerier) Query(ctx context.Context, database txOrDb, queryId string, fn func(db.Rows) error, query string, params ...interface{}) error {
	iterate := func() error {
		annotated, args := q.buildQuery(ctx, queryId, query, params...)
		q.logger.printQuery(ctx, q.xactGroup, annotated, args)

		rows, queryErr := database.Query(ctx, queryId, annotated, args...)
		if queryErr != nil {
			return queryErr
		}
		defer rows.Close()

		for rows.Next() {
			if cbErr := fn(rows); cbErr != nil {
				return cbErr
			}
		}
		// Next returning false may mean an error rather than exhaustion.
		return rows.Err()
	}
	return q.do(ctx, iterate)
}

// Exec runs a statement that returns no rows and returns its db.Result.
//
// The result from database.Exec is returned as-is alongside the error, even
// when the error is non-nil. The SQL is tagged via buildQuery, logged, and
// timed by do like the other helpers.
func (q *DbQuerier) Exec(ctx context.Context, database txOrDb, queryId, query string, params ...interface{}) (db.Result, error) {
	var result db.Result
	err := q.do(ctx, func() error {
		annotated, args := q.buildQuery(ctx, queryId, query, params...)
		q.logger.printQuery(ctx, q.xactGroup, annotated, args)

		var execErr error
		result, execErr = database.Exec(ctx, queryId, annotated, args...)
		return execErr
	})
	return result, err
}

// do runs fn and returns its error.
//
// If ctx carries a timing transaction, fn's execution is bracketed by a
// sub-transaction named q.xactGroup. Otherwise fn simply runs untimed.
func (q *DbQuerier) do(ctx context.Context, fn func() error) error {
	xact, ok := timing.XactFromContext(ctx)
	if !ok {
		return fn()
	}

	sub := xact.Sub(q.xactGroup)
	sub.Start()
	defer sub.End()
	return fn()
}

// buildQuery appends a trailing SQL line comment of the form
// " -- <caller>: <queryId>" to query and returns it with params unchanged.
// The tag makes the statement's origin visible wherever the raw SQL text is
// seen. ctx is currently unused.
//
// NOTE(review): caller and queryId are assumed to be trusted, newline-free
// identifiers. A newline would end the comment and splice the remainder into
// the statement.
func (q *DbQuerier) buildQuery(ctx context.Context, queryId, query string, params ...interface{}) (string, []interface{}) {
	// Every operand is a string, so fmt.Sprint inserts no separators.
	annotated := fmt.Sprint(query, " -- ", q.caller, ": ", queryId)
	return annotated, params
}
