# -*-coding:utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
import json
import os
import stat
import socket
import sys
import argparse
import logging
import traceback

import requests


class Constant:
    """Static settings shared by the script: paths, REST endpoints and config-file keys."""

    # The log file is written next to this script.
    LOG_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "md.log")

    # REST access settings
    MD_USER_NAME = "admin"
    PORT = "10880"
    HTTPS = "https://"

    # REST endpoints
    LOGIN = "/api/user/login"
    LOGOUT = "/api/user/logout"
    CREATE_LUN_TAKEOVER = "/api/cms/task/create_lun_takeover"
    VOLUME_STATUS = "/api/cms/volume_status"
    LUN_MIGRATION = "/api/cms/task/lun_migration"
    ROLLBACK_LUN_TAKEOVER = "/api/cms/task/rollback_lun_takeover"

    # Volume status name -> rank. The rank is simply the position in this
    # tuple and is used to order volumes when their status is displayed.
    ALL_VOLUME_STATUS = dict(
        (status_name, rank) for rank, status_name in enumerate((
            "INIT",
            "TAKEOVER_FAIL",
            "ROLLBACK_TAKEOVER_FAIL",
            "TAKEOVER",
            "MIGRATING",
            "PAUSED",
            "MIGRATE_FAIL",
            "MIGRATED",
            "FINISHED",
        ))
    )

    # Actions accepted by the lun_migration endpoint / command-line option.
    LUN_MIGRATION_CHOICES = [
        "MIGRATE",
        "PAUSE",
        "SYNC",
        "ROLLBACK",
        "SPLIT",
        "ROLLBACK_FINISHED_TO_INIT",
    ]

    # Compute-node (host) information files
    HOST_INFO_PATH = "/var/log/FusionStorage_eox/"
    HOST_INFO_FILE_SUFFIX = ".ini"
    VOLUME_Name_Prefix = "volume-"

    # Keys of the environment configuration file
    ENVIRONMENT_CONFIG_FILE = "md_cfg.json"
    MD_PORTAL_IP = "md_portal_ip"
    MD_ADMIN_PWD = "md_admin_pwd"
    SRC_STORAGE_SN = "srcStorageSn"
    SRC_POOL_ID = "srcPoolId"
    DEST_STORAGE_SN = "destStorageSn"
    DEST_POOL_ID = "destPoolId"
    HOST_UUID = "host_uuid"
    # Not read from the configuration file: this entry is derived from the
    # host information files and added to the configuration at run time.
    VOLUME_NAMES = "volumeNames"

    # Template written to disk when no configuration file exists yet; its
    # keys are also the list of parameters that must be present.
    ENVIRONMENT_CONFIG_FILE_DEMO = {
        MD_PORTAL_IP: "1.1.1.1",
        MD_ADMIN_PWD: "aaa",
        SRC_STORAGE_SN: "bbb",
        SRC_POOL_ID: 1,
        DEST_STORAGE_SN: "ccc",
        DEST_POOL_ID: 2,
        HOST_UUID: ["uuid1", "uuid2"],
    }


def get_json_as_file(file_path):
    """
    Load the JSON document stored in ``file_path``.

    The script cannot continue without this data, so every failure (file
    cannot be read, content is not valid JSON, or the document is empty) is
    logged, echoed to the console and ends the process with exit status 1.

    :param file_path: path of the JSON file to read
    :return: the parsed, non-empty JSON content
    """
    try:
        with open(file_path, 'r') as file_inspect_json:
            json_dict = json.load(file_inspect_json)
    except (IOError, OSError, ValueError) as err:
        # ValueError covers malformed JSON (and undecodable bytes); without
        # this a typo in a hand-edited file surfaces as a raw traceback.
        msg = "Failed to load json file {}. Detail: {}".format(file_path, err)
        log.error(msg)
        print(msg)
        sys.exit(1)

    if not json_dict:
        msg = "There is no info in {}".format(file_path)
        log.error(msg)
        print(msg)
        sys.exit(1)
    return json_dict


def update_json_file(json_path, stream):
    """
    Write ``stream`` to ``json_path`` as indented JSON, replacing any previous content.

    The file is created with owner read/write permission only.

    :param json_path: destination file path
    :param stream: JSON-serialisable object to store
    """
    open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    owner_read_write = stat.S_IRUSR | stat.S_IWUSR
    descriptor = os.open(json_path, open_flags, owner_read_write)
    with os.fdopen(descriptor, 'w') as out_file:
        json.dump(stream, out_file, indent=2, ensure_ascii=False)


