# Copyright (c) 2017-2026 Juancarlo Añez (apalala@gmail.com)
# SPDX-License-Identifier: BSD-4-Clause
from __future__ import annotations

import io
import multiprocessing
import sys
import time
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from itertools import batched
from pathlib import Path
from types import SimpleNamespace
from typing import Any, NamedTuple

try:
    import rich  # pyright: ignore[reportMissingImports]
except ImportError:
    rich = SimpleNamespace()  # ty:ignore[invalid-assignment]

from rich.progress import (  # pyright: ignore[reportMissingImports]
    BarColumn,
    Progress,
    TaskID,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)

from . import identity, memory_use, startscript, try_read
from .datetime import iso_logpath
from .unicode_characters import U_CHECK_MARK, U_CROSSED_SWORDS

__all__: list[str] = ['parallel_proc', 'processing_loop']


# Line terminator for progress output written to stderr: a carriage return
# when stderr is a terminal (so the next line overwrites the current one),
# a newline otherwise.
EOLCH = '\r' if sys.stderr.isatty() else '\n'
# Import-time side effect: raises the interpreter recursion limit to 65536.
# NOTE(review): presumably required by deeply recursive `process` callables
# run through this module — confirm before changing.
sys.setrecursionlimit(2**16)


class __Task(NamedTuple):
    """One unit of work: a payload plus the extra arguments for its processor."""

    # Passed as the first positional argument to the process callable.
    payload: Any
    # Extra positional arguments, unpacked after the payload.
    args: Iterable[Any]
    # Keyword arguments forwarded to the process callable.
    kwargs: Mapping[str, Any]


@dataclass(slots=True)
class ParprocResult:
    payload: Any
    outcome: Any | None = None
    exception: Any | None = None
    linecount: int = 0
    time: float = 0
    memory: int = 0

    @property
    def success(self):
        return self.exception is None

    def __str__(self):
        return str(self.__dict__)


def process_payload(
    process: Callable,
    task: Any,
    pickable: Callable = identity,
    reraise: bool = False,
) -> ParprocResult | None:
    """Run ``process`` on a single task and report what happened.

    Args:
        process: Callable invoked as
            ``process(task.payload, *task.args, **task.kwargs)``.
        task: Object exposing ``payload``, ``args`` and ``kwargs``.
        pickable: Converter applied to the outcome before it is stored on
            the result.
        reraise: Accepted for interface compatibility; this function does
            not consult it and always records failures on the result.

    Returns:
        A ``ParprocResult`` holding either the converted outcome or the
        exception that was raised, or ``None`` when the call was interrupted
        by ``KeyboardInterrupt``.
    """
    started = time.process_time()
    report = ParprocResult(task.payload)
    try:
        produced = process(task.payload, *task.args, **task.kwargs)
        report.memory = memory_use()
        # Prefer the line count the outcome carries; otherwise count the
        # lines of the payload's content.
        report.linecount = (
            produced.linecount
            if hasattr(produced, 'linecount')
            else len(try_read(task.payload).splitlines())
        )
        report.outcome = pickable(produced)
    except KeyboardInterrupt:
        return None
    except Exception as error:
        report.exception = error
    finally:
        # Recorded on every path, including failures.
        report.time = time.process_time() - started

    return report


def _executor_pmap(
    executor: Callable,
    process: Callable,
    tasks: Sequence[Any],
) -> Iterable[ParprocResult]:
    nworkers = max(1, multiprocessing.cpu_count())
    n = nworkers * 8
    chunks = batched(tasks, n)
    for chunk in chunks:
        with executor(max_workers=nworkers) as ex:
            futures = [ex.submit(process, task) for task in chunk]
            for future in as_completed(futures):
                yield future.result()


def _thread_pmap(process: Callable, tasks: Sequence[Any]) -> Iterable[ParprocResult]:
    """Map ``process`` over ``tasks`` through ``ThreadPoolExecutor``."""
    for result in _executor_pmap(ThreadPoolExecutor, process, tasks):
        yield result


def _process_pmap(process: Callable, tasks: Sequence[Any]) -> Iterable[ParprocResult]:
    """Map ``process`` over ``tasks`` through ``ProcessPoolExecutor``."""
    for result in _executor_pmap(ProcessPoolExecutor, process, tasks):
        yield result


def _imap_pmap(process: Callable, tasks: Sequence[Any]) -> Iterable[ParprocResult]:
    nworkers = max(1, multiprocessing.cpu_count())

    n = nworkers * 4
    chunks = batched(tasks, n)

    count = 0
    for chunk in chunks:
        count += len(chunk)
        with multiprocessing.Pool(processes=nworkers) as pool:
            yield from pool.imap_unordered(process, chunk)
    if len(tasks) != count:
        raise RuntimeError(
            'number of chunked tasks different %d != %d' % (len(tasks), count),
        )


