#!/usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright © 2015-2016 marmuta <marmvta@gmail.com>
# Copyright © 2016-2017 Francesco Fumanti <francesco.fumanti@gmx.net>
#
# This file is part of Onboard.
#
# Onboard is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# Onboard is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
import sys
import optparse
import subprocess
import re


def main():
    """
    Entry point: parse the global command line and dispatch to a command.

    Prints the usage text and exits with status 1 when no command is
    given, prints an error and exits with status 1 for an unknown
    command, and exits with status 0 after the command has run.
    """
    parser = optparse.OptionParser(
        usage="Usage: %prog command [command_options]\n"
              "\n"
              "Commands:\n"
              "   %prog list-committers          list committers per file\n"
              "   %prog update-headers           write license headers\n"
              "\n"
              "   (append -h for help with individual commands)\n")

    # Stop option parsing at the first positional argument, so that
    # everything after the command name is left to the command's parser.
    parser.disable_interspersed_args()
    _options, remaining = parser.parse_args(sys.argv[1:])

    if not remaining:
        parser.print_help()
        sys.exit(1)

    cmd_str, cmd_args = remaining[0], remaining[1:]
    command = find_command(cmd_str)
    if not command:
        print("error: unknown command '{}'".format(cmd_str),
              file=sys.stderr)
        sys.exit(1)

    command().run(cmd_args)

    sys.exit(0)


def find_command(cmd_str):
    """
    Return the command class that lists cmd_str in its cmd_strings,
    or None if no command answers to that name.
    """
    candidates = (CommandListCommitters, CommandHeaderUpdate)
    return next((cls for cls in candidates
                 if cmd_str in cls.cmd_strings),
                None)


