#!/usr/bin/env python3
import argparse
import collections
import collections.abc
import concurrent.futures
import dataclasses
import fnmatch
import json
import sys

import gitlab
import gitlab.v4.objects
import yaml


def thread_pool(num_workers, func, items, num_retries=0):
    """Run func(item) for every item on a thread pool and return the results.

    Each item gets up to ``num_retries + 1`` attempts; failures are reported
    on stderr (with a full traceback via sys.excepthook) and swallowed, so a
    permanently failing item yields None in the result list.

    BUG FIX: the wrapper previously discarded func's return value, so this
    function always returned a list of None despite collecting futures.
    Results are returned in the same order as *items*.
    """

    def func_with_retries(item):
        for attempt in range(num_retries + 1):
            try:
                return func(item)
            except Exception:
                # Build the most descriptive identifier we can for the log.
                itemid = getattr(item, "path_with_namespace", None)
                if itemid is None:
                    itemid = getattr(item, "id", None)
                if itemid is None:
                    itemid = item
                print(f"ERROR on item {itemid}, try #{attempt}", file=sys.stderr)
                sys.excepthook(*sys.exc_info())
        # All attempts failed: fall through and implicitly return None.

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(func_with_retries, item) for item in items]
        results = [f.result() for f in futures]
    return results


def flatten(items):
    """Recursively yield the leaves of arbitrarily nested iterables.

    Strings and bytes are treated as atoms, not as iterables of characters.
    """
    for element in items:
        is_atom = isinstance(element, (str, bytes)) or not isinstance(
            element, collections.abc.Iterable
        )
        if is_atom:
            yield element
        else:
            yield from flatten(element)


def is_glob(pattern):
    """Return True if *pattern* contains shell-style wildcard characters."""
    return bool(set("*?[") & set(pattern))


def asdict(d):
    """Normalize *d* into a plain dict for comparison purposes.

    Dataclasses go through dataclasses.asdict(); GitLab REST objects are
    converted from their attributes with server-managed fields removed and
    the asymmetric *_access_levels lists collapsed to a single level value.
    Anything else is returned unchanged.
    """
    if dataclasses.is_dataclass(d):
        return dataclasses.asdict(d)
    if isinstance(d, gitlab.base.RESTObject):
        result = dict(d.attributes)
        # Server-managed attributes that should never take part in diffs.
        for key in ("id", "created_at", "updated_at", "owner", "project_id"):
            result.pop(key, None)
        # The read API returns lists of access levels while the write API
        # takes a single value; collapse to the first entry's level.
        for plural, singular in (
            ("merge_access_levels", "merge_access_level"),
            ("push_access_levels", "push_access_level"),
        ):
            levels = result.pop(plural, None)
            if levels:
                result[singular] = levels[0]["access_level"]
        return result
    return d


def get_changes(src, dst):
    """Return the dict of attributes in *dst* that differ from *src*.

    Both arguments are normalized through asdict(). A destination key ending
    in "_attributes" is compared against the source key without that suffix
    (write-side vs read-side API naming). Note that nested dict values may be
    mutated in place during normalization.

    BUG FIX: the "next_run_at" computed key was removed with ``del``, which
    raised KeyError whenever the source dict did not contain it; use
    pop() with a default instead.
    """
    changes = {}
    src_dict = asdict(src)
    dst_dict = asdict(dst)
    for k, new_val in dst_dict.items():
        attributes_suffix = "_attributes"
        k_src = k
        if k_src.endswith(attributes_suffix):
            k_src = k_src[: -len(attributes_suffix)]
        old_val = src_dict.get(k_src)
        if isinstance(old_val, dict):
            # computed key, never meaningful for diffing; may be absent
            old_val.pop("next_run_at", None)
        if isinstance(new_val, dict):
            if "name_regex_delete" in new_val:
                # convert to the compatibility name used by the read-side API
                new_val["name_regex"] = new_val.pop("name_regex_delete")
        if old_val != new_val:
            changes[k] = new_val
    return changes


# Buckets produced by split_sets_by_action(): items left alone, items to
# create, items to delete, and items needing attribute changes.
SplittedSets = collections.namedtuple(
    "SplittedSets",
    "unchanged add drop change",
)


