#!/usr/bin/env python
# -*-coding:utf-8-*-
import os
import sys
import ctypes
import logging.handlers
from collections import namedtuple

# Major version of the running interpreter, collapsed to 2 or 3.
PYTHON_VERSION = 2 if sys.version_info[0] == 2 else 3

# Logging setup (below): one backup file, at most 1 MB per file; setup errors
# are ignored so they do not affect script execution.
# Default log location.
LOG_DIR = "/var/log/fsc_cli"
LOG_PATH = "/var/log/fsc_cli/fsc_cli.log"
# Directory of this file, its parent, and the parent's "log" sub-directory.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.abspath(os.sep.join((CUR_DIR, "..")))
DSWARE_API_LOG_DIR = os.sep.join((PARENT_DIR, "log"))
DSWARE_API_LOG_PATH = os.sep.join((DSWARE_API_LOG_DIR, "dsware_api.log"))
# Process-wide umask 027: files/directories created from now on (including the
# log file) get no group-write bit and no permissions for others.
os.umask(0o027)


def should_log_file(record):
    """Logging filter hook: keep a record only while LOG_PATH is not a symlink.

    The record itself is not inspected; the decision depends solely on the
    current value of the module-level LOG_PATH.
    """
    path_is_link = os.path.islink(LOG_PATH)
    return False if path_is_link else True


# Best-effort logging setup on the root logger. Any failure prints a single
# message and is otherwise ignored so the script keeps running without a file log.
try:
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    # Prefer the "log" directory next to this package when it already exists.
    if os.path.exists(DSWARE_API_LOG_DIR):
        LOG_DIR = DSWARE_API_LOG_DIR
        LOG_PATH = DSWARE_API_LOG_PATH
    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
        # Mode must be octal (rwxr-x---); a decimal 750 would be 0o1356,
        # denying the owner read access to the directory.
        os.chmod(LOG_DIR, 0o750)
    # Rotate at 1 MB, keeping one backup file.
    handler = logging.handlers.RotatingFileHandler(
        LOG_PATH, maxBytes=1024 * 1024, backupCount=1)
    formatter = logging.Formatter(
        fmt='[%(asctime)s] [%(filename)s:%(lineno)d] '
            '[%(levelname)s]: %(message)s',
        datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)
    # Drop records while LOG_PATH is a symlink (decided by should_log_file).
    loggingFilter = logging.Filter()
    loggingFilter.filter = should_log_file
    handler.addFilter(loggingFilter)
    logger.addHandler(handler)
except Exception:
    print("log error")


