import s
import datetime as dt
import airflow
from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
import sys
import json
import requests
import traceback
import psycopg2
import os
from threading import Thread
import time
from collections import defaultdict
from pytz import timezone

def execute_sql_script(connection,filename,dependency):
    """ Waits until every dependency table is updated, then runs the PSQL file.

    Each round starts one thread per outstanding dependency (running
    is_updated) and waits for all of them. Rounds repeat until every
    dependency is flagged as updated or the wait exceeds the timeout. On
    timeout a slack message listing the outstanding dependencies is sent and
    the PSQL file is NOT executed.

    connection -- open database connection, shared by the checks and the script.
    filename   -- path of the PSQL file; its content is split on ';' into commands.
    dependency -- list of 'schema.table' names to validate; blank entries are
                  treated as already satisfied.
    """
    # Longest time to stay in the validation loop, in seconds (18000 s = 5 hours).
    # NOTE(review): the previous comment said 3 hours; the value itself is unchanged.
    max_wait_seconds = 18000

    # Reads the psql file up front and releases the file handle immediately.
    with open(filename) as sql_file:
        file_info = sql_file.read()

    # A blank dependency has nothing to check, so it starts out satisfied.
    # (Leaving it False would keep the loop spinning until the timeout.)
    names = [name.strip() for name in dependency]
    checks = [name == '' for name in names]
    start_time = time.time()

    # all([]) is True, so an empty dependency list skips the while loop.
    updated = all(checks)

    print('Entering While LOOP')

    while not updated:
        # Gives up once the timeout is exceeded: sends a slack message and ends this airflow task.
        if time.time()-start_time > max_wait_seconds:
            pending = [dependency[i] for i, done in enumerate(checks) if not done]
            send_slack_message('Hi @channel the query `{}` is stuck in the dependency validation loop. These are the dependencies that the script is waiting on: `{}`'.format(filename, ','.join(pending)))
            break

        # Creates one thread per outstanding dependency; is_updated flips checks[index] on success.
        threads = []
        for index, done in enumerate(checks):
            if done:
                continue
            process = Thread(target=is_updated,args=[connection,names[index],checks,index])
            process.start()
            threads.append(process)

        # Syncing up all of the threads that were deployed.
        for process in threads:
            process.join()

        # Leaves the loop once every dependency is flagged as updated.
        updated = all(checks)

    print('Exited WHILE LOOP, took {} seconds'.format(time.time()-start_time))

    # Runs the psql file only when the dependencies were validated.
    if updated:
        execute_scripts(filename,connection,file_info.split(';'))

def execute_scripts(filename,connection, commands):
    """ Executes the PSQL commands one by one and sends a slack message on errors.

    Every non-blank command is executed and committed on its own. A failing
    command does not stop the run: its traceback is collected and the
    remaining commands are still attempted. Tracebacks mentioning
    'empty query' or 'SIGTERM', and tracebacks already collected, are not
    reported.

    filename   -- name of the PSQL file, used only in the slack message.
    connection -- open database connection (commit/rollback are called on it).
    commands   -- iterable of SQL command strings.
    """

    # As per https://jira.twitch.com/browse/DSDS-638, this flag ensures that we do not send success messages if it's False.
    send_success_messages = False

    # Initializes the error string and cursor
    errors = ''
    cursor = connection.cursor()
    try:
        # Iterates through all of the psql commands.
        for command in commands:
            # If command is empty, no need to run.
            if not command.strip():
                continue
            try:
                cursor.execute(command)
                connection.commit()
            except Exception:
                trace = traceback.format_exc()
                # Ends the failed transaction. Assuming a psycopg2-style connection,
                # every later command would otherwise be rejected until a rollback.
                connection.rollback()
                # Checks to make sure that the same error isn't already in the error string.
                if 'empty query' not in trace and trace not in errors and 'SIGTERM' not in trace:
                    errors += 'Command executed: {}\n'.format(command)+trace + '\n'
    finally:
        cursor.close()

    # Builds out the slack message to send based on whether errors is empty or not and sends the slack message.
    if errors.strip():
        message = ':x: Hi, the query `{}` finished running, but there were some errors :x: \n{}\nNote: This was part of the airflow job from the *coconut_box* server.'.format(filename,errors)
        send_slack_message(message)
    elif send_success_messages:
        send_slack_message(':white_check_mark: Hi, the query `{}` ran successfully!'.format(filename))