def split_sets_by_action(current, target, keyfunc):
    """Partition *current* against *target* into action buckets.

    Items are matched by keyfunc(); returns a SplittedSets of items that
    need no work, items to add, items to drop, and items to change.
    """
    target_by_key = {item: None for item in ()}  # placeholder removed below
    target_by_key = {keyfunc(item): item for item in target}
    current_keys = {keyfunc(item) for item in current}
    target_keys = set(target_by_key)

    to_add = [item for item in target if keyfunc(item) not in current_keys]
    to_drop = [item for item in current if keyfunc(item) not in target_keys]

    unchanged = []
    to_change = []
    for existing in current:
        key = keyfunc(existing)
        if key not in target_keys:
            continue
        bucket = to_change if get_changes(existing, target_by_key[key]) else unchanged
        bucket.append(existing)

    return SplittedSets(unchanged, to_add, to_drop, to_change)


def apply_changes(project, kind, collection, obj, target, description=None):
    """Build the list of Actions needed to update *obj* to match *target*.

    Each differing attribute becomes its own Action; the batch is closed by
    a SaveAction that either saves the object in place (when it supports
    SaveMixin) or deletes and recreates it through *collection*. Returns an
    empty list when the object already matches *target*.
    """
    actions = []
    changes = get_changes(obj, target)
    description_prefix = f"{description}: " if description else ""

    for k, v in changes.items():
        actions.append(
            Action(
                project=project,
                kind=kind,
                description=description_prefix + f"set {k} to {v}",
                # default-argument binding avoids the late-binding closure
                # pitfall inside this loop
                func=lambda obj=obj, k=k, v=v: setattr(obj, k, v),
            )
        )
    if not actions:
        return actions
    if isinstance(obj, gitlab.mixins.SaveMixin):
        actions.append(
            SaveAction(
                project=project,
                kind=kind,
                description=description_prefix + "save changes",
                func=obj.save,
            )
        )
    else:

        def recreate(*args, **kwargs):
            obj.delete()
            # BUG FIX: was ``collection.create(*args, *kwargs)``, which would
            # unpack the kwargs dict's *keys* as positional arguments
            collection.create(*args, **kwargs)

        actions.append(
            SaveAction(
                project=project,
                kind=kind,
                description=description_prefix + "recreate",
                func=recreate,
                params=(asdict(target),),
            )
        )
    return actions


def collection_apply(project, kind, collection, target, keyfunc):
    """Compute the Actions that reconcile a GitLab collection with *target*.

    Items are matched by keyfunc(); additions, deletions and edits are
    emitted in sorted key order.
    """
    sets = split_sets_by_action(collection.list(), target, keyfunc)

    actions = []
    for item in sorted(sets.add, key=keyfunc):
        actions.append(
            Action(
                project=project,
                kind=kind,
                description=f'add "{keyfunc(item)}"',
                func=collection.create,
                params=(asdict(item),),
            )
        )
    for item in sorted(sets.drop, key=keyfunc):
        actions.append(
            Action(
                project=project,
                kind=kind,
                description=f'drop "{keyfunc(item)}"',
                func=item.delete,
            )
        )
    for item in sorted(sets.change, key=keyfunc):
        # Look up the desired state matching this existing item.
        desired = next(t for t in target if keyfunc(t) == keyfunc(item))
        actions += apply_changes(
            project, kind, collection, item, desired, f'edit "{keyfunc(item)}"'
        )
    return actions


@dataclasses.dataclass(repr=False, eq=False, frozen=False)
class Action:
    """A single pending change against a GitLab project.

    The change is deferred: ``func(*params)`` only runs when invoke() is
    called, and its return value — or the raised exception — is stored in
    ``ret``.
    """

    project: gitlab.v4.objects.Project
    kind: str
    description: str
    msg: str = dataclasses.field(init=False)
    func: callable = None
    params: tuple = ()
    ret: object = None

    def __post_init__(self):
        # Pre-render the human-readable summary used by __str__.
        path = self.project.path_with_namespace
        self.msg = f"project {path}: {self.kind}: {self.description}"

    def invoke(self):
        """Execute the deferred call, capturing any exception in ``ret``."""
        try:
            self.ret = self.func(*self.params)
        except Exception as exc:
            self.ret = exc

    def is_error(self):
        """Whether the invocation ended with an exception."""
        return isinstance(self.ret, Exception)

    def __str__(self):
        suffix = " -> error" if self.is_error() else ""
        return self.msg + suffix

    def __repr__(self):
        return f"{type(self).__name__}<{self}>"

    def to_json(self):
        """Serialize to a JSON-friendly dict; adds an "error" key on failure."""
        data = {
            "project": {
                "id": self.project.id,
                "path_with_namespace": self.project.path_with_namespace,
            },
            "kind": self.kind,
            "description": self.description,
        }
        if self.is_error():
            data["error"] = str(self.ret)
        return data


class SaveAction(Action):
    """Marker subclass for the save/recreate step that closes an edit batch."""


