import json
from typing import Optional, Union

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.decorators import apply_defaults


class SageMakerManifestFileOperator(BaseOperator):
    """
    Operator that generates a JSON manifest file containing a list of S3 object keys
    that you want to input into SageMaker.

    The manifest is a JSON array whose first element is ``{"prefix": "s3://<bucket>/<prefix>"}``
    and whose remaining elements are the object keys with that prefix stripped. The file is
    uploaded to ``dest_bucket_name`` under ``manifest_file_key``, replacing any existing object.

    :param source_bucket_name: The name of the source bucket. (templated)
    :type source_bucket_name: str
    :param dest_bucket_name: The name of the destination bucket where the JSON manifest file
        is placed. (templated)
    :type dest_bucket_name: str
    :param manifest_file_key: The key of the manifest file in the destination S3 bucket.
        It must have a .json extension.
    :type manifest_file_key: str
    :param source_prefix: The prefix of the objects in the source S3 bucket.
        All objects matching this prefix in the bucket will be added to the manifest file. (templated)
    :type source_prefix: str
    :param aws_conn_id: Connection id of the S3 connection to use
    :type aws_conn_id: str
    :param source_keys: The key(s) to add to the manifest file. (templated)
        If specific object keys are provided, these will be added to the manifest file instead of
        all objects matching the prefix. A single key may be given as a plain string.

        The provided prefix must be shared among all keys. If keys don't share a prefix,
        provide an empty string as an argument to source_prefix.
    :type source_keys: str or list
    :raises AirflowException: if a key does not start with ``source_prefix``.
    """

    template_fields = (
        "source_bucket_name",
        "dest_bucket_name",
        "source_prefix",
        "source_keys",
    )

    @apply_defaults
    def __init__(
        self,
        *,
        source_bucket_name: str,
        dest_bucket_name: str,
        manifest_file_key: str,
        source_prefix: str = "",
        aws_conn_id: str = "aws_default",
        source_keys: Optional[Union[str, list]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.source_bucket_name = source_bucket_name
        self.dest_bucket_name = dest_bucket_name
        self.source_prefix = source_prefix
        self.manifest_file_key = manifest_file_key
        self.aws_conn_id = aws_conn_id
        self.source_keys = source_keys

    def _resolve_keys(self, s3_hook) -> list:
        """Return the object keys to list in the manifest, always as a list."""
        if self.source_keys:
            # A lone key may be passed as a plain string; iterating it directly
            # would walk its characters instead of treating it as one key.
            if isinstance(self.source_keys, str):
                return [self.source_keys]
            return list(self.source_keys)

        listed = s3_hook.list_keys(bucket_name=self.source_bucket_name, prefix=self.source_prefix)
        # NOTE(review): some S3Hook versions appear to return None rather than an
        # empty list when nothing matches; normalise so the caller can iterate.
        if not listed:
            self.log.warning(
                "No keys found in bucket '%s' with prefix '%s'; manifest will contain no objects.",
                self.source_bucket_name,
                self.source_prefix,
            )
            return []
        return list(listed)

    def execute(self, context):  # pylint: disable=unused-argument
        """Build the manifest from the resolved keys and upload it to the destination bucket."""
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
        prefix = self.source_prefix
        prefix_length = len(prefix)

        relative_keys = []
        for key in self._resolve_keys(s3_hook):
            if prefix and not key.startswith(prefix):
                raise AirflowException(f"The key '{key}' does not start with prefix: '{self.source_prefix}'")
            # The key is known to start with the prefix (or the prefix is empty),
            # so slicing removes exactly the leading prefix.
            relative_keys.append(key[prefix_length:])

        # First manifest entry carries the shared S3 prefix; the rest are keys relative to it.
        manifest_object = [{"prefix": f"s3://{self.source_bucket_name}/{prefix}"}, *relative_keys]

        s3_hook.load_string(
            string_data=json.dumps(manifest_object, ensure_ascii=False),
            key=self.manifest_file_key,
            bucket_name=self.dest_bucket_name,
            replace=True,
        )