def is_updated(connection,table,checks,index):
    """ Checks whether the table behind a cubes, tahoe, or dbsnapshots view was updated recently.

    Resolves the view `table` to the table named right after FROM in its
    pg_views definition, reads that table's newest partition value from
    svv_external_partitions (falling back to the date embedded in its
    svv_external_tables location) and compares that date with today's date
    in US/Pacific.

    connection -- open database connection; a cursor is created and closed here.
    table      -- dependency name in 'schema.view' form.
    checks     -- shared list of flags; checks[index] is set to True when the
                  table is fresh and is left untouched otherwise.
    index      -- position of this dependency inside `checks`.

    Returns True when the table's date is at most one day behind today.
    Returns False when it is older, or when a metadata lookup returns no row.
    """
    # Initializes cursor and table that's being checked.
    initial_cursor = connection.cursor()
    try:
        table_info = table.split('.')
        schema = table_info[0]
        table_name = table_info[1]

        # Runs the first command to get the most updated table name from pg_views
        command1 = 'select definition from pg_views where schemaname = \'{}\' and viewname = \'{}\';'.format(schema, table_name)
        initial_cursor.execute(command1)
        connection.commit()
        row = initial_cursor.fetchone()
        if row is None or row[0] is None:
            # Unknown view: report it as not updated instead of crashing the checker thread.
            print('No view definition found for dependency {}'.format(table))
            return False

        # Searches the view definition for the token after FROM: the updated table name and version.
        tokens = row[0].split()
        new_table = tokens[tokens.index('FROM')+1].split('.')
        schema = new_table[0]
        table_name = new_table[1].replace('"','')

        # Runs the second psql script to get the date the table was last updated.
        command2 = 'select max(values) from svv_external_partitions where schemaname = \'{}\' and tablename = \'{}\';'.format(schema, table_name)
        initial_cursor.execute(command2)
        connection.commit()
        row = initial_cursor.fetchone()
        date = None if row is None else row[0]

        # There are scenarios where the second psql query can yield no results. If it doesn't give a result, then run a third psql query to find the s3 bucket.
        if date is None:
            # Run the third psql script to get the s3 bucket name.
            command3='select location from svv_external_tables where schemaname = \'{}\' and tablename = \'{}\';'.format(schema, table_name)
            initial_cursor.execute(command3)
            connection.commit()
            row = initial_cursor.fetchone()
            if row is None or row[0] is None:
                print('No partition or location found for dependency {}'.format(table))
                return False

            # The last path segment of the location is expected to start with YYYYMMDD.
            stamp = row[0].split('/')[-1][:8]
            date = dt.date(int(stamp[:4]) , int(stamp[4:6]), int(stamp[6:8]))
        else:
            # The partition value is expected to look like YYYY-MM-DD.
            parts = date.split('-')
            date = dt.date(int(parts[0]),int(parts[1]),int(parts[2]))
    finally:
        initial_cursor.close()

    # Today's date is taken in US/Pacific.
    # NOTE(review): the table date is presumably Pacific as well; confirm.
    now = dt.datetime.now(timezone('US/Pacific'))
    today = dt.date(now.year,now.month,now.day)

    # The table counts as updated when its date is at most 1 day behind today.
    if (today-date).days <= 1:
        checks[index] = True
        return True
    return False

def send_slack_message(message):
    """ Builds the json object and POSTs it to the rollup slack webhook.

    message -- text of the slack message (slack markdown and @-mentions are enabled).

    Raises whatever requests raises on connection problems or on timeout.
    """
    # NOTE(review): a webhook URL works like a credential; consider loading it
    # from configuration instead of keeping it in source control.
    slack_url = 'https://hooks.slack.com/services/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'
    payload = {"text":message,"link_names":1,"mrkdwn":True}
    # The header name is 'Content-Type' (hyphen, not underscore). The timeout
    # keeps an unresponsive webhook from hanging the airflow task forever.
    requests.post(slack_url,data=json.dumps(payload),headers={'Content-Type': 'application/json'},timeout=30)

