#!/usr/bin/env -S python3 -u
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from unittest.mock import call, MagicMock
import os
import shlex
import subprocess
import sys
import time
import unittest
import base64

# Remote repository that the base image is pulled from
REMOTE_BASE_IMAGE_NAME = "public.ecr.aws/w0m4q9l7/github-awslabs-smithy-rs-ci"
# Name the base image carries locally, whether it was pulled or built
LOCAL_BASE_IMAGE_NAME = "smithy-rs-base-image"
# Name of the build image; matches the name that `Shell.docker_build_build_image` passes to `docker build -t`
BUILD_IMAGE_NAME = "smithy-rs-build-image"
# Tag that `acquire_build_image` gives the base image right before building the build image
LOCAL_TAG = "local"

# ANSI escape sequences used by `announce`: switch to bold yellow, then reset
C_YELLOW = '\033[1;33m'
C_RESET = '\033[0m'


# Prints `message` to stdout wrapped in the highlight color, then resets the color
def announce(message):
    print(C_YELLOW, message, C_RESET, sep="")


# Outcome of trying to pull the base image
class DockerPullResult(Enum):
    # `docker pull` exited with status 0
    SUCCESS = 1
    # No pull was attempted because `Shell.platform` reported `Platform.ARM_64`
    REMOTE_ARCHITECTURE_MISMATCH = 2
    # The output contained "toomanyrequests:"; also returned when all retries are used up
    ERROR_THROTTLED = 3
    # The output contained a message known to be transient (a TLS handshake timeout)
    RETRYABLE_ERROR = 4
    # The output contained "not found: manifest unknown"
    NOT_FOUND = 5
    # Any other failed pull
    UNKNOWN_ERROR = 6


# CPU architecture of the machine running this script, as classified by `Shell.platform`
class Platform(Enum):
    # Also the fallback for any machine that `Shell.platform` does not recognize as ARM
    X86_64 = 0
    ARM_64 = 1


# Script context: the paths, IDs, and environment-driven settings that the rest of
# the script needs in order to find, pull, or build the images
class Context:
    def __init__(self, start_path, script_path, tools_path, user_id, image_tag, allow_local_build, github_actions,
                 encrypted_docker_password, docker_passphrase):
        # Locations
        self.start_path, self.script_path, self.tools_path = start_path, script_path, tools_path
        # Directory containing the Dockerfiles, always derived from the tools path
        self.docker_image_path = f"{tools_path}/ci-build"

        # Identity of the user and of the required base image
        self.user_id, self.image_tag = user_id, image_tag

        # Behavior switches
        self.allow_local_build, self.github_actions = allow_local_build, github_actions

        # Credentials for the optional Docker login (either may be None)
        self.encrypted_docker_password = encrypted_docker_password
        self.docker_passphrase = docker_passphrase

    # Builds the context from the real environment: the working directory, the location
    # of this file, the enclosing git checkout, the current user, and environment variables.
    # Prints a summary of what was discovered (the Docker credentials are not printed).
    @staticmethod
    def default():
        started_in = os.path.realpath(os.curdir)
        here = os.path.dirname(os.path.realpath(__file__))
        (_, repo_root, _) = get_cmd_output("git rev-parse --show-toplevel", cwd=here)
        (_, uid, _) = get_cmd_output("id -u")
        (_, required_tag, _) = get_cmd_output("./docker-image-hash", cwd=here)

        context = Context(
            start_path=started_in,
            script_path=here,
            tools_path=repo_root + "/tools",
            user_id=uid,
            image_tag=required_tag,
            # Local builds stay enabled unless ALLOW_LOCAL_BUILD is exactly "false"
            allow_local_build=os.getenv("ALLOW_LOCAL_BUILD") != "false",
            github_actions=os.getenv("GITHUB_ACTIONS") == "true",
            # An empty variable is treated the same as an unset one
            encrypted_docker_password=os.getenv("ENCRYPTED_DOCKER_PASSWORD") or None,
            docker_passphrase=os.getenv("DOCKER_LOGIN_TOKEN_PASSPHRASE") or None,
        )

        summary = (
            ("Start path", context.start_path),
            ("Script path", context.script_path),
            ("Tools path", context.tools_path),
            ("User ID", context.user_id),
            ("Required base image tag", context.image_tag),
            ("Allow local build", context.allow_local_build),
            ("Running in GitHub Actions", context.github_actions),
        )
        for label, value in summary:
            print(f"{label}: {value}")
        return context


