#
#  Python ctxldaps
#  
#  Copyright (c) Citrix Systems, Inc. All Rights Reserved.
#  
import logging
import os
import warnings

# CA certificate file handed to ldap3's Tls(ca_certs_file=...) for server
# certificate validation.
CERTPATH = '/etc/xdl/.ldapcert'

# Runtime configuration, read once at import time from the environment.
# Each value is None when the corresponding variable is not set.
CTXLDAPS_KEYTAB = os.environ.get("CTXLDAPS_KEYTAB")
CTXLDAPS_LDAP_URL = os.environ.get("CTXLDAPS_LDAP_URL")
CTXLDAPS_HOST_UPN = os.environ.get("CTXLDAPS_HOST_UPN")
CTXLDAPS_BASE_DN = os.environ.get("CTXLDAPS_BASE_DN")
CTXLDAPS_FILTER = os.environ.get("CTXLDAPS_FILTER")
CTXLDAPS_ATTR = os.environ.get("CTXLDAPS_ATTR")
CTXLDAPS_DOMAIN_JOIN_METHOD = os.environ.get("CTXLDAPS_DOMAIN_JOIN_METHOD")

from abc import ABC


class ConnectionContext(ABC):
    """Own a connection object and unbind it once.

    Subclasses store their connection in ``self._connection``. Used as a
    context manager, ``__enter__`` returns the underlying connection object
    (not ``self``) and ``__exit__`` unbinds it without suppressing exceptions.
    """

    def __init__(self):
        # The live connection, or None when closed / never opened.
        self._connection = None

    def __del__(self):
        # Last-resort cleanup if close() was never called. getattr guards
        # against a subclass whose __init__ failed before super().__init__().
        if getattr(self, '_connection', None) is not None:
            self.close()

    @property
    def connection(self):
        """The wrapped connection, or None once closed."""
        return self._connection

    def close(self):
        """Unbind the wrapped connection, if any.

        The reference is dropped *before* ``unbind()`` is called, so a failing
        unbind is not retried later from ``__del__``. Any exception raised by
        ``unbind()`` still propagates to the caller.
        """
        connection, self._connection = self._connection, None
        if connection is not None:
            connection.unbind()

    def __enter__(self):
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Never suppress an exception raised inside the with-block.
        return False


import ssl
from ldap3 import Server, Connection, Tls
import ldap3


class KeyTabConn(ConnectionContext):
    """LDAP connection bound with SASL ``ldap3.KERBEROS`` using a keytab.

    The connection is created and bound in the constructor; failures raise.
    The Tls object requires certificate validation against CERTPATH. It
    presumably only takes effect when ``ldap_url`` is an ldaps:// URL, since
    start_tls is not called here.

    :param ldap_url: server URL passed to ``ldap3.Server``.
    :param host_upn: principal used as the SASL user.
    :param keytab: keytab path, passed as ``cred_store={'client_keytab': ...}``.
    :raises RuntimeError: if ``bind()`` reports failure. Exceptions raised by
        ldap3 itself (``raise_exceptions=True``) propagate unchanged.
    """

    def __init__(self, ldap_url, host_upn, keytab):
        super().__init__()
        self.ldap_url = ldap_url
        self.host_upn = host_upn
        self.keytab = keytab
        self._connection = self.__connect()

    def __connect(self):
        tls = Tls(validate=ssl.CERT_REQUIRED, ca_certs_file=CERTPATH, version=ssl.PROTOCOL_TLSv1_2)
        server = Server(self.ldap_url, tls=tls, get_info=ldap3.ALL)
        connection = Connection(server,
                                user=self.host_upn,
                                cred_store={'client_keytab': self.keytab},
                                authentication=ldap3.SASL,
                                sasl_mechanism=ldap3.KERBEROS,
                                return_empty_attributes=False,
                                raise_exceptions=True
                                )
        try:
            if not connection.bind():
                # %-style args: a None url/keytab cannot raise TypeError here
                # and mask the real failure.
                logging.error('Unable to connect to %s with %s', self.ldap_url, self.keytab)
                raise RuntimeError('Unable to connect')
        except Exception:
            # bind() may already have opened the socket; release it before
            # propagating instead of leaking a half-open connection.
            try:
                connection.unbind()
            except Exception:
                logging.debug('unbind after failed bind also failed', exc_info=True)
            raise
        logging.debug('Connected')
        return connection


import subprocess


