import os
import multiprocessing as mp
from time import time

import numpy as np
from PIL import Image

from PaIRS_UniNa.parForMulti import ParForMul, ParForPool
import PaIRS_UniNa.PaIRS_PIV as PaIRS_lib
from PaIRS_UniNa.procTools import saveResults

"""
Example script for running PaIRS-UniNa PIV processing in parallel mode.

User customization:
- CONFIG: general settings (paths, indices, pairing mode, output format)
- read_images(): adapt this function to your image naming convention
- saveResults(): adapt output format or output structure if needed

Note:
This script assumes a standard PaIRS-like workflow. Depending on your dataset,
you may need to customize both how images are read and how results are saved.
"""


# =========================
# USER CONFIGURATION
# =========================
# Note: read_images() and saveResults() may need customization depending on
# your input image naming convention and desired output format.
CONFIG = dict(
    # Input images
    image_folder="C:/path/to/your/images/",
    image_prefix="img_cam0_",
    image_extension=".png",
    image_digits=5,  # Number of digits used in input image names
    start_index=1,
    end_index=10,  # exclusive
    pair_mode=0,  # 0 = a/b naming, 1 = consecutive ("time-resolved"), 2 = consecutive with step 2 ("not time-resolved")

    # PaIRS configuration file
    cfg_file="C:/path/to/your/out.cfg",

    # Output
    output_folder="C:/path/to/output/",
    output_prefix="serialTest_",
    output_extension=".mat",  # supported by saveResults: ".mat" or ".plt"

    # Processing
    enable_log=False,
    pool_size=mp.cpu_count(),  # Number of worker processes
)


# =========================
# IMAGE UTILITIES
# =========================
def image_to_float(dtype):
    """
    Build a converter that turns an image into a C-contiguous NumPy array.

    Args:
        dtype: NumPy dtype of the arrays produced by the returned converter.

    Returns:
        A one-argument callable ``convert(image)`` that returns
        ``np.ascontiguousarray(image, dtype=dtype)``. It accepts anything
        NumPy can convert, e.g. a PIL image or a nested list.
    """
    def _convert(image):
        return np.ascontiguousarray(image, dtype=dtype)

    return _convert


# Converter used by read_images(). Its dtype follows PaIRS_lib.SizeOfReal():
# a size of 8 selects float64 and anything else selects float32, presumably so
# images match the library's native floating-point type.
Image2PIV_Float = image_to_float(
    np.float32 if PaIRS_lib.SizeOfReal() != 8 else np.float64
)


def read_images(img_index):
    """
    Read one image pair from disk and convert it with Image2PIV_Float.

    IMPORTANT:
    This function may need to be adapted depending on how the user named
    the input images.

    Typical cases:
    - pair_mode = 0: prefix + 'a'/'b' + index
    - pair_mode = 1 or 2: consecutive frames (i, i+1)

    In custom datasets, path_a and path_b may need to be modified.

    Args:
        img_index: index of the pair (pair_mode 0) or of its first frame
            (pair_mode 1 and 2).

    Returns:
        [image_a, image_b], each a C-contiguous float array.

    Raises:
        ValueError: if CONFIG["pair_mode"] is not 0, 1 or 2.
        OSError: if an image cannot be opened (missing file, unknown format).
    """
    folder = CONFIG["image_folder"]
    prefix = CONFIG["image_prefix"]
    ext = CONFIG["image_extension"]
    digits = CONFIG["image_digits"]
    mode = CONFIG["pair_mode"]

    if mode == 0:
        # Example:
        #   prefixa00001.png
        #   prefixb00001.png
        path_a = os.path.join(folder, f"{prefix}a{img_index:0{digits}d}{ext}")
        path_b = os.path.join(folder, f"{prefix}b{img_index:0{digits}d}{ext}")

    elif mode in (1, 2):
        # Example:
        #   prefix00001.png
        #   prefix00002.png
        path_a = os.path.join(folder, f"{prefix}{img_index:0{digits}d}{ext}")
        path_b = os.path.join(folder, f"{prefix}{img_index + 1:0{digits}d}{ext}")

    else:
        raise ValueError("Invalid pair_mode. Use 0, 1, or 2.")

    images = []
    for path in (path_a, path_b):
        # Image.open() is lazy and keeps the file handle open until the image
        # is closed. Converting inside the `with` block loads the pixels into a
        # new array and then releases the handle right away, instead of leaving
        # it to garbage collection in a long-running worker process.
        with Image.open(path) as img:
            images.append(Image2PIV_Float(img))

    return images