# Mapping strategy that parallel_proc uses when `parallel` is true. It is
# looked up at call time, so rebinding it to _thread_pmap or _process_pmap
# (same signature) switches strategy.
_active_pmap = _imap_pmap


def parallel_proc(
    payloads: Iterable[Any],
    process: Callable,
    *args: Any,
    **kwargs: Any,
):
    """Apply ``process`` to every payload, yielding one result per payload.

    Three keyword arguments are consumed here and not forwarded to
    ``process``: ``pickable`` (default ``identity``), ``parallel`` (default
    ``True``) and ``reraise`` (default ``False``). All remaining positional
    and keyword arguments are handed to ``process`` after the payload.

    A single payload is always processed inline. Otherwise the active
    parallel mapper is used, or the builtin ``map`` when ``parallel`` is
    false. A ``KeyboardInterrupt`` raised while producing results ends the
    generator quietly.

    Yields:
        Whatever ``process_payload`` returns for each payload.
    """
    pickable = kwargs.pop('pickable', identity)
    parallel = kwargs.pop('parallel', True)
    reraise = kwargs.pop('reraise', False)

    worker = partial(process_payload, process, pickable=pickable, reraise=reraise)
    tasks = [__Task(payload, args, kwargs) for payload in payloads]

    try:
        if len(tasks) == 1:
            (only_task,) = tasks
            yield worker(only_task)
            return
        mapper = _active_pmap if parallel else map
        yield from mapper(worker, tasks)
    except KeyboardInterrupt:
        return


