# Copyright 2011 Avaya Inc. All Rights Reserved.

"""
SSHd service control and configuration management.
"""

from . import shell
import unittest

__author__      = "Avaya Inc."
__copyright__   = "Copyright 2011, Avaya Inc."


DEFAULT_SSHD_CONFIG_FILE='/etc/ssh/sshd_config'


def restart():
    """
    Restart the SSHd service.

    Runs the command 'service sshd restart' through shell.sudo_call and
    hands back whatever that call returns.
    """
    restart_command = 'service sshd restart'
    return shell.sudo_call(restart_command)


def set_config(key, value, config_file=DEFAULT_SSHD_CONFIG_FILE):
    """
    Set value for specified parameter in sshd config file.

    key         -- parameter name; compared case-insensitively, because sshd
                   keywords are case-insensitive
    value       -- value to set; None (or an empty string) disables the
                   parameter
    config_file -- sshd configuration file (defaults to /etc/ssh/sshd_config)

    Only entries whose keyword is exactly `key` are touched, so setting
    'Port' leaves 'GatewayPorts' alone. Both active and commented-out
    ("#Key value") entries count as matches.

    If value is None (or '') every active matching entry is disabled by
    commenting it out. Otherwise every matching entry is replaced with
    "key value". If there is no matching entry at all, a new one is inserted
    before the first active 'Match' line, or appended at the end of the file
    when there is none, so the setting stays global.

    Other falsy values such as 0 are written as-is, not treated as "disable".
    """
    wanted = key.lower()
    disable = value is None or value == ''

    def entry_keyword(line):
        # Lower-cased keyword of an active or commented-out entry, or None for
        # blank lines. sshd allows whitespace or '=' between keyword and
        # arguments, so "Key value" and "Key=value" both yield "key".
        text = line.strip()
        if text.startswith('#'):
            text = text[1:].strip()
        if not text:
            return None
        return text.split(None, 1)[0].split('=', 1)[0].lower()

    # Read everything up front (and close the handle) before rewriting.
    with shell.file_open(config_file) as infile:
        lines = infile.readlines()

    found = False
    first_match = None
    for index, line in enumerate(lines):
        keyword = entry_keyword(line)
        is_comment = line.lstrip().startswith('#')
        # Remember where conditional 'Match' blocks begin, so a brand-new
        # entry can be placed in the global section.
        if keyword == 'match' and not is_comment and first_match is None:
            first_match = index
        if keyword != wanted:
            continue
        found = True
        if disable:
            if not is_comment:
                # Disable this parameter by commenting out its entry.
                lines[index] = '#%s' % line
        else:
            lines[index] = '%s %s\n' % (key, str(value))

    if not found and not disable:
        entry = '%s %s\n' % (key, str(value))
        if first_match is not None:
            lines.insert(first_match, entry)
        else:
            # Make sure the new entry starts on its own line.
            if lines and not lines[-1].endswith('\n'):
                lines[-1] += '\n'
            lines.append(entry)

    with shell.file_open(config_file, 'w') as outfile:
        outfile.writelines(lines)


def get_config(key, config_file=DEFAULT_SSHD_CONFIG_FILE):
    """
    Get value for specified parameter from the sshd config file.

    key         -- parameter name; compared case-insensitively, because sshd
                   keywords are case-insensitive
    config_file -- sshd configuration file (defaults to /etc/ssh/sshd_config)

    Returns the arguments of the first active entry whose keyword is exactly
    `key`, as a single string (multiple arguments stay space-separated;
    an entry with no arguments yields '').

    This method will return None if specified keyword value pair is
    either disabled (commented out) or not present in config file.
    """
    wanted = key.lower()
    with shell.file_open(config_file) as infile:
        for line in infile:
            text = line.strip()
            if not text or text.startswith('#'):
                continue
            # Keyword and arguments may be separated by whitespace or by an
            # optional '=' ("Key value", "Key=value", "Key = value").
            keyword = text.split(None, 1)[0].split('=', 1)[0]
            if keyword.lower() != wanted:
                continue
            rest = text[len(keyword):].strip()
            if rest.startswith('='):
                rest = rest[1:].strip()
            return rest
    return None


def set_timeout(t, config_file=DEFAULT_SSHD_CONFIG_FILE):
    """
    Convenience wrapper that sets the SSHd connection timeout.

    Stores `t` as the 'ClientAliveInterval' entry of the sshd config file
    (through set_config), then restarts the sshd service.

    t           -- value to set (in seconds)
    config_file -- sshd configuration file (defaults to /etc/ssh/sshd_config)
    """
    keyword = 'ClientAliveInterval'
    set_config(keyword, t, config_file)
    restart()


def get_timeout(config_file=DEFAULT_SSHD_CONFIG_FILE):
    """
    Convenience function for getting SSHd connection timeout (in seconds).

    Reads the 'ClientAliveInterval' entry of the sshd config file.

    config_file -- sshd configuration file (defaults to /etc/ssh/sshd_config)

    Returns the configured value as the string produced by get_config, or
    None when the entry is missing or commented out.
    """
    return get_config('ClientAliveInterval', config_file)


class Test(unittest.TestCase):
    """Tests for set_config/get_config against a temporary config file."""

    def setUp(self):
        import os
        import tempfile
        # mkstemp returns an already-open OS-level descriptor together with
        # the path; close the descriptor so it does not leak, then write the
        # fixture through a regular file object.
        fd, self.ssh_config = tempfile.mkstemp()
        os.close(fd)
        self.test_key = 'TestKey'
        with open(self.ssh_config, 'w') as outfile:
            outfile.write('TestKey 1')

    def tearDown(self):
        import os
        os.remove(self.ssh_config)

    def test_get_existing_config(self):
        self.assertEqual('1', get_config(self.test_key, self.ssh_config))

    def test_get_missing_config(self):
        self.assertIsNone(get_config('OtherKey', self.ssh_config))

    def test_set_get_config(self):
        expected_value = '10'
        set_config(self.test_key, expected_value, self.ssh_config)
        self.assertEqual(expected_value, get_config(self.test_key, self.ssh_config))

    def test_disable_config(self):
        set_config(self.test_key, None, self.ssh_config)
        self.assertFalse(get_config(self.test_key, self.ssh_config))
        with open(self.ssh_config) as infile:
            self.assertEqual('#TestKey 1', infile.readline().strip())

if __name__ == '__main__':
    # Run the embedded unit tests when this module is executed directly.
    unittest.main()