# Encrypt/decrypt through the KMC shared library via ctypes.
class KmcApi(object):
    """ctypes wrapper around the KMC shared library's encrypt/decrypt calls.

    Construct with the library path, call ``initial()`` once, then use
    ``encrypt``/``decrypt``; both return a ``(success, result)`` tuple.
    """

    # Fixed-size buffers handed to the C library; inputs must be strictly
    # shorter than this many bytes.
    _PLAIN_BUFF_LEN = 1024
    _CIPHER_BUFF_LEN = 1024

    # Argument bundles for the C calls, fields in C parameter order.
    _EncryptInfo = namedtuple('encrypt_tuple_info',
                              ['domain_id', 'plain_text', 'plain_len',
                               'cipher_text', 'cipher_len'])
    _DecryptInfo = namedtuple('decrypt_tuple_info',
                              ['domain_id', 'cipher_text', 'cipher_text_len',
                               'plain_text', 'plain_text_len'])

    def __init__(self, dllfile):
        """Remember the library path; the library is loaded by initial()."""
        self.path = dllfile
        self.lib = None

    def initial(self):
        """Load the shared library and initialise the KMC client.

        Returns:
            True on success; False (after logging) when the library cannot
            be loaded or DSWARE_KMC_Initialize returns a non-zero code.
        """
        try:
            self.lib = ctypes.cdll.LoadLibrary(self.path)
        except Exception as err:
            logger.error('load dll (%s) failed, err (%s).', self.path, err)
            return False
        # b"FSA" is the same byte string on Python 2 and Python 3.
        ret_code = self.lib.DSWARE_KMC_Initialize(b"FSA")
        if ret_code != 0:
            logger.error('initial kmc failed, err (%u)', ret_code)
            return False
        return True

    @staticmethod
    def _to_bytes(text):
        """Return text as bytes: byte strings pass through, text is UTF-8 encoded."""
        if isinstance(text, bytes):
            return text
        return text.encode('utf8')

    def _encrypt(self, encrypt_info):
        """Call DSWARE_KMC_Encrypt with ctypes arguments.

        Returns:
            The value read back through the output char** on success,
            None when the library returns a non-zero code.
        """
        # The library gets a char** that initially points at our cipher
        # buffer; the result is read back through the same pointer.
        cipher_text_ptr = ctypes.pointer(
            ctypes.cast(encrypt_info.cipher_text, ctypes.c_char_p))
        self.lib.DSWARE_KMC_Encrypt.argtypes = [
            ctypes.c_int, ctypes.c_char_p, ctypes.c_uint,
            ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_uint)]
        self.lib.DSWARE_KMC_Encrypt.restype = ctypes.c_int
        ret_code = self.lib.DSWARE_KMC_Encrypt(
            encrypt_info.domain_id, encrypt_info.plain_text,
            encrypt_info.plain_len, cipher_text_ptr,
            ctypes.pointer(encrypt_info.cipher_len))
        if ret_code != 0:
            logger.error('kmc encrypt failed, err (%d)', ret_code)
            return None
        return cipher_text_ptr[0]

    def encrypt(self, domain, plaintext):
        """Encrypt plaintext under the given KMC domain.

        Args:
            domain: integer domain id, passed to the library as a C int.
            plaintext: text (UTF-8 encoded before the call) or bytes; its
                encoded length must be below 1024 bytes.

        Returns:
            (True, ciphertext) on success, (False, None) otherwise.
        """
        plain_bytes = KmcApi._to_bytes(plaintext)
        # Check the encoded length: a non-ASCII string has more bytes than
        # characters, and every byte must fit in the buffer.
        if len(plain_bytes) >= self._PLAIN_BUFF_LEN:
            logger.error('plain text len (%u) too long.', len(plain_bytes))
            return False, None
        plain_text = ctypes.create_string_buffer(self._PLAIN_BUFF_LEN)
        try:
            ctypes.memmove(plain_text, plain_bytes, len(plain_bytes))
            cipher_text = ctypes.create_string_buffer(self._CIPHER_BUFF_LEN)
            encrypt_info = self._EncryptInfo(
                ctypes.c_int(domain), plain_text,
                # +1 counts the trailing zero byte of the zero-filled buffer.
                ctypes.c_uint(len(plain_bytes) + 1),
                cipher_text, ctypes.c_uint(self._CIPHER_BUFF_LEN))
            result_str = self._encrypt(encrypt_info)
        finally:
            # Wipe our plaintext copy whether or not encryption succeeded.
            KmcApi.memset(plain_text)
        if not result_str:
            return False, None
        return True, result_str

    def _decrypt(self, decrypt_info):
        """Call DSWARE_KMC_Decrypt with ctypes arguments.

        Returns:
            The value read back through the output char** on success,
            None when the library returns a non-zero code.
        """
        # char** that initially points at our plaintext buffer.
        plaintext_ptr = ctypes.pointer(
            ctypes.cast(decrypt_info.plain_text, ctypes.c_char_p))
        self.lib.DSWARE_KMC_Decrypt.argtypes = [
            ctypes.c_int, ctypes.c_char_p, ctypes.c_uint,
            ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_uint)]
        self.lib.DSWARE_KMC_Decrypt.restype = ctypes.c_int
        # Pass the char** itself (matching argtypes), not the raw buffer, so
        # the result can be read back through plaintext_ptr below.
        ret_code = self.lib.DSWARE_KMC_Decrypt(
            decrypt_info.domain_id, decrypt_info.cipher_text,
            decrypt_info.cipher_text_len, plaintext_ptr,
            ctypes.pointer(decrypt_info.plain_text_len))
        if ret_code != 0:
            logger.error('kmc decrypt failed, err (%d)', ret_code)
            return None
        return plaintext_ptr[0]

    def decrypt(self, domain, cipher):
        """Decrypt cipher under the given KMC domain.

        Args:
            domain: integer domain id, passed to the library as a C int.
            cipher: text (UTF-8 encoded before the call) or bytes; its
                encoded length must be below 1024 bytes.

        Returns:
            (True, plaintext) on success, (False, None) otherwise.
        """
        cipher_bytes = KmcApi._to_bytes(cipher)
        if len(cipher_bytes) >= self._CIPHER_BUFF_LEN:
            logger.error('cipher text len (%u) too long.', len(cipher_bytes))
            return False, None
        cipher_text = ctypes.create_string_buffer(self._CIPHER_BUFF_LEN)
        ctypes.memmove(cipher_text, cipher_bytes, len(cipher_bytes))
        plain_text = ctypes.create_string_buffer(self._PLAIN_BUFF_LEN)
        try:
            decrypt_info = self._DecryptInfo(
                ctypes.c_int(domain), cipher_text,
                ctypes.c_uint(len(cipher_bytes)),
                plain_text, ctypes.c_uint(self._PLAIN_BUFF_LEN))
            result_str = self._decrypt(decrypt_info)
        finally:
            # Reading a c_char_p yields an independent bytes copy, so the
            # buffer may be wiped once the result has been read.
            KmcApi.memset(plain_text)
        if not result_str:
            return False, None
        return True, result_str

    @staticmethod
    def memset(char_buffer):
        """Zero every byte of char_buffer in place (must be a ctypes object)."""
        ctypes.memset(char_buffer, 0, ctypes.sizeof(char_buffer))