def setup():
    """ Announces on slack that the rollup scripts are starting. """
    announcement = 'Hi @channel, running rollup scripts.'
    send_slack_message(announcement)

def unittests(connection):
    """ Announces on slack that the rollups finished and the unit tests are starting.

    connection -- accepted for the caller's benefit; it is not used here.
    """
    announcement = 'Hi @channel, rollups finished. Running unit tests'
    send_slack_message(announcement)

def end():
    """ Announces on slack that the airflow run ended. """
    announcement = 'Hi @channel, airflow run ended'
    send_slack_message(announcement)

def run_unit_tests(connection, filename):
    """ Runs one unit test file and sends a slack alert when its result asks for one.

    The whole file is executed as a single command and the first result row
    is read. The row's columns are used by position (see below); an alert is
    sent to the unit test channel when flags[int(passed)] is truthy. Any
    failure is reported to the rollup channel, except tracebacks mentioning
    'empty query' or 'SIGTERM'.

    connection -- open database connection (commit/rollback are called on it).
    filename   -- path of the unit test PSQL file.
    """
    # Reads the unit test file and releases the file handle immediately.
    with open(filename) as test_file:
        info = test_file.read()

    cursor = connection.cursor()

    # Executes the unit test and checks for whether there are any errors.
    try:
        cursor.execute(info)
        connection.commit()
        line = cursor.fetchone()

        # Fields required by JIRA Ticket https://jira.twitch.com/browse/DSDS-595
        day = line[0]
        ldap = line[1]
        flags = line[2:4]
        test_info = line[4]
        passed = line[5]
        unit_test_result = line[6]
        tables = line[-1]
        # `passed` selects which of the two flags decides whether to alert.
        if flags[int(passed)]:
            formatted_tables = ', '.join(['`'+ i.strip() + '`' for i in tables.split()])
            message = 'Hi @{} here is your data alert for *{}*:\n The logic you checked for:\n```{}```\nThe status of the check: {}\nThe related tables are: {}'.format(ldap,day,test_info,unit_test_result,formatted_tables)
            send_slack_unit_message(message)
    except Exception:
        trace = traceback.format_exc()
        # Ends the failed transaction so the connection stays usable
        # (assuming a psycopg2-style connection).
        connection.rollback()
        if 'empty query' not in trace and 'SIGTERM' not in trace:
            # The command that ran is the file content itself. The previous code
            # referenced an undefined name here and raised NameError instead.
            errors = 'Command executed: {}\n'.format(info)+trace + '\n'
            send_slack_message('Hi @channel, unit test file `{}` finished running with errors:\n {}'.format(filename, errors))
    finally:
        cursor.close()

def send_slack_unit_message(message):
    """ Builds the json object and POSTs it to the unit test slack webhook.

    message -- text of the slack message (slack markdown and @-mentions are enabled).

    Raises whatever requests raises on connection problems or on timeout.
    """
    # NOTE(review): a webhook URL works like a credential; consider loading it
    # from configuration instead of keeping it in source control.
    slack_url = 'https://hooks.slack.com/services/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'
    payload = {"text":message,"link_names":1,"mrkdwn":True}
    # The header name is 'Content-Type' (hyphen, not underscore). The timeout
    # keeps an unresponsive webhook from hanging the airflow task forever.
    requests.post(slack_url,data=json.dumps(payload),headers={'Content-Type': 'application/json'},timeout=30)



""" Builds out the DAG and initializes all of the conditions for running. """


#Contains all of the airflow tasks, keyed by number: 0 = start, -1 = unit test
#announcement, -2 = end, positive = rollup scripts. A defaultdict without a
#factory raises KeyError on missing keys, like a plain dict.
tasks = defaultdict()

#All of the psql files
files = os.listdir('/home/airflow/files')

#If no files exist, then exit script.
if len(files) == 0:
    print('ERROR: No files found in here.')
    sys.exit()

#Builds all tahoe and spade dependencies for each file
dependencies = defaultdict(list)

# Each row is read as '<script file name>,<dependency>,<dependency>,...'.
# The file name is stripped so a row without any dependency does not keep its
# trailing newline in the key, and blank rows are skipped so they cannot
# satisfy the "no dependencies" guard below.
with open('/home/airflow/dependencies.csv') as dependency_file:
    for line in dependency_file:
        fields = [field.strip() for field in line.split(',')]
        if not fields[0]:
            continue
        dependencies[fields[0]] = [field for field in fields[1:] if field]