def is_valid_ip(ip):
    """
    Tell whether ``ip`` is an IPv4 address written as a full dotted quad.

    ``socket.inet_aton`` on its own is too permissive for validating a value
    that is later pasted into a URL: it accepts shorthand such as ``"1"`` or
    ``"127.1"`` and, on some platforms, ignores garbage after a space. The
    shape is therefore checked first and ``inet_aton`` only validates the
    octet values.

    :param ip: value to check; anything that is not a string is invalid
    :return: True for a well-formed dotted-quad IPv4 address, else False
    """
    try:
        octets = ip.split(".")
    except (AttributeError, TypeError):
        # Not a text string (None, int, bytes, ...).
        return False

    if len(octets) != 4:
        return False
    for octet in octets:
        # Explicit ASCII digit test: rejects empty parts, signs, whitespace
        # and non-ASCII "digits" that str.isdigit() would let through.
        if not octet or any(char not in "0123456789" for char in octet):
            return False

    try:
        socket.inet_aton(ip)
    except (socket.error, TypeError, ValueError):
        return False
    return True


def config_check(configs):
    """
    Validate the environment configuration loaded from the config file.

    :param configs: parsed configuration; expected to be a JSON object (dict)
    :return: dict mapping a parameter name (or ``"all"``) to an error text;
             an empty dict means the configuration passed every check
    """
    if not configs:
        return {"all": "There is no params!"}

    # A syntactically valid JSON file may still hold a list or a scalar;
    # report that instead of failing on ``.get`` below.
    if not isinstance(configs, dict):
        return {"all": "The configuration must be a json object!"}

    # First make sure every required parameter is present and non-empty.
    err_dict = {}
    for key in Constant.ENVIRONMENT_CONFIG_FILE_DEMO:
        value = configs.get(key)
        if value is None or value == "":
            err_dict[key] = "There is no params!"

    if err_dict:
        return err_dict

    # Then check the format of the parameters that have one.
    if not is_valid_ip(configs.get(Constant.MD_PORTAL_IP)):
        err_dict[Constant.MD_PORTAL_IP] = "Invalid IP address format."

    if not isinstance(configs.get(Constant.HOST_UUID), list):
        err_dict[Constant.HOST_UUID] = "{} must be list! ".format(Constant.HOST_UUID)

    return err_dict


def get_environment_config():
    """
    Load and validate the environment configuration stored next to this script.

    When the configuration file does not exist, a template is written in its
    place and the process exits with status 1 so the user can fill it in.
    The process also exits with status 1 when validation reports problems.

    :return: the validated configuration
    """
    script_dir = os.path.abspath(os.path.dirname(__file__))
    config_path = os.path.join(script_dir, Constant.ENVIRONMENT_CONFIG_FILE)

    if os.path.isfile(config_path):
        config = get_json_as_file(config_path)
        problems = config_check(config)
        if not problems:
            return config
        msg = "The parameter is incorrectly set. Detail:{}".format(problems)
    else:
        # First run: leave a template behind for the user to complete.
        update_json_file(config_path, Constant.ENVIRONMENT_CONFIG_FILE_DEMO)
        msg = "Please fill in the configuration file first. Path: {}".format(config_path)

    log.error(msg)
    print(msg)
    exit(1)


def get_volumes_name(host_list):
    """
    Collect the de-duplicated volume names of every host in ``host_list``.

    Each host is described by a ``<host_uuid>.ini`` file under
    ``Constant.HOST_INFO_PATH``. Hosts whose file cannot be found are all
    gathered first and then reported together; in that case the process
    exits with status 1.

    :param host_list: iterable of host uuid strings
    :return: list of unique volume names (order is not defined)
    """
    missing_hosts = {}
    collected = []
    for host_uuid in host_list:
        ini_name = host_uuid + Constant.HOST_INFO_FILE_SUFFIX
        ini_path = os.path.join(Constant.HOST_INFO_PATH, ini_name)
        if os.path.isfile(ini_path):
            host_info = get_json_as_file(ini_path)
            collected.extend(get_node_volumes_by_ini(host_info))
        else:
            missing_hosts[host_uuid] = "The configuration file is not found or the permission is insufficient."

    if not missing_hosts:
        return list(set(collected))

    msg = "Something error has happened. Detail {}".format(missing_hosts)
    log.error(msg)
    print(msg)
    exit(1)


