package iprep

import (
	"database/sql"
	"flag"
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/abh/geoip"
	_ "github.com/mattn/go-sqlite3"

	"code.justin.tv/abuse/shodan/config"
	"code.justin.tv/abuse/shodan/whitelist"
)

const (
	// schema is the SQLite DDL that withTestDb executes against each freshly
	// created test database. It defines:
	//   - vba_score: per-ASN score rows; the unique vba_score_index on
	//     (client_asn_id, version) allows at most one row per ASN per version,
	//     which is what lets an "insert or replace" overwrite an earlier score.
	//   - versions: an (id, time) table keyed by id.
	schema = `CREATE TABLE vba_score (
	version INTEGER,
	count BIGINT,
	category TEXT,
	dscore BIGINT,
	flock_score BIGINT,
	dscore_str TEXT,
	client_asn_id BIGINT,
	time FLOAT
);
CREATE UNIQUE INDEX vba_score_index ON vba_score (client_asn_id, version);
CREATE TABLE versions (
	id INTEGER NOT NULL,
	time DATETIME,
	PRIMARY KEY (id)
);`
)

// dummyOptMap is an in-memory option source for tests. Each getter returns the
// value stored under name, or def when name is absent.
//
// Untyped constants stored in an interface{} take their default type (int or
// float64), so a value such as `"gvc.iprep.v1": testVersion` is an int even if
// the consumer reads it with GetInt64. To keep such maps usable, the integer
// getters accept any built-in integer type that fits in int64 (int, int8-int64,
// uint8-uint32), and the float getters accept those integers as well as
// float32/float64. A value that is not numeric, or that does not fit the
// requested integer width, still panics with the usual interface-conversion
// error, just as the plain type assertion did.
type dummyOptMap map[string]interface{}

// dummyOptToInt64 widens v to int64 when it holds a built-in integer type whose
// every value fits in int64. It reports false for any other type.
func dummyOptToInt64(v interface{}) (int64, bool) {
	switch n := v.(type) {
	case int:
		return int64(n), true
	case int8:
		return int64(n), true
	case int16:
		return int64(n), true
	case int32:
		return int64(n), true
	case int64:
		return n, true
	case uint8:
		return int64(n), true
	case uint16:
		return int64(n), true
	case uint32:
		return int64(n), true
	}
	return 0, false
}

// dummyOptToFloat64 converts v to float64 when it holds a float type or an
// integer accepted by dummyOptToInt64. It reports false for any other type.
func dummyOptToFloat64(v interface{}) (float64, bool) {
	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	}
	if i, ok := dummyOptToInt64(v); ok {
		return float64(i), true
	}
	return 0, false
}

// GetInt returns the integer stored under name, or def if name is absent.
func (m dummyOptMap) GetInt(name string, def int) int {
	val, ok := m[name]
	if !ok {
		return def
	}
	// The round-trip check rejects values that would silently truncate.
	if n, isInt := dummyOptToInt64(val); isInt && int64(int(n)) == n {
		return int(n)
	}
	return val.(int) // not a fitting integer: panic with the conversion error
}

// GetInt32 returns the integer stored under name as int32, or def if name is absent.
func (m dummyOptMap) GetInt32(name string, def int32) int32 {
	val, ok := m[name]
	if !ok {
		return def
	}
	// The round-trip check rejects values that would silently truncate.
	if n, isInt := dummyOptToInt64(val); isInt && int64(int32(n)) == n {
		return int32(n)
	}
	return val.(int32) // not a fitting integer: panic with the conversion error
}

// GetInt64 returns the integer stored under name as int64, or def if name is absent.
func (m dummyOptMap) GetInt64(name string, def int64) int64 {
	val, ok := m[name]
	if !ok {
		return def
	}
	if n, isInt := dummyOptToInt64(val); isInt {
		return n
	}
	return val.(int64) // not an integer: panic with the conversion error
}

// GetFloat32 returns the number stored under name as float32, or def if name is absent.
func (m dummyOptMap) GetFloat32(name string, def float32) float32 {
	val, ok := m[name]
	if !ok {
		return def
	}
	if f, isNum := dummyOptToFloat64(val); isNum {
		return float32(f)
	}
	return val.(float32) // not numeric: panic with the conversion error
}

// GetFloat64 returns the number stored under name as float64, or def if name is absent.
func (m dummyOptMap) GetFloat64(name string, def float64) float64 {
	val, ok := m[name]
	if !ok {
		return def
	}
	if f, isNum := dummyOptToFloat64(val); isNum {
		return f
	}
	return val.(float64) // not numeric: panic with the conversion error
}

// TestMain parses command-line flags before running the package's tests, so
// flag-backed settings such as config.AsnPath (dereferenced by withTestDb)
// hold their command-line values when tests run.
//
// This replaces a package init that called flag.Parse. Since Go 1.13 the
// testing flags (-test.v, -test.run, ...) are registered by testing.Init,
// which runs after package initialization. Parsing during init therefore
// rejects them with "flag provided but not defined".
//
// NOTE(review): a package may define only one TestMain; merge with any
// existing TestMain in another _test.go file of this package.
func TestMain(m *testing.M) {
	flag.Parse()
	os.Exit(m.Run())
}

// withTestDb creates a throwaway SQLite database file in a fresh temporary
// directory and applies schema to it. It then calls f with that file's path
// and the ASN database path from the config.AsnPath flag.
//
// The connection opened here stays open while f runs. It is closed, and the
// temporary directory removed, when withTestDb returns. Deferred calls also
// run if f fails the test via t.Fatal.
func withTestDb(t *testing.T, f func(dbPath, asnPath string)) {
	dir, err := ioutil.TempDir("", "shodan-test-db")
	if err != nil {
		t.Fatal("couldn't create temp dir:", err)
	}
	defer os.RemoveAll(dir)

	// filepath.Join uses the OS path separator instead of a hard-coded "/".
	dbPath := filepath.Join(dir, "vba_score.db")

	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		t.Fatal("couldn't create temp db:", err)
	}
	defer db.Close()

	exec(t, db, schema)

	f(dbPath, *config.AsnPath)
}

