#!/usr/bin/env python
# -*-coding:utf-8-*-
import traceback
import sys
import os
from dsware_kmc_tool import KmcApi
import logging.handlers
import getopt
import subprocess
import shlex
import binascii
import shutil

# Make sure LD_LIBRARY_PATH contains /usr/lib64 and the parent of the script
# directory. When the variable has to be changed, the script replaces itself
# via exec so the new process starts with the updated environment
# (presumably because the loader only reads the variable at process start).
current_dir = sys.path[0]
lib_path = os.path.join(current_dir, "..")
ld_path = os.getenv('LD_LIBRARY_PATH')
if not (ld_path is not None and lib_path in ld_path):
    if ld_path is not None:
        # Append to whatever the caller already exported.
        os.environ['LD_LIBRARY_PATH'] = ":".join(
            (os.environ['LD_LIBRARY_PATH'], "/usr/lib64", lib_path))
    else:
        os.environ['LD_LIBRARY_PATH'] = ":".join(("/usr/lib64", lib_path))
    try:
        # On success this call never returns; after the re-exec lib_path is
        # part of LD_LIBRARY_PATH, so this branch is not entered again.
        os.execv(sys.argv[0], sys.argv)
    except Exception as e:
        print("Exception: failed to execute under modified environment %s" % e)
        sys.exit(2)


# Logging initialisation: one backup file, 1 MB maximum size. Exceptions are
# ignored so that logging problems do not affect script execution.
logfile_dir = '/var/log/fsc_cli/'
logfile = os.path.join(logfile_dir, 'fsc_cli.log')
os.umask(0o0027)  # restrict permissions of files created from here on (log file included)
# Three random bytes name the per-run scratch directory. os.urandom is used
# instead of reading /dev/random directly: it does not stall waiting for
# entropy on older kernels and does not leave a module-level ``file`` name
# (which shadows the Python 2 builtin) behind.
sr = os.urandom(3)
rand = binascii.hexlify(sr)
# NOTE(review): "tmp" is relative to the current working directory, not /tmp;
# confirm this is intended before changing it.
tmp_random_dir = os.path.join("tmp", str(int(rand, 16)))


def should_log_file(record=None):
    """
    Decide whether a record may be written to the log file.

    The function is installed as the ``filter`` attribute of a
    ``logging.Filter`` instance, and logging calls that hook with the log
    record as its single argument. The record is therefore accepted (and
    ignored); without the parameter every log call raised ``TypeError``.

    :param record: the log record passed in by logging (unused)
    :return: True when ``logfile`` is not a symbolic link, False otherwise
    """
    return not os.path.islink(logfile)


# Configure the root logger with a size-rotated file handler (1 MiB per file,
# one backup). Any failure is printed and otherwise ignored so that logging
# problems never stop the script itself.
try:
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    if not os.path.exists(logfile_dir):
        os.makedirs(logfile_dir)
        # The mode must be octal: the previous decimal literal 600 is 0o1130
        # (sticky bit, owner execute-only). A directory also needs the execute
        # bit for its owner to create files inside it, hence 0o700.
        os.chmod(logfile_dir, 0o700)
    handler = logging.handlers.RotatingFileHandler(logfile, maxBytes=1024 * 1024, backupCount=1)
    formatter = logging.Formatter(
        fmt="[%(asctime)s] [%(levelname)s] [%(process)d] [%(filename)s:%("
            "lineno)d]: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S")

    handler.setFormatter(formatter)
    loggingFilter = logging.Filter()
    # logging calls filter(record); should_log_file only inspects the log
    # file path, so adapt the call instead of handing over the record.
    # Assigning the function directly made every log call raise TypeError.
    loggingFilter.filter = lambda record: should_log_file()
    handler.addFilter(loggingFilter)
    logger.addHandler(handler)
except Exception as e:
    print("log error %s" % e)


