import faulthandler
import os
from time import time

import numpy as np
from PIL import Image

import PaIRS_UniNa.PaIRS_PIV as PaIRS_lib
from PaIRS_UniNa.procTools import saveResults

"""
Example script for running PaIRS-UniNa PIV processing in serial mode.

User customization:
- CONFIG: general settings (paths, indices, pairing mode, output format)
- read_images(): adapt this function to your image naming convention
- saveResults() (imported from PaIRS_UniNa.procTools): adapt the output format or structure if needed

Note:
This script assumes a standard PaIRS-like workflow. Depending on your dataset,
you may need to customize both how images are read and how results are saved.
"""


# =========================
# USER CONFIGURATION
# =========================
# Depending on how your input images are named and which output format you
# want, read_images() and saveResults() may also need to be customized.
CONFIG = dict(
    # --- Input images -------------------------------------------------------
    image_folder="C:/path/to/your/images/",
    image_prefix="img_cam0_",
    image_extension=".png",
    image_digits=5,  # zero-padded width of the index in input file names
    start_index=1,
    end_index=10,  # exclusive
    # Frame pairing:
    #   0 = a/b naming (prefix + "a"/"b" + index)
    #   1 = consecutive frames (i, i+1), "time-resolved"
    #   2 = consecutive frames (i, i+1) with step 2, "not time-resolved"
    pair_mode=0,
    # --- PaIRS configuration file -------------------------------------------
    cfg_file="C:/path/to/your/out.cfg",
    # --- Output -------------------------------------------------------------
    output_folder="C:/path/to/output/",
    output_prefix="serialTest_",
    output_extension=".mat",  # saveResults supports ".mat" or ".plt"
    # --- Processing ---------------------------------------------------------
    num_threads=0,  # 0 = use all logical cores
    enable_log=True,
)


# =========================
# INITIALIZATION
# =========================
# Dump Python tracebacks on hard crashes (e.g. a fault inside a native extension).
faulthandler.enable()

# Report the versions of the PaIRS core and PIV modules in use.
print(f"PaIRS Core Version: {PaIRS_lib.Version(PaIRS_lib.MOD_PaIRS)!s}")
print(f"PaIRS PIV Version: {PaIRS_lib.Version(PaIRS_lib.MOD_PIV)!s}")

# Value returned by WrapperOutFromPIV.FunOut; non-zero asks PaIRS to stop.
flagStop = 0


# =========================
# CALLBACK FOR STOPPING PIV
# =========================
class WrapperOutFromPIV(PaIRS_lib.PyFunOutPIV):
    """
    Callback object handed to PaIRS (through GetPyFunction) for PIV_Run.

    PaIRS invokes FunOut during processing; a non-zero return value stops the
    computation. The value returned here is the module-level ``flagStop``, so
    setting that global to a non-zero value requests a stop.
    """

    def __init__(self):
        # Explicitly initialise the PaIRS base class.
        super(WrapperOutFromPIV, self).__init__()

    def FunOut(self, a, b, o):
        # a, b and o are supplied by PaIRS and ignored here.
        # NOTE(review): their meaning is defined by PaIRS_lib, not visible here.
        stop_requested = flagStop
        return stop_requested


# =========================
# IMAGE UTILITIES
# =========================
def image_to_float(dtype):
    """
    Build a converter that turns an image into a C-contiguous NumPy array.

    Parameters:
        dtype: NumPy dtype of the arrays produced by the converter.

    Returns:
        A one-argument callable ``convert(image)`` returning
        ``np.ascontiguousarray(image, dtype=dtype)``.
    """
    def convert(image):
        return np.ascontiguousarray(image, dtype=dtype)

    return convert


# Converter used by read_images(): choose the float width from
# PaIRS_lib.SizeOfReal() (8 bytes -> float64, anything else -> float32).
Image2PIV_Float = image_to_float(
    np.float32 if PaIRS_lib.SizeOfReal() != 8 else np.float64
)


def read_images(img_index):
    """
    Read one image pair from disk.

    IMPORTANT:
    This function may need to be adapted depending on how the user named
    the input images.

    Typical cases:
    - pair_mode = 0: prefix + 'a'/'b' + index
    - pair_mode = 1 or 2: consecutive frames (i, i+1)

    In custom datasets, path_a and path_b may need to be modified.

    Parameters:
        img_index: index of the (first) frame of the pair.

    Returns:
        ``[image_a, image_b]``, each converted with Image2PIV_Float.

    Raises:
        ValueError: if CONFIG["pair_mode"] is not 0, 1 or 2.
        OSError: propagated from PIL when a file cannot be opened
            (e.g. FileNotFoundError for a missing image).
    """
    folder = CONFIG["image_folder"]
    prefix = CONFIG["image_prefix"]
    ext = CONFIG["image_extension"]
    digits = CONFIG["image_digits"]
    mode = CONFIG["pair_mode"]

    if mode == 0:
        # Example:
        #   prefixa00001.png
        #   prefixb00001.png
        path_a = os.path.join(folder, f"{prefix}a{img_index:0{digits}d}{ext}")
        path_b = os.path.join(folder, f"{prefix}b{img_index:0{digits}d}{ext}")

    elif mode in (1, 2):
        # Example:
        #   prefix00001.png
        #   prefix00002.png
        path_a = os.path.join(folder, f"{prefix}{img_index:0{digits}d}{ext}")
        path_b = os.path.join(folder, f"{prefix}{img_index + 1:0{digits}d}{ext}")

    else:
        raise ValueError("Invalid pair_mode. Use 0, 1, or 2.")

    images = []
    for path in (path_a, path_b):
        # Image.open() is lazy and keeps the file open; the context manager
        # releases the handle right after the pixels are copied into the array.
        with Image.open(path) as image:
            images.append(Image2PIV_Float(image))

    return images


