# -*- coding: utf-8 -*-
#
# Copyright (C) 2020-2021 Matthias Klumpp <matthias@tenstral.net>
#
# Licensed under the GNU Lesser General Public License Version 3
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the license, or
# (at your option) any later version.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this software.  If not, see <http://www.gnu.org/licenses/>.

import os
import json
import struct
import logging as log
from enum import IntEnum
from uuid import UUID
from datetime import datetime

import numpy as np
from xxhash import xxh3_64

from .. import ureg

__all__ = ['TSyncFile', 'TSyncFileMode', 'TSyncTimeUnit']


class TSyncFileMode(IntEnum):
    '''Storage layout of the time data inside a TSync file.'''

    # Continuous time-point mapping with no gaps
    CONTINUOUS = 0

    # Only synchronization points are saved
    SYNCPOINTS = 1


class TSyncTimeUnit(IntEnum):
    '''Units in which a time column of a TSync file can be expressed.'''

    # Plain index; tsync_time_unit_to_punit() maps it to a dimensionless unit
    INDEX = 0

    # Actual time units, from finest to coarsest
    NANOSECONDS = 1
    MICROSECONDS = 2
    MILLISECONDS = 3
    SECONDS = 4


class TSyncDataType(IntEnum):
    '''Data types used for storing time values in the data file.'''

    INVALID = 0

    # signed integers
    INT16 = 2
    INT32 = 3
    INT64 = 4

    # unsigned integers
    UINT16 = 6
    UINT32 = 7
    UINT64 = 8


# Magic number expected in the first eight bytes of a tsync file
# (read as a little-endian unsigned 64-bit integer).
TSYNC_MAGIC = 0xF223434E5953548A

# Format version this reader was written for: the major version of a file
# has to match exactly, the minor version may not be newer.
TSYNC_VERSION_MAJOR = 1
TSYNC_VERSION_MINOR = 2

# Block terminator values: 64-bit variant for files that carry checksums,
# 32-bit variant for older files without them.
TSYNC_BLOCK_TERM = 0x1126000000000000
TSYNC_BLOCK_TERM_32 = 0x11260000


def tsync_dtype_to_pack_fmt_len(dtype: TSyncDataType):
    '''Map a tsync data type to its ``struct`` format string and byte size.

    Returns a ``(format, size)`` tuple describing a single little-endian
    value of the given type. Raises ``RuntimeError`` for any type that has
    no known layout (including ``TSyncDataType.INVALID``).
    '''
    known_layouts = (
        (TSyncDataType.INT16, '<h', 2),
        (TSyncDataType.UINT16, '<H', 2),
        (TSyncDataType.INT32, '<i', 4),
        (TSyncDataType.UINT32, '<I', 4),
        (TSyncDataType.INT64, '<q', 8),
        (TSyncDataType.UINT64, '<Q', 8),
    )
    for candidate, pack_fmt, byte_size in known_layouts:
        if dtype == candidate:
            return pack_fmt, byte_size
    raise RuntimeError('No data defined for how to unpack type {}'.format(dtype))


def tsync_time_unit_to_punit(unit: TSyncTimeUnit):
    '''Translate a tsync time unit into the matching unit of the ``ureg`` registry.

    Raises ``ValueError`` if no unit is known for the given value.
    '''
    # the registry attribute is only looked up for the unit that matched
    registry_names = (
        (TSyncTimeUnit.INDEX, 'dimensionless'),
        (TSyncTimeUnit.NANOSECONDS, 'nsec'),
        (TSyncTimeUnit.MICROSECONDS, 'usec'),
        (TSyncTimeUnit.MILLISECONDS, 'msec'),
        (TSyncTimeUnit.SECONDS, 'sec'),
    )
    for known_unit, attr_name in registry_names:
        if unit == known_unit:
            return getattr(ureg, attr_name)
    raise ValueError('Can not convert tsync time unit type "{}" to Pint unit type.'.format(unit))


