#!/usr/bin/python3
# -*- coding: utf-8 -*-
# --------------------------------------------------------------------
# Copyright © 2014-2015 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
# --------------------------------------------------------------------

# --------------------------------------------------------------------
# Terminology:
#
# - "update archive": a fake system-image "ubuntu-hash.tar.xz" tar
#   archive.
# --------------------------------------------------------------------

import tarfile
import tempfile
import unittest
from unittest.mock import patch
import os
import shutil
import sys

from ubuntucoreupgrader.upgrader import (
    tar_generator,
    Upgrader,
    parse_args,
)

# Make modules living next to this test file importable.
# NOTE(review): this runs after the ubuntucoreupgrader.upgrader import at
# the top of the file, so it can only affect the imports that follow it.
base_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(base_dir)

from ubuntucoreupgrader.tests.utils import (
    create_file,
    make_tmp_dir,
    UbuntuCoreUpgraderTestCase,
    )

# file mode to use for creating test directories.
# NOTE(review): not referenced in the code visible here; confirm it is still
# used elsewhere before relying on it.
TEST_DIR_MODE = 0o750

# Base name of the (fake) upgrader command file.
CMD_FILE = 'ubuntu_command'


def make_default_options():
    """
    Return the options object obtained by parsing an empty command line.
    """
    no_arguments = []
    return parse_args(no_arguments)


class UpgradeTestCase(unittest.TestCase):
    """
    Tests for ``tar_generator()``, which yields the members of an update
    archive that should be unpacked.
    """

    def test_tar_generator_unpack_assets(self):
        """
        The special top-level "removed" file and unrelated paths are not
        yielded; the system, asset and hardware.yaml members are, in order.
        """
        # Close the backing temporary file when done rather than leaving it
        # to the garbage collector.
        with tempfile.TemporaryFile() as tempf:
            with tarfile.TarFile(fileobj=tempf, mode="w") as tar:

                # special top-level file that should not be unpacked
                tar.add(__file__, "removed")

                tar.add(__file__, "system/bin/true")
                tar.add(__file__, "assets/vmlinuz")
                tar.add(__file__, "assets/initrd.img")
                tar.add(__file__, "unreleated/something")
                tar.add(__file__, "hardware.yaml")

                result = [m.name for m in
                          tar_generator(tar, "cache_dir", [], "/")]
                self.assertEqual(result, ["system/bin/true",
                                          "assets/vmlinuz",
                                          "assets/initrd.img",
                                          "hardware.yaml"])

    def test_tar_generator_system_files_unpack(self):
        """
        A system/dev entry that the test creates beforehand is not yielded,
        while one that does not exist is yielded under root_dir.
        """
        root_dir = make_tmp_dir()
        # Registered immediately so the directories are removed even when
        # an assertion below fails.
        self.addCleanup(shutil.rmtree, root_dir, ignore_errors=True)
        cache_dir = make_tmp_dir()
        self.addCleanup(shutil.rmtree, cache_dir, ignore_errors=True)

        sys_dir = os.path.join(cache_dir, 'system')

        os.makedirs(os.path.join(root_dir, 'dev'))
        os.makedirs(sys_dir)

        with tempfile.TemporaryFile() as tempf:
            with tarfile.TarFile(fileobj=tempf, mode="w") as tar:
                tar.add(__file__, "assets/vmlinuz")
                tar.add(__file__, "assets/initrd.img")

                device_a = '/dev/null'
                path = os.path.normpath('{}/{}'.format(sys_dir, device_a))
                touch_file(path)

                # NOTE(review): device_a is absolute, so os.path.join()
                # discards root_dir and this checks the host's /dev/null,
                # not a path below root_dir. Confirm which was intended.
                self.assertTrue(
                    os.path.exists(os.path.join(root_dir, device_a)))

                # should not be unpacked as already exists
                tar.add(__file__, "system{}".format(device_a))

                device_b = '/dev/does-not-exist'
                # NOTE(review): as with device_a, this checks the host path.
                self.assertFalse(
                    os.path.exists(os.path.join(root_dir, device_b)))

                # should be unpacked as does not exist
                tar.add(__file__, "system{}".format(device_b))

                expected = ["assets/vmlinuz", "assets/initrd.img",
                            "dev/does-not-exist"]

                expected_results = [os.path.join(root_dir, name)
                                    for name in expected]

                result = [m.name for m in
                          tar_generator(tar, cache_dir, [], root_dir)]

                self.assertEqual(result, expected_results)

    def test_tar_generator_system_files_remove_before_unpack(self):
        """
        Test that the upgrader removes certain files just prior to
        overwriting them via the unpack.
        """
        root_dir = make_tmp_dir()
        # Registered immediately so the directories are removed even when
        # an assertion below fails.
        self.addCleanup(shutil.rmtree, root_dir, ignore_errors=True)
        cache_dir = make_tmp_dir()
        self.addCleanup(shutil.rmtree, cache_dir, ignore_errors=True)
        sys_dir = os.path.join(cache_dir, 'system')

        os.makedirs(sys_dir)

        rel_file = 'a/file'
        rel_dir = 'a/dir'

        with tempfile.TemporaryFile() as tempf:
            with tarfile.TarFile(fileobj=tempf, mode="w") as tar:

                file_path = os.path.normpath(
                    '{}/{}'.format(root_dir, rel_file))
                touch_file(file_path)
                self.assertTrue(os.path.exists(file_path))

                dir_path = os.path.normpath('{}/{}'.format(root_dir, rel_dir))
                os.makedirs(dir_path)
                self.assertTrue(os.path.exists(dir_path))

                tar.add(__file__, "system/{}".format(rel_file))

                expected_results = [os.path.join(root_dir, rel_file)]

                result = [m.name for m in
                          tar_generator(tar, cache_dir, [], root_dir)]

                self.assertEqual(result, expected_results)

                # file should be removed
                self.assertFalse(os.path.exists(file_path))

                # directory should not be removed
                self.assertTrue(os.path.exists(dir_path))