# =========================
# OUTPUT HELPER
# =========================
class OutputData:
    """
    Minimal helper object compatible with procTools.saveResults().

    Holds ``outExt``, ``nimg``, ``ndig`` and ``outPathRoot`` and builds output
    file paths through ``resF()``.
    NOTE(review): which attributes saveResults() actually reads is defined in
    PaIRS_UniNa.procTools and is not visible here.
    """

    def __init__(self, output_folder, output_prefix, output_extension, nimg):
        """
        Parameters:
            output_folder: directory the result files are written to.
            output_prefix: file-name prefix joined onto ``output_folder``.
            output_extension: extension appended to every file name.
            nimg: number of image pairs; its decimal width sets the
                zero-padding used for integer indices.
        """
        self.outExt = output_extension
        self.nimg = nimg
        # Zero-padding width for integer indices; at least one digit.
        self.ndig = len(str(self.nimg)) if self.nimg > 0 else 1
        self.outPathRoot = os.path.join(output_folder, output_prefix)

    def resF(self, i, string=''):
        """
        Return the output file path for index or tag ``i``.

        Parameters:
            i: an integer index (Python ``int`` or NumPy integer), zero-padded
               to ``self.ndig`` digits, or a string tag used verbatim.
            string: unused; kept for interface compatibility.

        Returns:
            ``"<outPathRoot>_<i><outExt>"``, or ``''`` when ``i`` is of any
            other type or ``self.ndig`` has been set below -1 (never the case
            with the value computed in ``__init__``).
        """
        if self.ndig < -1:
            return ''

        if isinstance(i, str):
            return f"{self.outPathRoot}_{i}{self.outExt}"
        if isinstance(i, (int, np.integer)):
            # NumPy integer scalars are valid indices too; normalise to int.
            return f"{self.outPathRoot}_{int(i):0{self.ndig}d}{self.outExt}"
        return ''


# =========================
# MAIN PROCESSING
# =========================
def _pair_start_indices(start_index, end_index, pair_mode):
    """
    Return the indices ``i`` passed to read_images() during one run.

    ``end_index`` is exclusive: no image with index >= end_index is read.

    Parameters:
        start_index: first image index.
        end_index: exclusive upper bound on the image indices read.
        pair_mode: 0 (a/b pairs), 1 (pairs (i, i+1), step 1) or
            2 (pairs (i, i+1), step 2).

    Returns:
        A ``range`` of first-frame indices (possibly empty).

    Raises:
        ValueError: if ``pair_mode`` is not 0, 1 or 2.
    """
    if pair_mode == 0:
        # Frames a and b share the same index.
        return range(start_index, end_index)
    if pair_mode in (1, 2):
        # The second frame i + 1 must also stay below end_index.
        step = 2 if pair_mode == 2 else 1
        return range(start_index, end_index - 1, step)
    raise ValueError("Invalid pair_mode. Use 0, 1, or 2.")


def main():
    """
    Main processing loop.

    Processes every pair whose image indices lie in
    [CONFIG["start_index"], CONFIG["end_index"]) and saves one result file
    per pair.

    Users typically customize:
    - CONFIG
    - read_images()

    Advanced users may also customize:
    - saveResults() behavior (output format, naming, additional outputs)

    Raises:
        ValueError: if CONFIG["pair_mode"] is not 0, 1 or 2 (checked before
            PaIRS is initialised).
    """
    indices = _pair_start_indices(
        CONFIG["start_index"], CONFIG["end_index"], CONFIG["pair_mode"]
    )

    PIV = PaIRS_lib.PIV()

    # Setup callback (fun_out stays referenced for the whole run).
    fun_out = WrapperOutFromPIV()
    fun = PaIRS_lib.GetPyFunction(fun_out)

    # Load configuration
    PIV.DefaultValues()
    PIV.readCfgProc(CONFIG["cfg_file"])

    # Processing settings
    PIV.Inp.FlagLog = 1 if CONFIG["enable_log"] else 0
    PIV.Inp.FlagNumThreads = CONFIG["num_threads"]

    os.makedirs(CONFIG["output_folder"], exist_ok=True)

    # nimg is exactly the number of pairs processed by the loop below.
    output_data = OutputData(
        CONFIG["output_folder"],
        CONFIG["output_prefix"],
        CONFIG["output_extension"],
        len(indices),
    )

    var_names = ["X", "Y", "U", "V", "Fc", "info"]

    start_time = time()

    try:
        for i in indices:
            print(f"Processing image {i} -> {output_data.resF(i)}")

            # Read the input image pair.
            # This is one of the main functions users may want to customize.
            PIV.SetImg(read_images(i))
            PIV.PIV_Run(fun)

            variables = [PIV.x, PIV.y, PIV.u, PIV.v, PIV.sn, PIV.Info]

            # Save results using the PaIRS utility function.
            # Users may want to customize saveResults() depending on the desired
            # output format, file naming, or output structure.
            saveResults(output_data, i, variables, var_names)
            print("")

    except SystemError as e:
        # Report the chained cause when there is one; otherwise the
        # SystemError itself (printing e.__cause__ alone would show "None").
        reason = e.__cause__ if e.__cause__ is not None else e
        print("Error during processing:", reason)

    print(f"Total time = {time() - start_time:.2f} s")
    print(f"Threads used = {PIV.Inp.FlagNumThreads}")


if __name__ == "__main__":
    # Start processing only when run as a script, not when imported.
    main()