# Extract the volume names from the information of one compute node.
def get_node_volumes_by_ini(host_info):
    """
    Build the list of volume names attached to the VMs of one host.

    Every level of the document is checked so that an error message points
    at the exact host / VM / device that is incomplete. A VM entry that is
    empty or has no attached volumes ends the process immediately; volumes
    without an id are collected and reported together at the end, so several
    of them can be fixed in one pass. All failures exit with status 1.

    :param host_info: parsed host information (dict with "host" and "vms")
    :return: list of unique volume names (prefix + volume id), first-seen order
    """
    host_uuid = host_info.get("host", "unknown_host")
    volume_names = []
    seen_ids = set()
    missing_ids = []

    for vms in host_info.get("vms") or []:
        if not vms:
            msg = "There is no host vms info! Host uuid : {}".format(host_uuid)
            log.error(msg)
            print(msg)
            sys.exit(1)

        vms_name = vms.get("name", "unknown_vms")
        volumes_list = vms.get("volumes_attached", [])
        if not volumes_list:
            msg = "There is no host volumes list info! Host uuid : {}, vms name: {}".format(host_uuid, vms_name)
            log.error(msg)
            print(msg)
            sys.exit(1)

        for volume in volumes_list:
            volume_id = volume.get("volumeId")
            if not volume_id:
                # Track the missing id explicitly. Recognising it later from
                # a string prefix does not work: the placeholder would start
                # with the device name, so the bogus value would be returned
                # as if it were a real volume.
                device = volume.get("device", "unknown_device")
                missing_ids.append("{}_in_{}_in_{}".format(device, vms_name, host_uuid))
                continue
            if volume_id not in seen_ids:
                seen_ids.add(volume_id)
                volume_names.append(Constant.VOLUME_Name_Prefix + volume_id)

    if missing_ids:
        msg = "Volume id not found in {}".format(str(missing_ids))
        log.error(msg)
        print(msg)
        sys.exit(1)

    return volume_names


class RestClient:
    """
    Thin wrapper around ``requests`` for the management REST interface.

    NOTE(review): every request is sent with ``verify=False``, i.e. the
    server certificate is not validated. Kept as in the original; confirm
    this is acceptable for the deployment.
    """

    def __init__(self, ip, port=Constant.PORT):
        """
        :param ip: address of the REST server
        :param port: port of the REST server, as a string
        """
        self.ip = ip
        self.port = port
        # Session token; sent as X-Auth-Token once set via set_token().
        self.token = None

    def make_header(self, content_type='application/json', **kwargs):
        """
        Build the request headers.

        :param content_type: value of the Content-type header; None omits it
        :param kwargs: extra headers, added last so they override the defaults
        :return: dict of headers
        """
        header = {'Accept-Language': 'en-US'}
        if content_type is not None:
            header['Content-type'] = content_type
        if self.token is not None:
            header['X-Auth-Token'] = self.token
        header.update(kwargs)
        return header

    def login(self, password, username=Constant.MD_USER_NAME):
        """
        Send the login request.

        Any exception raised while sending is logged, printed and ends the
        process with exit status 1.

        :param password: password
        :param username: user name
        :return: the response object of the login request
        """
        url = Constant.HTTPS + self.ip + ":" + self.port + Constant.LOGIN
        login_data = {
            "username": username,
            "password": password,
        }
        json_data = json.dumps(login_data)
        login_header = {
            'Content-type': 'application/json',
            'Accept-Language': 'en-US'
        }
        # Silence the warnings triggered by skipping certificate verification.
        requests.packages.urllib3.disable_warnings()
        try:
            res = requests.post(url, data=json_data, headers=login_header, verify=False)
            res.close()
        except Exception as e:
            msg = "Failed to login [https://%s:%s], Detail:%s" % (
                self.ip, self.port, str(e))
            log.error(msg)
            print(msg)
            sys.exit(1)
        return res

    def normal_request(self, url, method, data=None, **kwargs):
        """
        Send a general request; everything except login goes through here.

        :param url: full request URL
        :param method: one of 'put', 'post', 'get', 'delete'
        :param data: JSON-serialisable body; not sent for 'delete'
        :param kwargs: forwarded to make_header()
        :return: the response object
        :raises ValueError: if ``method`` is not one of the supported verbs
        """
        senders = {
            'put': requests.put,
            'post': requests.post,
            'get': requests.get,
            'delete': requests.delete,
        }
        # Fail with a clear message instead of the UnboundLocalError that an
        # unmatched if/elif chain would produce further down.
        if method not in senders:
            raise ValueError("Unsupported request method: {}".format(method))

        # Silence the warnings triggered by skipping certificate verification.
        requests.packages.urllib3.disable_warnings()

        request_args = {'headers': self.make_header(**kwargs), 'verify': False}
        if method != 'delete':
            request_args['data'] = json.dumps(data) if data is not None else None

        res = senders[method](url, **request_args)
        res.close()
        return res

    def set_token(self, token):
        """Store the session token used for the X-Auth-Token header."""
        self.token = token