# =========================
# OUTPUT HELPER
# =========================
class OutputData:
    """
    Minimal output-naming helper compatible with procTools.saveResults().

    Attributes:
        outExt: extension appended to every output file name.
        nimg: number of images, used only to choose the zero-padding width.
        ndig: zero-padding width for integer indices (at least 1).
        outPathRoot: output folder joined with the output prefix.
    """

    def __init__(self, output_folder, output_prefix, output_extension, nimg):
        self.outExt = output_extension
        self.nimg = nimg
        if nimg > 0:
            self.ndig = len(str(nimg))
        else:
            # Non-positive counts still get a one-digit padding width.
            self.ndig = 1
        self.outPathRoot = os.path.join(output_folder, output_prefix)

    def resF(self, i, string=''):
        """
        Return the output file path for item ``i``.

        Args:
            i: an int (zero-padded to ``ndig`` digits) or a str (used verbatim).
            string: accepted for interface compatibility; ignored.

        Returns:
            "<outPathRoot>_<suffix><outExt>", or '' for any other type of ``i``
            (or if ndig is below -1).
        """
        if self.ndig < -1:
            return ''

        if isinstance(i, str):
            suffix = i
        elif isinstance(i, int):
            suffix = f"{i:0{self.ndig}d}"
        else:
            return ''

        return f"{self.outPathRoot}_{suffix}{self.outExt}"


# =========================
# PARALLEL HELPERS
# =========================
class WrapperOutFromPIV(PaIRS_lib.PyFunOutPIV):
    """
    Python output callback handed to the PaIRS PIV engine.

    PaIRS invokes FunOut during processing; a non-zero return value asks it to
    stop the computation. This implementation always returns 0, so it never
    interrupts a run.

    Attributes:
        stop_flag: value checked by process_piv() before each pair; a non-zero
            value makes it skip the pair.
        stop_event: event object supplied by the parallel framework, stored for
            reference but not inspected here.
    """

    def __init__(self, stop_flag, stop_event):
        super(WrapperOutFromPIV, self).__init__()
        self.stop_flag = stop_flag
        self.stop_event = stop_event

    def FunOut(self, a, b, o):
        # Always let the computation continue.
        continue_code = 0
        return continue_code


class AveragePIV:
    """
    Running accumulator for the PIV fields u, v, sn and info plus a count.

    Workers call add_result() for each processed pair. Partial accumulators
    are merged with add_average(), and compute_average() turns the sums into
    means.
    """

    def __init__(self):
        # One-element zero arrays broadcast against the first field added, so
        # the real field shape does not need to be known up front.
        self.u, self.v, self.sn, self.info = (np.zeros(1) for _ in range(4))
        self.count = 0

    def _accumulate(self, source, name_pairs):
        """Add ``source.<theirs>`` to ``self.<mine>`` for each (mine, theirs) pair."""
        for mine, theirs in name_pairs:
            # Out-of-place addition stores a fresh array, so the accumulator
            # never aliases or mutates the source's arrays.
            setattr(self, mine, getattr(self, mine) + getattr(source, theirs))

    def add_result(self, piv):
        """Accumulate one processed PIV result (reads piv.u, .v, .sn, .Info)."""
        self.count += 1
        self._accumulate(piv, (("u", "u"), ("v", "v"), ("sn", "sn"), ("info", "Info")))

    def add_average(self, other):
        """Accumulate another partial AveragePIV (its sums and its count)."""
        self.count += other.count
        self._accumulate(other, (("u", "u"), ("v", "v"), ("sn", "sn"), ("info", "info")))

    def compute_average(self):
        """Divide every accumulated sum by count in place; no-op if count <= 0."""
        if self.count <= 0:
            return
        for name in ("u", "v", "sn", "info"):
            total = getattr(self, name)
            total /= self.count
            setattr(self, name, total)


