
'''
Created on August 1, 2017

@author: feliu@cisco.com

Copyright (c) 2017 by Cisco Systems, Inc.
All rights reserved.

Classes used for device command executor invocation.
'''

import devpkg.utils.util

class ConfigKeeper:
    """
    ConfigKeeper keeps a snapshot of the current probed FMC configuration.
    """

    global_configs = {}

    def __init__(self):
        '''
        Start this keeper with an empty per-instance snapshot.
        '''
        # Maps a configuration name to its probed data,
        # e.g. {'accesspolicies':[...], 'securityzones':[...],...}
        self.configs = dict()
 
    def set(self, key, value):
        '''
        Save one configuration on this instance, replacing any value already stored under key.
        @param key - name of the configuration to save, e.g. 'securityzones'
        @param value - the data to save under that name, e.g. a list of security zones
        '''
        self.configs.update({key: value})

    @staticmethod
    def get_global(key):
        return ConfigKeeper.global_configs[key] if ConfigKeeper.global_configs.has_key(key) else []

    @staticmethod
    def set_global(key, value):
        ConfigKeeper.global_configs[key] = value

    def get(self, key):
        '''
        Look up a configuration saved on this instance.
        @param key - name of the configuration, e.g. 'securityzones'
        @return the stored value, or a new empty list when nothing is stored under key
        '''
        # dict.has_key() no longer exists in Python 3; dict.get() behaves the same on 2 and 3.
        return self.configs.get(key, [])


    def get_id_from_securityzones(self, name):
        '''
        Find the id of the security zone with the given name.
        The 'securityzones' configuration is searched first, then 'interfacesecurityzones'.
        @param name - security zone name to look for
        @return the id of the first matching zone converted with str(), or '' when none matches
        '''
        for bucket in ('securityzones', 'interfacesecurityzones'):
            match = next((zone for zone in list(self.get(bucket)) if zone['name'] == name), None)
            if match is not None:
                return str(match['id'])
        return ''

    def get_mode_from_securityzone_name(self, name):
        '''
        Find the interface mode of the security zone with the given name.
        The 'securityzones' configuration is searched first, then 'interfacesecurityzones'.
        @param name - security zone name to look for
        @return the 'interfaceMode' of the first matching zone converted with str(),
                or '' when none matches
        '''
        for bucket in ('securityzones', 'interfacesecurityzones'):
            match = next((zone for zone in list(self.get(bucket)) if zone['name'] == name), None)
            if match is not None:
                return str(match['interfaceMode'])
        return ''

    def get_network_from_networks_by_value(self, network_value):
        '''
        Find the first entry of the 'networks' configuration whose 'value' equals network_value.
        @param network_value - value to look for
        @return tuple (name, id) of the matching entry, or (None, None) when none matches
        '''
        match = next((net for net in list(self.get('networks')) if net['value'] == network_value), None)
        if match is None:
            return (None, None)
        return (match['name'], match['id'])

    def get_network_from_networks_by_name(self, network_name):
        '''
        Find a network by name, falling back to the 'hosts' configuration.
        @param network_name - name to look for
        @return tuple (value, id) of the first matching network; otherwise (value + '/32', id)
                of the first matching host; otherwise (None, None)
        '''
        network = next((item for item in list(self.get('networks')) if item['name'] == network_name), None)
        if network is not None:
            return (network['value'], network['id'])
        # See FMC defect: CSCvj36576 - Not able to create network with /32 prefix.
        # FMC treats a /32 network as a Host rather than a network, so a name that is missing
        # from the networks may still be present in the hosts bucket. When it is, /32 is appended
        # because the config from APIC is well-formed with a prefix, so the two compare correctly.
        host = next((item for item in list(self.get('hosts')) if item['name'] == network_name), None)
        if host is not None:
            return (host['value'] + '/32', host['id'])
        return (None, None)

    def get_host_from_hosts_by_value(self, host_value):
        '''
        Find the first entry of the 'hosts' configuration whose 'value' equals host_value.
        @param host_value - value to look for
        @return tuple (name, id) of the matching entry, or (None, None) when none matches
        '''
        match = next((host for host in list(self.get('hosts')) if host['value'] == host_value), None)
        if match is None:
            return (None, None)
        return (match['name'], match['id'])

    def get_host_from_hosts_by_name(self, host_name):
        '''
        Find the first entry of the 'hosts' configuration whose 'name' equals host_name.
        @param host_name - name to look for
        @return tuple (value, id) of the matching entry, or (None, None) when none matches
        '''
        match = next((host for host in list(self.get('hosts')) if host['name'] == host_name), None)
        if match is None:
            return (None, None)
        return (match['value'], match['id'])

    def get_ipv4_static_route_id(self, nameif, network_value, host_value):
        '''
        Find the id of the probed IPv4 static route matching an interface, a network and a gateway.
        A route matches when its 'interfaceName' equals nameif, the name of its first selected
        network equals the name of the network object holding network_value, and its gateway
        matches: by host object name when the gateway is an 'object', else by literal value.
        @param nameif - interface name of the route
        @param network_value - value of the route's network object
        @param host_value - value of the route's gateway
        @return id of the first matching route, or '' when no network object has network_value
                or no route matches
        '''
        network_name = self.get_network_from_networks_by_value(network_value)[0]
        if not network_name:
            return ''
        # host_name is None when no host object has host_value; object gateways then never match.
        host_name = self.get_host_from_hosts_by_value(host_value)[0]
        for route in list(self.get('ipv4staticroutes')):
            gateway = route['gateway']
            # 'in' replaces dict.has_key(), which no longer exists in Python 3.
            if 'object' in gateway:
                is_same_gateway = gateway['object']['name'] == host_name
            else:
                is_same_gateway = gateway['literal']['value'] == host_value
            if (nameif == route['interfaceName']
                    and network_name == route['selectedNetworks'][0]['name']
                    and is_same_gateway):
                return route['id']
        return ''

    def get_host_value_from_ipv4_route(self, route):
        '''
        Get the gateway value of an IPv4 static route.
        @param route - route dictionary; may be None or empty
        @return None when route is empty or has no 'gateway'; for an 'object' gateway, the value
                of the host with that object's name (None when no such host is found);
                otherwise the gateway's literal value
        '''
        # 'in' replaces dict.has_key(), which no longer exists in Python 3.
        if not route or 'gateway' not in route:
            return None
        gateway = route['gateway']
        if 'object' in gateway:
            return self.get_host_from_hosts_by_name(gateway['object']['name'])[0]
        return gateway['literal']['value']
            
    def get_network_value_from_ipv4_route(self, route):
        '''
        Get the value of the first selected network of an IPv4 static route.
        @param route - route dictionary; may be None or empty
        @return None when route is empty or its 'selectedNetworks' is missing or empty;
                otherwise the value found by name for the first selected network
                (None when the name is not found)
        '''
        # dict.get() replaces dict.has_key(), which no longer exists in Python 3. It also treats
        # an empty 'selectedNetworks' list like a missing one instead of raising IndexError.
        if not route or not route.get('selectedNetworks'):
            return None
        first_network_name = route['selectedNetworks'][0]['name']
        return self.get_network_from_networks_by_name(first_network_name)[0]
            
    def get_all_ipv4_static_route_ids_by_nameif(self, nameif):
        '''
        Collect the ids of all probed IPv4 static routes on a given interface.
        Interface names are compared after both sides are passed through asciistr().
        @param nameif - interface name to match against each route's 'interfaceName'
        @return list of ids of the matching routes, in probe order; empty when none matches
        '''
        return [
            route['id']
            for route in list(self.get('ipv4staticroutes'))
            if devpkg.utils.util.asciistr(route['interfaceName']) == devpkg.utils.util.asciistr(nameif)
        ]

    def get_id_from_probe(self, name, command_type):
        '''
        Find the id of a probed object by name.
        @param name - object name; 'name.vlanId' for a subinterface, and
                      'nameif#network_value#host_value' for command_type "ipv4staticroutes"
        @param command_type - one of "AccessPolicy", "etherchannelinterfaces",
                              "physicalinterfaces", "subinterfaces", "securityzones",
                              "ipv4staticroutes"
        @return id of the first matching object, or '' when nothing matches or the
                command_type is not one of the above
        '''
        # For each command_type, the configurations to search in order. The flag marks
        # configurations whose entries are matched on 'name.vlanId' instead of plain 'name'.
        search_plans = (
            # bucket is not separated for AccessPolicy as it can be applied to different device of different tenant
            ("AccessPolicy", (('accesspolicies', False),)),
            # in the case of subinterface of port-channel, we may find the id from subinterfaces
            ("etherchannelinterfaces", (('etherchannelinterfaces', False), ('subinterfaces', True))),
            # "physicalinterfaces" covers both physical and subinterface case, so look for subinterfaces also
            ("physicalinterfaces", (('physicalinterfaces', False), ('subinterfaces', True))),
            ("subinterfaces", (('subinterfaces', True),)),
            ("securityzones", (('securityzones', False),)),
        )
        for known_type, buckets in search_plans:
            if command_type == known_type:
                for bucket, match_name_and_vlan in buckets:
                    for entry in list(self.get(bucket)):
                        if match_name_and_vlan:
                            entry_name = (devpkg.utils.util.asciistr(entry['name']) + '.'
                                          + devpkg.utils.util.asciistr(entry['vlanId']))
                        else:
                            entry_name = entry['name']
                        if entry_name == name:
                            return devpkg.utils.util.asciistr(entry['id'])
                return ''
        if command_type == "ipv4staticroutes":
            # name in the form nameif#network_value#host_value
            nameif, network_value, host_value = name.split('#')
            return self.get_ipv4_static_route_id(nameif, network_value, host_value)
        return ''

    """
    Subinterface name may enter here. But associated interfaces with security zones only show the physical interfaces names. 
    Thus we have to check first
    """
    def get_security_zone_id_from_interface_name(self, intf_name):
        '''
        Find the id of the security zone associated with an interface.
        The kind of interface is decided from the name: a name containing '.' is a subinterface
        ('name.vlanId'), a name containing 'Port-channel' is a whole port-channel interface, and
        anything else is a whole physical interface. The lookup is done in the 'subinterfaces',
        'etherchannelinterfaces' or 'physicalinterfaces' configuration respectively.
        @param intf_name - interface name, possibly unicode
        @return id of the associated security zone, or '' when the interface is not found or has
                no 'securityZone'. If several probed entries match, the last one with a
                'securityZone' wins.
        '''
        # Encode unicode characters if there are any.
        intf_name_str = devpkg.utils.util.asciistr(intf_name)
        security_zone_id = ''
        if '.' in intf_name_str:
            phy_intf_name, vlan_id = intf_name_str.split('.')
            for subinterface in self.get('subinterfaces'):
                if (devpkg.utils.util.asciistr(subinterface['name']) == phy_intf_name
                        and devpkg.utils.util.asciistr(subinterface['vlanId']) == vlan_id
                        and 'securityZone' in subinterface):
                    security_zone_id = subinterface['securityZone']['id']
        else:
            if 'Port-channel' in intf_name_str:
                bucket = 'etherchannelinterfaces'
            else:  # Physical interface
                bucket = 'physicalinterfaces'
            for intf in self.get(bucket):
                # The port-channel branch used to subscript the has_key method itself
                # (has_key['securityZone']), which raised TypeError on every matching entry.
                # 'in' also replaces dict.has_key(), which no longer exists in Python 3.
                if intf['name'] == intf_name_str and 'securityZone' in intf:
                    security_zone_id = intf['securityZone']['id']
        return security_zone_id

    def get_interface_from_probe(self, name, command_type):
        '''
        Find a probed interface entry by name.
        @param name - interface name; for command_type "subinterfaces" it may also be given
                      as 'name.vlanId'
        @param command_type - name of the configuration to search, e.g. "physicalinterfaces"
        @return the first matching entry, or None when none matches
        '''
        try_name_and_vlan = command_type == "subinterfaces"
        for candidate in list(self.get(command_type)):
            if try_name_and_vlan:
                name_and_vlan = (devpkg.utils.util.asciistr(candidate['name']) + '.'
                                 + devpkg.utils.util.asciistr(candidate['vlanId']))
                if name_and_vlan == name:
                    return candidate
            # The plain name is still accepted when the 'name.vlanId' form does not match.
            if candidate['name'] == name:
                return candidate
        return None

    def get_security_zone_name_from_zone_id(self, zone_id):
        '''
        Find the name of the security zone with the given id in the 'securityzones' configuration.
        @param zone_id - security zone id to look for
        @return name of the first matching zone, or None when none matches
        '''
        matching_names = (zone['name'] for zone in self.get('securityzones') if zone['id'] == zone_id)
        return next(matching_names, None)

    def get_associated_securityzones_of_interfaces(self, intfs):
        '''
        For each selected interface, search the probed interfaces and collect the security zone
        associated with every matching one.
        @param intfs - selected interfaces, each a dictionary with 'type' ('PhysicalInterface',
                       'SubInterface' or 'EtherchannelInterface') and 'name'
        @return list of the 'securityZone' dictionaries of the matching probed interfaces, each
                updated with 'name' and, when the zone is found in 'securityzones', with
                'interfaceMode' and 'interfaces'. Interfaces of any other type match nothing.
        '''
        # Configurations to search per interface type. 'PhysicalInterface' covers both physical
        # interfaces and sub-interfaces, so both lists are searched.
        buckets_by_type = {
            'PhysicalInterface': ('physicalinterfaces', 'subinterfaces'),
            'SubInterface': ('subinterfaces',),
            'EtherchannelInterface': ('etherchannelinterfaces',),
        }
        securityzones = list(self.get('securityzones'))
        associated_szs = []
        for selected_intf in intfs:
            # Rebuilt for every interface: an unrecognized type searches nothing, instead of
            # hitting an unbound variable or reusing the previous interface's list.
            searching_intfs = []
            for bucket in buckets_by_type.get(selected_intf['type'], ()):
                searching_intfs.extend(self.get(bucket))
            for target_intf in searching_intfs:
                # Both target_intf['name'] and selected_intf['name'] could be either of the
                # forms 'name.vlan_id' or 'name'.
                target_intf_name = target_intf['name']
                # dict.get() and 'in' replace dict.has_key(), which no longer exists in Python 3.
                sub_intf_id = target_intf.get('subIntfId')
                if '.' not in target_intf_name and sub_intf_id:
                    sub_intf_id_str = devpkg.utils.util.asciistr(sub_intf_id)
                    if sub_intf_id_str:
                        target_intf_name = target_intf_name + '.' + sub_intf_id_str
                if target_intf_name != selected_intf['name'] or 'securityZone' not in target_intf:
                    continue
                # NOTE: this is the dictionary held in the probed configuration itself, so the
                # fields written below are added to the stored snapshot as well.
                associated_sz = target_intf['securityZone']
                associated_sz['name'] = self.get_security_zone_name_from_zone_id(associated_sz['id'])
                # update 'interfaceMode' and 'interfaces' field
                for sz in securityzones:
                    if sz['name'] == associated_sz['name']:
                        associated_sz['interfaceMode'] = sz['interfaceMode']
                        associated_sz['interfaces'] = sz['interfaces']
                associated_szs.append(associated_sz)
        return associated_szs
        
    def get_accessrules_with_securityzones(self, associated_szs):
        '''
        Collect the probed access rules that reference any of the given security zones, either
        as a source zone or as a destination zone.
        @param associated_szs - security zone dictionaries, each with a 'name'
        @return list of matching rules from the 'accessrules' configuration without duplicates,
                ordered by zone and then by probe order; empty when there are no zones or no rules
        '''
        # Fetched once; it does not change while the zones are being processed.
        all_rules = list(self.get('accessrules'))
        if not associated_szs or not all_rules:
            return []

        accessrules = []
        for sz in associated_szs:
            for rule in all_rules:
                all_zones = []
                for zone_field in ('sourceZones', 'destinationZones'):
                    # 'in' replaces dict.has_key(), which no longer exists in Python 3.
                    if zone_field in rule and 'objects' in rule[zone_field]:
                        all_zones.extend(rule[zone_field]['objects'])
                references_zone = any(zone['name'] == sz['name'] for zone in all_zones)
                if references_zone and rule not in accessrules:
                    accessrules.append(rule)
        return accessrules

                