class Command:
    """
    Base class of the sub-commands.

    Holds the tables of known committers, additional copyrights and
    excluded revisions, plus the helpers that query 'bzr annotate'.
    """

    class Committer:
        """ A known committer and all email addresses they committed with. """

        def __init__(self, name, emails, ignore=False, hide_email=False):
            """
            name:       display name
            emails:     list or tuple of addresses, the first one is the
                        main address
            ignore:     flag read by the header update command to skip
                        this committer's bzr contributions
            hide_email: if True, get_main_email() returns None
            """
            self.name = name
            self.emails = emails
            self.ignore = ignore
            self.hide_email = hide_email
            assert isinstance(self.emails, (list, tuple))

        def __repr__(self):
            return type(self).__name__ + "(" + repr(self.name) + ")"

        def get_main_email(self):
            """ Return the first email address, or None if it is hidden. """
            if not self.hide_email:
                return self.emails[0]
            return None

    _known_committers = [
        Committer("Alan Bell",
                  ["alanbell@ubuntu.com"]),
        Committer("Barry Warsaw",
                  ["barry@ubuntu.com"]),
        Committer("Chris Jones",
                  ["tortoise@tortuga",
                   "cej@blinkenlights",
                   "tortoise@chrisjdesktop",
                   "tortoise@xuanwu"]),
        Committer("Francesco Fumanti",
                  ["francesco.fumanti@gmx.net",
                   "frafu@UbuntuDesktop"]),
        Committer("Gerd Kohlberger",
                  ["lowfi@chello.at",
                   "gerdk@src.gnome.org"]),
        Committer("Henrik Nilsen Omma",
                  ["henrik@ubuntu.com"],
                  # according to Chris Jones he only committed files
                  ignore=True),
        Committer("Jeremy Bicha",
                  ["jbicha@ubuntu.com",
                   "jbicha@gnome.org",
                   "jbicha@linux.com"]),
        Committer("Johannes Almer",
                  ["almer@nomail.com"],
                  # Johannes asked not to publish his email
                  hide_email=True),
        Committer("Launchpad Code Hosting",
                  ["codehost@crowberry"]),
        Committer("Luke Yelavitch",
                  ["luke.yelavich@canonical.com"]),
        Committer("marmuta",
                  ["marmvta@gmail.com",
                   "user@dingsdale",
                   "user@box.local.org"]),
        Committer("Mike Gabriel",
                  ["mike.gabriel@das-netzwerkteam.de",
                   "sunweaver@debian.org"]),
        Committer("Martin Böhme",
                  ["martin.bohm@kubuntu.org"]),
        Committer("Martin Wimpress",
                  ["martin@flexion.org"]),
        Committer("Reiner Herrmann",
                  ["reiner@reiner-h.de"]),
        Committer("Simon Schumann",
                  ["simon.schumann@web.de"]),
    ]

    # Additional copyrights per file that can't be found with bzr annotate.
    # Filename may be a regex pattern.
    _extra_copyrights = [
        ["debian/rules",
         "barry@ubuntu.com", ['20110624']],
        ["data/onboard-autostart.desktop",
         "martin@flexion.org", ['20150101']],
        ["data/onboard-autostart.desktop",
         "luke.yelavich@canonical.com", ['20110101']],
        ["data/onboard-autostart.desktop",
         "marmvta@gmail.com", ['20120101', '20150101']],
        ["data/layoutstrings.py",
         "marmvta@gmail.com", ['20150101']],
        ["data/layoutstrings.py",
         "tortoise@tortuga", ['20080101', '20100101']],
        ["data/onboard.desktop.in",
         "mike.gabriel@das-netzwerkteam.de", ['20160731', '20170708']],
        ["data/onboard-settings.desktop.in",
         "mike.gabriel@das-netzwerkteam.de", ['20160731', '20170708']],
        ["gnome/legacy/Onboard_Indicator@onboard.org/extension.js",
         "simon.schumann@web.de", ['20160101']],
        ["gnome/legacy/Onboard_Indicator@onboard.org/metadata.json",
         "jbicha@ubuntu.com", ['20160731']],
        ["layouts/Whiteboard.onboard",
         "almer@nomail.com", ['20140101']],
        ["layouts/Whiteboard_wide.onboard",
         "almer@nomail.com", ['20140101']],
        ["Onboard/Config.py",
         "francesco.fumanti@gmx.net", ['20080101', '20110101']],
        ["Onboard/Exceptions.py",
         "tortoise@tortuga", ['20080101', '20100101']],
        ["Onboard/DBusUtils.py",
         "lowfi@chello.at", ['20130101']],
        ["Onboard/DBusUtils.py",
         "marmvta@gmail.com", ['20130101', '20140101', '20150101']],
        ["Onboard/IconPalette.py",
         "francesco.fumanti@gmx.net", ['20080101']],
        ["Onboard/osk/osk_audio.c",
         "lowfi@chello.at", ['20130101']],
        ["Onboard/osk/osk_click_mapper.c",
         "lowfi@chello.at", ['20110101']],
        ["Onboard/osk/osk_click_mapper.c",
         "marmvta@gmail.com", ['20120101']],
        ["Onboard/osk/osk_devices.c",
         "lowfi@chello.at", ['20110101']],
        ["Onboard/osk/osk_module.c",
         "lowfi@chello.at", ['20110101']],
        ["Onboard/osk/osk_module.h",
         "lowfi@chello.at", ['20110101']],
        ["Onboard/osk/osk_struts.c",
         "lowfi@chello.at", ['20120101']],
        ["Onboard/osk/osk_udev.c",
         "marmvta@gmail.com", ['20160101']],
        ["Onboard/osk/osk_uinput.c",
         "marmvta@gmail.com", ['20120101']],
        ["Onboard/osk/osk_virtkey.c",
         "tortoise@tortuga", ['20060101', '20070101', '20080101']],
        ["Onboard/osk/osk_virtkey.c",
         "marmvta@gmail.com", ['20100101']],
        ["Onboard/osk/osk_virtkey_x.c",
         "lowfi@chello.at", ['20130101']],
        ["Onboard/osk/osk_virtkey_x.c",
         "tortoise@tortuga", ['20060101', '20070101', '20080101']],
        ["Onboard/osk/osk_virtkey_x.c",
         "marmvta@gmail.com", ['20100101']],
        ["Onboard/Scanner.py",
         "lowfi@chello.at", ['20110101', '20130101']],
        ["Onboard/Sound.py",
         "lowfi@chello.at", ['20130101']],
        ["Onboard/UDevTracker.py",
         "marmvta@gmail.com", ['20160101']],
        ["Onboard/utils.py",
         "lowfi@chello.at", ['20120101']],
        ["setup.py",
         "francesco.fumanti@gmx.net", ['20150101']],
        ["setup.py",
         "reiner@reiner-h.de", ['20150101']],
        ["themes/Aubergine.colors",
         "alanbell@ubuntu.com", ['20110101']],
        ["themes/ModelM.colors",
         "alanbell@ubuntu.com", ['20110101']],
        ["themes/Typist.colors",
         "alanbell@ubuntu.com", ['20110101']],
    ]

    # bzr commits that should not grant copyright
    _exclude_revisions = [
        # Auto-generate license headers.
        "1941.1.6",
        # Regenerate license headers with revisions 1941.1.6,1941.1.8 excluded.
        "1941.1.8",
        # Merge license-cleanup branch including generated license headers.
        "1962",
        # Update license headers.
        "1963", "2178", "2248",
    ]

    # paths excluded by default (matched with startswith())
    _exclude_paths = [".bzr"]

    def run(self, *args):
        """ Execute the command. Must be overridden by subclasses. """
        raise NotImplementedError()

    @staticmethod
    def iter_filtered_files(options, args):
        """
        Yield the file names to process.

        If args is not empty its entries are yielded unchanged. Otherwise
        the current directory is walked recursively and every file whose
        path doesn't start with one of options.exclude_paths or the
        default _exclude_paths is yielded.
        """
        if args:
            yield from args
            return

        exclude_paths = options.exclude_paths + Command._exclude_paths
        for dirpath, dirnames, filenames in os.walk("."):

            dirpath = os.path.normpath(dirpath)  # remove starting "./"
            if dirpath == ".":
                dirpath = ""

            for basename in filenames:
                fn = os.path.join(dirpath, basename)
                if not any(fn.startswith(ep) for ep in exclude_paths):
                    yield fn

    @staticmethod
    def get_committer_emails(fn, exclude_emails=(), exclude_revisions=()):
        """
        Return the sorted list of distinct committer email addresses that
        'bzr annotate' reports for fn, without those in exclude_emails and
        without lines that belong to exclude_revisions.
        """
        code = Command.get_code(fn, None, exclude_revisions)
        # Sorted, so that the result doesn't depend on set iteration order.
        return sorted({fields[1] for fields in code
                       if fields[1] not in exclude_emails})

    @staticmethod
    def get_committers(fn, exclude_emails=(), exclude_revisions=()):
        """
        Returns Committer instances, each one only once, sorted by name.

        Raises ValueError if an email address is not in _known_committers.
        """
        emails = Command.get_committer_emails(fn, exclude_emails,
                                              exclude_revisions)
        committers = {Command.get_committer_from_email(e) for e in emails}
        return sorted(committers, key=lambda committer: committer.name)

    @staticmethod
    def get_committer_from_email(email):
        """
        Return the Committer that owns the given email address.

        Raises ValueError for addresses that are not in _known_committers.
        """
        for committer in Command._known_committers:
            if email in committer.emails:
                return committer

        raise ValueError("Unknown email address '{}'. Please add it to "
                         "'_known_committers' in tools/licensing"
                         .format(email))

    @staticmethod
    def get_committer_name_from_email(email):
        """
        Return the name of the committer that owns the email address.

        Raises ValueError for addresses that are not in _known_committers.
        """
        return Command.get_committer_from_email(email).name

    @staticmethod
    def get_current_email_from_name(name):
        """
        Return the first email address of the committer with the given
        name, or an empty string if there is no such committer.
        """
        for committer in Command._known_committers:
            if committer.name == name:
                return committer.emails[0]
        return ""

    @staticmethod
    def get_code(fn, emails, exclude_revisions=()):
        """
        Get the part of fn that was committed with one of the given emails.

        fn:                file to run 'bzr annotate' on
        emails:            collection of email addresses to keep,
                           None to keep the lines of all committers
        exclude_revisions: revisions whose lines are dropped

        Returns a list with one entry per kept line. Each entry consists
        of the whitespace separated fields of the annotated line
        (revision, email, timestamp, "|", ...) followed by the complete
        annotated line as last element.
        Output that can't be decoded (binary file?) yields an empty list.

        Raises RuntimeError if 'bzr annotate' fails or can't be started.
        """
        # Pass the file name as a separate argument instead of pasting it
        # into a shell command line. File names with quotes or other
        # shell meta characters are then handled correctly.
        cmd = ["bzr", "annotate", "--all", "--long", "--quiet", fn]
        try:
            output = subprocess.check_output(cmd).decode()
            lines = output.split("\n")
        except UnicodeDecodeError:  # binary file?
            lines = []
        except (subprocess.CalledProcessError, OSError) as e:
            # CalledProcessError: file not versioned?
            # OSError: bzr couldn't be started. Reported as RuntimeError
            # too, as it was when bzr was run through the shell.
            msg = "File '{}' might not be versioned. " \
                  "Please run 'bzr clean-tree' or " \
                  "update tools/licensing, FileTypeNotVersioned to " \
                  "include any new unversioned files.".format(fn)
            raise RuntimeError(msg) from e

        code = []
        for line in lines:
            if not line:
                continue
            fields = line.split()
            assert(fields[3] == "|")
            revision = fields[0]
            email = fields[1]
            if revision in exclude_revisions:
                continue
            if emails is None or email in emails:
                code.append(fields + [line])
        return code

    @staticmethod
    def find_file_type(fn):
        """
        Return an instance of the first file type class whose matches()
        accepts fn, or None if no file type matches.
        """
        # The order matters, the first match wins.
        file_types = [FileTypeNotVersioned,
                      FileTypeBinary,
                      FileTypeNoHeader,
                      FileTypePython,
                      FileTypeC,
                      FileTypeShell,
                      FileTypeXML,
                      FileTypePO,
                      FileTypeDesktopShortcut,
                      FileTypeINI,
                      ]
        for ft in file_types:
            if ft.matches(fn):
                return ft()
        return None


