import copy
import unittest
from unittest import mock

import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from parameterized import parameterized

from twitch_airflow_components.operators.sagemaker_processing import SageMakerProcessingOperator

# Name used for the "ProcessingJobName" entry of the fixture config below.
job_name = "test-job-name"

# Baseline processing-job configuration handed to the operator in the tests.
# The "{{ ... }}" values are plain string literals here; nothing in this
# module renders them.
create_processing_params = {
    "AppSpecification": {
        "ContainerArguments": ["container_arg"],
        "ContainerEntrypoint": ["container_entrypoint"],
        "ImageUri": "{{ image_uri }}",
    },
    "Environment": {"{{ key }}": "{{ value }}"},
    "ExperimentConfig": {
        "ExperimentName": "ExperimentName",
        "TrialComponentDisplayName": "TrialComponentDisplayName",
        "TrialName": "TrialName",
    },
    "ProcessingInputs": [
        {
            "InputName": "AnalyticsInputName",
            "S3Input": {
                "LocalPath": "{{ Local Path }}",
                "S3CompressionType": "None",
                "S3DataDistributionType": "FullyReplicated",
                "S3DataType": "S3Prefix",
                "S3InputMode": "File",
                "S3Uri": "{{ S3Uri }}",
            },
        }
    ],
    "ProcessingJobName": job_name,
    "ProcessingOutputConfig": {
        "KmsKeyId": "KmsKeyID",
        "Outputs": [
            {
                "OutputName": "AnalyticsOutputName",
                "S3Output": {
                    "LocalPath": "{{ Local Path }}",
                    "S3UploadMode": "EndOfJob",
                    "S3Uri": "{{ S3Uri }}",
                },
            }
        ],
    },
    "ProcessingResources": {
        "ClusterConfig": {
            # Integer-valued fields; the tests below expect the operator to
            # list these paths in ``integer_fields``.
            "InstanceCount": 2,
            "InstanceType": "ml.p2.xlarge",
            "VolumeSizeInGB": 30,
            "VolumeKmsKeyId": "{{ kms_key }}",
        }
    },
    "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
    "Tags": [{"{{ key }}": "{{ value }}"}],
}

# Same configuration plus a "StoppingCondition" section.
# A deep copy is used so that no nested dict or list is shared with
# ``create_processing_params``: with a shallow ``dict.copy()`` an in-place
# change to a nested value made through one fixture would silently show up
# in the other and leak between test cases.
create_processing_params_with_stopping_condition = copy.deepcopy(
    create_processing_params
)
create_processing_params_with_stopping_condition["StoppingCondition"] = {
    "MaxRuntimeInSeconds": 3600
}


class TestSageMakerProcessingOperator(unittest.TestCase):
    """Tests for constructing and executing ``SageMakerProcessingOperator``."""

    # Integer field paths expected for every config used in these tests.
    _CLUSTER_INTEGER_FIELDS = [
        ["ProcessingResources", "ClusterConfig", "InstanceCount"],
        ["ProcessingResources", "ClusterConfig", "VolumeSizeInGB"],
    ]

    def setUp(self):
        """Store the operator keyword arguments shared by all tests."""
        self.processing_config_kwargs = {
            "task_id": "test_sagemaker_operator",
            "wait_for_completion": False,
            "check_interval": 5,
        }

    def _build_operator(self, config):
        """Return an operator built from ``config`` and the shared kwargs."""
        return SageMakerProcessingOperator(
            config=config, **self.processing_config_kwargs
        )

    @parameterized.expand(
        [
            (
                create_processing_params,
                [*_CLUSTER_INTEGER_FIELDS],
            ),
            (
                create_processing_params_with_stopping_condition,
                [
                    *_CLUSTER_INTEGER_FIELDS,
                    ["StoppingCondition", "MaxRuntimeInSeconds"],
                ],
            ),
        ]
    )
    def test_integer_fields_are_set(self, config, expected_fields):
        """``integer_fields`` equals the expected list of key paths."""
        operator = self._build_operator(config)
        assert operator.integer_fields == expected_fields

    # Patch decorators are applied bottom-up: the ``create_processing_job``
    # mock is the first injected argument, the ``get_conn`` mock the second.
    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(
        SageMakerHook,
        "create_processing_job",
        return_value={
            "ProcessingJobArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 200},
        },
    )
    def test_execute(self, mock_processing, mock_client):
        """``execute`` calls ``create_processing_job`` once with the config."""
        operator = self._build_operator(create_processing_params)
        operator.execute(None)

        expected_call_kwargs = {
            "wait_for_completion": False,
            "check_interval": 5,
            "max_ingestion_time": None,
        }
        mock_processing.assert_called_once_with(
            create_processing_params, **expected_call_kwargs
        )

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(
        SageMakerHook,
        "create_processing_job",
        return_value={
            "ProcessingJobArn": "testarn",
            "ResponseMetadata": {"HTTPStatusCode": 404},
        },
    )
    def test_execute_with_failure(self, mock_processing, mock_client):
        """``execute`` raises ``AirflowException`` when the mocked response status is 404."""
        operator = self._build_operator(create_processing_params)
        with pytest.raises(AirflowException):
            operator.execute(None)