class Error(Action):
    """An action representing a precomputed failure.

    Invoking it performs no GitLab work; it just materializes an Exception
    built from the action message, so it always reports as an error.
    """

    def invoke(self):
        self.ret = Exception(self.msg)

    def is_error(self):
        return True


class Rule:
    """Dispatch one rule mapping to ``rule_<key>`` handler methods.

    A Rule is prepared with a project and a rule dict (parsed from the YAML
    rules file); compute() fans each rule entry out to the matching handler
    and gathers the resulting Action objects.
    """

    def __init__(self, rulez):
        # Back-reference to the owning GitlabRulez driver.
        self.rulez = rulez
        self.rule = {}

    def prepare(self, project, rule):
        """Bind this instance to a concrete project and rule mapping."""
        self.project = project
        self.rule = rule

    def compute(self):
        """Run every matching ``rule_<key>`` handler and collect the Actions.

        List values are expanded as positional arguments, dict values as
        keyword arguments, and anything else is passed as a single argument.
        Keys without a matching handler are silently ignored.
        """
        actions = []
        for key, value in self.rule.items():
            method = getattr(self, "rule_" + key, None)
            if method is None:
                continue
            # isinstance() instead of ``type(v) is`` so list/dict subclasses
            # dispatch the same way as plain lists/dicts.
            if isinstance(value, list):
                actions += method(*value)
            elif isinstance(value, dict):
                actions += method(**value)
            else:
                actions += method(value)
        return actions

    def rule_ensure_branches(self, *items):
        """Create missing branches from a base ref.

        Each item is a dict with ``name``, ``base`` and an optional
        ``required`` flag; a missing base ref produces an Error only when
        ``required`` is true.
        """
        # NOTE(review): only the first 100 branches are examined; projects
        # with more branches may get spurious create actions — TODO confirm.
        branches = self.project.branches.list(per_page=100)
        actions = []
        for i in items:
            name = i["name"]
            base = i["base"]
            required = i.get("required", False)

            # Nothing to do if the branch already exists.
            branch = next((b for b in branches if b.name == name), None)
            if branch:
                continue

            basebranch = next((b for b in branches if b.name == base), None)
            if basebranch:
                actions.append(
                    Action(
                        project=self.project,
                        kind="ensure_branch",
                        description="branch {} from {}".format(name, base),
                        func=self.project.branches.create,
                        params=({"branch": name, "ref": base},),
                    )
                )
            elif required:
                actions.append(
                    Error(
                        project=self.project,
                        kind="ensure_branch",
                        description='base ref "{}" for branch "{}" not found'.format(
                            base, name
                        ),
                    )
                )
        return actions

    def rule_protected_branches(self, *branches):
        """Reconcile the project's protected branches with *branches*."""
        actions = collection_apply(
            self.project,
            "protected_branches",
            self.project.protectedbranches,
            branches,
            lambda b: asdict(b)["name"],
        )
        return actions

    def rule_protected_tags(self, *tags):
        """Reconcile the project's protected tags with *tags*."""
        actions = collection_apply(
            self.project,
            "protected_tags",
            self.project.protectedtags,
            tags,
            lambda b: asdict(b)["name"],
        )
        return actions

    def rule_settings(self, **settings):
        """Apply project-level settings directly on the project object."""
        actions = apply_changes(self.project, "settings", None, self.project, settings)
        return actions

    def rule_default_branch(self, branch):
        """Shorthand for setting the ``default_branch`` project setting."""
        return self.rule_settings(default_branch=branch)

    def rule_hooks(self, *hooks):
        """Reconcile the project's webhooks, keyed by URL."""
        actions = collection_apply(
            self.project, "hooks", self.project.hooks, hooks, lambda s: asdict(s)["url"]
        )
        return actions

    def rule_schedules(self, *schedules):
        """Reconcile pipeline schedules, keyed by description."""
        actions = collection_apply(
            self.project,
            "schedules",
            self.project.pipelineschedules,
            schedules,
            lambda s: asdict(s)["description"],
        )
        return actions

    def rule_variables(self, *variables):
        """Reconcile CI/CD variables, keyed by variable name."""
        actions = collection_apply(
            self.project,
            "variables",
            self.project.variables,
            variables,
            lambda s: asdict(s)["key"],
        )
        return actions