def _build_progressbar(total: int) -> tuple[Progress, TaskID]:
    """Create the progress display and register its single task.

    Args:
        total: Number of steps the task is expected to complete.

    Returns:
        The ``Progress`` instance and the id of the task added to it.
    """
    columns = (
        TextColumn(f"[progress.description]{startscript()}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    bar = Progress(
        *columns,
        refresh_per_second=1,
        speed_estimate_period=30.0,
    )
    task_id = bar.add_task('', total=total)
    return bar, task_id


def processing_loop(
    filenames: Sequence[str],
    process: Callable,
    *args: Any,
    reraise: bool = False,
    **kwargs: Any,
) -> Iterable[ParprocResult]:
    """Process ``filenames`` with a progress bar, yielding successful results.

    Each file is run through ``parallel_proc``. Failures are written to a log
    (a file from ``iso_logpath`` when more than one filename was given,
    otherwise stderr) and are not yielded. Results with neither an exception
    nor an outcome are skipped silently. After the loop a summary is emitted
    through ``file_process_summary``.

    Args:
        filenames: The files to process.
        process: Callable applied to each filename.
        *args: Extra positional arguments forwarded to ``parallel_proc``.
        reraise: When true, raise the first recorded exception in this
            process after logging it. It is not forwarded to
            ``parallel_proc``.
        **kwargs: Extra keyword arguments forwarded to ``parallel_proc``.

    Yields:
        Each result that has no exception and a non-``None`` outcome.
    """
    try:
        total = len(filenames)
        elapsed = 0.0
        worker_seconds = 0.0
        began = time.time()
        results = parallel_proc(filenames, process, *args, **kwargs) or []
        succeeded = 0
        succeeded_lines = 0
        progress, progress_task = _build_progressbar(total)

        # Only multi-file runs get a log file; otherwise logging goes to stderr.
        logpath = (
            iso_logpath(prefix=startscript().replace('.', '_'))
            if total > 1
            else None
        )

        @contextmanager
        def logctx() -> Generator[io.TextIOBase | Any, None, None]:
            # Yield the stream log lines should go to, appending to the log
            # file when there is one.
            if not isinstance(logpath, Path):
                yield sys.stderr
                return
            with logpath.open(mode="a", encoding="utf-8") as logfile:
                yield logfile

        def log_failure(failed: ParprocResult) -> None:
            # Write the failing payload and its exception to the log.
            try:
                with logctx() as log:
                    print('ERROR:', failed.payload, file=log)
                    print(failed.exception, file=log)
            except Exception:
                # in case of errors while serializing the exception
                with logctx() as log:
                    print(
                        'EXCEPTION',
                        type(failed.exception).__name__,
                        file=log,
                    )

        with progress:
            for result in results:
                if result is None:
                    continue

                elapsed = time.time() - began
                filename = Path(result.payload).name
                icon = (
                    f'[red]{U_CROSSED_SWORDS}'
                    if result.exception
                    else f'[green]{U_CHECK_MARK}'
                )
                progress.update(
                    progress_task,
                    advance=1,
                    description=f'{icon} {filename}',
                )

                if result.exception:
                    log_failure(result)
                    if reraise:
                        raise result.exception
                    continue
                if result.outcome is None:
                    continue

                succeeded += 1
                succeeded_lines += result.linecount
                worker_seconds += result.time
                yield result

            progress.update(progress_task, advance=0, description='')
            progress.stop()

        with logctx() as log:
            file_process_summary(
                filenames,
                elapsed,
                worker_seconds,
                succeeded,
                succeeded_lines,
                log,
            )
    except KeyboardInterrupt:
        return


def file_process_progress(
    latest_result: ParprocResult,
    count: int,
    total: int,
    total_time: float,
):
    """Print a one-line textual progress report to stderr.

    The line shows the count, a 24-character bar, the percentage, an ETA,
    the time of the latest result, a memory figure and the payload's file
    name. It ends with ``EOLCH``.

    Args:
        latest_result: The most recently completed result.
        count: Number of items completed so far.
        total: Number of items overall.
        total_time: Wall-clock seconds elapsed so far.
    """
    filename = latest_result.payload

    # Guard the ratios: with nothing expected or nothing completed yet there
    # is no basis for a percentage or an estimate, so report zero instead of
    # failing with ZeroDivisionError.
    percent = count / total if total else 0.0
    mb_memory = (latest_result.memory + memory_use()) // (1024 * 1024)

    if count:
        eta = (total - count) * 0.8 * total_time / (0.2 * count)
    else:
        eta = 0.0
    bar = '[%-24s]' % ('#' * round(24 * percent))

    print(
        '%3d/%-3d' % (count, total),
        bar,
        '%3d%%' % (100 * percent),
        f'{format_hours(eta)}ETA',
        format_minutes(latest_result),
        '%3dMiB' % mb_memory if mb_memory else '',
        # Pad, then cut, so the name always occupies exactly 40 columns.
        (Path(filename).name + ' ' * 80)[:40],
        file=sys.stderr,
        end=EOLCH,
    )


def format_minutes(result: ParprocResult) -> str:
    """Format ``result.time`` (seconds) as ``MMM:SS.S``.

    The value is rounded to a tenth of a second first and then split into
    whole minutes and remaining seconds. This keeps the minutes field from
    rounding up on its own (45s is ``0:45.0``, not ``1:45.0``) and lets
    59.96s carry into ``1:00.0``.
    """
    minutes, seconds = divmod(round(result.time, 1), 60)
    return f'{minutes:3.0f}:{seconds:04.1f}'


def format_hours(time: float) -> str:
    """Format a duration in seconds as ``HH:MM:SS``.

    The duration is rounded to whole seconds once, before being split, so a
    fractional value such as 59.6 carries into the minutes field instead of
    being rendered with a seconds field of ``60``.
    """
    total_seconds = int(round(time))
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f'{hours:2d}:{minutes:02d}:{seconds:02d}'


def file_process_summary(
    filenames: Sequence[str],
    total_time: float,
    run_time: float,
    success_count: int,
    success_linecount: int,
    log,
):
    """Write the end-of-run summary and exit with status 1 on failures.

    Every input file is re-read to count its lines. The summary goes to
    ``log`` and, when ``log`` is not stderr, to stderr as well. Any files
    beyond ``success_count`` are counted as failures; if there are any, the
    name of the log is announced and the process exits through
    ``sys.exit(1)``.

    Args:
        filenames: The files that were submitted for processing.
        total_time: Wall-clock seconds for the whole run.
        run_time: Sum of the per-file processing times.
        success_count: Number of files processed successfully.
        success_linecount: Lines accounted for by the successful files.
        log: Text stream the summary is written to.
    """
    lines_per_file = [len(try_read(fname).splitlines()) for fname in filenames]
    filecount = len(lines_per_file)
    linecount = sum(lines_per_file)
    failures = filecount - success_count
    success_rate = 100 * success_count / filecount if filecount else 0

    summary_text = '''\
        ──────────────────────────────────────────────────────────────────────
        {:12,d}   files input
        {:12,d}   source lines input
        {:12,d}   total lines processed
        {:12,d}   successes
        {:12,d}   failures
        {:12.1f}%  success rate
         {:>13s}  time
         {:>13s}  run time
    '''
    # Drop the source indentation from every row of the template.
    template = '\n'.join(row.strip() for row in summary_text.splitlines())

    summary = template.format(
        filecount,
        linecount,
        success_linecount,
        success_count,
        failures,
        success_rate,
        format_hours(total_time),
        format_hours(run_time),
    )

    print(summary, file=log)
    # Blank out whatever the progress output left on the current line.
    print(EOLCH + 80 * ' ', file=sys.stderr)
    if log != sys.stderr:
        print(summary, file=sys.stderr)

    if not failures:
        return
    rich.print(f'[red bold]FAILURES: [green]{log.name}')
    print(file=sys.stderr)
    sys.exit(1)
