import pytest

from conductor.config import RedshiftConfig
from tests.utils import create_conductor_instance


def test_redshift():
    """Exercise RedshiftOperator with and without UNLOAD-to-S3.

    Three operators are built from a single conductor instance that has a
    Redshift configuration:

    * a plain query, which passes the SQL through and exposes no
      ``output_s3_prefix`` attribute;
    * an UNLOAD with an empty ``unload_prefix``, whose generated SQL,
      connection extras and S3 output URL are checked exactly;
    * an UNLOAD with a non-empty ``unload_prefix``, whose S3 output URL is
      checked to end with that prefix.
    """
    # Common S3 root shared by every unload location asserted below.
    # NOTE(review): the path segments presumably come from the defaults in
    # tests.utils.create_conductor_instance (project/env/account/branch/dag);
    # confirm there if they change. Deliberately NOT an f-string: the
    # "{{run_id}}" template placeholder must stay literally doubled.
    run_root = "s3://test-project.test-env.123456789012/test-branch/test-dag/{{run_id}}"
    query = "SELECT * FROM test"

    conductor = create_conductor_instance(
        redshift=RedshiftConfig(
            host="host",
            port=5439,
            cluster_identifier="test_cluster",
            db_name="test_db",
            db_user="test_user",
            unload_role="arn:aws:iam::400577501318:role/RedshiftS3UnloadRole",
        )
    )

    # --- Plain query: SQL is untouched and there is no S3 prefix attribute.
    plain_op = conductor.operators.RedshiftOperator(task_id="redshift", query=query)

    with pytest.raises(AttributeError):
        plain_op.output_s3_prefix

    assert plain_op.task.sql == query

    # --- UNLOAD with an empty prefix: output lands directly under the task id.
    unload_op = conductor.operators.RedshiftOperator(
        task_id="redshift-unload",
        query=query,
        unload=True,
        unload_prefix="",
    )
    expected_unload_sql = "\n".join(
        [
            "UNLOAD($$ SELECT * FROM test $$)",
            "TO '" + run_root + "/redshift-unload/'",
            "FORMAT PARQUET",
            "MAXFILESIZE 50 MB",
            # The configured unload role is chained with a second role
            # (two adjacent literals form this single list element).
            "IAM_ROLE 'arn:aws:iam::400577501318:role/RedshiftS3UnloadRole,"
            "arn:aws:iam::123456789012:role/0EKirPILJEXjpsva-unload-role'",
            "PARALLEL ON",
            "ALLOWOVERWRITE",
        ]
    )
    assert unload_op.task.sql == expected_unload_sql
    assert (
        unload_op.task.connection.extra
        == '{"iam": true, "aws_conn_id": null, "redshift": true, "cluster-identifier": "test_cluster"}'
    )
    assert unload_op.outputs.s3_url == run_root + "/redshift-unload"

    # --- UNLOAD with a non-empty prefix: it is appended after the task id.
    prefixed_op = conductor.operators.RedshiftOperator(
        task_id="redshift-unload-prefix",
        query=query,
        unload=True,
        unload_prefix="prefix",
    )
    assert prefixed_op.outputs.s3_url == run_root + "/redshift-unload-prefix/prefix"


def test_redshift_errors(c):
    """Constructing a RedshiftOperator from the ``c`` fixture raises ValueError.

    NOTE(review): presumably the default ``c`` fixture has no Redshift
    configuration, which is what triggers the error; confirm in conftest.
    """
    operator_kwargs = {"task_id": "redshift", "query": "SELECT * FROM test"}

    with pytest.raises(ValueError):
        c.operators.RedshiftOperator(**operator_kwargs)