#If no dependencies, then quit since we can't validate the data we're inserting.
if len(dependencies) == 0:
    print('ERROR: dependencies.csv does not seem correct')
    sys.exit()

# Gets all of the unit test files
unit_test_files = os.listdir('/home/airflow/files_testing')

# Created empty dag
default_args = {
    'owner': 'manchikc',
    'start_date': dt.datetime(2019,8,11,11,0,0,0),
    'retries': 1,
    'retry_delay': dt.timedelta(minutes=5),
    'depends_on_past': False,
}
dag = DAG('rollups',default_args=default_args, schedule_interval='0 12 * * *',concurrency=10,catchup=False)

# NOTE(review): this runs at module import, so a connection is opened every
# time this file is loaded, not only when a task runs.
conn = s.getConnection()

# Initialized the fixed tasks: start announcement, unit test announcement, end announcement.
tasks[0] = PythonOperator(task_id='run_rollup_scripts',python_callable=setup,dag=dag)
tasks[-1] = PythonOperator(task_id='unit_tests',python_callable=unittests,dag=dag, op_kwargs={'connection':conn})
tasks[-2] = PythonOperator(task_id='end',python_callable=end,dag=dag)

# Adding airflow tasks to the tasks list. The number before the first '_' in
# the file name becomes the task's key in `tasks`.
for script_name in files:
    filename = '/home/airflow/files/{}'.format(script_name)
    extension = script_name.split('.')[-1]
    if extension == 'sql':
        tasks[int(script_name.split('_')[0])] = PythonOperator(task_id=script_name,python_callable=execute_sql_script,op_kwargs={'connection':conn,'filename':filename,'dependency':dependencies[script_name]},dag=dag)
    elif extension == 'py':
        # Runs the script by its full path; the bare file name only resolves
        # if the command happens to run from /home/airflow/files.
        tasks[int(script_name.split('_')[0])] = BashOperator(task_id=script_name, bash_command='/usr/bin/python3.6 {}'.format(filename), dag=dag)
    else:
        print(filename)
        print('WRONG NAME OF ROLLUP SCRIPT')

# Building dependency relationships. The second '_'-separated field of a
# rollup task id names its upstream task number(s), joined by '-' when there
# are several.
for i in list(tasks.keys()):
    # Keys <= 0 are the fixed start / unit test / end tasks.
    if i <= 0:
        continue
    upstream_field = tasks[i].task_id.split('_')[1]
    if '-' in upstream_field:
        for r in upstream_field.split('-'):
            try:
                tasks[int(r)] >> tasks[i]
            except Exception:
                # Reports the bad reference and keeps wiring the remaining ones.
                # SystemExit and KeyboardInterrupt are no longer swallowed here.
                print('INVALID DEPENDENCY: {}'.format(r))
    else:
        tasks[int(upstream_field)] >> tasks[i]
    # Every rollup task must finish before the unit test announcement.
    tasks[i] >> tasks[-1]

#Unique file numbers for unit tests.
unit_test = -100000
new_keys = []

# Created unit test airflow tasks and sets the dependencies ready.
for file in unit_test_files:
    filename = '/home/airflow/files_testing/{}'.format(file)
    extension = file.split('.')[-1]
    if extension == 'sql':
        tasks[unit_test] = PythonOperator(task_id=file,python_callable=run_unit_tests,op_kwargs={'connection':conn,'filename':filename},dag=dag)
    elif extension == 'py':
        # Runs this unit test file; the previous code formatted a stale loop
        # variable from an earlier loop into the command.
        tasks[unit_test] = BashOperator(task_id=file, bash_command='/usr/bin/python3.6 {}'.format(filename), dag=dag)
    else:
        print('WRONG NAME OF UNIT TEST FILE')
        # No task was created for this file, so there is nothing to wire up
        # (looking up tasks[unit_test] would raise KeyError).
        continue

    # Unit tests run after the unit test announcement and before the end task.
    tasks[-1] >>tasks[unit_test]
    tasks[unit_test] >> tasks[-2]
    unit_test-=1