class CommandListCommitters(Command):
    """ Command 'list-committers': print the committers of each file. """

    cmd_strings = ["list-committers", "list"]

    @staticmethod
    def parse_cmd_line(args):
        """ Parse the command's arguments, returns (options, files). """
        parser = optparse.OptionParser(
            usage="Usage: %prog list-committers [options] [file1 file2 ...]\n"
                  "List committers per file. "
                  "If no individual files are specified\n"
                  "files are found recursively starting "
                  "from the current directory.")
        parser.add_option("-C", "--exclude-committer",
                          dest="exclude_committers",
                          default=[], action="append",
                          help="exclude committer (email address)")
        parser.add_option("-P", "--exclude-path", dest="exclude_paths",
                          default=[], action="append",
                          help="exclude (partial) path")
        parser.add_option("-s", "--source-code", dest="source_code",
                          default=False, action="store_true",
                          help="print source code")
        return parser.parse_args(args)

    def run(self, args):
        """
        Print each file name followed by its committers to stdout and,
        with --source-code, the lines each of them committed.

        Files of unknown type and files whose type can't list committers
        are collected and reported on stderr at the end.
        """
        failed_files = []
        unknown_files = []
        options, args = self.parse_cmd_line(args)
        for fn in self.iter_filtered_files(options, args):
            ft = self.find_file_type(fn)
            if not ft:
                unknown_files.append(fn)
                continue
            if not ft.can_list_committers():
                failed_files.append([fn, ft])
                continue

            print(fn)
            emails = self.get_committer_emails(fn,
                                               options.exclude_committers)
            # Sorted, so that repeated runs print the committers in the
            # same order.
            for email in sorted(emails):
                print(" " * 4 +
                      self.get_committer_name_from_email(email) +
                      " <" + email + ">")
                if options.source_code:
                    for fields in self.get_code(fn, [email]):
                        # the last field is the complete annotated line
                        print(" " * 8 + fields[-1])
            print()

        if failed_files:
            print(file=sys.stderr)
            print("Files that were recognized but couldn't be processed:",
                  file=sys.stderr)
            for fn, ft in failed_files:
                print(" " * 4, fn.ljust(50), ft.__class__.__name__,
                      file=sys.stderr)

        if unknown_files:
            print(file=sys.stderr)
            print("Unknown files, ignored:", file=sys.stderr)
            for fn in unknown_files:
                print(" " * 4, fn, file=sys.stderr)


