# This file implements the scoring service shell. You normally don't need to modify it for different
# algorithms. It starts nginx and gunicorn with the correct configurations and then waits until either
# of them exits, at which point it signals the other to shut down and exits as well.
#
# The Flask application served by gunicorn is the `app` object defined in wsgi.py.
#
# We set the following parameters:
#
# Parameter                Environment Variable              Default Value
# ---------                --------------------              -------------
# number of workers        MODEL_SERVER_WORKERS              the number of CPU cores
# timeout                  MODEL_SERVER_TIMEOUT              60 seconds

import logging
import multiprocessing
import os
import signal
import sys
from pathlib import Path
from subprocess import Popen, check_call
from typing import Union

# Gunicorn worker timeout in seconds. Parsed to int (like MODEL_SERVER_WORKERS) so
# a malformed value fails at import time rather than after nginx has been started.
MODEL_SERVER_TIMEOUT: int = int(os.environ.get("MODEL_SERVER_TIMEOUT", 60))
# Number of gunicorn worker processes; defaults to the number of CPU cores.
MODEL_SERVER_WORKERS = int(
    os.environ.get("MODEL_SERVER_WORKERS", multiprocessing.cpu_count())
)

# Directory holding access.log / error.log, which are symlinked to stdout/stderr.
LOGDIR = "/var/log/nginx/"
# Directory containing nginx.conf.
PROGRAM_DIR = "/opt/program/"


def sigterm_handler(nginx_pid, gunicorn_pid):
    """Signal both child servers to stop, then exit this process with status 0.

    nginx is sent SIGQUIT first, then gunicorn is sent SIGTERM. A child that
    cannot be signalled (e.g. it has already exited) is skipped silently.

    Args:
        nginx_pid: PID of the nginx process.
        gunicorn_pid: PID of the gunicorn process.

    Raises:
        SystemExit: always, with exit code 0.
    """
    shutdown_plan = (
        (nginx_pid, signal.SIGQUIT),
        (gunicorn_pid, signal.SIGTERM),
    )
    for target_pid, stop_signal in shutdown_plan:
        try:
            os.kill(target_pid, stop_signal)
        except OSError:
            # Best effort: the target may already be gone.
            pass

    sys.exit(0)


def _link_logs_to_stdio() -> None:
    """Point access.log and error.log in LOGDIR at stdout/stderr so they reach the container logs."""
    check_call(["ln", "-sf", "/dev/stdout", str(Path(LOGDIR, "access.log"))])
    check_call(["ln", "-sf", "/dev/stderr", str(Path(LOGDIR, "error.log"))])


def _gunicorn_command() -> list:
    """Build the gunicorn argv from the module-level timeout and worker settings."""
    return [
        "gunicorn",
        "--timeout",
        str(MODEL_SERVER_TIMEOUT),
        "-k",
        "gevent",
        "-b",
        "unix:/tmp/gunicorn.sock",
        "-w",
        str(MODEL_SERVER_WORKERS),
        "wsgi:app",
    ]


def start_server() -> None:
    """Launch nginx and gunicorn, then block until either child process exits.

    nginx is started with PROGRAM_DIR/nginx.conf; gunicorn serves ``wsgi:app``
    with gevent workers bound to the unix socket /tmp/gunicorn.sock. A SIGTERM
    handler forwards shutdown to both children. When either child exits, the
    other is signalled and this process exits with status 0 via
    ``sigterm_handler``, so this function does not return normally.

    Raises:
        subprocess.CalledProcessError: if symlinking the log files fails.
        OSError: if nginx or gunicorn cannot be launched. If gunicorn fails to
            start, nginx is sent SIGQUIT before the error is re-raised.
    """
    logging.info(
        "Starting the inference server with %s workers.", MODEL_SERVER_WORKERS
    )

    _link_logs_to_stdio()

    nginx = Popen(["nginx", "-c", str(Path(PROGRAM_DIR, "nginx.conf"))])
    try:
        gunicorn = Popen(_gunicorn_command())
    except Exception:
        # Don't leave nginx running without a backend; mirror sigterm_handler's
        # best-effort SIGQUIT, then surface the original error.
        try:
            os.kill(nginx.pid, signal.SIGQUIT)
        except OSError:
            pass
        raise

    signal.signal(
        signal.SIGTERM, lambda _signum, _frame: sigterm_handler(nginx.pid, gunicorn.pid)
    )

    # If either subprocess exits, so do we. os.wait() reaps any child, so keep
    # waiting until one of the two servers is the one that exited.
    server_pids = {nginx.pid, gunicorn.pid}
    while True:
        exited_pid, _ = os.wait()
        if exited_pid in server_pids:
            break

    logging.info("Inference server exiting")
    sigterm_handler(nginx.pid, gunicorn.pid)


# Script entry point: start nginx and gunicorn and block until one of them exits.
if __name__ == "__main__":
    start_server()
