#!/usr/bin/env python3

import argparse
from contextlib import contextmanager
from collections import namedtuple
import copy
import functools
import glob
import logging
import os
import os.path
from pathlib import Path
import re
import shutil
import subprocess
from subprocess import check_output
import tempfile

import yaml

# Default value of the ``--config`` command line option.
DEFAULT_REPOS_YAML = "repositories.yaml"
# Base logger of this script; every repository gets its own child logger.
LOGGER = logging.getLogger("bump")
# Logger used to trace the external commands executed by the script.
CMD_LOGGER = LOGGER.getChild("cmd")
# URL passed to ``git submodule add`` when a repository lacks the submodule.
HPC_CC_HTTP_REMOTE = "https://github.com/BlueBrain/hpc-coding-conventions.git"

# Configuration files that must not be tracked: when one of them shows up as
# untracked or modified in ``git status``, it is appended to ``.gitignore``.
IGNORED_CONFIG_FILES = {
    ".clang-format",
    ".clang-tidy",
    ".cmake-format.yaml",
    ".pre-commit-config.yaml",
}


@contextmanager
def spack_build_env(spec):
    """Run the body of the ``with`` statement inside a `spack build-env`.

    The environment is obtained by running ``env`` through
    ``spack build-env <spec> bash --norc -c env``. While the context is
    active, ``os.environ`` holds only the variables reported by that command;
    the previous environment is put back on exit, even if the body raises.

    Args:
        spec: spack spec whose build environment is loaded.

    Yields:
        A set of ``(name, value)`` tuples, one per variable parsed from the
        command output (note: a set of pairs, not a dict).
    """
    saved_environ = os.environ.copy()
    raw_env = check_output(
        ["spack", "build-env", spec, "bash", "--norc", "-c", "env"]
    ).decode("utf-8")
    # Only keep lines that look like `NAME=value`; anything else (for
    # instance the continuation of a multi-line value) is skipped.
    assignment = re.compile("^[a-zA-Z][a-zA-Z0-9_]*=.*")
    environ = set()
    for line in raw_env.splitlines():
        if assignment.match(line):
            name, value = line.split("=", 1)
            environ.add((name, value))
    try:
        os.environ.clear()
        os.environ.update(environ)
        yield environ
    finally:
        os.environ.clear()
        os.environ.update(saved_environ)


def git_status():
    """Classify the entries reported by ``git status --short``.

    Only four status prefixes are recognized; every other line is ignored:

    * ``"A  "`` -- file added to the index
    * ``" M "`` -- file modified in the work tree, not staged
    * ``"D  "`` -- file deleted in the index
    * ``"?? "`` -- untracked file

    Returns:
        A tuple ``(adds, changes, dels, unknowns)`` of sets of paths, as
        printed by git after the 3-character status prefix.
    """
    adds, changes, dels, unknowns = set(), set(), set(), set()
    # Every recognized prefix is exactly 3 characters long, so the first
    # 3 characters of a line can be used directly as a lookup key.
    buckets = {
        "D  ": dels,
        " M ": changes,
        "?? ": unknowns,
        "A  ": adds,
    }
    listing = check_output(["git", "status", "--short"]).decode("utf-8")
    for entry in listing.splitlines():
        bucket = buckets.get(entry[:3])
        if bucket is not None:
            bucket.add(entry[3:])
    return adds, changes, dels, unknowns


def _check_call(*args):
    """Log a command, run it, and fail if it exits with a non-zero status.

    Args:
        *args: the program and its arguments, each one a string.

    Returns:
        The value returned by ``subprocess.check_call`` (0 on success).

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.
    """
    command_line = " ".join(args)
    CMD_LOGGER.info(command_line)
    return subprocess.check_call(args)


def _call(*args):
    """Log a command, run it, and report its exit status.

    Unlike ``_check_call``, a non-zero exit status is returned to the caller
    instead of being turned into an exception.

    Args:
        *args: the program and its arguments, each one a string.

    Returns:
        The exit status of the command.
    """
    command_line = " ".join(args)
    CMD_LOGGER.info(command_line)
    return subprocess.call(args)