# Returns True when at least one of `messages` occurs as a substring of `stdout` or `stderr`
def output_contains_any(stdout, stderr, messages):
    return any(needle in stdout or needle in stderr for needle in messages)

# Mockable shell commands
#
# Every call to `uname` and `docker` goes through a method of this class so that the
# self tests can swap individual methods for mocks.
class Shell:
    # Returns the `Platform` that this script is running on
    def platform(self):
        (_, machine, _) = get_cmd_output("uname -m")
        # 64-bit ARM is reported as "arm64" on macOS but as "aarch64" on Linux
        if machine in ("arm64", "aarch64"):
            return Platform.ARM_64
        return Platform.X86_64

    # Returns True if the given `image_name` and `image_tag` exist locally
    def docker_image_exists_locally(self, image_name, image_tag):
        # `docker inspect` exits non-zero for an unknown image, so don't raise on failure
        (status, _, _) = get_cmd_output(f"docker inspect \"{image_name}:{image_tag}\"", check=False)
        return status == 0

    # Logs in to the public ECR registry as user `AWS`, passing `password` via stdin.
    # Raises if the login command fails.
    def docker_login(self, password):
        get_cmd_output("docker login --username AWS --password-stdin public.ecr.aws", input=password.encode('utf-8'))

    # Pulls the requested `image_name` with `image_tag`. Returns `DockerPullResult`.
    def docker_pull(self, image_name, image_tag):
        (status, stdout, stderr) = get_cmd_output(f"docker pull \"{image_name}:{image_tag}\"", check=False)
        print("Docker pull output:")
        print("-------------------")
        print(stdout)
        print("-------------------")
        print(stderr)
        print("-------------------")

        # A failed pull is classified by looking for known messages in its output
        not_found_messages = ["not found: manifest unknown"]
        throttle_messages = ["toomanyrequests:"]
        retryable_messages = ["net/http: TLS handshake timeout"]
        if status == 0:
            return DockerPullResult.SUCCESS
        elif output_contains_any(stdout, stderr, throttle_messages):
            return DockerPullResult.ERROR_THROTTLED
        elif output_contains_any(stdout, stderr, not_found_messages):
            return DockerPullResult.NOT_FOUND
        elif output_contains_any(stdout, stderr, retryable_messages):
            return DockerPullResult.RETRYABLE_ERROR
        return DockerPullResult.UNKNOWN_ERROR

    # Builds the base image with the Dockerfile in `path` and tags it with `image_tag`
    def docker_build_base_image(self, image_tag, path):
        run(f"docker build -t \"{LOCAL_BASE_IMAGE_NAME}:{image_tag}\" .", cwd=path)

    # Builds the local build image from `add-local-user.dockerfile` in `docker_image_path`,
    # passing `user_id` to the build as the `USER_ID` build argument
    def docker_build_build_image(self, user_id, docker_image_path):
        run(
            f"docker build -t {BUILD_IMAGE_NAME} --file add-local-user.dockerfile --build-arg=USER_ID={user_id} .",
            cwd=docker_image_path
        )

    # Saves the Docker image named `image_name` with `image_tag` to `output_path`
    def docker_save(self, image_name, image_tag, output_path):
        run(f"docker save -o \"{output_path}\" \"{image_name}:{image_tag}\"")

    # Tags an image with a new image name and tag
    def docker_tag(self, image_name, image_tag, new_image_name, new_image_tag):
        run(f"docker tag \"{image_name}:{image_tag}\" \"{new_image_name}:{new_image_tag}\"")