def touch_file(path):
    """
    Create an empty file at @path, also creating any of its parent
    directories that do not exist yet.
    """
    no_content = None
    create_file(path, no_content)


def make_commands(update_list):
    """
    Convert the specified list of update archives into a list of command
    file commands.

    Each archive ``name`` produces one ``update name name.asc`` command,
    where the second argument is the archive's signature file.

    No ``mount system`` / ``unmount system`` verbs are added (which would
    be more realistic) since the tests cannot handle them by default; tests
    that need those verbs insert them into the returned list themselves.

    :param update_list: iterable of update archive file names.
    :return: a new list of command strings, one per archive, in input
             order (callers may mutate it freely).
    """
    return ['update {0} {0}.asc'.format(archive) for archive in update_list]


class UbuntuCoreUpgraderObectTestCase(UbuntuCoreUpgraderTestCase):
    """
    Tests that drive an ``Upgrader`` through command lists built by
    ``make_commands()``.

    NOTE(review): "Obect" looks like a typo for "Object"; the name is left
    unchanged so existing test selectors keep working.
    """

    def test_create_object(self):
        """
        Running the upgrader unpacks a regular file from the update
        archive's system directory into options.root_dir.
        """
        root_dir = make_tmp_dir()
        # Registered immediately so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, root_dir, ignore_errors=True)

        file_name = 'created-regular-file'

        file_path = os.path.join(self.update.system_dir, file_name)
        create_file(file_path, 'foo bar')

        self.cache_dir = self.update.tmp_dir

        archive = self.update.create_archive(self.TARFILE)
        self.assertTrue(os.path.exists(archive))

        # make_commands() refers to a "<archive>.asc" signature for each
        # archive, so create an (empty) one in the cache directory.
        dest = os.path.join(self.cache_dir, os.path.basename(archive))
        touch_file('{}.asc'.format(dest))

        commands = make_commands([self.TARFILE])

        options = make_default_options()

        # XXX: doesn't actually exist, but the option must be set since
        # the upgrader looks for the update archives in the directory
        # where this file is claimed to be.
        options.cmdfile = os.path.join(self.cache_dir, CMD_FILE)

        options.root_dir = root_dir

        upgrader = Upgrader(options, commands, [])
        # NOTE(review): presumably makes the mountpoint check succeed
        # without a real mount (the command is `true`) -- confirm.
        upgrader.MOUNTPOINT_CMD = "true"
        upgrader.run()

        path = os.path.join(root_dir, file_name)
        self.assertTrue(os.path.exists(path))

    @patch('ubuntucoreupgrader.upgrader.get_mount_details')
    @patch('ubuntucoreupgrader.upgrader.mount')
    @patch('ubuntucoreupgrader.upgrader.unmount')
    def test_no_format_in_cmd(self, mock_umount, mock_mount,
                              mock_mount_details):
        """
        Without a 'format system' command, mkfs must not be called.
        """
        # If the command file does not contain the format command, mkfs
        # should not be called.
        with patch('ubuntucoreupgrader.upgrader.mkfs') as mock_mkfs:
            args = ['cmdfile']
            options = parse_args(args=args)
            commands = make_commands([self.TARFILE])
            upgrader = Upgrader(options, commands, [])
            upgrader.TIMESTAMP_FILE = '/dev/null'
            upgrader.MOUNTPOINT_CMD = "true"
            upgrader.run()

        # No format command in command file, so should not have been
        # called.
        self.assertFalse(mock_mkfs.called)

    @patch('ubuntucoreupgrader.upgrader.get_mount_details')
    @patch('ubuntucoreupgrader.upgrader.mount')
    @patch('ubuntucoreupgrader.upgrader.unmount')
    def test_format(self, mock_umount, mock_mount, mock_mount_details):
        """
        A 'format system' command calls mkfs with the details returned by
        get_mount_details() and marks the other partition as formatted.
        """
        MOCK_FS_TUPLE = ("device", "fstype", "label")
        mock_mount_details.return_value = MOCK_FS_TUPLE

        # mkfs should be called if the format command is specified in
        # the command file.
        with patch('ubuntucoreupgrader.upgrader.mkfs') as mock_mkfs:
            args = ['cmdfile']
            options = parse_args(args=args)
            commands = make_commands([self.TARFILE])

            # add format command to command file
            commands.insert(0, 'format system')

            upgrader = Upgrader(options, commands, [])
            upgrader.TIMESTAMP_FILE = '/dev/null'
            upgrader.MOUNTPOINT_CMD = "true"
            upgrader.run()
            self.assertTrue(upgrader.other_has_been_formatted)

        mock_mkfs.assert_called_with(*MOCK_FS_TUPLE)

    @patch.object(Upgrader, 'sync_partitions')
    @patch('ubuntucoreupgrader.upgrader.get_mount_details')
    @patch('ubuntucoreupgrader.upgrader.remount')
    @patch('ubuntucoreupgrader.upgrader.mount')
    @patch('ubuntucoreupgrader.upgrader.fsck')
    @patch('ubuntucoreupgrader.upgrader.mkfs')
    @patch('ubuntucoreupgrader.upgrader.unmount')
    def test_mount_unmount_no_format(self, mock_umount, mock_mkfs, mock_fsck,
                                     mock_mount, mock_remount,
                                     mock_mount_details, mock_sync_partitions):
        """
        With mount/unmount commands but no format command, the upgrader
        calls sync_partitions().
        """
        MOCK_FS_TUPLE = ("device", "fstype", "label")
        mock_mount_details.return_value = MOCK_FS_TUPLE

        # No 'format system' command is given, so sync_partitions() is
        # expected to be called.
        args = ['cmdfile']
        options = parse_args(args=args)
        commands = make_commands([self.TARFILE])

        commands.insert(0, 'mount system')
        commands.append('unmount system')

        upgrader = Upgrader(options, commands, [])
        upgrader.TIMESTAMP_FILE = '/dev/null'
        upgrader.MOUNTPOINT_CMD = "true"
        upgrader.run()

        self.assertTrue(mock_sync_partitions.called)

    @patch.object(Upgrader, 'sync_partitions')
    @patch('ubuntucoreupgrader.upgrader.get_mount_details')
    @patch('ubuntucoreupgrader.upgrader.remount')
    @patch('ubuntucoreupgrader.upgrader.mount')
    @patch('ubuntucoreupgrader.upgrader.fsck')
    @patch('ubuntucoreupgrader.upgrader.mkfs')
    @patch('ubuntucoreupgrader.upgrader.unmount')
    def test_mount_unmount_with_format(self, mock_umount, mock_mkfs, mock_fsck,
                                       mock_mount, mock_remount,
                                       mock_mount_details,
                                       mock_sync_partitions):
        """
        When a format command precedes mount/unmount, the upgrader does
        not call sync_partitions().
        """
        MOCK_FS_TUPLE = ("device", "fstype", "label")
        mock_mount_details.return_value = MOCK_FS_TUPLE

        # A 'format system' command is given first, so sync_partitions()
        # is expected NOT to be called.
        args = ['cmdfile']
        options = parse_args(args=args)
        commands = make_commands([self.TARFILE])

        commands.insert(0, 'format system')
        commands.insert(1, 'mount system')
        commands.append('unmount system')

        upgrader = Upgrader(options, commands, [])
        upgrader.TIMESTAMP_FILE = '/dev/null'
        upgrader.MOUNTPOINT_CMD = "true"
        upgrader.run()

        self.assertFalse(mock_sync_partitions.called)

    def test_empty_removed_file(self):
        """
        An empty 'removed' file must not make the upgrader delete its
        cache directory (regression test for LP: #1437225).
        """
        root_dir = make_tmp_dir()
        # Registered immediately so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, root_dir, ignore_errors=True)

        file_name = 'created-regular-file'

        file_path = os.path.join(self.update.system_dir, file_name)
        create_file(file_path, 'foo bar')

        self.cache_dir = self.update.tmp_dir

        removed_file = self.update.removed_file
        # Create an empty removed file
        create_file(removed_file, '')

        archive = self.update.create_archive(self.TARFILE)
        self.assertTrue(os.path.exists(archive))

        # make_commands() refers to a "<archive>.asc" signature for each
        # archive, so create an (empty) one in the cache directory.
        dest = os.path.join(self.cache_dir, os.path.basename(archive))
        touch_file('{}.asc'.format(dest))

        commands = make_commands([self.TARFILE])

        options = make_default_options()

        # XXX: doesn't actually exist, but the option must be set since
        # the upgrader looks for the update archives in the directory
        # where this file is claimed to be.
        options.cmdfile = os.path.join(self.cache_dir, CMD_FILE)

        options.root_dir = root_dir

        upgrader = Upgrader(options, commands, [])
        upgrader.cache_dir = self.cache_dir
        upgrader.MOUNTPOINT_CMD = "true"

        si_file = self.TARFILE
        si_signature = "{}.asc".format(self.TARFILE)
        upgrader._cmd_update([si_file, si_signature])

        # Ensure that the upgrader has not attempted to remove the
        # cache_dir when the 'removed' file is empty.
        #
        # (Regression test for LP: #1437225).
        self.assertTrue(os.path.exists(upgrader.cache_dir))

        shutil.rmtree(self.cache_dir)

    def test_other_considered_empty(self):
        """
        other_considered_empty() stays true until channel.ini under the
        mount target has non-zero size, and is true again once it is
        removed.
        """
        si_config_file = '/etc/system-image/channel.ini'

        cache_dir = make_tmp_dir()
        # Registered immediately so the directory is removed even when an
        # assertion below fails.
        self.addCleanup(shutil.rmtree, cache_dir, ignore_errors=True)
        args = ['cmdfile']
        options = parse_args(args=args)
        commands = make_commands([self.TARFILE])

        upgrader = Upgrader(options, commands, [])
        upgrader.TIMESTAMP_FILE = '/dev/null'
        upgrader.MOUNTPOINT_CMD = "true"
        upgrader.cache_dir = cache_dir

        self.assertTrue(upgrader.other_considered_empty())

        target = upgrader.get_mount_target()
        channel_ini = os.path.normpath('{}/{}'.format(target,
                                       si_config_file))
        si_dir = os.path.dirname(channel_ini)

        os.makedirs(si_dir)

        # directory alone: still empty
        self.assertTrue(upgrader.other_considered_empty())

        touch_file(channel_ini)

        # zero-length channel.ini: still empty
        self.assertTrue(upgrader.other_considered_empty())

        create_file(channel_ini, "my size is >0")

        self.assertFalse(upgrader.other_considered_empty())

        os.remove(channel_ini)

        self.assertTrue(upgrader.other_considered_empty())

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