class CommandHeaderUpdate(Command):
    """ Command 'update-headers': rewrite the license header of each file. """

    cmd_strings = ["update-headers", "update"]

    @staticmethod
    def parse_cmd_line(args):
        """ Parse the command's arguments, returns (options, files). """
        parser = optparse.OptionParser(
            usage="Usage: %prog update-headers [options] [file1 file2 ...]\n"
                  "Rewrite license header per file. "
                  "If no individual files are specified\n"
                  "files are found recursively starting from "
                  "the current directory.")
        parser.add_option("-C", "--exclude-committer",
                          dest="exclude_committers",
                          default=[], action="append",
                          help="exclude committer (email address)")
        parser.add_option("-P", "--exclude-path", dest="exclude_paths",
                          default=[], action="append",
                          help="exclude (partial) path")
        parser.add_option("-v", "--verbose", dest="verbose",
                          default=False, action="store_true",
                          help="print more information")
        return parser.parse_args(args)

    def run(self, args):
        """
        Regenerate the license header of every file that needs one and
        write the file back in place.

        Files of unknown type, files that don't need a header (verbose
        only) and files whose license template couldn't be determined
        are reported on stderr at the end.
        """
        unknown_files = []
        noheader_files = []
        failed_files = []
        options, args = self.parse_cmd_line(args)
        for fn in self.iter_filtered_files(options, args):
            ft = self.find_file_type(fn)
            if not ft:
                unknown_files.append(fn)
                continue
            if not ft.needs_license_header():
                noheader_files.append(fn)
                continue

            # read file
            try:
                with open(fn, encoding="UTF-8") as f:
                    lines = f.readlines()
            except UnicodeDecodeError:  # binary file?
                lines = None

            print("'{}', {}".format(fn, type(ft).__name__))
            try:
                template = ft.get_license_template(lines)
            except RuntimeError as e:
                msg = "WARNING: skipping '{}'. {}" \
                      .format(fn, e)
                print(msg, file=sys.stderr)
                failed_files.append([fn, ft, e])
                continue

            header = self._create_license_header(
                fn, template, options.exclude_committers)

            # insert header
            if lines:   # no error, not empty file
                lines = ft.insert_license_header(lines, header)
                if options.verbose:
                    for line in lines[:40]:
                        print(" " * 8, repr(line))

                # write file
                with open(fn, mode="w", encoding="UTF-8") as f:
                    f.writelines(lines)

        if noheader_files and \
           options.verbose:
            print(file=sys.stderr)
            print("Files not requiring license header:", file=sys.stderr)
            for fn in noheader_files:
                print(" " * 4, fn, file=sys.stderr)

        if failed_files:
            print(file=sys.stderr)
            print("Files that were recognized but couldn't be processed:",
                  file=sys.stderr)
            for fn, ft, msg in failed_files:
                print("    {:50} {:16} {}"
                      .format(fn, ft.__class__.__name__, msg),
                      file=sys.stderr)

        if unknown_files:
            print(file=sys.stderr)
            print("Unknown files, ignored:", file=sys.stderr)
            for fn in unknown_files:
                print(" " * 4, fn, file=sys.stderr)

    def _create_license_header(self, fn, template, exclude_committers):
        """
        Build the license header of file fn.

        fn:                 file the header is meant for
        template:           header text; may contain the placeholders
                            {copyright} and {prog}
        exclude_committers: email addresses to leave out when looking
                            for the committers of fn

        Returns the header as list of lines without line ends and
        without trailing whitespace. The lines are not commented yet.
        """
        copyright_text = ""
        if "{copyright}" in template:
            # collect committer information from bzr
            copyrights = {}
            committers = self.get_committers(fn, exclude_committers,
                                             self._exclude_revisions)
            for committer in committers:
                if committer.ignore:
                    continue
                # find the timestamps of contribution
                code = self.get_code(fn, committer.emails,
                                     self._exclude_revisions)
                timestamps = {fields[2] for fields in code}
                copyrights[committer] = tuple(sorted(timestamps))

            # add additional copyrights that can't be found with bzr annotate
            for file_pattern, email, timestamps in self._extra_copyrights:
                if re.fullmatch(file_pattern, fn):
                    committer = self.get_committer_from_email(email)
                    # Merge instead of append: the tuple has to stay
                    # sorted, because its last element is taken as the
                    # time of the last contribution below.
                    merged = set(copyrights.get(committer, ()))
                    merged.update(timestamps)
                    copyrights[committer] = tuple(sorted(merged))

            # Sort by last time of contribution, then by the whole
            # timestamp tuple. The name is the final tie-breaker to keep
            # the order of the lines stable between runs.
            entries = sorted(
                copyrights.items(),
                key=lambda item: (item[1][-1], item[1], item[0].name))

            # build the copyright section
            for committer, timestamps in entries:
                years = sorted({int(ts[:4]) for ts in timestamps})
                year_str = self._build_year_range_str(years)
                copyright_text += "Copyright © " + year_str + " " + \
                                  committer.name
                email = committer.get_main_email()
                if email is not None:
                    copyright_text += " <" + email + ">"

                copyright_text += "\n"

        # assemble the final header
        header = template.format(copyright=copyright_text, prog="Onboard")
        return [line.rstrip() for line in header.split("\n")]

    @staticmethod
    def _build_year_range_str(years):
        """
        Format sorted years as comma separated list, with runs of
        consecutive years collapsed to "first-last".

        Doctests:
        >>> func = CommandHeaderUpdate._build_year_range_str

        >>> func((1,2,3,4,5,))
        '1-5'
        >>> func((1,2,3,5,))
        '1-3, 5'
        >>> func((1,3,4,5,))
        '1, 3-5'
        >>> func((1,2,4,5,))
        '1-2, 4-5'
        >>> func((1,3,5,))
        '1, 3, 5'
        """
        def append_section(year_range):
            if year_range[0] == year_range[1]:
                year_sections.append(str(year_range[1]))
            elif year_range[0] < year_range[1]:
                year_sections.append(str(year_range[0]) + "-" +
                                     str(year_range[1]))
            else:
                assert(False)  # years must be sorted

        year_sections = []
        year_range = None   # [first, last] of the run being collected
        for year in years:
            if year_range is not None:
                # directly following year? -> extend range
                if year_range[1] == year - 1:
                    year_range[1] = year
                else:  # gap -> format the range and start a new one
                    append_section(year_range)
                    year_range = None
            if year_range is None:
                year_range = [year, year]

        if year_range is not None:
            append_section(year_range)

        return ", ".join(year_sections)