# Pulls a Docker image and retries if it gets throttled or hits a retryable error.
#
# Returns a `DockerPullResult`:
#  - `REMOTE_ARCHITECTURE_MISMATCH`, without attempting a pull, when `shell.platform()` is ARM
#  - otherwise the first pull result that is neither throttling nor a retryable error
#  - `ERROR_THROTTLED` when all `max_attempts` attempts had to be retried
#
# `throttle_sleep_time` and `retryable_error_sleep_time` are the seconds to wait before
# the next attempt after the respective kind of failure.
def docker_pull_with_retry(shell, image_name, image_tag, throttle_sleep_time=120, retryable_error_sleep_time=1,
                           max_attempts=5):
    if shell.platform() == Platform.ARM_64:
        return DockerPullResult.REMOTE_ARCHITECTURE_MISMATCH
    for attempt in range(1, max_attempts + 1):
        announce(f"Attempting to pull remote image {image_name}:{image_tag} (attempt {attempt})...")
        result = shell.docker_pull(image_name, image_tag)
        if result not in (DockerPullResult.ERROR_THROTTLED, DockerPullResult.RETRYABLE_ERROR):
            return result
        if attempt == max_attempts:
            # Nothing is left to retry, so don't announce a retry or wait for one
            break
        if result == DockerPullResult.ERROR_THROTTLED:
            announce("Docker pull failed due to throttling. Waiting and trying again...")
            time.sleep(throttle_sleep_time)
        else:
            announce("A retryable error occurred. Trying again...")
            time.sleep(retryable_error_sleep_time)
    # Hit max retries; the image probably exists, but we are getting throttled hard. Fail.
    announce("Image pulling throttled for too many attempts. The remote image might exist, but we can't get it.")
    return DockerPullResult.ERROR_THROTTLED


# Runs `command` in `cwd`. The command is a shell-style string that is split into
# arguments with `shlex`; it is not passed through a shell. Both of the command's
# output streams are sent to this script's stderr.
# Raises `subprocess.CalledProcessError` if the command exits with a non-zero status.
def run(command, cwd=None):
    argv = shlex.split(command)
    subprocess.run(argv, cwd=cwd, check=True, stdout=sys.stderr, stderr=sys.stderr)


# Runs `command` and captures its output.
#
# `command` is either a list of arguments or a shell-style string, which is split into
# arguments with `shlex` (no shell is involved). Additional keyword arguments, such as
# `input`, are forwarded to `subprocess.run`.
#
# Returns `(status, stdout, stderr)`, with both outputs decoded as UTF-8 and stripped of
# surrounding whitespace. When `check` is true, a non-zero exit status raises an
# `Exception` instead. Its message contains the full command line and the captured
# output, so anything passed as an argument ends up in that message.
def get_cmd_output(command, cwd=None, check=True, **kwargs):
    argv = shlex.split(command) if isinstance(command, str) else command

    result = subprocess.run(
        argv,
        capture_output=True,
        check=False,
        cwd=cwd,
        **kwargs
    )
    stdout = result.stdout.decode("utf-8").strip()
    stderr = result.stderr.decode("utf-8").strip()
    if check and result.returncode != 0:
        # Show the command the way it would be typed rather than as a list repr
        printable = " ".join(shlex.quote(str(arg)) for arg in argv)
        raise Exception(f"failed to run '{printable}'.\n{stdout}\n{stderr}")

    return result.returncode, stdout, stderr


# Decodes the base64 `secret`, decrypts it with gpg using `passphrase`, and passes the
# decrypted text to `shell.docker_login` as the password.
# Raises if `secret` is not valid base64, if `passphrase` is missing or empty, or if
# the gpg command fails.
def decrypt_and_login(shell, secret, passphrase):
    ciphertext = base64.b64decode(secret, validate=True)
    if not passphrase:
        raise Exception("a secret was set but no passphrase was set (or it was empty)")

    gpg_command = ["gpg", "--decrypt", "--batch", "--quiet", "--passphrase", passphrase, "--output", "-"]
    # Only the decrypted stdout is needed; a gpg failure raises inside `get_cmd_output`
    (_, password, _) = get_cmd_output(gpg_command, input=ciphertext)
    shell.docker_login(password)
    print("Docker login success!")


