#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2018 Huawei Technologies Co. Ltd. All rights reserved.
""" rollback tool"""

from __future__ import print_function
import logging
import os
import shutil
import sys
import six

import update_tool as update

# configuration files restored from a backup; each one is looked up in the
# backup directory by its basename
BACKUP_CONF_LIST = [
    '/etc/neutron/neutron.conf',
    '/etc/neutron/huawei_driver_config.ini',
]
# name of the code package directory restored from a backup
CODE_MAIN_NAME = 'networking_huawei'

LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
FORMATTER = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# send this module's log records to stdout
CONSOLE_HANDLER = logging.StreamHandler(sys.stdout)
CONSOLE_HANDLER.setFormatter(FORMATTER)
LOGGER.addHandler(CONSOLE_HANDLER)


class RollBack(update.UpdateFunBase):
    """Restore networking_huawei code, conf files and tools from a backup.

    Backups are the sub-directories of the (inherited) get_history_path();
    each one is addressed by its directory name, called the "version"
    throughout this class.
    """

    def __init__(self):
        super(RollBack, self).__init__()
        super(RollBack, self).check_package_file_list()
        # directory that receives the restored CODE_MAIN_NAME package
        self.package_home_path = self.get_package_home_path()

    def roll_back(self, version, backup_path):
        """Restore code, conf files and tools from one backup directory.

        :param version: backup directory name; logged, and used by
            roll_back_tools() to find the versioned tool backups.
        :param backup_path: absolute path of that backup directory.
        """
        dst_main_path = os.path.join(self.package_home_path, CODE_MAIN_NAME)
        LOGGER.info("start to recover code env to version :%s", version)
        update.copytree(os.path.join(backup_path, CODE_MAIN_NAME), dst_main_path)

        LOGGER.info("start to rollback conf file")
        for conf_file in BACKUP_CONF_LIST:
            src_conf_file = os.path.join(backup_path, os.path.basename(conf_file))
            LOGGER.info("copy conf file '%s' to '%s'", src_conf_file, conf_file)
            # Copy to the full target path instead of its directory: if the
            # directory were missing, copying "into" it would silently create
            # a regular file named after the directory.
            shutil.copy(src_conf_file, conf_file)

        LOGGER.info("start to rollback tools")
        self.roll_back_tools(version)

        LOGGER.info("start to restart neutron")
        # NOTE(review): restart=False despite the log line above; the flag's
        # meaning is defined by restart_check in update_tool -- confirm.
        self.restart_check(restart=False)

    def get_version_code_list(self):
        """Return the names of the restorable backups, in menu order.

        A sub-directory of the history path is a candidate when its name ends
        in a digit and contains at least three '_' separators. The list is
        sorted by the text after the third '_'. On ties, non-'default' entries
        stay ahead of 'default' ones. Exits with status 0 when nothing
        qualifies.

        :return: list of backup directory names.
        """
        history_path = self.get_history_path()
        regular, defaults = [], []
        for version_code in os.listdir(history_path):
            if not version_code[-1].isdigit():
                continue
            if not os.path.isdir(os.path.join(history_path, version_code)):
                continue
            if len(version_code.split('_', 3)) < 4:
                # The sort key below needs the part after the third '_';
                # such a name used to abort the whole tool with IndexError.
                LOGGER.warning("skip unrecognized backup directory: %s", version_code)
                continue
            if version_code.startswith('default'):
                defaults.append(version_code)
            else:
                regular.append(version_code)

        version_list = regular + defaults
        if not version_list:
            LOGGER.error('no version to rollback,exit')
            sys.exit(0)
        # list.sort is stable, so ties keep non-default entries first
        version_list.sort(key=lambda name: name.split('_', 3)[3])
        return version_list

    def choose_version(self, version_list):
        """Show a numbered menu of *version_list* and roll back to the choice.

        The choice is read from stdin as a 1-based number. Exits with status
        -1 on missing, non-numeric or out-of-range input.

        :param version_list: backup names, as from get_version_code_list().
        """
        separator = '-' * 50
        print(separator)
        print('choose one number(1,2,3...) below.\ne.g.1\nsee path \'%s\' for detail' % self.get_history_path())
        print(separator)
        for version_id, version_code in enumerate(version_list, 1):
            print("[%d] ==> %s\n" % (version_id, version_code))
        print(separator)

        try:
            user_input = six.moves.input("you choose : ")
        except EOFError:
            # stdin closed / not interactive: fail cleanly instead of a traceback
            print("no input received")
            sys.exit(-1)
        try:
            num = int(user_input)
        except ValueError:
            print("unknown input : %s catch an exception" % user_input)
            sys.exit(-1)
        if not 1 <= num <= len(version_list):
            print("unknown input : %s" % user_input)
            sys.exit(-1)

        version_name = version_list[num - 1]
        LOGGER.info("start to rollback,version :%s", version_name)
        self.roll_back(version_name, self.get_code_full_path(version_name))

    def get_code_full_path(self, base_name):
        """Return the absolute path of backup *base_name*.

        Exits with status 1 when that path is not a directory.
        """
        full_path = os.path.join(self.get_history_path(), base_name)
        if not os.path.isdir(full_path):
            LOGGER.error("unknown version name :%s", base_name)
            sys.exit(1)
        return full_path

    @classmethod
    def roll_back_tools(cls, version, search_root='/'):
        """Move versioned backups of the huawei tools back to their names.

        Every ``tools`` sub-directory of a directory whose path ends with
        ``networking-huawei`` (searched under *search_root*) is scanned for
        files named ``<name>.<id>-<rest>``. Here ``<id>`` is the part of
        *version* before the first '_'. Each such file is moved to ``<name>``.
        Nothing is moved when *version* contains no '_'.

        :param version: backup name, as from get_version_code_list().
        :param search_root: directory to search from; defaults to the whole
            filesystem, as before.
        :return: None
        """
        version = str(version)
        if '_' not in version:
            LOGGER.error("unknown version :%s", version)
            return
        marker = '.%s-' % version.split('_')[0]

        # Kernel pseudo-filesystems: never hold the package, slow to walk.
        skip_dirs = frozenset(('/proc', '/sys', '/dev'))
        tools_paths = []
        for dirpath, dirnames, _ in os.walk(search_root):
            # in-place pruning stops os.walk from descending into them
            dirnames[:] = [name for name in dirnames
                           if os.path.join(dirpath, name) not in skip_dirs]
            if dirpath.endswith('networking-huawei') and 'tools' in dirnames:
                tools_paths.append(os.path.join(dirpath, 'tools'))

        for tools_path in tools_paths:
            for tool_file in os.listdir(tools_path):
                cut = tool_file.find(marker)
                # -1: not a backup of this version; 0: no original name to
                # restore to (slicing would yield an empty target name)
                if cut <= 0:
                    continue
                rollback_file = os.path.join(tools_path, tool_file)
                rollback_no_version = os.path.join(tools_path, tool_file[:cut])
                LOGGER.info('rollback file: %s to file: %s', rollback_file, rollback_no_version)
                shutil.move(rollback_file, rollback_no_version)


def main():
    """Command-line entry point of the rollback tool.

    Arguments after the script name:
      (none)     print the backup menu and ask which one to restore;
      PREVIOUS   restore the last entry of get_version_code_list();
      <N>        restore the N-th (1-based) entry of that list.

    Anything else is logged as an error and exits with status 1.
    """
    rollback = RollBack()
    version_list = rollback.get_version_code_list()
    args = sys.argv[1:]
    if not args:
        rollback.choose_version(version_list)
        return

    if len(args) > 1:
        LOGGER.error('unknown input %s', ' '.join(args))
        sys.exit(1)

    choice = args[0]
    if choice == "PREVIOUS":
        version_name = version_list[-1]
    else:
        try:
            version_num = int(choice)
        except ValueError:
            LOGGER.error('unknown input %s', choice)
            sys.exit(1)
        if not 1 <= version_num <= len(version_list):
            # used to log and then exit 0; report failure like the other paths
            LOGGER.error("unknown input :%s", version_num)
            sys.exit(1)
        version_name = version_list[version_num - 1]

    LOGGER.info("start to rollback,version :%s", version_name)
    rollback.roll_back(version_name, rollback.get_code_full_path(version_name))


if __name__ == '__main__':
    main()