def check_tgt_valid(domain_join_method=None):
    """
    Check whether the current Kerberos Ticket Granting Ticket (TGT) is valid.

    Runs ``klist -s``. When the domain join method is Centrify, the
    Centrify-bundled klist is used. No cache argument is passed, so klist
    presumably checks the default credential cache.

    :param domain_join_method: join method name; defaults to
        CTXLDAPS_DOMAIN_JOIN_METHOD. Compared case-insensitively against
        "Centrify". None/empty selects the system ``klist``.
    :returns: returncode
        0 - TGT is valid
        1 - TGT is invalid
        2 - Failed to check
    """
    if domain_join_method is None:
        domain_join_method = CTXLDAPS_DOMAIN_JOIN_METHOD
    # `or ""` avoids AttributeError when the environment variable is unset.
    if (domain_join_method or "").upper() == "CENTRIFY":
        klist_command = ['/usr/share/centrifydc/kerberos/bin/klist', '-s']
    else:
        klist_command = ["klist", "-s"]
    try:
        # subprocess.run kills the child when the timeout expires;
        # Popen.communicate(timeout=...) would leave it running.
        proc = subprocess.run(klist_command, stdin=subprocess.DEVNULL,
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                              timeout=20)
    except Exception as e:
        logging.debug("klist exception output:\n{0}".format(e))
        return 2
    err = proc.stderr.decode(errors="replace") if proc.stderr else ""
    if err:
        logging.debug("klist error output:\n{0}".format(err))
    return 0 if proc.returncode == 0 else 1


def kerberos_authenticate(host_upn=None, domain_join_method=None):
    """
    Regain TGT by kinit.

    Runs ``kinit -k <host_upn>``. When the domain join method is Centrify,
    the Centrify-bundled kinit is used.

    :param host_upn: principal to authenticate; defaults to CTXLDAPS_HOST_UPN.
    :param domain_join_method: join method name; defaults to
        CTXLDAPS_DOMAIN_JOIN_METHOD. Compared case-insensitively against
        "Centrify".
    :returns: True if kinit exited with status 0. False otherwise, including
        when kinit could not be started or timed out.
    """
    if host_upn is None:
        host_upn = CTXLDAPS_HOST_UPN
    if domain_join_method is None:
        domain_join_method = CTXLDAPS_DOMAIN_JOIN_METHOD
    # `or ""` avoids AttributeError when the environment variable is unset.
    if (domain_join_method or "").upper() == "CENTRIFY":
        kinit_command = ['/usr/share/centrifydc/kerberos/bin/kinit', '-k', host_upn]
    else:
        kinit_command = ["kinit", "-k", host_upn]
    # NOTE(review): no "-t <keytab>" is passed, so kinit presumably uses its
    # default keytab rather than CTXLDAPS_KEYTAB; confirm this is intended.
    try:
        # subprocess.run kills the child when the timeout expires.
        proc = subprocess.run(kinit_command, stdin=subprocess.DEVNULL,
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                              timeout=20)
    except Exception as e:
        logging.debug("kinit exception output:\n{0}".format(e))
        return False
    err = proc.stderr.decode(errors="replace") if proc.stderr else ""
    if err:
        logging.debug("kinit error output:\n{0}".format(err))
    return proc.returncode == 0


from contextlib import contextmanager
import json


@contextmanager
def connect():
    """
    Yield a Kerberos-bound LDAP connection configured from the CTXLDAPS_* settings.

    When ``check_tgt_valid()`` reports the TGT as invalid (1), ``kinit`` is run
    first; a failed check (2) does not trigger renewal. The connection is
    unbound when the with-block exits.

    :raises: whatever TGT checking or KeyTabConn raises when the connection
        cannot be established (logged at debug level first). Exceptions raised
        inside the with-block propagate unchanged.
    """
    ldap_url = CTXLDAPS_LDAP_URL
    keytab = CTXLDAPS_KEYTAB
    host_upn = CTXLDAPS_HOST_UPN
    # %-style args so an unset (None) setting does not raise TypeError here.
    logging.debug('ldap_url:%s keytab:%s host_upn:%s', ldap_url, keytab, host_upn)
    try:
        if check_tgt_valid() == 1:
            logging.debug("TGT has been expired, regain now.")
            kerberos_authenticate()
        conn_ctx = KeyTabConn(ldap_url, host_upn, keytab)
    except Exception as e:
        logging.debug("failed to connect:\n{0}".format(e))
        # Re-raise: swallowing here would only surface as contextmanager's
        # unhelpful "generator didn't yield" RuntimeError.
        raise
    # The yield is deliberately outside the try above, so errors from the
    # caller's with-block are not mistaken for connection failures and swallowed.
    try:
        yield conn_ctx.connection
    finally:
        # Unbind deterministically instead of relying on __del__.
        conn_ctx.close()