class Operate:
    """
    High-level operations of the tool; one method per REST task.

    Creating an instance logs in immediately. Except for login(), a failed
    request is logged, printed and ends the process with exit status 1.
    """

    def __init__(self, params):
        """
        :param params: validated environment configuration (dict), including
                       the volume names collected from the host files
        """
        self.params = params
        self.ip = self.params.get(Constant.MD_PORTAL_IP)
        self.password = self.params.get(Constant.MD_ADMIN_PWD)
        self.port = Constant.PORT
        self.rest_client = RestClient(self.ip)
        try:
            self.login(self.password)
        except Exception as err:
            # "except X as err" is valid on Python 2.6+ and Python 3; the
            # former comma form is a syntax error on Python 3.
            print("login failed. detail: {}".format(err))
            log.error(traceback.format_exc())
            sys.exit(1)

    def _url(self, path):
        """Return the full URL of the REST endpoint ``path``."""
        return Constant.HTTPS + self.ip + ":" + self.port + path

    def _task_body(self):
        """Return the request body fields shared by the takeover and migration tasks."""
        keys = (
            Constant.SRC_STORAGE_SN,
            Constant.SRC_POOL_ID,
            Constant.DEST_STORAGE_SN,
            Constant.DEST_POOL_ID,
            Constant.VOLUME_NAMES,
        )
        return {key: self.params.get(key) for key in keys}

    def _exit_on_failure(self, result, action):
        """Log, print and exit with status 1 unless ``result`` reports HTTP 200 and code 0."""
        status_code, error_code, error_des = result.get_res_code()
        if status_code == 200 and error_code == 0:
            return
        msg = "Failed to %s, Detail:[status:%s,code:%s]%s" \
              % (action, status_code, error_code, error_des)
        log.error(msg)
        print(msg)
        sys.exit(1)

    def login(self, password):
        """
        Log in and hand the returned token (the response "data") to the REST client.

        :param password: password of the admin user
        :raises Exception: if the server does not answer HTTP 200 with code 0
        """
        res = self.rest_client.login(password)
        result = ResponseParse(res)
        status_code, error_code, error_des = result.get_res_code()
        if status_code != 200 or error_code != 0:
            msg = "Failed to login, Detail:[status:%s,code:%s]%s" \
                  % (status_code, error_code, error_des)
            log.error(msg)
            raise Exception(msg)

        self.rest_client.set_token(result.get_res_data())

    def logout(self):
        """Log out; the user name is passed to the server as a request header."""
        header = {"username": Constant.MD_USER_NAME}
        res = self.rest_client.normal_request(self._url(Constant.LOGOUT), 'delete', **header)
        self._exit_on_failure(ResponseParse(res), "logout")

    def create_lun_takeover(self):
        """Create the LUN takeover task for the configured volumes."""
        res = self.rest_client.normal_request(
            self._url(Constant.CREATE_LUN_TAKEOVER), "post", self._task_body())

        result = ResponseParse(res)
        self._exit_on_failure(result, "create lun takeover")

        print("create lun takeover success. detail:{}".format(result.get_res_data()))

    def lun_migration(self, action):
        """
        Deliver a volume migration task.

        :param action: migration action, one of Constant.LUN_MIGRATION_CHOICES
        """
        body = {"action": action}
        body.update(self._task_body())

        res = self.rest_client.normal_request(self._url(Constant.LUN_MIGRATION), "post", body)

        result = ResponseParse(res)
        self._exit_on_failure(result, "lun migration")

        log.info("success. detail:{}".format(result.get_res_data()))
        print("Start lun migration success. detail:{}".format(result.get_res_data()))

    def volume_status(self):
        """Query the status of the configured volumes and print one line per volume."""
        body = {Constant.VOLUME_NAMES: self.params.get(Constant.VOLUME_NAMES)}

        res = self.rest_client.normal_request(self._url(Constant.VOLUME_STATUS), "get", body)

        result = ResponseParse(res)
        self._exit_on_failure(result, "get volume status")

        log.info("Get volume status success. detail:{}".format(result.get_res_data()))
        print("Get volume status success. detail:\n")
        for item in result.get_volumes_data():
            print('Name: {:<10} Volume Status: {:<5} Process: {}'.
                  format(item['name'], item['volume_status'], item['process']))

    def rollback_lun_takeover(self):
        """Roll back the LUN takeover of the configured volumes."""
        body = {Constant.VOLUME_NAMES: self.params.get(Constant.VOLUME_NAMES)}

        res = self.rest_client.normal_request(self._url(Constant.ROLLBACK_LUN_TAKEOVER), "put", body)

        result = ResponseParse(res)
        self._exit_on_failure(result, "rollback lun takeover")

        log.info("success. detail:{}".format(result.get_res_data()))
        print("success. detail:{}".format(result.get_res_data()))