@contextmanager
def pushd(path):
    """Temporarily change the current working directory.

    The previous working directory is restored when the context exits,
    whether the body completed normally or raised.

    Args:
        path: directory to switch to.

    Yields:
        The new current working directory, as reported by ``os.getcwd()``.

    Raises:
        OSError: if ``path`` cannot be entered; in that case the working
            directory is left untouched.
    """
    old_cwd = os.getcwd()
    # os.chdir() returns None, so the new location has to be queried
    # explicitly to yield something useful.
    os.chdir(path)
    try:
        yield os.getcwd()
    finally:
        os.chdir(old_cwd)


class SpackDevBuild:
    """Run ``spack dev-build`` and give access to its build directory.

    Creating an instance immediately executes ``spack dev-build`` in the
    current working directory, locates the ``spack-build-*`` directory it
    produced and captures the build environment.

    The instance is also a context manager: inside the ``with`` block the
    current working directory is the build directory and ``os.environ`` is
    replaced by the build environment. Both are restored on exit. The same
    instance can be entered several times, one after the other.
    """

    def __init__(self, spec, **kwargs):
        """Run the build.

        Args:
            spec: spack spec to build.
            **kwargs: options forwarded to :meth:`get_command`
                (``test``, ``overwrite``, ``until``).

        Raises:
            subprocess.CalledProcessError: if a spawned command fails.
            RuntimeError: if the build directory cannot be identified.
        """
        self._spec = spec
        self._command = SpackDevBuild.get_command(spec, **kwargs)
        self._call()
        self._env = self._extract_environ()

    def _call(self):
        """Execute the command and find the build directory it created.

        A marker file is touched before the build; the build directory is
        the first ``spack-build-*`` directory modified after the marker.
        """
        timestamp_file = Path(".spack-dev-build.ts")
        timestamp_file.touch()
        self._path = None
        try:
            _check_call(*self._command)
            reference = os.path.getmtime(timestamp_file)
            for d in glob.glob("spack-build-*"):
                if os.path.isdir(d) and os.path.getmtime(d) > reference:
                    self._path = d
                    break
        finally:
            # Do not leave the marker behind, even when the build fails.
            timestamp_file.unlink()
        if not self._path:
            raise RuntimeError("Could not deduce spack dev-build directory")

    def _extract_environ(self):
        """Return the build environment as a ``dict`` of strings.

        ``spack-build-env.txt`` is sourced from the current working directory
        in a shell, then the interpreter designated by ``$SPACK_PYTHON``
        prints ``repr(dict(os.environ))``, which is parsed back here.
        """
        import ast

        cmd = "source spack-build-env.txt; unset PYTHONHOME; $SPACK_PYTHON -c "
        cmd += "'import os; print(repr(dict(os.environ)))'"
        # The output is the repr of a dict of strings, i.e. a plain Python
        # literal: literal_eval parses it without executing arbitrary code,
        # which a bare eval() on subprocess output would allow.
        return ast.literal_eval(check_output(cmd, shell=True).decode("utf-8"))

    @classmethod
    def get_command(cls, spec, test=None, overwrite=True, until=None):
        """Build the ``spack dev-build`` command line.

        Args:
            spec: spack spec to build, always the last argument.
            test: if set, value of the ``--test=`` option.
            overwrite: whether to pass ``--overwrite``.
            until: if set, value of the ``--until`` option.

        Returns:
            The command as a list of strings.
        """
        cmd = ["spack", "dev-build"]
        if overwrite:
            cmd.append("--overwrite")
        if until:
            cmd += ["--until", until]
        if test:
            cmd += [f"--test={test}"]
        cmd.append(spec)
        return cmd

    @property
    def spec(self):
        """The spack spec given to the constructor."""
        return self._spec

    @property
    def path(self):
        """The build directory found after running the command."""
        return self._path

    @property
    def command(self):
        """The ``spack dev-build`` command that was executed."""
        return self._command

    def __enter__(self):
        """Enter the build directory and install the build environment.

        Returns:
            The build directory.
        """
        self.__prev_cwd = os.getcwd()
        # A plain dict snapshot is enough to restore the variables later.
        self.__prev_environ = os.environ.copy()
        os.chdir(self.path)
        os.environ.clear()
        os.environ.update(self._env)
        return self.path

    def __exit__(self, type, value, traceback):
        """Restore the working directory and environment saved on entry."""
        os.chdir(self.__prev_cwd)
        # Restore in place: rebinding `os.environ` to another object would
        # leave the environment inherited by child processes untouched and
        # break any code holding a reference to the original mapping.
        os.environ.clear()
        os.environ.update(self.__prev_environ)
        del self.__prev_cwd
        del self.__prev_environ