class FileType:
    """
    Base class of the file types.

    Subclasses decide with matches() which files they handle and
    provide comment_lines() and find_insert_range(), which
    insert_license_header() relies on.
    """

    @staticmethod
    def matches(fn):
        """ Return True if file fn is of this type. Must be overridden. """
        raise NotImplementedError()

    def needs_license_header(self):
        """ Return True if files of this type get a license header. """
        return True

    def can_list_committers(self):
        """ Return True if committers can be listed for this type. """
        return True

    # Static, so that the call raises NotImplementedError no matter
    # whether it is made on the class or on an instance.
    @staticmethod
    def read_license_header():
        raise NotImplementedError()

    @staticmethod
    def match_names(fn, paths, files, extensions):
        """
        Return True if fn starts with one of paths or ends with one of
        files or one of extensions.
        """
        return fn.startswith(tuple(paths)) or \
            fn.endswith(tuple(files)) or \
            fn.endswith(tuple(extensions))

    def get_license_template(self, lines):
        """
        Return the license text with the placeholders {copyright}
        and {prog}. The parameter lines is not used here.
        """
        template = lstrip_docstring("""
{copyright}
This file is part of {prog}.

{prog} is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3 of the License, or
(at your option) any later version.

{prog} is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
""")    # noqa - shut up flake8
        return template

    def insert_license_header(self, lines, header):
        """
        Replace the existing license header in lines with header.

        lines:  lines of the file, including their line ends
        header: uncommented header lines without line ends

        Returns a new list with the range reported by
        find_insert_range() replaced by the commented header. If no
        range was found, lines is returned unchanged.
        """
        begin, end = self.find_insert_range(lines)

        # empty file or no complete range?
        if begin is None or end is None:
            return lines

        commented = [line + "\n" for line in self.comment_lines(header)]

        # Drop the old range [begin, end) and put the header in its
        # place. max() keeps all lines in case end is before begin.
        return lines[:begin] + commented + lines[max(begin, end):]

    def comment_lines(self, lines):
        """ Return lines turned into a comment. Must be overridden. """
        raise NotImplementedError()