def read_utf8_xxh_from_file(f, xxh):
    '''Read a length-prefixed UTF-8 string from a binary .tsync file.

    The string is stored as a little-endian unsigned 32-bit byte count
    followed by that many bytes of UTF-8 data. Only the string data (not
    the length prefix) is fed into the running checksum.

    f: binary file object positioned at the length prefix; it has to be
       backed by a real file descriptor, as its size is queried via fstat.
    xxh: hasher object whose ``update()`` receives the raw string bytes.

    Returns the decoded string; a length field of 0xFFFFFFFF yields an
    empty string without reading or hashing anything further.

    Raises ValueError if the declared length exceeds the number of bytes
    left in the file, and UnicodeDecodeError if the data is no valid UTF-8.
    '''

    (length,) = struct.unpack('<I', f.read(4))
    if length == 0xFFFFFFFF:
        return ''

    # Reject lengths that run past the end of the file, so a damaged length
    # field can neither trigger a huge read nor yield a silently truncated
    # string.
    bytes_left = os.fstat(f.fileno()).st_size - f.tell()
    if length > bytes_left:
        raise ValueError('String length in binary too long ({}).'.format(length))

    data = f.read(length)
    xxh.update(data)
    return str(data, 'utf-8')


class TSyncFile:
    '''
    Read a TimeSync (.tsync) binary file as generated by the
    Syntalos DAQ system.

    Once a file has been read with open() (or by passing a filename to the
    constructor), the header metadata is available through the properties
    of this class and `times` holds one row per stored pair of time values.
    '''

    def __init__(self, fname=None):
        '''Create an empty instance; if fname is given, read that file right away.'''
        self._format_version = '1.0'
        self._time_created = None
        self._generator_name = ''
        self._collection_id = UUID(int=0x00)
        self._ts_mode = TSyncFileMode.CONTINUOUS
        self._block_size = 128
        self._custom = {}
        self._time_labels = ('A', 'B')
        self._time_units = (
            tsync_time_unit_to_punit(TSyncTimeUnit.MICROSECONDS),
            tsync_time_unit_to_punit(TSyncTimeUnit.MICROSECONDS),
        )
        self._times = np.empty((0, 2))
        if fname:
            self.open(fname)

    @property
    def time_created(self):
        '''Creation time read from the file header (None if no file was read).'''
        return self._time_created

    @property
    def tolerance(self) -> int:
        '''The tolerance range value, in microseconds'''
        # stored as the "tolerance_us" entry of the custom properties
        return self._custom.get('tolerance_us', 0)

    @tolerance.setter
    def tolerance(self, usec: int):
        self._custom['tolerance_us'] = usec

    @property
    def generator_name(self) -> str:
        '''Name of the module that generated this file.'''
        return self._generator_name

    @generator_name.setter
    def generator_name(self, name: str):
        self._generator_name = name

    @property
    def collection_id(self) -> UUID:
        '''Data collection ID this file belongs to.'''
        return self._collection_id

    @collection_id.setter
    def collection_id(self, uuid: UUID):
        self._collection_id = uuid

    @property
    def sync_mode(self) -> TSyncFileMode:
        '''Time data storage mode.'''
        return self._ts_mode

    @sync_mode.setter
    def sync_mode(self, mode: TSyncFileMode):
        self._ts_mode = mode

    @property
    def custom(self) -> dict:
        '''User-defined custom properties of this file.'''
        return self._custom

    @custom.setter
    def custom(self, v: dict):
        self._custom = v

    @property
    def time_labels(self):
        '''Labels of the two encoded times.'''
        return self._time_labels

    @time_labels.setter
    def time_labels(self, v):
        self._time_labels = v

    @property
    def time_units(self):
        '''Units of the two encoded times.'''
        return self._time_units

    @time_units.setter
    def time_units(self, v):
        self._time_units = v

    @property
    def times(self):
        '''The actual time values of the two clocks.'''
        return self._times

    @times.setter
    def times(self, v):
        self._times = v

    def _read_xxh_unpack(self, format, buffer):
        '''Feed buffer into the running checksum and unpack it as one value.'''
        self._xxh.update(buffer)
        (v,) = struct.unpack(format, buffer)
        return v

    def _read_utf8_xxh_from_file(self, f):
        '''Read a length-prefixed UTF-8 string, adding its data to the running checksum.'''
        return read_utf8_xxh_from_file(f, self._xxh)

    def open(self, fname):
        '''Read the tsync file fname and replace the contents of this object.

        Raises ValueError if the file is no tsync file, has an unsupported
        format version, or fails one of the structural / checksum checks.
        '''
        with open(fname, 'rb') as f:
            try:
                check_xxh, term_bytecount, time_dtypes = self._read_header(f, fname)
                self._read_data(f, fname, check_xxh, term_bytecount, time_dtypes)
            finally:
                # The hasher is only needed while parsing. Drop it on every
                # exit path, including the early return for files without
                # data and any parsing error.
                if hasattr(self, '_xxh'):
                    del self._xxh

    def _read_header(self, f, fname):
        '''Parse the file header and store its metadata on this object.

        Returns a tuple (check_xxh, term_bytecount, time_dtypes):
        whether checksums are present and verified, the size in bytes of
        the terminator following every block, and the TSyncDataType pair
        of the two time columns.
        '''
        (magic_number,) = struct.unpack('<Q', f.read(8))
        if magic_number != TSYNC_MAGIC:
            raise ValueError('Unrecognized file type: This file is no tsync file.')

        # read file header block
        self._xxh = xxh3_64()
        major_version = self._read_xxh_unpack('<H', f.read(2))
        minor_version = self._read_xxh_unpack('<H', f.read(2))
        self._format_version = '{}.{}'.format(major_version, minor_version)
        if major_version != TSYNC_VERSION_MAJOR or minor_version > TSYNC_VERSION_MINOR:
            raise ValueError(
                'Can not read TSync format version {} (max {}.{})'.format(
                    self._format_version, TSYNC_VERSION_MAJOR, TSYNC_VERSION_MINOR
                )
            )
        log.debug('Reading tsync %s file: %s', self._format_version, fname)
        check_xxh = major_version >= 1 and minor_version >= 2
        if not check_xxh:
            log.warning(
                'Tsync file version (%s) is too old, checksum validation for '
                'integrity checks will be skipped.',
                self._format_version,
            )

        self._time_created = datetime.utcfromtimestamp(self._read_xxh_unpack('<q', f.read(8)))
        self._generator_name = self._read_utf8_xxh_from_file(f)
        self._collection_id = UUID(self._read_utf8_xxh_from_file(f))
        user_json_raw = self._read_utf8_xxh_from_file(f)
        self._custom = {}
        if user_json_raw:
            self._custom = json.loads(user_json_raw)

        self._ts_mode = TSyncFileMode(self._read_xxh_unpack('<H', f.read(2)))
        self._block_size = self._read_xxh_unpack('<i', f.read(4))

        time1Name = self._read_utf8_xxh_from_file(f)
        time1Unit = TSyncTimeUnit(self._read_xxh_unpack('<H', f.read(2)))
        time1DType = TSyncDataType(self._read_xxh_unpack('<H', f.read(2)))

        time2Name = self._read_utf8_xxh_from_file(f)
        time2Unit = TSyncTimeUnit(self._read_xxh_unpack('<H', f.read(2)))
        time2DType = TSyncDataType(self._read_xxh_unpack('<H', f.read(2)))

        self._time_labels = (time1Name, time2Name)
        self._time_units = (
            tsync_time_unit_to_punit(time1Unit),
            tsync_time_unit_to_punit(time2Unit),
        )

        # skip alignment padding (up to the next multiple of 8 bytes)
        padding = (f.tell() * -1) & (8 - 1)
        self._xxh.update(f.read(padding))

        # check header CRC
        if check_xxh:
            term_bytecount = 16
            (block_term,) = struct.unpack('<Q', f.read(8))
            (expected_header_cs,) = struct.unpack('<Q', f.read(8))
            if block_term != TSYNC_BLOCK_TERM:
                raise ValueError(
                    'Header block terminator not found: The file is either '
                    'invalid or its header block was damaged.'
                )
            if expected_header_cs != self._xxh.intdigest():
                raise ValueError(
                    'Header checksum mismatch: The file is either invalid or '
                    'its header block was damaged.'
                )
        else:
            term_bytecount = 8
            (block_term,) = struct.unpack('<I', f.read(4))
            f.read(4)
            if block_term != TSYNC_BLOCK_TERM_32:
                # check if we maybe had no padding due to an erroneous writer
                f.seek((padding + 4 + 4) * -1, os.SEEK_CUR)
                (block_term,) = struct.unpack('<I', f.read(4))
                f.read(4)
                if block_term != TSYNC_BLOCK_TERM_32:
                    raise ValueError(
                        'Header block terminator not found: The file is either '
                        'invalid or its header block was damaged.'
                    )

        # the data blocks that follow each start with a fresh checksum
        self._xxh.reset()
        return check_xxh, term_bytecount, (time1DType, time2DType)

    def _read_data(self, f, fname, check_xxh, term_bytecount, time_dtypes):
        '''Read all data blocks following the header into self._times.

        f has to be positioned directly after the header terminator. The
        remaining arguments are the values returned by _read_header();
        fname is only used for error messages.
        '''
        time1DType, time2DType = time_dtypes
        tfmt1, tlen1 = tsync_dtype_to_pack_fmt_len(time1DType)
        tfmt2, tlen2 = tsync_dtype_to_pack_fmt_len(time2DType)

        self._times = np.empty((0, 2))
        bytes_per_entry = tlen1 + tlen2

        bytes_remaining = os.fstat(f.fileno()).st_size - f.tell()
        if bytes_remaining <= 0:
            # no data is present
            return

        # A non-positive block size can only come from a damaged header and
        # would otherwise surface as a division by zero or an index error.
        if self._block_size <= 0:
            raise ValueError(
                'File "{}" may be corrupt: Invalid block size ({}).'.format(
                    fname, self._block_size
                )
            )
        bytes_per_block = bytes_per_entry * self._block_size + term_bytecount

        # work out the number of entries from the file size: a run of full
        # blocks, optionally followed by one shorter block
        whole_block_count = bytes_remaining // bytes_per_block
        last_block_bytes_remaining = bytes_remaining - (whole_block_count * bytes_per_block)
        if last_block_bytes_remaining == 0:
            last_block_len = 0
        else:
            last_block_len = (last_block_bytes_remaining - term_bytecount) / bytes_per_entry
            if last_block_len.is_integer() and last_block_len > 0:
                last_block_len = int(last_block_len)
            else:
                raise ValueError(
                    'File "{}" may be corrupt: Suspicious size ({}) of '
                    'last data block.'.format(fname, last_block_len)
                )
        entries_n = whole_block_count * self._block_size + last_block_len

        self._times = np.zeros((entries_n, 2), dtype=np.int64)
        # only assigned here, nothing in this class reads it
        self._block_crc = 0
        b_index = 0
        i = 0
        while bytes_remaining != 0:
            time1 = self._read_xxh_unpack(tfmt1, f.read(tlen1))
            time2 = self._read_xxh_unpack(tfmt2, f.read(tlen2))
            bytes_remaining -= bytes_per_entry
            self._times[i] = np.array([time1, time2])

            i += 1
            b_index += 1
            # a block ends once it is full, or - for the last, shorter
            # block - when only its terminator is left in the file
            if b_index == self._block_size or bytes_remaining == term_bytecount:
                bytes_remaining -= term_bytecount

                if not check_xxh:
                    (block_term,) = struct.unpack('<I', f.read(4))
                    f.read(4)
                    if block_term != TSYNC_BLOCK_TERM_32:
                        raise ValueError(
                            'Block terminator not found: Some data may be corrupted.'
                        )
                else:
                    # check validity of the block we read last
                    (block_term,) = struct.unpack('<Q', f.read(8))
                    (expected_cs,) = struct.unpack('<Q', f.read(8))
                    if block_term != TSYNC_BLOCK_TERM:
                        raise ValueError(
                            'Block terminator not found: Some data is likely corrupted.'
                        )
                    if expected_cs != self._xxh.intdigest():
                        raise ValueError('Block checksum mismatch: Some data is likely corrupted.')
                    self._xxh.reset()
                b_index = 0