# Makes sure the local build image exists. The base image it needs is obtained by
# reusing a local copy, else pulling it from the remote repository, else (if allowed)
# building it from the Dockerfile.
#
# `context` defaults to `Context.default()` and `shell` to a real `Shell`. Both defaults
# are created when the function is called rather than when it is defined, so loading
# this file (for example to run the self tests) doesn't execute any commands.
#
# Returns the exit code for the script: 0 on success, or 1 if the pull failed with an
# unknown error or a local build is required but not allowed.
def acquire_build_image(context=None, shell=None):
    if context is None:
        context = Context.default()
    if shell is None:
        shell = Shell()

    # Logging in is optional; it only happens when an encrypted password was provided
    if context.encrypted_docker_password is not None:
        decrypt_and_login(shell, context.encrypted_docker_password, context.docker_passphrase)
    # If the image doesn't already exist locally, then look remotely
    if not shell.docker_image_exists_locally(LOCAL_BASE_IMAGE_NAME, context.image_tag):
        announce("Base image not found locally.")
        pull_result = docker_pull_with_retry(shell, REMOTE_BASE_IMAGE_NAME, context.image_tag)
        if pull_result != DockerPullResult.SUCCESS:
            if pull_result == DockerPullResult.REMOTE_ARCHITECTURE_MISMATCH:
                announce("Remote architecture is not the same as the local architecture. A local build is required.")
            elif pull_result == DockerPullResult.UNKNOWN_ERROR:
                # Don't paper over an unexplained failure with a local build
                announce("An unknown failure happened during Docker pull. This needs to be examined.")
                return 1
            else:
                announce("Failed to pull remote image, which can happen if it doesn't exist.")

            if not context.allow_local_build:
                announce("Local build turned off by ALLOW_LOCAL_BUILD env var. Aborting.")
                return 1

            announce("Building a new image locally.")
            shell.docker_build_base_image(context.image_tag, context.docker_image_path)

            if context.github_actions:
                announce("Saving base image for use in later jobs...")
                shell.docker_save(
                    LOCAL_BASE_IMAGE_NAME,
                    context.image_tag,
                    context.start_path + "/smithy-rs-base-image"
                )
        else:
            announce("Successfully pulled remote image!")
            # Give the pulled image its local name so later runs find it without pulling
            shell.docker_tag(REMOTE_BASE_IMAGE_NAME, context.image_tag, LOCAL_BASE_IMAGE_NAME, context.image_tag)
    else:
        announce("Base image found locally! No retrieval or rebuild necessary.")

    announce("Creating local build image...")
    shell.docker_tag(LOCAL_BASE_IMAGE_NAME, context.image_tag, LOCAL_BASE_IMAGE_NAME, LOCAL_TAG)
    shell.docker_build_build_image(context.user_id, context.docker_image_path)
    return 0