class GitlabRulez:
    """Top-level driver: connect to GitLab, load the YAML rules, compute the
    Actions for every matching project, and optionally apply them."""

    def __init__(self, actually_apply=False, output_format="plain"):
        # When False ("diff" mode) actions are only computed and reported,
        # never invoked — except Error actions, see apply_rules().
        self.actually_apply = actually_apply
        self.output_format = output_format  # "plain" or "json"
        self.rulez = {}  # parsed YAML rules document
        self.gl = None  # gitlab.Gitlab connection, set by connect()
        self.actions = []  # all computed Action objects
        self.executed = []  # subset of actions that were invoke()d

    def connect(self, gitlab_instance, gitlab_server_url, gitlab_api_token):
        """Open and authenticate the GitLab connection.

        An explicit server URL (plus token) wins; otherwise fall back to the
        named instance from the python-gitlab configuration file.
        """
        if gitlab_server_url:
            self.gl = gitlab.Gitlab(gitlab_server_url, private_token=gitlab_api_token)
        else:
            self.gl = gitlab.Gitlab.from_config(gitlab_instance)
        self.gl.auth()

    def _convert_constant_value(self, item, key):
        """Replace the symbolic access-level name in item[key] with its
        numeric value from gitlab.const; no-op when the key is absent."""
        if key not in item:
            return
        symbol = item[key]
        if symbol == "NO_ACCESS":
            # NOTE(review): mapped locally to 0, presumably because the
            # constant is not exposed by gitlab.const — TODO confirm.
            item[key] = 0
        else:
            item[key] = getattr(gitlab.const, symbol)

    def load_rules(self, config_path):
        """Load the YAML rules file and normalize access-level constants
        in the protected_branches entries."""
        with open(config_path) as f:
            self.rulez = yaml.safe_load(f)

        for rule in self.rulez["rules"]:
            for b in rule.get("protected_branches", []):
                self._convert_constant_value(b, "merge_access_level")
                self._convert_constant_value(b, "push_access_level")

    def _project_match_filter(self, project, filterglob):
        # Glob match against the full "group/subgroup/project" path.
        return fnmatch.fnmatch(project.path_with_namespace, filterglob)

    def _project_find_matching_rules(self, project):
        """Return the rules whose "matches" globs hit this project and whose
        "excludes" globs (if any) do not."""
        rules = []
        for rule in self.rulez["rules"]:
            excludes = (
                self._project_match_filter(project, i)
                for i in flatten(rule.get("excludes", []))
            )
            if any(excludes):
                continue
            matches = (
                self._project_match_filter(project, i) for i in flatten(rule["matches"])
            )
            if any(matches):
                rules.append(rule)
        return rules

    def _project_compute_rule(self, project, rule):
        # Delegate the per-rule work to a Rule dispatcher instance.
        r = Rule(self)
        r.prepare(project, rule)
        return r.compute()

    def _project_compute_rules(self, project, rules):
        """Compute the Actions for one project across all its matching rules."""
        if self.output_format == "plain":
            print(project.id, project.path_with_namespace)
        if not rules:
            return []
        actions = []
        for rule in rules:
            actions += self._project_compute_rule(project, rule)
        return actions

    def apply_rules(self, filterglob=None, archived=False):
        """Compute Actions for every matching project and maybe run them.

        Actions are invoked only when actually_apply is set; Error actions
        are always "invoked" so they show up in the summary.
        """
        projects = self.fetch_projects(filterglob, archived=archived)

        actions = []
        num_worker_threads = 10

        def compute_rules(project):
            rules = self._project_find_matching_rules(project)
            a = self._project_compute_rules(project, rules)
            # list.extend is thread-safe under CPython's GIL.
            actions.extend(a)

        thread_pool(num_worker_threads, compute_rules, projects, num_retries=2)
        self.actions = actions

        for action in self.actions:
            if self.output_format == "plain":
                print(action)
            if self.actually_apply or isinstance(action, Error):
                action.invoke()
                self.executed.append(action)

    def summary(self):
        """Print the run summary in the selected format and return the
        process exit code (0 clean, 10 pending changes, 1 failures)."""
        if self.output_format == "json":
            return self.summary_json()
        return self.summary_plain()

    def summary_json(self):
        """Dump all actions as JSON to stdout; see summary() for exit codes."""
        ret = 0
        if self.actions and not self.actually_apply:
            # Diff mode found pending changes.
            ret = 10
        data = {"actions": [a.to_json() for a in self.actions]}
        json.dump(data, sys.stdout)
        failed = any(action.is_error() for action in self.executed)
        if failed:
            ret = 1
        return ret

    def summary_plain(self):
        """Print a human-readable summary; see summary() for exit codes."""
        ret = 0
        print("computed {} actions".format(len(self.actions)))
        if self.actions:
            if self.actually_apply:
                print("applied {} actions".format(len(self.executed)))
            else:
                # Diff mode found pending changes.
                ret = 10
        failed = [action for action in self.executed if action.is_error()]
        if failed:
            print("found {} errors".format(len(failed)))
            for action in failed:
                print("{}: {}".format(action, action.ret))
            ret = 1
        return ret

    def fetch_projects(self, filterglob=None, **kwargs):
        """Return projects sorted by path, optionally filtered by glob.

        A non-glob filter is treated as an exact project path and fetched
        directly; otherwise all listing pages are fetched in parallel and
        filtered client-side.
        """
        if filterglob and not is_glob(filterglob):
            return [self.gl.projects.get(filterglob)]

        num_worker_threads = 10
        per_page = 100
        projects = []

        def fetch_page(page):
            p = self.gl.projects.list(per_page=per_page, page=page, **kwargs)
            projects.extend(p)

        # Probe once to learn the page count, then fetch pages concurrently.
        total_pages = self.gl.projects.list(
            iterator=True, per_page=100, **kwargs
        ).total_pages
        thread_pool(
            num_worker_threads, fetch_page, range(1, total_pages + 1), num_retries=2
        )
        if filterglob:
            projects = [
                p for p in projects if self._project_match_filter(p, filterglob)
            ]
        projects.sort(key=lambda p: p.path_with_namespace)
        return projects

    def list_projects(self, filterglob=None, archived=False):
        """Print the matching projects (id and path) in the selected format."""
        projects = self.fetch_projects(
            filterglob, archived=archived, query_data={"simple": True}
        )
        if self.output_format == "json":
            data = [
                {"id": p.id, "path_with_namespace": p.path_with_namespace}
                for p in projects
            ]
            json.dump(data, sys.stdout)
        else:
            for p in projects:
                print(p.id, p.path_with_namespace)


