import io
import os
import pickle
import sys
from dataclasses import dataclass
from pathlib import Path

import pytest

import conductor.docker.predictor
import conductor.utils.sagemaker_training_environment
from conductor.docker.predictor import ping, transformation
from conductor.types.model import Model
from tests.docker.test_files import FIXTURE_DIR


class MockModel(Model):
    """Trivial ``Model`` whose prediction is twice the integer in its input."""

    def train(self):
        """Do nothing; this stub has no state to fit."""

    def predict(self, data: io.StringIO):
        """Read *data* as a single integer and return that integer doubled."""
        return int(data.getvalue()) * 2


@dataclass
class FlaskRequest:
    """Minimal stand-in for ``flask.request``: a raw body plus its content type."""

    # Raw request body. Callers pass ``str.encode(...)`` output, i.e. ``bytes``
    # (the previous ``bytearray`` annotation did not match actual usage).
    data: bytes
    # MIME type of the body; defaults to CSV.
    content_type: str = "text/csv"


@pytest.mark.datafiles(Path(FIXTURE_DIR, "test_project"))
def test_predict(datafiles, mocker, monkeypatch):
    """Exercise the predictor's ``ping`` and ``transformation`` handlers.

    A pickled ``MockModel`` is written into the copied fixture project and the
    predictor's ``MODEL_PATH`` is pointed at it. ``ping`` must answer
    ``200 OK``. A fake Flask request whose body is ``"10"`` must then be
    transformed into a response of ``20``.
    """
    # Go through monkeypatch so the working directory and sys.path are restored
    # at teardown. Plain os.chdir / sys.path.append would leak into every test
    # that runs afterwards.
    monkeypatch.chdir(datafiles)
    monkeypatch.syspath_prepend(str(datafiles))
    mocker.patch.object(
        conductor.utils.sagemaker_training_environment,
        "INPUT_CONFIG_PATH",
        str(datafiles),
    )
    mocker.patch.object(
        conductor.utils.sagemaker_training_environment, "MODEL_PATH", str(datafiles)
    )

    # Presumably read by the predictor to locate the model class; verify
    # against conductor.docker.predictor.
    monkeypatch.setenv("MODEL_CLS_MODULE", MockModel.__module__)
    monkeypatch.setenv("MODEL_CLS_NAME", MockModel.__name__)

    model_path = Path(datafiles, "model.pkl")
    # Write-only is enough here; the file is never read back through this handle.
    with open(model_path, "wb") as f:
        pickle.dump(MockModel(), f)
    mocker.patch.object(conductor.docker.predictor, "MODEL_PATH", model_path)

    resp = ping()
    assert resp.status == "200 OK"

    # Replace flask.request as seen by the predictor module with a fake carrying "10".
    mocker.patch.object(
        conductor.docker.predictor.flask,
        "request",
        FlaskRequest("10".encode(encoding="utf-8")),
    )
    resp = transformation()
    # MockModel.predict doubles the parsed integer: 10 -> 20.
    assert resp.response == 20