# Unit tests for this script, run by passing `--self-test`.
# Docker and platform access is replaced with mocks created by `mock_shell`.
class SelfTest(unittest.TestCase):
    # Creates a `Context` with fixed, fake paths and IDs.
    # Deliberately not named `test_*`: unittest would collect it as a test case, and
    # newer Python versions warn about test methods that return a value.
    def make_context(self, github_actions=False, allow_local_build=False, encrypted_docker_password=None,
                     docker_passphrase=None):
        return Context(
            start_path="/tmp/test/start-path",
            script_path="/tmp/test/script-path",
            tools_path="/tmp/test/tools-path",
            user_id="123",
            image_tag="someimagetag",
            encrypted_docker_password=encrypted_docker_password,
            docker_passphrase=docker_passphrase,
            github_actions=github_actions,
            allow_local_build=allow_local_build,
        )

    # Creates a `Shell` whose every method is replaced with a `MagicMock`
    def mock_shell(self):
        shell = Shell()
        shell.platform = MagicMock()
        shell.docker_build_base_image = MagicMock()
        shell.docker_build_build_image = MagicMock()
        shell.docker_image_exists_locally = MagicMock()
        shell.docker_pull = MagicMock()
        shell.docker_save = MagicMock()
        shell.docker_tag = MagicMock()
        shell.docker_login = MagicMock()
        return shell

    # Calls `docker_pull_with_retry` for a test image with all sleeping disabled
    def pull_with_retry(self, shell):
        return docker_pull_with_retry(
            shell,
            "test-image",
            "test-image-tag",
            throttle_sleep_time=0,
            retryable_error_sleep_time=0
        )

    def test_retry_architecture_mismatch(self):
        shell = self.mock_shell()
        shell.platform.side_effect = [Platform.ARM_64]
        self.assertEqual(DockerPullResult.REMOTE_ARCHITECTURE_MISMATCH, self.pull_with_retry(shell))

    # Note: decryption is not mocked, so this test runs the real `gpg` command
    def test_docker_login(self):
        shell = self.mock_shell()
        acquire_build_image(self.make_context(
            encrypted_docker_password="jA0ECQMCvYU/JxsX3g/70j0BxbLLW8QaFWWb/DqY9gPhTuEN/xdYVxaoDnV6Fha+lAWdT7xN0qZr5DHPBalLfVvvM1SEXRBI8qnfXyGI",
            docker_passphrase="secret"), shell)
        shell.docker_login.assert_called_with("payload")

    def test_retry_immediate_success(self):
        shell = self.mock_shell()
        shell.docker_pull.side_effect = [DockerPullResult.SUCCESS]
        self.assertEqual(DockerPullResult.SUCCESS, self.pull_with_retry(shell))

    def test_retry_immediate_not_found(self):
        shell = self.mock_shell()
        shell.docker_pull.side_effect = [DockerPullResult.NOT_FOUND]
        self.assertEqual(DockerPullResult.NOT_FOUND, self.pull_with_retry(shell))

    def test_retry_immediate_unknown_error(self):
        shell = self.mock_shell()
        shell.docker_pull.side_effect = [DockerPullResult.UNKNOWN_ERROR]
        self.assertEqual(DockerPullResult.UNKNOWN_ERROR, self.pull_with_retry(shell))

    def test_retry_throttling_then_success(self):
        shell = self.mock_shell()
        shell.docker_pull.side_effect = [
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.SUCCESS
        ]
        self.assertEqual(DockerPullResult.SUCCESS, self.pull_with_retry(shell))

    def test_retry_throttling_and_retryable_error_then_success(self):
        shell = self.mock_shell()
        shell.docker_pull.side_effect = [
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.RETRYABLE_ERROR,
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.SUCCESS
        ]
        self.assertEqual(DockerPullResult.SUCCESS, self.pull_with_retry(shell))

    def test_retry_throttling_then_not_found(self):
        shell = self.mock_shell()
        shell.docker_pull.side_effect = [
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.NOT_FOUND
        ]
        self.assertEqual(DockerPullResult.NOT_FOUND, self.pull_with_retry(shell))

    def test_retry_max_attempts(self):
        shell = self.mock_shell()
        shell.docker_pull.side_effect = [
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.ERROR_THROTTLED,
            DockerPullResult.ERROR_THROTTLED,
        ]
        self.assertEqual(DockerPullResult.ERROR_THROTTLED, self.pull_with_retry(shell))

    # When: the base image already exists locally with the right image tag
    # It should: build a local build image using that local base image
    def test_image_exists_locally_already(self):
        shell = self.mock_shell()
        shell.platform.side_effect = [Platform.X86_64]
        shell.docker_image_exists_locally.side_effect = [True]

        self.assertEqual(0, acquire_build_image(self.make_context(), shell))

        shell.docker_image_exists_locally.assert_called_once()
        shell.docker_tag.assert_called_with(LOCAL_BASE_IMAGE_NAME, "someimagetag", LOCAL_BASE_IMAGE_NAME, LOCAL_TAG)
        shell.docker_build_build_image.assert_called_with("123", "/tmp/test/tools-path/ci-build")

    # When:
    #  - the base image doesn't exist locally
    #  - the base image doesn't exist remotely
    #  - local builds are allowed
    #  - NOT running in GitHub Actions
    # It should: build a local image from scratch and NOT save it to file
    def test_image_local_build(self):
        context = self.make_context(allow_local_build=True)
        shell = self.mock_shell()
        shell.platform.side_effect = [Platform.X86_64]
        shell.docker_image_exists_locally.side_effect = [False]
        shell.docker_pull.side_effect = [DockerPullResult.NOT_FOUND]

        self.assertEqual(0, acquire_build_image(context, shell))
        shell.docker_image_exists_locally.assert_called_once()
        shell.docker_build_base_image.assert_called_with("someimagetag", "/tmp/test/tools-path/ci-build")
        shell.docker_save.assert_not_called()
        shell.docker_tag.assert_called_with(LOCAL_BASE_IMAGE_NAME, "someimagetag", LOCAL_BASE_IMAGE_NAME, LOCAL_TAG)
        shell.docker_build_build_image.assert_called_with("123", "/tmp/test/tools-path/ci-build")

    # When:
    #  - the base image doesn't exist locally
    #  - the base image exists remotely
    #  - local builds are allowed
    #  - there is a difference in platform between local and remote
    #  - NOT running in GitHub Actions
    # It should: build a local image from scratch and NOT save it to file
    def test_image_local_build_architecture_mismatch(self):
        context = self.make_context(allow_local_build=True)
        shell = self.mock_shell()
        shell.platform.side_effect = [Platform.ARM_64]
        shell.docker_image_exists_locally.side_effect = [False]

        self.assertEqual(0, acquire_build_image(context, shell))
        shell.docker_image_exists_locally.assert_called_once()
        shell.docker_build_base_image.assert_called_with("someimagetag", "/tmp/test/tools-path/ci-build")
        shell.docker_save.assert_not_called()
        shell.docker_tag.assert_called_with(LOCAL_BASE_IMAGE_NAME, "someimagetag", LOCAL_BASE_IMAGE_NAME, LOCAL_TAG)
        shell.docker_build_build_image.assert_called_with("123", "/tmp/test/tools-path/ci-build")

    # When:
    #  - the base image doesn't exist locally
    #  - the base image doesn't exist remotely
    #  - local builds are allowed
    #  - running in GitHub Actions
    # It should: build a local image from scratch and save it to file
    def test_image_local_build_github_actions(self):
        context = self.make_context(allow_local_build=True, github_actions=True)
        shell = self.mock_shell()
        shell.platform.side_effect = [Platform.X86_64]
        shell.docker_image_exists_locally.side_effect = [False]
        shell.docker_pull.side_effect = [DockerPullResult.NOT_FOUND]

        self.assertEqual(0, acquire_build_image(context, shell))
        shell.docker_image_exists_locally.assert_called_once()
        shell.docker_build_base_image.assert_called_with("someimagetag", "/tmp/test/tools-path/ci-build")
        shell.docker_save.assert_called_with(
            LOCAL_BASE_IMAGE_NAME,
            "someimagetag",
            "/tmp/test/start-path/smithy-rs-base-image"
        )
        shell.docker_tag.assert_called_with(LOCAL_BASE_IMAGE_NAME, "someimagetag", LOCAL_BASE_IMAGE_NAME, LOCAL_TAG)
        shell.docker_build_build_image.assert_called_with("123", "/tmp/test/tools-path/ci-build")

    # When:
    #  - the base image doesn't exist locally
    #  - the base image doesn't exist remotely
    #  - local builds are NOT allowed
    # It should: fail since local builds are not allowed
    def test_image_fail_local_build_disabled(self):
        context = self.make_context(allow_local_build=False)
        shell = self.mock_shell()
        shell.platform.side_effect = [Platform.X86_64]
        shell.docker_image_exists_locally.side_effect = [False]
        shell.docker_pull.side_effect = [DockerPullResult.NOT_FOUND]

        self.assertEqual(1, acquire_build_image(context, shell))
        shell.docker_image_exists_locally.assert_called_once()
        shell.docker_build_base_image.assert_not_called()
        shell.docker_save.assert_not_called()
        shell.docker_tag.assert_not_called()
        shell.docker_build_build_image.assert_not_called()

    # When:
    #  - the base image doesn't exist locally
    #  - the base image exists remotely
    # It should: pull the remote image and tag it
    def test_pull_remote_image(self):
        context = self.make_context(allow_local_build=False)
        shell = self.mock_shell()
        shell.platform.side_effect = [Platform.X86_64]
        shell.docker_image_exists_locally.side_effect = [False]
        shell.docker_pull.side_effect = [DockerPullResult.SUCCESS]

        self.assertEqual(0, acquire_build_image(context, shell))
        shell.docker_image_exists_locally.assert_called_once()
        shell.docker_build_base_image.assert_not_called()
        shell.docker_save.assert_not_called()
        shell.docker_tag.assert_has_calls([
            call(REMOTE_BASE_IMAGE_NAME, "someimagetag", LOCAL_BASE_IMAGE_NAME, "someimagetag"),
            call(LOCAL_BASE_IMAGE_NAME, "someimagetag", LOCAL_BASE_IMAGE_NAME, LOCAL_TAG)
        ])
        shell.docker_build_build_image.assert_called_with("123", "/tmp/test/tools-path/ci-build")


# Script entry point.
#
# With `--self-test` as the first argument, runs the unit tests; any arguments after it
# are left in place for unittest. Otherwise acquires the build image and exits with the
# status that `acquire_build_image` returns.
def main():
    # Run unit tests if given `--self-test` argument
    if len(sys.argv) > 1 and sys.argv[1] == "--self-test":
        # Remove the flag itself (not whatever argument happens to be last) so unittest
        # doesn't try to interpret it, while still seeing options such as `-v`
        del sys.argv[1]
        unittest.main()
    else:
        sys.exit(acquire_build_image())


if __name__ == "__main__":
    main()