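# This file implements a Flask server that serves inferences. To serve your own algorithm, point the
# MODEL_CLS_MODULE and MODEL_CLS_NAME environment variables at a model class that provides
# load(path), predict(data) and predict_mimetype(), rather than hard-coding the scoring here.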
import importlib
import io
import os

import flask

# Base directory of the container's ML files (/opt/ml presumably follows the
# SageMaker container layout — confirm against the deployment).
PREFIX = "/opt/ml/"
# Serialized model artifact; passed to the model class's load() on first use.
MODEL_PATH: str = os.path.join(PREFIX, "model", "model.pkl")

# A singleton holder for the model: the model is loaded lazily on first use and then cached on the class.
# Predictions are made by calling the loaded model's own predict method (see /invocations below).


class ModelProvider:
    """Lazily construct the model once and hand out the cached instance.

    The concrete model class is resolved at runtime from two environment
    variables: MODEL_CLS_MODULE (an importable module path) and
    MODEL_CLS_NAME (an attribute of that module exposing ``load(path)``).
    """

    model = None  # cached model instance; stays None until a load returns a value

    @classmethod
    def get_model(cls, path: str):
        """Return the cached model, loading it from ``path`` on the first call.

        Raises:
            KeyError: if MODEL_CLS_MODULE or MODEL_CLS_NAME is not set.
        """
        if cls.model is not None:
            return cls.model
        # Read both env vars before importing, so a missing variable fails fast.
        module_name = os.environ["MODEL_CLS_MODULE"]
        class_name = os.environ["MODEL_CLS_NAME"]
        loader = getattr(importlib.import_module(module_name), class_name)
        cls.model = loader.load(path)
        return cls.model


# The flask app for serving predictions; the /ping (health check) and
# /invocations (inference) routes below are registered on it.
app = flask.Flask(__name__)


@app.route("/ping", methods=["GET"])
def ping() -> flask.Response:
    """Report whether the container is working and healthy.

    The container is declared healthy when the model can be loaded (or is
    already loaded). A load failure is logged and reported as unhealthy,
    instead of escaping from the handler as an unhandled exception.

    Returns:
        A response whose body is a single newline: status 200 when healthy,
        404 otherwise.
    """
    try:
        # You can insert additional health checks here.
        health = ModelProvider.get_model(MODEL_PATH) is not None
    except Exception:
        # Health-check boundary: a missing env var, import error or unreadable
        # model file means "unhealthy", not a server crash. Log the cause so the
        # failure is diagnosable from the container logs.
        app.logger.exception("Health check failed: could not load model from %s", MODEL_PATH)
        health = False

    status = 200 if health else 404
    return flask.Response(response="\n", status=status, mimetype="application/json")


@app.route("/invocations", methods=["POST"])
def transformation() -> flask.Response:
    """Do an inference on a single batch of data, using the model's set mimetype
    as the response type.

    The raw request body is decoded as UTF-8 text and passed to the model's
    ``predict`` as a text stream.

    Returns:
        The model's prediction with status 200 and the mimetype reported by
        ``model.predict_mimetype()``; or status 400 if the body is not valid
        UTF-8.
    """
    # Use get_data() rather than request.data: request.data is empty when the
    # client sends a form content type (e.g. curl -d defaults to
    # application/x-www-form-urlencoded), because the body is consumed as form
    # data. get_data() always returns the raw body.
    raw_body = flask.request.get_data()
    try:
        decoded_request = raw_body.decode("utf-8")
    except UnicodeDecodeError:
        # Malformed client input is a client error, not a server crash.
        return flask.Response(
            response="Request body must be UTF-8 encoded text.\n",
            status=400,
            mimetype="text/plain",
        )
    data = io.StringIO(decoded_request)
    # Do the prediction
    model = ModelProvider.get_model(MODEL_PATH)
    result = model.predict(data)
    return flask.Response(
        response=result, status=200, mimetype=model.predict_mimetype()
    )