class ResponseParse(object):
    """Helpers for reading the JSON envelope (code / description / data) of a REST response."""

    def __init__(self, res):
        # Raw response object returned by the REST request.
        self.res = res

    def get_res_code(self):
        """
        Summarise the outcome of the request.

        Anything that cannot be interpreted (non-200 status, body that is not
        a JSON object, missing or non-numeric "code") is reported as error
        code -1 rather than raising, so callers can print a failure message.

        :return: tuple ``(status_code, error_code, error_des)``; the request
                 succeeded when status_code is 200 and error_code is 0
        """
        status_code = self.res.status_code
        if status_code != 200:
            return status_code, -1, "failed"

        try:
            body = self.res.json()
        except ValueError:
            # HTTP 200 but the body is not JSON.
            return status_code, -1, "failed"
        if not isinstance(body, dict):
            return status_code, -1, "failed"

        try:
            error_code = int(body.get("code"))
        except (TypeError, ValueError):
            error_code = -1

        if error_code == 0:
            error_des = "success"
        else:
            error_des = body.get("description")
            if error_des is None:
                # Never describe a non-zero code as "success".
                error_des = "failed"
        return status_code, error_code, error_des

    def get_res_data(self):
        """Return the "data" field of the JSON response body."""
        res = self.res.json()
        return res.get("data")

    def get_volumes_data(self):
        """
        Return the volumes listed in "data", ordered by status.

        Volumes are ordered by the rank of their status in
        ``Constant.ALL_VOLUME_STATUS``; volumes sharing a status keep the
        order of the response and volumes with an unknown status come last.

        :return: list of dicts with keys 'name', 'volume_status', 'process'
        """
        volumes = self.res.json().get("data") or []
        rank_of = Constant.ALL_VOLUME_STATUS
        unknown_rank = len(rank_of)

        rows = [
            {'name': volume.get('name'), 'volume_status': volume.get('status'), 'process': volume.get('process')}
            for volume in volumes
        ]
        # list.sort is stable, so this matches bucketing by status.
        rows.sort(key=lambda row: rank_of.get(row['volume_status'], unknown_rank))
        return rows


