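'''
get_table_structure.py

Dump the column, index and constraint definitions of every table owned by a
Zenith database user to JSON files, or, in comparison mode, diff them against
a pre-upgrade dump and generate the SQL that reconciles the two.
'''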
import json
import multiprocessing
import os
import re
import shutil
import subprocess
import sys
import threading
from concurrent.futures import ThreadPoolExecutor, wait

import py_zenith

# All configuration comes from nine positional command-line arguments; fail
# with a usage message instead of a bare IndexError when some are missing.
if len(sys.argv) < 10:
    sys.exit("usage: %s DB_NAME PASSWORD IP PORT INSTANCE_NAME BASE_DIR UPGRADE "
             "INSTANCE_NAME_OLD SQL_DIR" % os.path.basename(sys.argv[0]))

# Database user whose tables are processed (also used as the login name).
DB_NAME = sys.argv[1].lower()
PASSWORD = sys.argv[2]
IP = sys.argv[3]
PORT = sys.argv[4]
# Instance name used as a path component for the dumps / generated SQL.
INSTANCE_NAME = sys.argv[5]
# Root directory of the JSON structure dumps.
BASE_DIR = sys.argv[6]
# "pre_upgrade", "base_line" or "upgrade" dumps table structures; any other
# value compares against the pre-upgrade dump and generates SQL.
UPGRADE = sys.argv[7]
# Instance name under BASE_DIR that holds the pre-upgrade dump.
INSTANCE_NAME_OLD = sys.argv[8]
# Root directory of the generated SQL.
SQL_DIR = sys.argv[9]


def _dump_tables(rows, new_dir):
    '''
    Dump the structure of every table in rows into new_dir, in parallel on POOL.

    Failures are reported per table instead of being silently dropped.

    :param rows: query rows whose first column is a table name
    :param new_dir: destination directory for the JSON dumps
    :return: None
    '''
    pending = {}
    for row in rows:
        table_name = row[0].upper()
        future = POOL.submit(get_table_structure_and_into_file, table_name, new_dir)
        pending[future] = table_name
    wait(pending)
    for future, table_name in pending.items():
        error = future.exception()
        if error is not None:
            print("%s dump failed: %s" % (table_name, error))


def _compare_tables(rows):
    '''
    Compare every table in rows that also exists in the pre-upgrade dump and
    finalize the generated SQL directory.

    A directory with differences gets a trailing "commit;" in its sql.sql.
    A directory without differences is removed.

    :param rows: query rows whose first column is a table name
    :return: None
    '''
    # Directory receiving the generated SQL
    new_sql_dir = os.path.join(SQL_DIR, INSTANCE_NAME, DB_NAME)
    # Directory receiving the post-upgrade table structures
    new_structure_dir = os.path.join("%s_structrue" % SQL_DIR, INSTANCE_NAME, DB_NAME)
    # Directory holding the pre-upgrade table structures
    pre_dir = os.path.join(BASE_DIR, INSTANCE_NAME_OLD, DB_NAME)

    futures = []
    pre_tables = None  # listed once, lazily, on the first row
    dirs_created = False
    for row in rows:
        if pre_tables is None:
            pre_tables = get_tables_from_pre(pre_dir)
        table_name = row[0].upper()
        if table_name not in pre_tables:
            continue
        # Output directories are only created when at least one table matches.
        if not dirs_created:
            for directory in (new_sql_dir, new_structure_dir):
                if not os.path.exists(directory):
                    os.makedirs(directory)
            dirs_created = True
        futures.append(POOL.submit(
            compare_table_structure, pre_dir, new_sql_dir, table_name, new_structure_dir))
    wait(futures)

    # A directory "has a difference" if any table written into it had one.
    differences = {}
    for future in futures:
        is_difference, sql_dir = future.result()
        differences[sql_dir] = differences.get(sql_dir, False) or is_difference

    for sql_dir, has_difference in differences.items():
        if has_difference:
            print("%s has difference" % sql_dir)
            with open(os.path.join(sql_dir, "sql.sql"), "a+") as _f:
                _f.write("commit;\n")
        else:
            print("%s no difference" % sql_dir)
            # Do not keep a directory when nothing differs. rmtree is
            # synchronous and safe for paths containing spaces.
            shutil.rmtree(sql_dir, ignore_errors=True)


def get_tables(new_dir):
    '''
    Process every table owned by DB_NAME, in parallel on the global POOL.

    When UPGRADE is "pre_upgrade", "base_line" or "upgrade", each table's
    structure is dumped as JSON into new_dir.

    Otherwise, every table that also exists in the pre-upgrade dump
    (BASE_DIR/INSTANCE_NAME_OLD/DB_NAME) is compared with it, and the
    differences are appended as SQL to SQL_DIR/INSTANCE_NAME/DB_NAME/sql.sql.

    :param new_dir: directory receiving the JSON dumps in dump mode
    :return: None
    '''
    zenith_connect = py_zenith.Zenith(
        host=IP,
        username=DB_NAME,
        password=PASSWORD,
        port=str(PORT))
    try:
        sql = "select name from sys.sys_tables where USER# = " \
              "(select ID from sys.sys_users where name='%s')" % DB_NAME.upper()
        result = zenith_connect.select(sql)
        if UPGRADE in ("pre_upgrade", "base_line", "upgrade"):
            _dump_tables(result, new_dir)
        else:
            _compare_tables(result)
    finally:
        # Release the connection even if the query or a worker raises.
        zenith_connect.close()


# One lock shared by all worker threads. They all append to the same sql.sql,
# and unsynchronized buffered writes from several threads can interleave.
_SQL_FILE_LOCK = threading.Lock()


def _load_structure_files(directory, table_name):
    '''
    Load the three JSON dumps written for one table.

    :param directory: directory holding "<table>", "<table>_index" and
        "<table>_constraint"
    :param table_name: table name used as the dump file prefix
    :return: (columns dict, index dict, constraint dict)
    '''
    loaded = []
    for suffix in ("", "_index", "_constraint"):
        with open(os.path.join(directory, table_name + suffix), "r") as _f:
            loaded.append(json.load(_f))
    return tuple(loaded)


def _normalize_column_list(columns):
    '''
    Canonicalize a comma-separated column list by removing spaces and sorting
    the entries. Column order or whitespace alone is then not reported as a change.
    '''
    return ",".join(sorted(columns.replace(" ", "").split(",")))


def compare_table_structure(pre_dir, new_sql_dir, table_name, new_structure_dir):
    '''
    Compare one table's post-upgrade structure with its pre-upgrade dump.

    The SQL that brings the old structure in line with the new one is
    appended to new_sql_dir/sql.sql.

    The current structure is first fetched and dumped into new_structure_dir,
    then both dumps are loaded and compared. Statements are emitted in this
    order: drops, column modifications, column adds, index creations,
    constraint additions. Names within each group are sorted, so the
    generated script is deterministic.

    :param pre_dir: directory with the pre-upgrade JSON dumps
    :param new_sql_dir: directory whose sql.sql receives the statements
    :param table_name: table to compare
    :param new_structure_dir: directory for the post-upgrade JSON dumps
    :return: (is_difference, new_sql_dir). is_difference is True when at
        least one statement was generated for this table.
    '''
    get_table_structure_and_into_file(table_name, new_structure_dir)

    columns_new, index_new, constraint_new = _load_structure_files(
        new_structure_dir, table_name)
    columns_old, index_old, constraint_old = _load_structure_files(pre_dir, table_name)

    need_del = []
    need_modify = []
    need_add_columns_list = []
    need_add_index_list = []
    need_add_constraint_list = []

    # --- constraints ---
    # NOTE(review): every constraint is (re)created as UNIQUE. Confirm that
    # no other constraint kind (e.g. primary key) ever reaches the dumps.
    def add_constraint(name):
        return 'ALTER TABLE "%s" ADD CONSTRAINT "%s" UNIQUE (%s);\n' % (
            table_name, name, constraint_new[name])

    def drop_constraint(name):
        return 'ALTER TABLE "%s" DROP CONSTRAINT "%s";\n' % (table_name, name)

    for name in sorted(set(constraint_new) - set(constraint_old)):
        need_add_constraint_list.append(add_constraint(name))
    for name in sorted(set(constraint_old) - set(constraint_new)):
        need_del.append(drop_constraint(name))
    # Same name but different columns: drop and recreate.
    for name in sorted(set(constraint_new) & set(constraint_old)):
        if (_normalize_column_list(constraint_new[name])
                != _normalize_column_list(constraint_old[name])):
            need_del.append(drop_constraint(name))
            need_add_constraint_list.append(add_constraint(name))

    # --- indexes ---
    def create_index(name):
        return 'CREATE INDEX "%s" ON "%s"(%s);\n' % (name, table_name, index_new[name])

    def drop_index(name):
        return 'drop index if exists "%s";\n' % name

    for name in sorted(set(index_new) - set(index_old)):
        need_add_index_list.append(create_index(name))
    for name in sorted(set(index_old) - set(index_new)):
        need_del.append(drop_index(name))
    # Same name but different columns: drop and recreate.
    for name in sorted(set(index_new) & set(index_old)):
        if _normalize_column_list(index_new[name]) != _normalize_column_list(index_old[name]):
            need_del.append(drop_index(name))
            need_add_index_list.append(create_index(name))

    # --- columns ---
    for column in sorted(set(columns_new) - set(columns_old)):
        need_add_columns_list.append("alter table %s add %s %s;\n" % (
            table_name, column, columns_new[column]))
    for column in sorted(set(columns_old) - set(columns_new)):
        need_del.append("alter table %s drop %s;\n" % (table_name, column))
    # Same column but different definition (type, default, nullability, ...).
    for column in sorted(set(columns_new) & set(columns_old)):
        if columns_new[column] != columns_old[column]:
            need_modify.append("alter table %s modify(%s %s);\n" % (
                table_name, column, columns_new[column]))

    statements = (need_del + need_modify + need_add_columns_list
                  + need_add_index_list + need_add_constraint_list)
    is_difference = bool(statements)
    with _SQL_FILE_LOCK:
        # The file is opened (and therefore created) even when nothing differs.
        with open(os.path.join(new_sql_dir, "sql.sql"), "a+") as _f:
            if is_difference:
                # One write per table keeps this table's statements contiguous.
                _f.write("\n--%s\n" % table_name + "".join(statements))
    return is_difference, new_sql_dir


def get_tables_from_pre(pre_dir):
    '''
    List the entries of a pre-upgrade dump directory (<base>/<instance>/<db>).

    Each dumped table contributes its own file plus "<table>_index" and
    "<table>_constraint", so those names appear in the result as well.

    :param pre_dir: directory holding the pre-upgrade JSON dumps
    :return: set of entry names found in pre_dir
    '''
    return {entry for entry in os.listdir(pre_dir)}


def get_table_structure_and_into_file(table_name, new_dir):
    '''
    Fetch one table's structure and dump it to JSON files in new_dir.

    The dumps are named "<table>", "<table>_index" and "<table>_constraint".
    They are written only when "<table>" does not yet exist in new_dir;
    otherwise the files already there are left untouched.

    :param table_name: table name, also used as the dump file prefix
    :param new_dir: destination directory (this function does not create it)
    :return: (columns dict, index dict, constraint dict) as freshly fetched
    '''
    columns, indexes, constraints = get_table_structure(table_name)
    already_dumped = os.path.exists(os.path.join(new_dir, "%s" % table_name))
    if not already_dumped:
        targets = (("", columns), ("_index", indexes), ("_constraint", constraints))
        for suffix, payload in targets:
            with open(os.path.join(new_dir, "%s%s" % (table_name, suffix)), "w") as _f:
                json.dump(payload, _f, indent="\t")
    return columns, indexes, constraints


def _run_structure_script(table_name):
    '''
    Run get_table_structure.sh for one table and return its decoded stdout.

    :param table_name: table whose DDL the script should print
    :return: the script's standard output as text
    '''
    # An argument list (instead of splitting a formatted string) keeps a
    # password or identifier containing whitespace as a single argument.
    # run() waits for the child, so no zombie processes are left behind.
    # stdin is /dev/null because nothing is ever written to the script.
    completed = subprocess.run(
        ["sh", "get_table_structure.sh", DB_NAME, PASSWORD, IP, str(PORT), table_name],
        stdout=subprocess.PIPE,
        stdin=subprocess.DEVNULL)
    return completed.stdout.decode()


def _parse_table_structure(table_structure_str, table_name):
    '''
    Extract columns, indexes and constraints from the script's DDL text.

    Columns come from the lines between a line holding only "(" and a line
    holding only ")", located after "SQL" in the output. Each column maps its
    quoted name to the rest of the line, without surrounding whitespace and
    commas. Index and constraint names map to the text inside the last
    parenthesised group of their CREATE INDEX / ADD CONSTRAINT line.

    :param table_structure_str: output of get_table_structure.sh
    :param table_name: table name, used only in error messages
    :return: (columns dict, index dict, constraint dict)
    :raises ValueError: if no table definition or a malformed column line is found
    '''
    body = re.search(r".*SQL.*?\n\(\n(.*?)\n\)\n", table_structure_str, re.S)
    if body is None:
        raise ValueError(
            "no table definition found in get_table_structure.sh output for %s" % table_name)

    table_structure_dict = {}
    for row in body.group(1).split("\n"):
        if not row.split():
            continue  # blank line
        column = re.search(r'"(.*)"(.*)', row)
        if column is None:
            raise ValueError("unexpected column line for %s: %r" % (table_name, row))
        table_structure_dict[column.group(1).strip()] = column.group(2).strip().strip(",")

    # Indexes
    table_index_dict = {}
    for row in re.findall(r"CREATE INDEX.*", table_structure_str):
        index_name = re.search(r'.*CREATE INDEX.*?"(.*?)"', row).group(1)
        table_index_dict[index_name] = re.search(r".*\((.*?)\)", row).group(1)

    # Constraints
    table_constraint_dict = {}
    for row in re.findall(r"ADD CONSTRAINT.*", table_structure_str):
        constraint_name = re.search(r'.*ADD CONSTRAINT.*?"(.*?)"', row).group(1)
        table_constraint_dict[constraint_name] = re.search(r".*\((.*?)\)", row).group(1)

    return table_structure_dict, table_index_dict, table_constraint_dict


def get_table_structure(table_name):
    '''
    Fetch one table's column, index and constraint definitions.

    :param table_name: table to inspect
    :return: (columns dict, index dict, constraint dict)
    :raises ValueError: if the script output has no recognizable table definition
    '''
    return _parse_table_structure(_run_structure_script(table_name), table_name)


def main():
    '''
    Script entry point.

    Ensures BASE_DIR/INSTANCE_NAME/DB_NAME exists, then processes all tables
    with get_tables.

    :return: None
    '''
    dump_dir = os.path.join(BASE_DIR, INSTANCE_NAME, DB_NAME)
    directory_missing = not os.path.exists(dump_dir)
    if directory_missing:
        os.makedirs(dump_dir)
    get_tables(dump_dir)


if __name__ == '__main__':
    # get_tables submits its work to this module-level POOL. The with block
    # shuts the executor down, waiting for its workers, once main() returns.
    with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count() // 2 + 1) as POOL:
        main()