// exec runs qry with args against db and returns the driver's result.
// If the statement fails, the test is aborted via t.Fatalf, with the query,
// its arguments and the error in the message.
func exec(t *testing.T, db *sql.DB, qry string, args ...interface{}) sql.Result {
	res, execErr := db.Exec(qry, args...)
	if execErr == nil {
		return res
	}
	// t.Fatalf stops the test goroutine, so the return below is never reached.
	t.Fatalf("unable to run query.\nquery: %v\nargs: %v\nerror: %v", qry, args, execErr)
	return res
}

// ipToASNID looks up the ASN name for ip in geo and converts it to an ASN ID
// with asnNameToID. The test is failed via t.Fatalf if the conversion returns
// an error.
func ipToASNID(t *testing.T, geo *geoip.GeoIP, ip string) string {
	// The second result of GetName is deliberately discarded.
	// NOTE(review): in github.com/abh/geoip this is the netmask, not an
	// error — confirm against the vendored version.
	name, _ := geo.GetName(ip)

	id, convErr := asnNameToID(name)
	if convErr == nil {
		return id
	}
	t.Fatalf("couldn't get ASN ID for %v, ASN name %v: %v\n", ip, name, convErr)
	return id
}

// setScore writes a vba_score row for testASNID at the given version, with
// dscore set to score, into repSource's database.
//
// "insert or replace" relies on the unique index on
// (client_asn_id, version), so calling it again for the same ASN and version
// overwrites the earlier score rather than adding a second row.
// count is set to defMinRequestsForScore+1. NOTE(review): presumably this is
// so the row meets the minimum-request threshold (the test configures
// "gvc.iprep.score_minreq" with the same constant) — confirm in stdSource.
func setScore(t *testing.T, repSource *stdSource, version int, testASNID string, score float64) {
	// The space after "count)" was previously missing ("count)values"). SQLite
	// accepts that only because ')' ends a token.
	exec(t, repSource.db,
		"insert or replace into vba_score (version, client_asn_id, dscore, count) "+
			"values (?, ?, ?, ?)",
		version, testASNID, score, defMinRequestsForScore+1)
}

// expectFracViews calls repSource.FracViewsForIP(ip, nil) and fails the test
// if the call errors or the returned value is not exactly expected.
// The comparison is exact; callers in this file only pass 0 or 1.
func expectFracViews(t *testing.T, repSource Source, ip string, expected float64) {
	got, lookupErr := repSource.FracViewsForIP(ip, nil)
	switch {
	case lookupErr != nil:
		t.Fatalf("couldn't get rep for %v: %v\n", ip, lookupErr)
	case got != expected:
		t.Fatalf("expected fracViews = %v, got %v", expected, got)
	}
}

// testFracViewsForScore replaces repSource's repCache with an empty map and
// stores score for testASNID at version. It then asserts that testIP yields
// fracViews.
// NOTE(review): the cache reset presumably prevents a value memoized by an
// earlier check from masking the new score — confirm how stdSource uses repCache.
func testFracViewsForScore(t *testing.T, repSource *stdSource, version int, testIP, testASNID string, score, fracViews float64) {
	t.Logf("verifying fracViews for %v/%v with score %v is %v\n", testIP, testASNID, score, fracViews)

	freshCache := make(map[string]float64)
	repSource.repCache = freshCache

	setScore(t, repSource, version, testASNID, score)
	expectFracViews(t, repSource, testIP, fracViews)
}

// TestBasicAsnRep builds a stdSource over a fresh score database and checks
// FracViewsForIP for a fixed test IP:
//   - with no score row for the IP's ASN the result is 1;
//   - with scores 0 and defVBadThreshold-1 it is 1;
//   - with scores defVBadThreshold and defVBadThreshold+1 it is 0.
func TestBasicAsnRep(t *testing.T) {
	withTestDb(t, func(dbPath, asnPath string) {
		tx := time.Now().Unix()
		const testVersion = 1
		dbOpts := dummyOptMap{
			"gvc.iprep.v1":           testVersion,
			"gvc.iprep.v2":           testVersion,
			"gvc.iprep.t1":           tx,
			"gvc.iprep.t2":           tx,
			"gvc.iprep.vbad_thresh":  defVBadThreshold,
			"gvc.iprep.score_minreq": defMinRequestsForScore,
		}

		opaqueRepSource, err := NewSource(asnPath, dbPath, &whitelist.Whitelist{}, dbOpts)
		if err != nil {
			t.Fatal("couldn't create rep source:", err)
		}
		// The helpers below need the concrete type (its db, asn and repCache).
		// A checked assertion reports an unexpected type as a clean test
		// failure instead of a panic.
		repSource, ok := opaqueRepSource.(*stdSource)
		if !ok {
			t.Fatalf("NewSource returned %T, want *stdSource", opaqueRepSource)
		}

		const testIP = "192.16.71.177"
		testASNID := ipToASNID(t, repSource.asn, testIP)

		// default should be fracviews 1
		t.Log("checking fracViews without score set")
		expectFracViews(t, repSource, testIP, 1)

		cases := []struct{ score, fracViews float64 }{
			{0, 1},
			{defVBadThreshold - 1, 1},
			{defVBadThreshold, 0},
			{defVBadThreshold + 1, 0},
		}
		for _, c := range cases {
			testFracViewsForScore(t, repSource, testVersion, testIP, testASNID, c.score, c.fracViews)
		}
	})
}