class ParseArgs:
    """Command-line definition and parsing for the tool."""

    DESCRIPTION = "Takeover and migration tool script in the NFV scenario."
    VOLUME_STATUS = "Querying the Volume Status on Hosts.The status includes: INIT, TAKEOVER_FAIL, " \
                    "ROLLBACK_TAKEOVER_FAIL, TAKEOVER, MIGRATING, PAUSED, MIGRATE_FAIL, MIGRATED, FINISHED"
    CREATE_LUN_TAKEOVER = "Create lun takeover tasks."
    LUN_MIGRATION = "Delivering a Volume Migration Task. The input parameters include: {0} (Only migrated " \
                    "volumes can be split.)".format(Constant.LUN_MIGRATION_CHOICES)
    ROLLBACK_LUN_TAKEOVER = "Rollback lun takeover"
    # Short and long spelling of each option, in pairs.
    PARAMS_LIST = [
        "-vs", "--volume_status",
        "-clt", "--create_lun_takeover",
        "-lm", "--lun_migration",
        "-rlt", "--rollback_lun_takeover"
    ]

    @classmethod
    def parse_args(cls, args):
        """
        Parse the command-line arguments.

        With no arguments the help text is shown and the process exits. An
        option that is not listed in PARAMS_LIST, or more than two arguments,
        prints an explanation plus the help text and exits with status 1.

        :param args: argument list, normally ``sys.argv[1:]``
        :return: the argparse namespace
        """
        parse = argparse.ArgumentParser(prog=__file__, add_help=False, description=cls.DESCRIPTION)
        parse.add_argument(cls.PARAMS_LIST[0], cls.PARAMS_LIST[1], action="store_true", help=cls.VOLUME_STATUS)
        parse.add_argument(cls.PARAMS_LIST[2], cls.PARAMS_LIST[3], action="store_true", help=cls.CREATE_LUN_TAKEOVER)
        parse.add_argument(cls.PARAMS_LIST[4], cls.PARAMS_LIST[5], nargs=1,
                           help=cls.LUN_MIGRATION, choices=Constant.LUN_MIGRATION_CHOICES)
        parse.add_argument(cls.PARAMS_LIST[6], cls.PARAMS_LIST[7], action="store_true", help=cls.ROLLBACK_LUN_TAKEOVER)

        if not args:
            sys.exit(parse.format_help())

        for arg in args:
            # Compare only the option name so that "--opt=value" is accepted;
            # startswith() is also safe for an empty argument.
            option = arg.strip().split("=", 1)[0]
            if option.startswith("-") and option not in cls.PARAMS_LIST:
                print("parameter Error.\n{}".format(parse.format_help()))
                # Stop here: the message announces an error, so no task may run.
                sys.exit(1)

        if len(args) > 2:
            print("The parameter is too long.Only one task can be executed at a time."
                  " \n{}".format(parse.format_help()))
            sys.exit(1)

        return parse.parse_args(args)


def main():
    """
    Entry point: load the configuration, collect the volume names, parse the
    command line, log in, run the requested task(s) and log out.
    """
    # Read the environment information from the JSON file next to the script
    # and validate it. The first run generates a template and asks the user
    # to fill it in.
    config = get_environment_config()

    # Resolve the volume names from the host uuids (validated inside the function).
    volumes_name = get_volumes_name(config.get(Constant.HOST_UUID))
    config.update({Constant.VOLUME_NAMES: volumes_name})

    # Read the command-line input.
    params = sys.argv[1:]
    arg_obj = ParseArgs.parse_args(params)

    # Creating the object also logs in.
    opr = Operate(config)

    try:
        if arg_obj.create_lun_takeover:
            opr.create_lun_takeover()

        if arg_obj.lun_migration:
            # argparse stores the value as a one-element list (nargs=1).
            opr.lun_migration(arg_obj.lun_migration[0])

        if arg_obj.volume_status:
            opr.volume_status()

        if arg_obj.rollback_lun_takeover:
            opr.rollback_lun_takeover()
    finally:
        # Release the token even when a task fails: tasks report failure by
        # exiting the process, which would otherwise skip the logout.
        opr.logout()


# Logging is configured at module level; the functions above look up ``log``
# when they are called, i.e. after this has run.
# - %(msecs)03d zero-pads the milliseconds so ".005" is not written as ".5".
# - %H is the 24-hour clock; %I (12-hour) without an AM/PM marker makes
#   morning and afternoon entries indistinguishable.
Log_Format = ('%(asctime)s.%(msecs)03d %(levelname)s pid[%(process)s] '
              '[%(filename)s %(funcName)s:%(lineno)d]:%(message)s')
logging.basicConfig(filename=Constant.LOG_PATH, format=Log_Format, datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO)
log = logging.getLogger()

if __name__ == "__main__":
    main()