class ParallelPIVParameters:
    """
    Settings shared by every worker task.

    Values are read from CONFIG at construction time, except output_data,
    which the caller must assign (main() sets it to an OutputData) before
    processing starts.
    """

    def __init__(self):
        # Logging switch stored as a 0/1 integer rather than a bool.
        self.flag_log = int(bool(CONFIG["enable_log"]))
        # Global accumulator into which collect_results() merges worker averages.
        self.average = AveragePIV()
        self.cfg_file = CONFIG["cfg_file"]
        self.output_data = None
        # Names passed to saveResults(), matching the order of the arrays
        # process_piv() hands over (x, y, u, v, sn, Info).
        self.output_names = ["X", "Y", "U", "V", "Fc", "info"]


def init_piv(stop_event, image_index, proc_id, params: ParallelPIVParameters, *args, **kwargs):
    """
    Create and configure this worker's PIV instance, then process its first item.

    The instance is loaded from params.cfg_file and restricted to
    FlagNumThreads = 1, presumably because the process pool already supplies
    the parallelism. It also gets its own AveragePIV accumulator and output
    callback.

    Returns:
        (flag_out, message, piv): the result of process_piv() plus the
        configured instance, or (-1, error text, None) if anything raises.
    """
    try:
        engine = PaIRS_lib.PIV()
        engine.DefaultValues()
        engine.readCfgProc(params.cfg_file)
        engine.Inp.FlagLog = params.flag_log
        engine.Inp.FlagNumThreads = 1

        # Per-worker state that process_piv() relies on.
        engine.average = AveragePIV()
        engine.wrapper = WrapperOutFromPIV(0, stop_event)
        engine.fun = PaIRS_lib.GetPyFunction(engine.wrapper)

        flag_out, message = process_piv(image_index, proc_id, engine, params, *args, **kwargs)
    except Exception as exc:
        print(exc)
        return -1, str(exc), None

    return flag_out, message, engine


def process_piv(image_index, proc_id, piv, params: ParallelPIVParameters, *args, **kwargs):
    """
    Process one image pair with an already-initialized PIV instance.

    Users may want to customize:
    - read_images() to match their input image naming convention
    - saveResults() to match their preferred output format or structure

    Args:
        image_index: index passed to read_images() and used to name the output.
        proc_id: worker identifier supplied by the parallel framework (unused).
        piv: PIV instance prepared by init_piv() (provides .wrapper, .fun and
            .average).
        params: shared ParallelPIVParameters (output naming and variable names).

    Returns:
        (flag_out, message) where flag_out is 1 = processed,
        0 = skipped/stopped, -1 = error.
    """
    if piv.wrapper.stop_flag != 0:
        return 0, f"Skipped {image_index} (pid {os.getpid()})"

    # Read the input image pair.
    # This is one of the main functions users may want to customize.
    # A missing/unreadable image (OSError) or an invalid pair_mode (ValueError)
    # is reported through the -1 error flag documented above instead of
    # escaping from the worker.
    try:
        images = read_images(image_index)
    except (OSError, ValueError) as exc:
        return -1, f"{exc} {image_index}"
    piv.SetImg(images)

    try:
        result = piv.PIV_Run(piv.fun)

        if result == 0:
            piv.average.add_result(piv)

            variables = [piv.x, piv.y, piv.u, piv.v, piv.sn, piv.Info]

            # Save results using the PaIRS utility function.
            # Users may want to customize saveResults() depending on the desired
            # output format, file naming, or output structure.
            saveResults(params.output_data, image_index, variables, params.output_names)

            return 1, f"Processed image {image_index} -> {params.output_data.resF(image_index)}"

        return 0, f"Blocked {image_index}"

    except SystemError as exc:
        # Errors from the native library can surface as SystemError with the
        # underlying exception chained in __cause__. When nothing is chained,
        # report the SystemError itself so the message never reads "None".
        cause = exc.__cause__ if exc.__cause__ is not None else exc
        return -1, f"{cause} {image_index}"


def finalize_piv(proc_id, piv, params, *args, **kwargs):
    """
    Hand back the partial accumulator built by one worker.

    Returns:
        The worker's ``piv.average`` object; the other arguments are ignored.
    """
    partial = piv.average
    return partial