def execute_cmd(cmd, text_in=None):
    """
    Run one command through subprocess, optionally feeding text to its stdin.
    :param cmd: command line as a single string, tokenised with shlex.split
    eg1: 'ls /opt'
    eg2: 'chown omm:wheel CheckReport.json'
    :param text_in: text written to the child's standard input (may be None)
    :return: (return code, standard output); only reached when the code is 0
    :raises RuntimeError: with the child's stderr when it exits non-zero
    """
    logger.info("exe cmd: %s " % cmd)
    argv = shlex.split(cmd)
    child = subprocess.Popen(argv,
                             universal_newlines=True,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
    out_text, err_text = child.communicate(text_in)
    exit_status = child.returncode
    if exit_status != 0:
        # Failure: record both streams, then surface stderr to the caller.
        logger.error("exe: %s fail, out: %s, err: %s" % (cmd, out_text, err_text))
        raise RuntimeError(err_text)
    return exit_status, out_text


def execute_cmd_with_pipe(cmd_list):
    """
    Run several commands through subprocess, chained like a shell pipeline.
    :param cmd_list: command strings; stdout of each feeds stdin of the next
    eg: ['ps -ef', 'grep ha.bin', 'grep -v grep']
    :return: (return code, standard output) of the LAST command; the output
             is bytes because the pipes are opened in binary mode
    :raises RuntimeError: with the last command's stderr when it exits
             non-zero
    """
    argv_list = [shlex.split(cmd) for cmd in cmd_list]
    process = subprocess.Popen(argv_list[0], stdin=subprocess.PIPE,
                               stderr=subprocess.PIPE, stdout=subprocess.PIPE)
    if len(argv_list) > 1:
        # Only the last process goes through communicate(), so nothing would
        # ever close the first process's stdin pipe. Close it now so that a
        # first command that reads stdin sees EOF instead of hanging.
        process.stdin.close()
    upstream = []
    for argv in argv_list[1:]:
        downstream = subprocess.Popen(argv, stdin=process.stdout,
                                      stderr=subprocess.PIPE,
                                      stdout=subprocess.PIPE)
        # The child holds its own copy of the pipe. Dropping ours lets the
        # upstream process receive SIGPIPE if the downstream one exits early.
        process.stdout.close()
        upstream.append(process)
        process = downstream
    stdout, stderr = process.communicate()
    # Reap the earlier stages so they do not linger as zombies. Their stderr
    # was never consumed, so it is simply closed.
    for finished in upstream:
        finished.stderr.close()
        finished.wait()
    code = process.returncode
    if code:
        logger.error("execute command fail, out: %s, err: %s" % (
            stdout, stderr))
        raise RuntimeError(stderr)
    return code, stdout


def get_conf_password_kmc(ssl_password):
    """
    Decrypt a configured password through the KMC library wrapper (KmcApi).

    The process exits with status 1 when KMC initialisation or decryption
    reports failure.

    :param ssl_password: encrypted password value taken from the config file
    :return: decrypted password as text (bytes results are decoded)
    """
    domain_id = 50  # domain passed to KmcApi.decrypt; meaning is defined by the KMC library
    sys.path.insert(0, os.path.join(current_dir, ".."))
    kmc_lib = os.path.join(current_dir, "../libfsbKmcTool.so")
    kmc_api = KmcApi(kmc_lib)
    if not kmc_api.initial():
        logger.error("initial kmc failed")
        sys.exit(1)
    ret_val, text_out = kmc_api.decrypt(domain_id, ssl_password)
    if not ret_val:
        # Log before exiting, matching the initialisation failure above, so
        # the exit is not silent.
        logger.error("decrypt by kmc failed")
        sys.exit(1)
    if isinstance(text_out, bytes):
        text_out = text_out.decode()
    return text_out


def get_conf_password_openssl(ssl_password):
    """
    Decrypt a configured password with ``openssl aes-256-cbc``.

    The first 16 characters of ``ssl_password`` are passed to openssl as the
    IV (``-iv``); the remainder is echoed into openssl as base64 input.

    :param ssl_password: stored value, IV prefix followed by the cipher text
    :return: decrypted text with trailing whitespace removed
    """
    # NOTE(review): the AES key is hard-coded in the script.
    key_value = "0E879F0106D3EED8FAEE29C20EF7104F6A62BD0413B96E1F542C47DD7A82C010"
    iv_str, password = ssl_password[:16], ssl_password[16:]
    decrypt_cmd = "openssl aes-256-cbc -d -K %s -iv %s -base64" % (
        key_value, iv_str)
    pipeline = ["echo %s" % password, decrypt_cmd]
    ret_code, ret_data = execute_cmd_with_pipe(pipeline)
    if ret_code:
        logger.error('execute cmd to get openssl password fail')
    plain_text = ret_data.decode() if isinstance(ret_data, bytes) else ret_data
    return plain_text.rstrip()


def get_conf_password(conf_dir, key):
    """
    Read an encrypted password from dsware-api.properties and decrypt it.

    The file is scanned once. The value is taken from the first line that
    starts with ``key`` and contains ``=``. A matching line without ``=`` is
    skipped (previously the whole line was used as the password). If any line
    starts with ``api.content=`` the value is decrypted with openssl,
    otherwise through KMC.

    :param conf_dir: directory containing ``dsware-api.properties``
    :param key: property name prefix to look up, e.g. ``ssl.self.ks.passwd``
    :return: the decrypted password
    :raises Exception: when no non-empty value is found for ``key``
    """
    cert_config_path = os.path.join(conf_dir, "dsware-api.properties")
    ssl_password = None
    key_seen = False
    kmc_flag = True
    with open(cert_config_path, 'r') as conf_file:
        for line in conf_file:
            line = line.strip()
            if line.startswith("api.content="):
                kmc_flag = False
            if not key_seen and line.startswith(key):
                _, separator, value = line.partition("=")
                if separator:
                    # Only the first well-formed match counts.
                    key_seen = True
                    ssl_password = value.strip()
    if not ssl_password:
        raise Exception("get %s fail in %s" % (key, cert_config_path))
    if kmc_flag:
        logger.info("get password from kmc")
        return get_conf_password_kmc(ssl_password)
    logger.info("get password from ssl")
    return get_conf_password_openssl(ssl_password)


def clear_dest_file(file_name):
    """
    Remove ``file_name`` if present and make sure its parent directory exists.

    :param file_name: path of the file that is about to be (re)created
    """
    # lexists also reports dangling symlinks, which must be removed as well
    # so that the next write does not follow them.
    if os.path.lexists(file_name):
        os.remove(file_name)
    dir_name = os.path.dirname(file_name)
    # A bare file name has an empty dirname; os.makedirs("") would raise.
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)