def get_basedn_from_ldap(connection):
    """
    Return the server's ``defaultNamingContext`` from the connection's server info.

    :param connection: connection whose ``server.info.other`` mapping holds the
        server's root attributes (presumably populated because the Server was
        created with ``get_info``).
    :returns: the first ``defaultNamingContext`` value. Returns "" when the
        server info, the attribute, or its values are missing, so callers can
        fall back using their ``.strip() == ""`` checks instead of crashing
        on None.
    """
    info = getattr(connection.server, "info", None)
    other = getattr(info, "other", None) or {}
    values = other.get("defaultNamingContext")
    if not values:
        return ""
    # A bare string would otherwise be indexed to its first character.
    if isinstance(values, str):
        return values
    return values[0]


def get_basedn_from_fqdn(connection):
    """
    Derive a base DN from the domain part of ``connection.fqdn``.

    The first label (the host name) is dropped and each remaining non-empty
    label becomes a ``DC=`` component, e.g.
    "dc1.example.com" -> "DC=example,DC=com".

    :returns: the derived DN, or "" when there are no domain labels. This
        replaces the bogus non-empty "DC=" the naive join produced, which
        would slip past the caller's emptiness check.
    """
    # NOTE(review): confirm that `connection` actually carries an `fqdn`
    # attribute; when it does not, this returns "" instead of raising.
    fqdn = getattr(connection, "fqdn", None) or ""
    # Filtering empty labels also handles a trailing dot ("host.example.com.").
    labels = [label for label in fqdn.split('.')[1:] if label]
    if not labels:
        return ""
    return ",".join("DC=" + label for label in labels)


def _collect_values(entries, attr):
    """Extract the values to print for *attr* from JSON-decoded ldap3 entries.

    :param entries: iterable of dicts as produced by
        ``json.loads(entry.entry_to_json())``.
    :param attr: attribute to report. The pseudo-attribute "netBiosName"
        instead reports "<distinguishedName>,samAccountName=<sAMAccountName>"
        built from the first value of each.
    :returns: list of values in entry order. Entries missing the needed
        attributes are skipped (and debug-logged).
    """
    values = []
    for entry in entries:
        attributes = entry.get("attributes", {})
        if attr == "netBiosName":
            domain = attributes.get("distinguishedName")
            username = attributes.get("sAMAccountName")
            if domain is None or username is None:
                logging.debug('distinguishedName or sAMAccountName is None')
                continue
            values.append(domain[0] + ",samAccountName=" + username[0])
        else:
            attribute = attributes.get(attr)
            if attribute is None:
                logging.debug('attribute ' + attr + ' is None')
                continue
            values.extend(attribute)
    return values


def main():
    """Search LDAP using the CTXLDAPS_* settings and print the matching values.

    All values are printed on a single line, each followed by ";".

    :raises RuntimeError: if no base DN can be determined, or the filter or
        attribute setting is empty/unset.
    """
    warnings.filterwarnings("ignore")
    log_file = "/var/log/xdl/ctxldaps.log"
    logging.basicConfig(filename=log_file, filemode="a", level=logging.ERROR)
    logging.debug("ctxldaps entry")
    values = []
    with connect() as connection:
        # Get parameters; unset environment variables (None) count as empty
        # instead of crashing on .strip().
        base_dn = CTXLDAPS_BASE_DN or ""
        if not base_dn.strip():
            base_dn = get_basedn_from_ldap(connection) or ""
        if not base_dn.strip():
            base_dn = get_basedn_from_fqdn(connection) or ""
        search_filter = CTXLDAPS_FILTER or ""
        attr = CTXLDAPS_ATTR or ""
        if not (base_dn.strip() and search_filter.strip() and attr.strip()):
            logging.error("baseDN, filter or attr is None!")
            raise RuntimeError("baseDN, filter or attr is None!")
        logging.debug('baseDN:%s filter:%s attr:%s', base_dn, search_filter, attr)

        # LDAP Search. NOTE(review): only these attributes are requested, so
        # any other CTXLDAPS_ATTR always yields no values; confirm intended.
        connection.search(base_dn, search_filter,
                          attributes=["objectSid", "distinguishedName", "sAMAccountName",
                                      "dNSHostName", "servicePrincipalName"])
        values = _collect_values(
            (json.loads(entry.entry_to_json()) for entry in connection.entries), attr)
    # join() instead of repeated `+=` avoids quadratic string building.
    print("".join(value + ";" for value in values))
    logging.debug("ctxldaps exit")


if __name__ == '__main__':
    main()