def collect_results(proc_id, worker_has_run, partial_average, params, *args, **kwargs):
    """
    Merge one worker's partial average into the shared accumulator.

    Args:
        worker_has_run: truthy if the worker produced a partial result; when
            falsy, partial_average is ignored.
        partial_average: object accepted by ``params.average.add_average()``.
        params: shared parameters holding the global ``average``.

    Returns:
        ``params.average`` (updated in place when the worker has run).
    """
    total = params.average
    if not worker_has_run:
        return total
    total.add_average(partial_average)
    return total


def callback(new_data_available, percentage, proc_id, task_flag, item_name, task_message):
    """
    Progress callback called by the parallel loop.

    Prints one status line per completed item. task_message already describes
    the outcome (e.g. "Processed image 3 -> ...", "Blocked 3", "Skipped 3 ..."
    or an error text). It is therefore printed as-is, tagged with the item
    name, so skipped, blocked and failed items are not labelled as processed.

    Returns:
        False, i.e. never requests an interruption (return True to stop).
    """
    if new_data_available:
        print(f"[{item_name}] {task_message}")
    return False


# =========================
# MAIN PROCESSING
# =========================
def _first_frame_indices(start_index, end_index, pair_mode):
    """
    Return the indices passed to read_images(), treating end_index as exclusive.

    - pair_mode 0: every index in [start_index, end_index) names an a/b pair.
    - pair_mode 1: index i pairs frames (i, i + 1), so i + 1 must stay below
      end_index.
    - pair_mode 2: like mode 1, but i advances by 2 (non-overlapping pairs);
      a trailing unpaired frame is ignored.

    Raises:
        ValueError: if pair_mode is not 0, 1 or 2.
    """
    if pair_mode == 0:
        return range(start_index, end_index)
    if pair_mode in (1, 2):
        step = 2 if pair_mode == 2 else 1
        return range(start_index, end_index - 1, step)
    raise ValueError("Invalid pair_mode. Use 0, 1, or 2.")


def main():
    """
    Main processing loop.

    Prints the library versions, builds the list of pairs from CONFIG
    (end_index exclusive), processes them on a worker pool, and prints a
    summary of the averaged field.

    Users typically customize:
    - CONFIG
    - read_images()

    Advanced users may also customize:
    - saveResults() behavior (output format, naming, additional outputs)

    Raises:
        ValueError: if CONFIG["pair_mode"] is not 0, 1 or 2.
    """
    print("PaIRS Core Version:", PaIRS_lib.Version(PaIRS_lib.MOD_PaIRS))
    print("PaIRS PIV Version:", PaIRS_lib.Version(PaIRS_lib.MOD_PIV))

    start_time = time()

    # Validate the configuration before creating folders or starting workers.
    image_indices = _first_frame_indices(
        CONFIG["start_index"], CONFIG["end_index"], CONFIG["pair_mode"]
    )
    # Derive the count from the indices so naming and processing always agree.
    nimg = len(image_indices)
    if nimg == 0:
        print("No image pairs to process: check start_index, end_index and pair_mode.")
        return

    os.makedirs(CONFIG["output_folder"], exist_ok=True)

    output_data = OutputData(
        CONFIG["output_folder"],
        CONFIG["output_prefix"],
        CONFIG["output_extension"],
        nimg,
    )

    params = ParallelPIVParameters()
    params.output_data = output_data

    parallel_for = ParForMul()
    parallel_for.numCoresParPool = parallel_for.numUsedCores = CONFIG["pool_size"]

    args = (params,)
    kwargs = {}

    pool_manager = ParForPool()
    pool_manager.startParPool(CONFIG["pool_size"])
    try:
        # Extra positional arguments (params) follow the item list, which is
        # how the original call passed them as well.
        average_result, _flags, _messages, _errors = parallel_for.parForExtPool(
            pool_manager.parPool,
            process_piv,
            image_indices,
            *args,
            initTask=init_piv,
            finalTask=finalize_piv,
            wrapUp=collect_results,
            callBack=callback,
            **kwargs,
        )
    finally:
        # Release the worker processes even if the parallel loop raises.
        pool_manager.closeParPool()

    average_result.compute_average()
    print("")
    print(f"Average U sum = {average_result.u.ravel().sum()} | Count = {average_result.count}")
    print(f"Total time = {time() - start_time:.2f} s")


# The guard is required for multiprocessing: with the "spawn" start method
# (the default on Windows) workers re-import this module and must not re-run main().
if __name__ == "__main__":
    main()