class FileTypePython(FileType):
    """ Matches python files """
    @staticmethod
    def matches(fn):
        """
        Return True if fn ends with ".py" or if its first line is
        a python shebang.
        """
        if fn.endswith(".py"):
            return True

        # has python shebang?
        lines = read_file(fn)
        if lines:
            if re.match("#!.*python.*", lines[0]):
                return True

        return False

    # Static, so that the call raises NotImplementedError no matter
    # whether it is made on the class or on an instance.
    @staticmethod
    def read_license_header():
        raise NotImplementedError()

    def comment_lines(self, lines):
        """ Prefix each line with "# ", without trailing whitespace. """
        return [("# " + line).rstrip() for line in lines]

    def find_insert_range(self, lines):
        """
        Find the line range of the existing license comment.

        Returns (begin, end), with end being exclusive. Both are equal
        if there is no license comment yet and both are None if no
        position was found.

        Doctests:
        >>> func = FileTypePython().find_insert_range

        >>> func([])
        (None, None)

        >>> func(["#!/bin/usr/python3",
        ...       "# -*- coding: utf-8 -*-",
        ...       "",
        ...       "import os"])
        (3, 3)

        >>> func(["#!/bin/usr/python3",
        ...       "# -*- coding: utf-8 -*-",
        ...       "import os"])
        (2, 2)

        >>> func(["# -*- coding: utf-8 -*-",
        ...       '"\"" doc string "\""',
        ...       "",
        ...       "import os"])
        (1, 1)

        >>> func([" #!/bin/usr/python3",
        ...       " # -*- coding: utf-8 -*-",
        ...       "",
        ...       "import os"])
        (3, 3)

        >>> func(["#!/bin/usr/python3",
        ...       "# -*- coding: utf-8 -*-",
        ...       "",
        ...       "# Copyright",
        ...       "# GPL",
        ...       "",
        ...       "import os"])
        (3, 5)

        >>> func(["#!/bin/usr/python3",
        ...       "# -*- coding: utf-8 -*-",
        ...       "# Copyright",
        ...       "# GPL",
        ...       "",
        ...       "import os"])
        (2, 4)

        >>> func(["#!/bin/usr/python3",
        ...       "",
        ...       "# Copyright",
        ...       "# GPL",
        ...       "",
        ...       "import os"])
        (2, 4)

        >>> func(["# -*- coding: utf-8 -*-",
        ...       "",
        ...       "# Copyright",
        ...       "# GPL",
        ...       "",
        ...       "import os"])
        (2, 4)

        >>> func(["#!/bin/usr/python3",
        ...       "# -*- coding: utf-8 -*-",
        ...       "",
        ...       "# Copyright",
        ...       "# GPL",
        ...       "",
        ...       "# Usage:",
        ...       "# prog options",
        ...       "",
        ...       "import os"])
        (3, 5)
        """

        begin = None
        end = None
        for iline, line in enumerate(lines):
            stripped = line.strip()
            if stripped.startswith("#"):
                # shebang?
                if stripped.startswith("#!"):
                    begin = None

                # encoding line?
                # (raw string, "\*" is not a valid string escape)
                elif re.match(r"#.*-\*-.*[Cc]oding", stripped):
                    begin = None

                else:
                    # regular comment line, part of the license comment
                    if begin is None:
                        begin = iline
                    end = iline + 1
            elif not stripped:
                # a blank line ends a comment block that has begun
                if begin is not None:
                    break
            else:
                # first line of code, insert here if nothing was found
                if begin is None:
                    begin = iline
                    end = iline
                break

        return begin, end


class FileTypeC(FileType):
    """ Matches C or C++ files """
    @staticmethod
    def matches(fn):
        """ Return True if fn has one of the extensions c, cpp or h. """
        return fn.endswith((".c", ".cpp", ".h"))

    def comment_lines(self, lines):
        """ Wrap lines into a single /* ... */ block comment. """
        body = [(" * " + text).rstrip() for text in lines]
        return ["/*", *body, " */"]

    def find_insert_range(self, lines):
        """
        Find the line range of the existing license comment.

        Returns (begin, end), with end being exclusive.

        Doctests:
        >>> func = FileTypeC().find_insert_range

        >>> func(["/*",
        ...       " * Copyright*-",
        ...       " */",
        ...
        ...       "#include"])
        (0, 3)

        >>> func(["/*",
        ...       " * Copyright*-",
        ...       " */",
        ...       "#include"])
        (0, 3)

        >>> func(["",
        ...       "/*",
        ...       " * Copyright*-",
        ...       " */",
        ...
        ...       "#include"])
        (1, 4)

        >>> func(["#include"])
        (0, 0)

        >>> func(["",
        ...       "#include"])
        (1, 1)
        """

        first = None
        last = None
        in_comment = False

        for index, raw in enumerate(lines):
            text = raw.strip()

            # inside of a block comment: only look for its end
            if in_comment:
                if text.endswith("*/"):
                    in_comment = False
                    last = index + 1
                continue

            # start of a block comment
            if text.startswith("/*"):
                in_comment = True
                if first is None:
                    first = index
                continue

            # code: stop, insert here if no comment was found before
            if text:
                if first is None:
                    first = last = index
                break

            # blank line: stop once a comment has been seen
            if first is not None:
                break

        return first, last


class FileTypeShell(FileTypePython):
    """ Matches shell scripts """
    @staticmethod
    def matches(fn):
        """
        Return True if fn ends with ".sh" or if its first line is
        a bash, sh or dash shebang.
        """
        if fn.endswith(".sh"):
            return True

        # has shell shebang?
        content = read_file(fn)
        if not content:
            return False
        return bool(re.match("#!.*(?:/bash|/sh|/dash).*", content[0]))


class FileTypeINI(FileTypePython):
    """ Matches INI style files """
    @staticmethod
    def matches(fn):
        """
        Return True if fn ends with ".ini" or if its first line that is
        neither blank nor a "#" comment starts with a "[...]" section.
        """
        if fn.endswith(".ini"):
            return True

        # NOTE(review): read_file() is defined elsewhere. The other file
        # types check its result before use, so a falsy result is
        # presumably possible and is treated as "no lines" here.
        for line in read_file(fn) or ():
            line = line.strip()
            if line and \
               not line.startswith("#"):
                # Only the first significant line decides.
                # (raw string, "\[" is not a valid string escape)
                return bool(re.match(r"^\[.*\].*", line))

        # no significant line at all
        return False