def get_server_pem(conf_dir, dest_cert_file, dest_key_file, password):
    """
    Export the client certificate and private key from client_self.keystore.

    The certificate is written to ``dest_cert_file``. The private key is first
    exported unencrypted to a temporary file and then re-encoded as PKCS#8
    protected with ``password`` into ``dest_key_file``. The temporary key file
    is removed even when one of the commands fails.

    :param conf_dir: directory holding the keystore and dsware-api.properties
    :param dest_cert_file: output path for the certificate
    :param dest_key_file: output path for the encrypted private key
    :param password: password used to protect the exported private key
    :raises RuntimeError: propagated from execute_cmd when a command fails
    """
    client_self_password = get_conf_password(conf_dir, "ssl.self.ks.passwd")
    client_self_path = os.path.join(conf_dir, "client_self.keystore")
    tmp_client_self_key = os.path.join(tmp_random_dir, "client_self.key")
    clear_dest_file(tmp_client_self_key)
    try:
        # execute_cmd raises on a non-zero exit status, so no return-code
        # checks are needed after the calls below.
        cmd_str = "openssl pkcs12 -in %s -passin stdin -clcerts -nokeys -out %s" \
                  % (client_self_path, dest_cert_file)
        execute_cmd(cmd_str, client_self_password)

        cmd_str = "openssl pkcs12 -nodes -in %s -passin stdin -nocerts -out %s" \
                  % (client_self_path, tmp_client_self_key)
        execute_cmd(cmd_str, client_self_password)

        cmd_str = "openssl pkcs8 -in %s -passin stdin -topk8 -v2 aes-256-cbc" \
                  " -out %s -passout stdin" % (tmp_client_self_key, dest_key_file)
        # Both -passin and -passout read from stdin, one line each.
        new_password = "%s\n%s\n" % (password, password)
        execute_cmd(cmd_str, new_password)
    finally:
        # The temporary file holds the private key in clear text (-nodes);
        # never leave it behind, not even on failure.
        clear_dest_file(tmp_client_self_key)


def get_ca_pem(conf_dir, dest_ca_file, jre_dir=None):
    """
    Export the CA certificates from client_trust.keystore as PEM.

    The JKS trust store is converted to a temporary PKCS12 file with keytool,
    and openssl then extracts the CA certificates into ``dest_ca_file``. The
    temporary PKCS12 file is removed even when a step fails.

    :param conf_dir: directory holding the keystore and dsware-api.properties
    :param dest_ca_file: output path for the CA certificates
    :param jre_dir: JRE root containing ``bin/keytool``; when omitted the
                    first ``jre*`` directory under /usr/share/dsware/ is used
    :raises RuntimeError: when no JRE directory is found or a command fails
    """
    client_trust_password = get_conf_password(conf_dir, "ssl.trust.ks.passwd")
    client_trust_path = os.path.join(conf_dir, "client_trust.keystore")
    tmp_pkcs12_file = os.path.join(tmp_random_dir, "ca.p12")
    clear_dest_file(tmp_pkcs12_file)
    try:
        if jre_dir:
            jre_bin_dir = os.path.join(jre_dir, "bin")
        else:
            finder = os.popen("find /usr/share/dsware/ -name 'jre*' -type d")
            try:
                jre_candidates = finder.readlines()
            finally:
                finder.close()
            if not jre_candidates:
                # Previously this surfaced as a bare IndexError.
                raise RuntimeError(
                    "no jre directory found under /usr/share/dsware/")
            jre_path = jre_candidates[0].replace("\n", "")
            jre_bin_dir = os.path.join(jre_path, "bin")
        key_tool_bin = os.path.join(jre_bin_dir, "keytool")
        cmd_str = "%s -importkeystore -srckeystore %s -srcstoretype JKS " \
                  "-deststoretype PKCS12 -destkeystore %s" \
                  % (key_tool_bin, client_trust_path, tmp_pkcs12_file)
        # keytool's password prompts are answered from stdin, one line each
        # (presumably destination password, its confirmation and the source
        # password -- confirm against the keytool version in use).
        ca_password = "%s\n%s\n%s\n" % (
            client_trust_password, client_trust_password, client_trust_password)
        # execute_cmd raises on a non-zero exit status.
        execute_cmd(cmd_str, ca_password)
        cmd_str = "openssl pkcs12 -in %s -passin stdin -cacerts -nokeys -out %s" \
                  % (tmp_pkcs12_file, dest_ca_file)
        execute_cmd(cmd_str, client_trust_password)
    finally:
        clear_dest_file(tmp_pkcs12_file)