class Repository(
    namedtuple(
        "Repository",
        [
            "url",
            "features",
            "location",
            "cmake_project_name",
            "default_branch",
            "spack_spec",
            "spack_until",
            "patch",
        ],
    )
):
    """A project whose coding-conventions submodule has to be bumped.

    Fields:
        url: git URL of the project.
        features: mapping of feature name to a truthy/falsy value; truthy
            features are turned ON through CMake variables and select which
            checks :meth:`test` runs.
        location: path of the submodule inside the project.
        cmake_project_name: prefix of the project's CMake variables.
        default_branch: branch of the project the bump branch starts from.
        spack_spec: if set, the project is configured with
            ``spack dev-build`` instead of a plain CMake invocation.
        spack_until: spack phase at which ``spack dev-build`` stops.
        patch: absolute path of a patch file to apply, or None.

    Every git command is executed in the current working directory, which
    is expected to be the git repository the project remotes are added to.
    """

    @staticmethod
    def create(**kwargs):
        """Instantiate the `Repository` subclass matching the git server.

        Args:
            **kwargs: the fields of the named tuple. ``location``,
                ``default_branch``, ``spack_spec``, ``spack_until`` and
                ``patch`` are optional; a ``patch`` path is made absolute.

        Returns:
            A repository with a ``log`` attribute holding its own logger.

        Raises:
            Exception: if the URL does not belong to a supported server.
        """
        url = kwargs["url"]
        kwargs.setdefault("location", "hpc-coding-conventions")
        kwargs.setdefault("default_branch", "master")
        kwargs.setdefault("spack_spec", None)
        kwargs.setdefault("spack_until", "cmake")
        if "patch" in kwargs:
            # The working directory changes later on, so a relative path
            # would no longer point to the patch file.
            kwargs["patch"] = os.path.abspath(kwargs["patch"])
        else:
            kwargs["patch"] = None
        if "github.com" in url:
            repo = GitHubRepository(**kwargs)
        elif "bbpcode.epfl.ch" in url:
            repo = GerritRepository(**kwargs)
        elif "bbpgitlab.epfl.ch" in url:
            repo = GitLabRepository(**kwargs)
        else:
            raise Exception("Unsupported git server")
        repo.log = LOGGER.getChild(repo.name)
        return repo

    def submit(self, revision=None):
        """Publish the bump for review; implemented by the subclasses.

        Args:
            revision: revision the submodule was bumped to.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def name(self):
        """Identifier built from the last two components of the URL.

        For instance ``git@github.com:Owner/project.git`` gives
        ``Owner-project``.
        """
        name = self.url.split(":")[-1]
        name = name.rsplit("/", 2)[-2:]
        if name[-1].endswith(".git"):
            name[-1] = name[-1][:-4]
        return "-".join(name)

    def bump_branch_exists(self, revision):
        """Tell whether the local bump branch for ``revision`` exists."""
        return self._bump_branch(revision) in self.branches

    @property
    def remote(self):
        """Name of the git remote of this project (same as :attr:`name`)."""
        return self.name

    @property
    def branch(self):
        """Same value as :attr:`name`."""
        return self.name

    @property
    def remotes(self):
        """Lines printed by ``git remote``."""
        return check_output(["git", "remote"]).decode("utf-8").split("\n")

    @property
    def branches(self):
        """Lines printed by ``git branch``, without the ``*`` marker."""
        eax = []
        for br in check_output(["git", "branch"]).decode("utf-8").split("\n"):
            if br.startswith("*"):
                br = br[1:]
            eax.append(br.strip())
        return eax

    def fetch(self):
        """Fetch the project remote, registering it first if needed.

        Returns:
            True.
        """
        self.log.info("fetching repository")
        if self.remote not in self.remotes:
            _check_call("git", "remote", "add", self.name, self.url)
        _check_call("git", "fetch", self.name)
        return True

    def _bump_branch(self, revision):
        """Name of the local branch holding the bump to ``revision``."""
        return self.name + "-" + revision[:8]

    def _upstream_branch(self, revision):
        """Name of the bump branch prefixed by the remote name."""
        return self.remote + "/" + self._bump_branch(revision)

    def checkout_branch(self, revision):
        """Create the bump branch from the project's default branch.

        An existing branch of the same name is deleted first. The submodule
        is added if the project does not have it yet, then all submodules
        are initialized and updated.

        Returns:
            True.
        """
        branch = self._bump_branch(revision)
        if branch in self.branches:
            # The branch to delete cannot be the one checked out.
            _check_call("git", "checkout", "-f", "master")
            _check_call("git", "branch", "-D", branch)
        self.log.info("checkout branch %s", branch)

        _check_call(
            "git",
            "checkout",
            "-b",
            branch,
            "{}/{}".format(self.remote, self.default_branch),
        )
        if not os.path.exists(self.location):
            # dirname() is empty when the submodule sits at the top level of
            # the project; chdir("") fails, so stay in the current directory.
            with pushd(os.path.dirname(self.location) or "."):
                _check_call("git", "submodule", "add", "--force", HPC_CC_HTTP_REMOTE)
        _check_call("git", "submodule", "update", "--recursive", "--init")
        return True

    def bump(self, revision):
        """Check out ``revision`` in the submodule and stage the change.

        Returns:
            True if the work tree or index differs afterwards, False if
            there is nothing to commit.
        """
        self.log.info("Bump submodule")
        with pushd(self.location):
            _check_call("git", "fetch")
            _check_call("git", "checkout", revision)
        if self.is_dirty:
            _check_call("git", "add", self.location)
            _check_call("git", "status")
            return True
        else:
            return False

    def update_gitignore(self):
        """Stage modified files and ignore the generated configuration files.

        Untracked files listed in ``IGNORED_CONFIG_FILES`` are appended to
        ``.gitignore``. Modified files listed there are removed with
        ``git rm --force`` and appended to ``.gitignore`` too. Every other
        modified file is staged.

        Returns:
            True.
        """
        _, changes, _, unknowns = git_status()
        ignore_additions = []
        for unknown in unknowns:
            if unknown in IGNORED_CONFIG_FILES:
                ignore_additions.append(unknown)
        for change in changes:
            if change in IGNORED_CONFIG_FILES:
                _check_call("git", "rm", "--force", change)
                # Ignore the file that was just removed (not a leftover
                # value of the previous loop).
                ignore_additions.append(change)
            else:
                _check_call("git", "add", change)
        ignore_additions.sort()
        if ignore_additions:
            with open(".gitignore", "a") as ostr:
                for ignored in ignore_additions:
                    print(ignored, file=ostr)
            _check_call("git", "add", ".gitignore")
        return True

    def commit(self, revision):
        """Commit the staged changes with :meth:`commit_message`.

        Returns:
            True.
        """
        _check_call("git", "commit", "-m", self.commit_message(revision))
        return True

    def commit_message(self, revision):
        """Message of the bump commit for ``revision``."""
        return f"Bump {self.location} submodule to {revision[:8]}"

    def merge_request_title(self, revision):
        """Title of the merge request; same as the commit message."""
        return self.commit_message(revision)

    @property
    def is_dirty(self):
        """True if there are unstaged or staged differences."""
        return (
            _call("git", "diff", "--quiet") != 0
            or _call("git", "diff", "--staged", "--quiet") != 0
        )

    def test(self):
        """Configure the project and run the checks enabled in ``features``.

        The formatting and static analysis targets run inside the build
        directory; pre-commit runs from the top of the project.

        Returns:
            True.

        Raises:
            subprocess.CalledProcessError: if a command fails.
        """
        with self._test_cmake():
            self._test_formatting_targets()
            self._test_static_analysis()
        self._test_precommit()
        return True

    @property
    def cmake_base_cmd(self):
        """``cmake`` command enabling every truthy feature.

        Each one is passed as ``-D<cmake_project_name>_<FEATURE>:BOOL=ON``.
        """
        cmd = ["cmake"]
        var_prefix = "-D" + self.cmake_project_name + "_"
        for feature, value in self.features.items():
            if value:
                cmd.append(var_prefix + feature.upper() + ":BOOL=ON")
        return cmd

    def _test_cmake(self):
        """Configure the project with CMake.

        Returns:
            A context manager that enters the build directory.
        """
        if self.spack_spec:
            dev_build = SpackDevBuild(self.spack_spec, until=self.spack_until)
            with dev_build:
                # recall CMake to set proper variables
                _check_call(*self.cmake_base_cmd, ".")
            return dev_build

        if os.path.isdir("_build"):
            shutil.rmtree("_build")
        os.makedirs("_build")
        with pushd("_build"):
            cmd = copy.copy(self.cmake_base_cmd)
            python = check_output(["which", "python"]).decode("utf-8").rstrip()
            cmd.append("-DPYTHON_EXECUTABLE=" + python)
            cmd.append("..")
            _check_call(*cmd)
        # Hand back an actual context manager: a `functools.partial` object
        # cannot be used in a `with` statement.
        return pushd("_build")

    def _test_formatting_targets(self):
        """Run the ``clang-format`` and ``cmake-format`` make targets.

        Each target runs if the ``formatting`` feature or the feature named
        after the tool is enabled.
        """
        if self.features.get("formatting") or self.features.get("clang_format"):
            _check_call("make", "clang-format")
        if self.features.get("formatting") or self.features.get("cmake_format"):
            _check_call("make", "cmake-format")

    def _test_static_analysis(self):
        """Run the ``clang-tidy`` make target if ``static_analysis`` is on."""
        if self.features.get("static_analysis"):
            _check_call("make", "clang-tidy")

    def _test_precommit(self):
        """Run ``pre-commit`` on all files if ``precommit`` is on."""
        if self.features.get("precommit"):
            _check_call("pre-commit", "run", "-a")

    def clean_local_branches(self):
        """Delete the local branches of this project and clean the work tree.

        ``master`` is checked out, untracked and ignored files are removed,
        every local branch whose name starts with the remote name is deleted
        and the ``pre-commit`` git hook is removed if present.
        """
        _check_call("git", "checkout", "-f", "master")
        _check_call("git", "clean", "-ffdx")
        for branch in self.branches:
            if branch.startswith(self.remote):
                _check_call("git", "branch", "-D", branch)
        precommit_hook = os.path.join(".git", "hooks", "pre-commit")
        if os.path.exists(precommit_hook):
            os.remove(precommit_hook)


class GitHubRepository(Repository):
    """Project hosted on GitHub; bumps are submitted as pull requests."""

    def submit(self, revision):
        """Push the current branch and open a pull request.

        ``HEAD`` is pushed to the project remote, then the ``hub`` command
        line tool creates the pull request, using the commit message of the
        bump as its message.

        Args:
            revision: revision the submodule was bumped to.

        Returns:
            True.
        """
        push_command = ("git", "push", self.remote, "HEAD")
        _check_call(*push_command)
        pull_request_message = self.commit_message(revision)
        _check_call("hub", "pull-request", "-m", pull_request_message)
        return True


class GerritRepository(Repository):
    """Project hosted on Gerrit; bumps are submitted with ``git-review``."""

    def submit(self, revision):
        """Send the current branch for review by running ``git-review``.

        Args:
            revision: revision the submodule was bumped to (not used here).

        Returns:
            True.
        """
        review_command = ("git-review",)
        _check_call(*review_command)
        return True


class GitLabRepository(Repository):
    """Project hosted on GitLab; bumps are submitted as merge requests."""

    def submit(self, revision):
        """Push the current branch and ask GitLab to open a merge request.

        The merge request is requested through git push options.

        Args:
            revision: revision the submodule was bumped to, used to build
                the title of the merge request.

        Returns:
            True.
        """
        # Push options must each be introduced by `-o`; passed as bare
        # arguments, git would take them for refspecs and fail.
        _check_call(
            "git",
            "push",
            "-o",
            "merge_request.create",
            "-o",
            "merge_request.title=" + self.merge_request_title(revision),
            self.remote,
            "HEAD",
        )
        return True


def repositories(file):
    """Load the repositories described in a YAML configuration file.

    Args:
        file: path of a YAML file whose top-level ``repositories`` key holds
            a list of mappings, each one passed as keyword arguments to
            ``Repository.create``.

    Returns:
        The list of created repositories, in file order.
    """
    with open(file) as istr:
        config = yaml.safe_load(istr)
    repos = []
    for entry in config["repositories"]:
        repos.append(Repository.create(**entry))
    return repos


class IntegrationRepository:
    """Local git repository in which all the projects are bumped.

    Every project is registered as a remote of this single repository.
    Projects that need no further processing (bump branch already there,
    submodule already up to date) are recorded as ignored and skipped by
    the following stages.
    """

    def __init__(self, repos, top_repository):
        """Create the local repository if it does not exist yet.

        Args:
            repos: the repositories to process.
            top_repository: path of the local git repository. If the
                directory is missing, it is initialized with one commit.

        Raises:
            Exception: if ``top_repository`` has no ``.git`` directory.
        """
        self._repos = repos
        self._ignored = set()
        self.top_repository = top_repository

        if not os.path.isdir(self.top_repository):
            _check_call("git", "init", self.top_repository)
            with pushd(self.top_repository):
                # Create a first commit so that a branch exists.
                with open("test", "w") as ostr:
                    ostr.write("test")
                _check_call("git", "add", "test")
                _check_call("git", "commit", "-n", "-m", "add test")
        if not os.path.isdir(os.path.join(self.top_repository, ".git")):
            raise Exception("Could not find .git directory in " + self.top_repository)

    @property
    def repos(self):
        """The repositories that have not been ignored so far."""
        return [repo for repo in self._repos if repo.name not in self._ignored]

    def clean_local_branches(self):
        """Remove the local branches of every repository."""
        with pushd(self.top_repository):
            for repo in self.repos:
                repo.clean_local_branches()

    def update(self, revision, dry_run):
        """Bump the submodule of every repository to ``revision``.

        Args:
            revision: revision the submodule is bumped to.
            dry_run: if true, changes are committed locally but not
                submitted.

        NOTE(review): each stage runs over all the repositories before the
        next stage starts, whereas checking out a branch switches the single
        shared work tree. With more than one repository, the later stages
        appear to operate on the branch checked out last -- confirm the
        intended usage.
        """
        with pushd(self.top_repository):
            self._fetch()
            self._checkout_branch(revision)
            self._patch()
            self._bump(revision)
            self._test()
            self._update_gitignore()
            self._commit(revision)
            if not dry_run:
                self._submit(revision)

    def _fetch(self):
        """Fetch every repository; raise if one of them reports a failure."""
        succeeded = True
        for repo in self.repos:
            succeeded &= repo.fetch()
        if not succeeded:
            raise Exception("Fetch operation failed")

    def _checkout_branch(self, revision):
        """Create the bump branches.

        Repositories that already have their bump branch are ignored.
        """
        succeeded = True
        for repo in self.repos:
            if repo.bump_branch_exists(revision):
                repo.log.info(
                    'ignored because branch "%s" already exists.',
                    repo._bump_branch(revision),
                )
                self._ignored.add(repo.name)
            else:
                succeeded &= repo.checkout_branch(revision)
        if not succeeded:
            raise Exception("Checkout operation failed")

    def _patch(self):
        """Apply the patch of every repository that declares one."""
        for repo in self.repos:
            if repo.patch:
                # Feed the patch through stdin rather than a shell
                # redirection, so that paths with spaces or shell
                # metacharacters are handled correctly.
                with open(repo.patch, "rb") as patch_file:
                    subprocess.check_call(["patch", "-p1"], stdin=patch_file)

    def _bump(self, revision):
        """Bump the submodules.

        Repositories whose submodule is already up to date are ignored.
        """
        for repo in self.repos:
            if not repo.bump(revision):
                repo.log.info("ignore because submodule is up to date")
                self._ignored.add(repo.name)

    def _test(self):
        """Test every repository; raise if one of them reports a failure."""
        succeeded = True
        for repo in self.repos:
            succeeded &= repo.test()
        if not succeeded:
            raise Exception("Test operation failed")

    def _update_gitignore(self):
        """Update ``.gitignore`` for every repository."""
        succeeded = True
        for repo in self.repos:
            succeeded &= repo.update_gitignore()
        if not succeeded:
            raise Exception("Gitignore update operation failed")

    def _commit(self, revision):
        """Commit the bump for every repository."""
        succeeded = True
        for repo in self.repos:
            succeeded &= repo.commit(revision)
        if not succeeded:
            raise Exception("Commit operation failed")

    def _submit(self, revision):
        """Submit the bump of every repository.

        Raises:
            Exception: if any repository reports a failure.
        """
        succeeded = True
        for repo in self.repos:
            # Accumulate the results: a plain assignment would only keep
            # the outcome of the last repository.
            succeeded &= repo.submit(revision)
        if not succeeded:
            raise Exception("Submit operation failed")


def main(**kwargs):
    """Command line entry point.

    Parses the command line, resolves the requested revision with
    ``git rev-parse`` in the current working directory, loads the
    repositories from the configuration file, then either cleans the local
    branches (``--clean``) or bumps the submodule in every repository.

    Args:
        **kwargs: forwarded to ``ArgumentParser.parse_args``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config",
        default=DEFAULT_REPOS_YAML,
        metavar="FILE",
        help="Configuration file [default: %(default)s]",
    )
    cli.add_argument(
        "-r",
        "--revision",
        default="HEAD",
        metavar="HASH",
        help="Git revision to bump [default: %(default)s]",
    )
    cli.add_argument(
        "--clean",
        action="store_true",
        help="remove local git branches",
    )
    cli.add_argument(
        "-d",
        "--dry-run",
        action="store_true",
        help="do not create pull-request or gerrit review",
    )
    default_repo = os.path.join(tempfile.gettempdir(), "hpc-cc-projects")
    cli.add_argument(
        "--repo",
        default=default_repo,
        help="Collective git repository [default: %(default)s]",
    )
    cli.add_argument(
        "-p",
        "--project",
        nargs="+",
        help="Filter repositories by CMake project name",
    )
    options = cli.parse_args(**kwargs)

    # Turn the user-provided revision into a full commit hash.
    rev_parse_output = check_output(["git", "rev-parse", options.revision])
    revision = rev_parse_output.decode("utf-8").strip()

    repos = repositories(options.config)
    if options.project:
        selected = set(options.project)
        repos = [repo for repo in repos if repo.cmake_project_name in selected]

    integration = IntegrationRepository(repos, options.repo)
    if not options.clean:
        integration.update(revision, options.dry_run)
        return
    integration.clean_local_branches()


# Script entry point: log at INFO level so that the executed commands and the
# per-repository messages are visible, then process the command line.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