class FileTypeDesktopShortcut(FileTypeINI):
    """ Matches desktop files """
    @staticmethod
    def matches(fn):
        """ Return True if fn ends with ".desktop" or ".desktop.in". """
        return fn.endswith((".desktop", ".desktop.in"))

    def needs_license_header(self):
        """ Desktop files don't get a license header. """
        return False


class FileTypeXML(FileType):
    """ Matches XML-style files (.xml, .html, .colors, .theme, .onboard) """
    @staticmethod
    def matches(fn):
        """ Return True if fn ends with one of the XML-style extensions. """
        extensions = ["xml", "html", "colors", "theme", "onboard"]
        return any(fn.endswith("." + ext) for ext in extensions)

    def comment_lines(self, lines):
        """ Wrap lines into a single "<!--" ... "-->" comment block. """
        return ["<!--"] + list(lines) + ["-->"]

    def find_insert_range(self, lines):
        """
        Find the line range of the existing license comment.

        Returns a tuple (begin, end) of line indexes, with end being
        one past the line that closes the comment. Without a comment,
        begin and end both point at the first content line. Both are
        None if there is no content at all, and end stays None if a
        comment is opened but never closed.

        Doctests:
        >>> func = FileTypeXML().find_insert_range

        >>> func(['<?xml version="1.0" ?>',
        ...       "<!---",
        ...       "Copyright",
        ...       "-->",
        ...       "",
        ...       "<keyboard>"])
        (1, 4)

        # comment opened and closed on the same line
        >>> func(['<?xml version="1.0" ?>',
        ...       "<!-- Copyright -->",
        ...       "",
        ...       "<keyboard>"])
        (1, 2)

        # no comment at all
        >>> func(['<?xml version="1.0" ?>',
        ...       "<keyboard>"])
        (1, 1)
        """

        opener = "<!--"
        closer = "-->"

        begin = None
        end = None
        multiline_comment = False

        for iline, line in enumerate(lines):
            stripped = line.strip()
            if multiline_comment:
                if stripped.endswith(closer):
                    multiline_comment = False
                    end = iline + 1
            else:
                # encoding line?
                if stripped.startswith("<?"):
                    begin = None
                elif stripped.startswith(opener):
                    begin = iline
                    # The comment may already end on this very line. The
                    # length check keeps opener and closer from overlapping,
                    # e.g. in "<!-->".
                    if stripped.endswith(closer) and \
                       len(stripped) >= len(opener) + len(closer):
                        end = iline + 1
                    else:
                        multiline_comment = True
                elif not stripped:
                    # a blank line after the comment ends the search
                    if begin is not None:
                        break
                else:
                    # first content line; without a preceding comment
                    # it marks an empty range to insert at
                    if begin is None:
                        begin = iline
                        end = iline
                    break

        return begin, end


class FileTypePO(FileType):
    """ Matches *.po translation files """
    @staticmethod
    def matches(fn):
        """ Return True if fn ends with ".po". """
        extensions = ["po"]
        return any(fn.endswith("." + ext) for ext in extensions)

    def comment_lines(self, lines):
        """
        Turn lines into "#" comments. Non-empty lines get a space
        after the "#", empty lines become a bare "#".
        """
        return ["#" + (" " if line else "") + line for line in lines]

    def find_insert_range(self, lines):
        """
        Find the line range of the existing license comment.

        The range starts at the first line and ends before the first
        line that is neither blank nor a "#" comment. Returns a tuple
        (begin, end); end is None if there is no such content line and
        begin is None too if lines is empty.

        Doctests:
        >>> func = FileTypePO().find_insert_range

        >>> func(['msgid ""'])
        (0, 0)

        >>> func(['',
        ...       '',
        ...       'msgid ""'])
        (0, 2)

        >>> func(['',
        ...       '# German translation for onboard',
        ...       '# Copyright (c) 2011 Rosetta Contributors',
        ...       '# This file is distributed under the same license',
        ...       '# FIRST AUTHOR <EMAIL@ADDRESS>, 2011.',
        ...       '#',
        ...       '',
        ...       'msgid ""'])
        (0, 7)
        """

        begin = None
        end = None

        for iline, line in enumerate(lines):
            stripped = line.strip()

            # Whatever the first line is, the range begins there.
            if begin is None:
                begin = iline

            # The first line that is neither blank nor a comment
            # terminates the range.
            if stripped and not stripped.startswith("#"):
                end = iline
                break

        return begin, end

    def get_license_template(self, lines):
        """
        Build the license header template from the original header of
        a *.po file.

        lines are the lines of the file. Only the leading "#" comment
        lines up to the first blank line are evaluated.

        Returns the template as a single string with lines joined by
        newline characters. It contains the placeholder "{prog}".

        Raises RuntimeError if the expected parts of the original
        header were not all found.
        """
        template_lines = []
        footer_lines = []

        # Extract these snippets from the original license header:
        # # German translation for onboard
        #   ^--                        --^
        # # Copyright (c) 2011 Rosetta Contributors and Canonical Ltd 2011
        #   ^--                                 --^
        # and everything below the line
        # # This file is distributed under the same license...
        state = 0
        final_state = 3
        for line in lines:
            stripped = line.strip()
            if stripped.startswith("#"):

                if state == 0:
                    m = re.search(r"\w.+ translation for ", stripped)
                    if m:
                        template_lines.append(m.group() + "{prog}")
                        template_lines.append("")
                        state = 1

                elif state == 1:
                    m = re.search("Copyright", stripped)
                    if m:
                        template_lines.append(m.group() + ":")
                        state = 2

                # last line of the fixed part
                elif state == 2:
                    m = re.search(r"#\s*This file is distributed under the "
                                  r"same license as.*package",
                                  stripped)
                    if m:
                        state = 3

                elif state == 3:
                    # match everything after the comment character
                    m = re.search(r"#\s?(.*)", stripped)
                    if m and m.group(1):   # no empty lines
                        footer_lines.append(m.group(1))

            elif not stripped:
                # the header ends at the first blank line
                break

        # parsing failed?
        if state != final_state:
            msg = "Failed to parse original license header (state=={})." \
                  .format(state)
            raise RuntimeError(msg)

        # BSD-3-clause license
        # Onboard's strings are GPLv3, the translated ones BSD, see
        # https://help.launchpad.net/Translations/LicensingFAQ
        # The [1:] drops the leading "Copyright:" line, which was already
        # taken over from the original header above.
        bsd = split_docstring("""
Copyright:
  Original strings: See files they were extracted from
  Translations: Rosetta Contributers

License:
  Original strings: GPL-3+ license
  Translations: BSD-3-clause license
\n
""")[1:]

        template_lines += bsd
        template_lines += footer_lines
        template_lines += [""]

        return "\n".join(template_lines)