def get_pem(conf_dir, output_dir, password, jre_dir=None):
    """
    Generate cert_file.cert, key_file.ca and ca_file.ca in ``output_dir``.

    Existing output files are removed first. The per-run temporary directory
    is removed afterwards, also when the export fails.

    :param conf_dir: directory holding the keystores and dsware-api.properties
    :param output_dir: directory receiving the generated PEM files
    :param password: password protecting the exported private key
    :param jre_dir: optional JRE root used to locate keytool
    """
    dest_cert_file = os.path.join(output_dir, "cert_file.cert")
    dest_key_file = os.path.join(output_dir, "key_file.ca")
    dest_ca_file = os.path.join(output_dir, "ca_file.ca")

    for dest_file in (dest_cert_file, dest_key_file, dest_ca_file):
        clear_dest_file(dest_file)

    try:
        get_server_pem(conf_dir, dest_cert_file, dest_key_file, password)
        get_ca_pem(conf_dir, dest_ca_file, jre_dir)
    finally:
        # The scratch directory may still hold intermediate key material
        # after a failure. Remove it together with its content, and do not
        # let cleanup problems mask the original error.
        shutil.rmtree(tmp_random_dir, ignore_errors=True)


def copy_file(src_dir, dst_path):
    """
    Copy a file into place unless the destination already exists.

    The source is the file in ``src_dir`` that has the same base name as
    ``dst_path``. Missing parent directories of ``dst_path`` are created.

    :param src_dir: directory containing the source file
    :param dst_path: full destination path; left untouched when it exists
    """
    if os.path.exists(dst_path):
        return
    dst_dir = os.path.dirname(dst_path)
    # Create the directory only when it is missing: os.makedirs raises if it
    # already exists, which used to break the copy whenever just the file
    # (but not its directory) was absent.
    if dst_dir and not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)
    shutil.copy(os.path.join(src_dir, os.path.basename(dst_path)), dst_path)


def main(argv):
    """
    Command-line entry point.

    Parses ``-d <input dir>``, ``-o <output dir>`` and optional
    ``-j <jre dir>``, prompts for the certificate password with terminal echo
    disabled, copies the KMC key files into place when they are missing, and
    generates the PEM files.

    Exits with status 2 on an option parsing error and with status 1 when a
    required option is missing or PEM generation fails.

    :param argv: command-line arguments without the program name
    """
    try:
        opts, args = getopt.getopt(argv, "d:o:j:")
    except getopt.GetoptError:
        print('%s -d <input dir> -o <output dir>' % sys.argv[0])
        sys.exit(2)
    conf_dir = None
    output_dir = None
    jre_dir = None
    for opt, arg in opts:
        if opt in ("-d", "--ifile"):
            conf_dir = arg
        elif opt in ("-o", "--ofile"):
            output_dir = arg
        elif opt in ("-j", "--jre_dir"):
            jre_dir = arg
    if not conf_dir or not output_dir:
        print('%s -d <input dir> -o <output dir>' % sys.argv[0])
        sys.exit(1)

    os.system("stty -echo")
    try:
        if 2 == sys.version_info[0]:
            cert_password = raw_input("password:")
        else:
            cert_password = input("password:")
    finally:
        # Restore echo even when the prompt is interrupted (Ctrl-C / EOF);
        # otherwise the user's terminal is left with echo switched off.
        os.system("stty echo")
    try:
        primary_key = "/opt/dsware/agent/conf/kmc/conf/primary_ks.key"
        standby_key = "/opt/dsware/agent/conf/kmc/bkp/standby_ks.key"
        copy_file(lib_path, primary_key)
        copy_file(lib_path, standby_key)
        get_pem(conf_dir, output_dir, cert_password, jre_dir)
    except Exception as exc:
        logger.error("Exception: %s" % exc)
        logger.error(traceback.format_exc())
        sys.exit(1)
    finally:
        # NOTE(review): the messages name conf_dir although the files are
        # written to output_dir -- confirm which one is meant.
        logger.info("general pem to %s finished!" % conf_dir)
    logger.info("general pem to %s success!" % conf_dir)


# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