if __name__ == "__main__":
    # The subcommands share two options; define them once instead of three
    # near-identical copies.

    def add_filter_option(subparser, help_text):
        # The help text differs between list-projects and diff/apply, so it
        # is passed in rather than hard-coded.
        subparser.add_argument("--filter", type=str, help=help_text)

    def add_archived_option(subparser):
        # NOTE(review): const=None (not True) looks deliberate: passing
        # --archived sets archived=None, i.e. "apply no archived filter",
        # while the default False excludes archived projects — TODO confirm
        # against python-gitlab's projects.list() semantics.
        subparser.add_argument(
            "--archived",
            action="store_const",
            const=None,
            default=False,
            help="include archived projects",
        )

    parser = argparse.ArgumentParser(prog="gitlab-rulez")
    parser.add_argument(
        "--gitlab-instance",
        type=str,
        default="apertis",
        help="get connection parameters from this configured instance",
    )
    parser.add_argument("--gitlab-api-token", type=str, help="the GitLab API token")
    parser.add_argument("--gitlab-server-url", type=str, help="the GitLab instance URL")
    parser.add_argument(
        "--output",
        type=str,
        help="the output format",
        default="plain",
        choices=["plain", "json"],
    )

    subparsers = parser.add_subparsers(dest="command", required=True)

    parser_list = subparsers.add_parser("list-projects", help="list the project paths")
    add_filter_option(
        parser_list, "only lists objects matching the specified path glob"
    )
    add_archived_option(parser_list)

    parser_diff = subparsers.add_parser("diff", help="check what the rules would do")
    add_filter_option(
        parser_diff, "only act on objects matching the specified path glob"
    )
    add_archived_option(parser_diff)
    parser_diff.add_argument("RULES", help="the YAML rules file")

    parser_apply = subparsers.add_parser("apply", help="apply the changes to GitLab")
    add_filter_option(
        parser_apply, "only act on objects matching the specified path glob"
    )
    add_archived_option(parser_apply)
    parser_apply.add_argument("RULES", help="the YAML rules file")

    args = parser.parse_args(sys.argv[1:])

    if args.command == "list-projects":
        # list-projects never mutates anything, so actually_apply is False
        # and no rules file is needed.
        r = GitlabRulez(False, args.output)
        r.connect(args.gitlab_instance, args.gitlab_server_url, args.gitlab_api_token)
        r.list_projects(filterglob=args.filter, archived=args.archived)
        sys.exit(0)

    # "apply" executes the computed actions; "diff" only reports them.
    actually_apply = args.command == "apply"
    r = GitlabRulez(actually_apply, args.output)
    r.load_rules(args.RULES)
    r.connect(args.gitlab_instance, args.gitlab_server_url, args.gitlab_api_token)
    r.apply_rules(filterglob=args.filter, archived=args.archived)
    ret = r.summary()
    sys.exit(ret)