class FileTypeBinary(FileType):
    """ Match binary files; they are excluded from all processing. """

    @staticmethod
    def matches(fn):
        # presumably the last argument of match_names lists file name
        # endings -- confirm against FileType.match_names
        endings = [".png", ".oga"]
        return FileType.match_names(fn, [], [], endings)

    def needs_license_header(self):
        """ Binary files cannot carry a license header. """
        return False

    def can_list_committers(self):
        """ Committers aren't listed for binary files. """
        return False


class FileTypeNotVersioned(FileType):
    """ Match files that aren't under source control. """

    @staticmethod
    def matches(fn):
        # presumably directory fragments come first and file name endings
        # last -- confirm against FileType.match_names
        directories = ["build/"]
        endings = [".swp", "~"]
        return FileType.match_names(fn, directories, [], endings)

    def needs_license_header(self):
        """ Unversioned files get no license header. """
        return False

    def can_list_committers(self):
        """ Committers aren't listed for unversioned files. """
        return False


class FileTypeNoHeader(FileType):
    """
    Match versioned, non-binary files that either don't need a
    license header or can't have one.
    """
    # not referenced anywhere inside this class
    files = []

    @staticmethod
    def matches(fn):
        # presumably the third argument holds file names and the fourth
        # file name endings -- confirm against FileType.match_names
        names = ["COPYING", "README", "TODO",
                 "AUTHORS", "HACKING",
                 "NEWS", "CHANGELOG", "MANIFEST.in",
                 ".bzrignore", "attic/Makefile"]
        endings = [".lm", ".svg", ".ui"]
        return FileType.match_names(fn, [], names, endings)

    def needs_license_header(self):
        """ These files never get a license header. """
        return False


def read_file(fn):
    """
    Return the lines of the UTF-8 encoded text file fn, line endings
    included. Files that fail to decode yield an empty list.
    """
    with open(fn, encoding="UTF-8") as stream:
        try:
            return stream.readlines()
        except UnicodeDecodeError:
            # not valid UTF-8, probably a binary file
            return []


def write_file(fn, lines):
    """
    Write lines to the file fn as UTF-8 text, replacing any previous
    content. The lines are written as they are; no line endings are
    added.
    """
    # The file has to be opened for writing explicitly; open() defaults
    # to read mode, in which writelines() fails.
    with open(fn, "w", encoding="UTF-8") as f:
        f.writelines(lines)


def lstrip_docstring(docstring):
    """
    Return docstring as a single string again after it went through
    split_docstring with default arguments.
    """
    cleaned_lines = split_docstring(docstring)
    return "\n".join(cleaned_lines)


def split_docstring(docstring, strip_leading_spaces=False):
    """
    Split docstring into a list of lines.

    Tabs are expanded to spaces. By default trailing whitespace is
    removed from each line; with strip_leading_spaces=True leading
    whitespace is removed instead. One empty line at the top and one
    at the bottom are cut off, if present.

    An empty docstring yields an empty list.
    """
    lines = docstring.expandtabs().splitlines()
    if strip_leading_spaces:
        lines = [line.lstrip() for line in lines]
    else:
        lines = [line.rstrip() for line in lines]

    # cut off empty lines at top and bottom;
    # there may be no lines at all, hence the guards
    if lines and not lines[0]:
        lines = lines[1:]
    if lines and not lines[-1]:
        lines = lines[:-1]

    return lines

if __name__ == '__main__':
    # Regular runs dispatch to main(); passing "--doctest" anywhere
    # on the command line runs the doctests of this module instead.
    if "--doctest" not in sys.argv:
        main()
    else:
        import doctest
        doctest.testmod()