class LegacyTSyncFile:
    '''
    Read a legacy TimeSync (.tsync) binary file as generated by the
    Syntalos DAQ system (legacy variant for an older, experimental
    version of this file format that was briefly in use).
    '''

    # magic number stored in the first four bytes of a legacy file
    _LEGACY_MAGIC = 0xC6BBDFBC

    def __init__(self, fname=None):
        '''Create an empty instance; if fname is given, read that file right away.'''
        self._format_version = 1
        self._time_created = None
        # has to use the attribute name the `tolerance` property reads, so
        # the property also works before a file was opened
        self._tolerance_us = 0
        self._generator_name = ''
        self._custom = {}
        self._time_labels = ('A', 'B')
        self._time_units = (
            tsync_time_unit_to_punit(TSyncTimeUnit.MICROSECONDS),
            tsync_time_unit_to_punit(TSyncTimeUnit.MICROSECONDS),
        )
        self._times = np.empty((0, 2))
        if fname:
            self.open(fname)

    @property
    def time_created(self):
        '''Creation time read from the file header (None if no file was read).'''
        return self._time_created

    @property
    def tolerance(self) -> int:
        '''The tolerance range value, in microseconds'''
        return self._tolerance_us

    @tolerance.setter
    def tolerance(self, usec: int):
        self._tolerance_us = usec

    @property
    def generator_name(self) -> str:
        '''Name of the module that generated this file.'''
        return self._generator_name

    @generator_name.setter
    def generator_name(self, name: str):
        self._generator_name = name

    @property
    def sync_mode(self) -> TSyncFileMode:
        '''Time data storage mode (always SYNCPOINTS for this reader).'''
        return TSyncFileMode.SYNCPOINTS

    @property
    def custom(self) -> dict:
        '''User-defined custom properties of this file.'''
        return self._custom

    @custom.setter
    def custom(self, v: dict):
        self._custom = v

    @property
    def time_labels(self):
        '''Labels of the two encoded times.'''
        return self._time_labels

    @time_labels.setter
    def time_labels(self, v):
        self._time_labels = v

    @property
    def time_units(self):
        '''Units of the two encoded times.'''
        return self._time_units

    @time_units.setter
    def time_units(self, v):
        self._time_units = v

    @property
    def times(self):
        '''The actual time values of the two clocks.'''
        return self._times

    @times.setter
    def times(self, v):
        self._times = v

    @staticmethod
    def is_legacy(fname):
        '''Check whether fname starts with the magic number of the legacy format.

        Files too short to contain the magic number are reported as not
        being legacy files.
        '''
        with open(fname, 'rb') as f:
            header = f.read(4)
        if len(header) < 4:
            return False
        (magic_number,) = struct.unpack('<I', header)
        return magic_number == LegacyTSyncFile._LEGACY_MAGIC

    def open(self, fname):
        '''Read the legacy tsync file fname and replace the contents of this object.

        Raises ValueError if the file is no legacy tsync file, has an
        unsupported format version, or looks damaged.
        '''
        with open(fname, 'rb') as f:
            (magic_number,) = struct.unpack('<I', f.read(4))
            if magic_number != self._LEGACY_MAGIC:
                raise ValueError('Unrecognized file type.')

            (self._format_version,) = struct.unpack('<I', f.read(4))
            if self._format_version != 1:
                raise ValueError(
                    'Can not read TSync format version {}'.format(self._format_version)
                )
            log.debug('Reading legacy tsync file: %s', fname)

            (ts,) = struct.unpack('<q', f.read(8))
            self._time_created = datetime.utcfromtimestamp(ts)
            (self._tolerance_us,) = struct.unpack('<I', f.read(4))

            # The shared string reader requires a hasher; the resulting
            # digest is never checked in this method.
            xxh = xxh3_64()
            try:
                self._generator_name = read_utf8_xxh_from_file(f, xxh)
            except UnicodeDecodeError as uni_e:
                raise ValueError(
                    'This legacy tsync file is damaged and can not be read.'
                ) from uni_e

            json_raw = read_utf8_xxh_from_file(f, xxh)
            self._custom = {}
            if json_raw:
                self._custom = json.loads(json_raw)

            tlabel1 = read_utf8_xxh_from_file(f, xxh)
            tlabel2 = read_utf8_xxh_from_file(f, xxh)
            self._time_labels = (tlabel1, tlabel2)

            (tuv1,) = struct.unpack('<H', f.read(2))
            (tuv2,) = struct.unpack('<H', f.read(2))
            self._time_units = (
                tsync_time_unit_to_punit(TSyncTimeUnit(tuv1)),
                tsync_time_unit_to_punit(TSyncTimeUnit(tuv2)),
            )

            self._times = np.empty((0, 2))
            # one record: uint32 index followed by two int64 time values
            record = struct.Struct('<Iqq')
            bytes_per_block = record.size
            bytes_remaining = os.fstat(f.fileno()).st_size - f.tell()
            if bytes_remaining <= 0:
                # no data is present
                return

            if bytes_remaining % bytes_per_block != 0:
                raise ValueError('File may be corrupt: Not a whole number of data blocks found!')

            num_data_blocks = bytes_remaining // bytes_per_block
            indices_continuous = True
            self._times = np.zeros((num_data_blocks, 2), dtype=np.int64)
            for i in range(num_data_blocks):
                index, time1, time2 = record.unpack(f.read(bytes_per_block))
                if index != i:
                    indices_continuous = False
                self._times[i] = (time1, time2)

            if not indices_continuous:
                log.warning('Indices in time sync file were not continuous.')


def load_data(part_paths, aux_data_list):
    '''Entry point for automatic dataset loading.

    This function is used internally to load Syntalos' .tsync files
    as data or auxiliary data. It yields one reader object per entry of
    part_paths, using the legacy reader for files in the legacy format.
    The aux_data_list argument is not used by this function.
    '''
    for path in part_paths:
        reader_cls = LegacyTSyncFile if LegacyTSyncFile.is_legacy(path) else TSyncFile
        yield reader_cls(path)
