import os
import sys
import time
import argparse

import numpy as np
import cv2
from tqdm import tqdm

import torch
import torch.nn.functional as F

from data import test_dataset

# 兼容你原先的警告修复逻辑
try:
    import imageio
    from PIL import Image
except Exception:
    pass


def _sigmoid(x: np.ndarray) -> np.ndarray:
    # 数值稳定一点
    x = x.astype(np.float32)
    return 1.0 / (1.0 + np.exp(-x))


def _norm01(x: np.ndarray) -> np.ndarray:
    x = x.astype(np.float32)
    mn = float(x.min())
    mx = float(x.max())
    return (x - mn) / (mx - mn + 1e-8)


class AscendOMRunner:
    """
    Load an offline model (.om) with pyACL (the AscendCL Python API) and run it.

    Device buffers for every model input and output are allocated once in
    ``__init__`` and reused by each :meth:`infer` call. Call :meth:`close` (or
    use the runner in a ``with`` statement) to release them. ``close`` is
    idempotent and is also invoked from ``__del__`` as a last resort.

    Note: return-value details of pyACL differ slightly across CANN versions;
    this class follows the commonly documented call shapes.
    """

    # (constant name, fallback value, numpy type) for aclDataType. The fallback
    # values follow the AscendCL aclDataType enum and are used only when the
    # installed pyACL does not expose the named constant.
    _ACL_DTYPE_TABLE = (
        ("ACL_FLOAT", 0, np.float32),
        ("ACL_FLOAT32", 0, np.float32),  # not a real enum name; kept for compatibility
        ("ACL_FLOAT16", 1, np.float16),
        ("ACL_INT8", 2, np.int8),
        ("ACL_INT32", 3, np.int32),
        ("ACL_UINT8", 4, np.uint8),
        ("ACL_INT16", 6, np.int16),
        ("ACL_UINT16", 7, np.uint16),
        ("ACL_UINT32", 8, np.uint32),
        ("ACL_INT64", 9, np.int64),
        ("ACL_UINT64", 10, np.uint64),
        ("ACL_DOUBLE", 11, np.float64),
        ("ACL_BOOL", 12, np.bool_),
    )

    def __init__(self, om_path: str, device_id: int = 0):
        """
        Initialise ACL, select ``device_id``, load ``om_path`` and allocate I/O buffers.

        Raises:
            ImportError (or whatever ``import acl`` raises): pyACL unavailable.
            RuntimeError: an ACL call returned a non-zero status. Anything that
                was already acquired is released before the error propagates.
        """
        try:
            import acl
        except Exception as e:
            print("[ERROR] import acl 失败：", repr(e))
            print("请确认你已执行：. /usr/local/Ascend/ascend-toolkit/set_env.sh")
            raise

        self.acl = acl
        self.om_path = om_path
        self.device_id = device_id

        # Constants: some pyACL versions expose them under acl.const. Otherwise
        # fall back to the AscendCL enum values.
        const = getattr(self.acl, "const", self.acl)
        self.ACL_MEM_MALLOC_HUGE_FIRST = getattr(const, "ACL_MEM_MALLOC_HUGE_FIRST", 0)
        self.ACL_MEMCPY_HOST_TO_DEVICE = getattr(const, "ACL_MEMCPY_HOST_TO_DEVICE", 1)
        self.ACL_MEMCPY_DEVICE_TO_HOST = getattr(const, "ACL_MEMCPY_DEVICE_TO_HOST", 2)

        # Resource-tracking state read by close(), so it can safely tear down a
        # partially initialised runner.
        self._acl_inited = False
        self._device_set = False
        self.model_id = None
        self.model_desc = None
        self.input_dataset = None
        self.output_dataset = None
        self.input_bufs = []
        self.output_bufs = []
        self.input_num = 0
        self.output_num = 0
        self._closed = False

        try:
            # init + set device
            ret = self.acl.init()
            if ret != 0:
                raise RuntimeError(f"acl.init failed, ret={ret}")
            self._acl_inited = True

            ret = self.acl.rt.set_device(self.device_id)
            if ret != 0:
                raise RuntimeError(f"acl.rt.set_device({self.device_id}) failed, ret={ret}")
            self._device_set = True

            # load model
            model_id, ret = self.acl.mdl.load_from_file(self.om_path)
            if ret != 0:
                raise RuntimeError(f"acl.mdl.load_from_file failed, ret={ret}")
            self.model_id = model_id

            # model description
            self.model_desc = self.acl.mdl.create_desc()
            ret = self.acl.mdl.get_desc(self.model_desc, self.model_id)
            if ret != 0:
                raise RuntimeError(f"acl.mdl.get_desc failed, ret={ret}")

            # Build input/output datasets (device buffers allocated once).
            self.input_dataset, self.input_bufs = self._prepare_dataset(io_type="input")
            self.output_dataset, self.output_bufs = self._prepare_dataset(io_type="output")

            self.input_num = self.acl.mdl.get_num_inputs(self.model_desc)
            self.output_num = self.acl.mdl.get_num_outputs(self.model_desc)
        except Exception:
            # Release whatever was acquired before the failure.
            self.close()
            raise

    def _acl_dtype_to_np(self, acl_dtype: int):
        """
        Map an aclDataType value to a numpy scalar type.

        Unknown values fall back to ``np.float32`` (the previous behaviour).
        """
        acl_mod = getattr(self.acl, "const", self.acl)
        mapping = {}
        for const_name, fallback, np_type in self._ACL_DTYPE_TABLE:
            # setdefault: the first registration wins if two names share a value.
            mapping.setdefault(int(getattr(acl_mod, const_name, fallback)), np_type)
        return mapping.get(int(acl_dtype), np.float32)

    @staticmethod
    def _best_effort(fn, *args):
        """Call ``fn(*args)`` and ignore any exception (teardown helper only)."""
        try:
            fn(*args)
        except Exception:
            pass

    def _release_buffers(self, bufs):
        """Destroy each data buffer and free its device memory, best effort."""
        for item in bufs:
            self._best_effort(self.acl.destroy_data_buffer, item["data"])
            self._best_effort(self.acl.rt.free, item["buffer"])

    def _prepare_dataset(self, io_type: str):
        """
        Create an ACL dataset with one device buffer per model input/output.

        Args:
            io_type: ``"input"`` for model inputs; any other value means outputs.

        Returns:
            ``(dataset, bufs)`` where ``bufs`` is a list of dicts with keys
            ``"buffer"`` (device pointer), ``"data"`` (data-buffer handle) and
            ``"size"`` (bytes).

        Raises:
            RuntimeError: malloc or add_dataset_buffer failed. Buffers allocated
                so far, and the dataset itself, are released first.
        """
        if io_type == "input":
            io_num = self.acl.mdl.get_num_inputs(self.model_desc)
            get_size = self.acl.mdl.get_input_size_by_index
        else:
            io_num = self.acl.mdl.get_num_outputs(self.model_desc)
            get_size = self.acl.mdl.get_output_size_by_index

        dataset = self.acl.mdl.create_dataset()
        bufs = []
        try:
            for i in range(io_num):
                buf_size = int(get_size(self.model_desc, i))
                dev_ptr, ret = self.acl.rt.malloc(buf_size, self.ACL_MEM_MALLOC_HUGE_FIRST)
                if ret != 0:
                    raise RuntimeError(f"acl.rt.malloc failed, ret={ret}, size={buf_size}")

                data_buf = self.acl.create_data_buffer(dev_ptr, buf_size)
                # Record the buffer before adding it, so cleanup frees it even if
                # the add fails.
                bufs.append({"buffer": dev_ptr, "data": data_buf, "size": buf_size})

                new_dataset, ret = self.acl.mdl.add_dataset_buffer(dataset, data_buf)
                if ret != 0:
                    raise RuntimeError(f"acl.mdl.add_dataset_buffer failed, ret={ret}")
                dataset = new_dataset
        except Exception:
            self._release_buffers(bufs)
            self._best_effort(self.acl.mdl.destroy_dataset, dataset)
            raise

        return dataset, bufs

    def infer(self, inputs: list):
        """
        Run one synchronous inference.

        Args:
            inputs: one numpy array per model input (e.g. NCHW float32). The raw
                bytes are copied as-is, so dtype and element count must match
                what the OM expects. Inputs smaller than the device buffer are
                copied as a prefix, as before.

        Returns:
            List of flat (1-D) numpy arrays, one per model output. Each dtype is
            taken from the model description.

        Raises:
            RuntimeError: the runner is closed, or an ACL call failed.
            ValueError: wrong number of inputs, or an input is larger than its
                device buffer.
        """
        if self._closed:
            raise RuntimeError("AscendOMRunner is closed")
        if len(inputs) != self.input_num:
            raise ValueError(f"模型输入个数={self.input_num}，但你传了 {len(inputs)} 个")

        # H2D copy
        for i, arr in enumerate(inputs):
            bytes_data = np.ascontiguousarray(arr).tobytes()
            dev = self.input_bufs[i]
            if len(bytes_data) > dev["size"]:
                raise ValueError(
                    f"input {i} has {len(bytes_data)} bytes but the OM input buffer "
                    f"holds {dev['size']} bytes (check shape/dtype)"
                )
            # bytes_to_ptr presumably points into bytes_data's storage, so
            # bytes_data stays referenced until memcpy returns.
            bytes_ptr = self.acl.util.bytes_to_ptr(bytes_data)
            ret = self.acl.rt.memcpy(
                dev["buffer"],
                dev["size"],
                bytes_ptr,
                len(bytes_data),
                self.ACL_MEMCPY_HOST_TO_DEVICE,
            )
            if ret != 0:
                raise RuntimeError(f"H2D memcpy failed, ret={ret}, input_idx={i}")

        # Execute (synchronous API).
        ret = self.acl.mdl.execute(self.model_id, self.input_dataset, self.output_dataset)
        if ret != 0:
            raise RuntimeError(f"acl.mdl.execute failed, ret={ret}")

        # D2H copy of every output through a temporary host buffer.
        outs = []
        for i in range(self.output_num):
            out_size = int(self.output_bufs[i]["size"])

            host_ptr, ret = self.acl.rt.malloc_host(out_size)
            if ret != 0:
                raise RuntimeError(f"acl.rt.malloc_host failed, ret={ret}, out_idx={i}")

            try:
                ret = self.acl.rt.memcpy(
                    host_ptr,
                    out_size,
                    self.output_bufs[i]["buffer"],
                    out_size,
                    self.ACL_MEMCPY_DEVICE_TO_HOST,
                )
                if ret != 0:
                    raise RuntimeError(f"D2H memcpy failed, ret={ret}, out_idx={i}")

                bytes_out = self.acl.util.ptr_to_bytes(host_ptr, out_size)
                acl_dtype = self.acl.mdl.get_output_data_type(self.model_desc, i)
                np_dtype = self._acl_dtype_to_np(acl_dtype)
                # .copy() gives an owned, writable array independent of bytes_out.
                out_arr = np.frombuffer(bytes_out, dtype=np_dtype).copy()
            finally:
                # Always release the pinned host buffer, even on failure.
                self.acl.rt.free_host(host_ptr)

            outs.append(out_arr)

        return outs

    def close(self):
        """
        Release buffers, datasets, model, device and ACL. Safe to call repeatedly.

        Each step is best effort and independent, so one failing call does not
        prevent the device reset and ACL finalisation that follow. Resources
        that were never acquired (partial ``__init__``) are skipped.
        """
        if getattr(self, "_closed", True):
            return
        self._closed = True

        # Destroy input and output buffers.
        self._release_buffers(self.input_bufs)
        self._release_buffers(self.output_bufs)
        self.input_bufs, self.output_bufs = [], []

        for ds in (self.input_dataset, self.output_dataset):
            if ds is not None:
                self._best_effort(self.acl.mdl.destroy_dataset, ds)
        self.input_dataset = self.output_dataset = None

        if self.model_desc is not None:
            self._best_effort(self.acl.mdl.destroy_desc, self.model_desc)
            self.model_desc = None
        if self.model_id is not None:
            self._best_effort(self.acl.mdl.unload, self.model_id)
            self.model_id = None

        if self._device_set:
            self._best_effort(self.acl.rt.reset_device, self.device_id)
            self._device_set = False
        if self._acl_inited:
            self._best_effort(self.acl.finalize)
            self._acl_inited = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __del__(self):
        # Last-resort cleanup. close() is idempotent, so an earlier explicit
        # close() is not repeated here.
        self.close()


def parse_args():
    """Parse the command-line options of the OM inference benchmark.

    Returns:
        argparse.Namespace with ``testsize``, ``max_images``, ``device_id``,
        ``om`` and ``warmup``.
    """
    parser = argparse.ArgumentParser()
    # (flag, type, default, help text or None)
    option_specs = (
        ("--testsize", int, 352, None),
        ("--max_images", int, 30, None),
        ("--device_id", int, 0, None),
        ("--om", str, "ocinet_352.om", "ATC 转换后的 .om 路径"),
        ("--warmup", int, 2, "预热次数（不计入统计）"),
    )
    for flag, arg_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    return parser.parse_args()


def _output_to_map(flat: np.ndarray, testsize: int) -> np.ndarray:
    """Reshape a flat OM output into a square 2-D map.

    Uses ``testsize x testsize`` when the element count matches, otherwise any
    perfect square. Raises RuntimeError when neither fits.
    """
    if flat.size == testsize * testsize:
        return flat.reshape(testsize, testsize)
    # Fallback: treat the output as some other square map.
    side = int(np.sqrt(flat.size))
    if side * side == flat.size:
        return flat.reshape(side, side)
    raise RuntimeError(f"无法解析输出形状，res0.size={flat.size}")


def _to_binary_mask(logits_map: np.ndarray, gt, image_np: np.ndarray) -> np.ndarray:
    """Apply sigmoid, resize, min-max normalise and Otsu-threshold into a uint8 mask.

    The target size is the GT's (H, W) when ``gt`` is given. Otherwise it is
    ``image_np.shape[2:4]`` (the network input, assumed NCHW), where the
    resize is usually a no-op.
    """
    prob = _sigmoid(logits_map)

    if gt is not None:
        h, w = np.asarray(gt).shape[:2]
    else:
        h, w = image_np.shape[2], image_np.shape[3]
    if prob.shape[:2] != (h, w):
        prob = cv2.resize(prob, (w, h), interpolation=cv2.INTER_LINEAR)

    prob_u8 = np.uint8(_norm01(prob) * 255)
    # Otsu binarisation.
    _, binary = cv2.threshold(prob_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary


def _process_dataset(runner, opt, dataset: str, dataset_path: str, max_images: int) -> None:
    """Run OM inference on one dataset split, save the Otsu masks and print timing."""
    save_path = "./models/" + dataset + "/"
    os.makedirs(save_path, exist_ok=True)

    image_root = dataset_path + dataset + "/img/"
    gt_root = dataset_path + dataset + "/gt/"

    print(f"\n{'='*50}")
    print(f"开始处理数据集: {dataset}")
    print(f"图像路径: {image_root}")
    print(f"保存路径: {save_path}")
    print(f"{'='*50}")

    if not os.path.exists(image_root):
        print(f"错误: 图像路径不存在: {image_root}")
        return

    test_loader = test_dataset(image_root, gt_root, opt.testsize)
    total_images = test_loader.size
    if total_images == 0:
        print(f"错误: 数据集 {dataset} 中没有图像")
        return

    images_to_process = min(max_images, total_images)
    print(f"发现 {total_images} 张图像，将处理前 {images_to_process} 张")

    warmed_up = False
    time_sum = 0.0
    processed_count = 0

    pbar = tqdm(total=images_to_process, desc=f"处理 {dataset}", unit="图像")
    try:
        for i in range(images_to_process):
            try:
                image, gt, name = test_loader.load_data()

                # image: torch tensor (1,3,H,W)
                image_np = np.ascontiguousarray(
                    image.detach().cpu().numpy().astype(np.float32)
                )

                # Warm up on the first successfully loaded image. These extra
                # runs are untimed, and the image itself is still timed,
                # post-processed and saved below (no image is skipped).
                if not warmed_up:
                    for _ in range(opt.warmup):
                        runner.infer([image_np])
                    warmed_up = True

                t0 = time.perf_counter()
                outs = runner.infer([image_np])
                inference_time = time.perf_counter() - t0
                time_sum += inference_time
                processed_count += 1

                # The first output plays the role of the PyTorch model's `res`.
                res_map = _output_to_map(outs[0], opt.testsize)
                binary = _to_binary_mask(res_map, gt, image_np)

                out_file = os.path.join(save_path, os.path.splitext(name)[0] + ".jpg")
                if not cv2.imwrite(out_file, binary):
                    raise RuntimeError(f"cv2.imwrite failed: {out_file}")

                avg_time = time_sum / processed_count
                remaining = avg_time * (images_to_process - (i + 1))

                pbar.set_postfix({
                    "当前图像": name,
                    "推理时间": f"{inference_time:.3f}s",
                    "平均时间": f"{avg_time:.3f}s",
                    "剩余时间": f"{remaining:.0f}s",
                })

                if (i + 1) % 5 == 0 or i == 0 or i == images_to_process - 1:
                    # pbar.write keeps the progress bar intact.
                    pbar.write(f"  [{i+1}/{images_to_process}] 图像: {name} | "
                               f"本次: {inference_time:.3f}s | "
                               f"平均: {avg_time:.3f}s | "
                               f"剩余: {remaining:.1f}s")
            except Exception as e:
                # Best effort per image: report and move on to the next one.
                pbar.write(f"处理图像 {i+1} 时出错: {e}")
            finally:
                pbar.update(1)
    finally:
        pbar.close()

    if processed_count > 0:
        avg_inference_time = time_sum / processed_count
        fps = processed_count / time_sum if time_sum > 0 else 0.0
        print(f"\n数据集 {dataset} 处理完成!")
        print(f"实际计时图像数(不含warmup): {processed_count}")
        print(f"总时间: {time_sum:.3f}秒")
        print(f"平均每张图像推理时间: {avg_inference_time:.3f}秒")
        print(f"FPS: {fps:.2f}")

    print(f"{'='*50}\n")


def main():
    """Benchmark the OM model on the configured Rail splits and save Otsu masks."""
    opt = parse_args()

    # Dataset layout kept from the original script.
    dataset_path = "./dataset/Rail/1130/"
    test_datasets = ["965", "165"]
    MAX_IMAGES = opt.max_images

    # NPU runner (OM inference).
    if not os.path.exists(opt.om):
        print(f"[ERROR] 找不到 OM 文件：{opt.om}")
        print("请先用 atc 生成 ocinet_352.om，然后用 --om 指向它。")
        sys.exit(1)

    runner = AscendOMRunner(opt.om, device_id=opt.device_id)
    try:
        print("\n" + "=" * 60)
        print("NPU(OM+pyACL) 推理模式已启用")
        print(f"OM: {opt.om}")
        print(f"Device: {opt.device_id}")
        print(f"限制: 每个数据集最多处理 {MAX_IMAGES} 张图像")
        print("=" * 60)

        for dataset in test_datasets:
            _process_dataset(runner, opt, dataset, dataset_path, MAX_IMAGES)
    finally:
        # Drop the last reference so the runner's __del__ -> close() releases
        # the device now (once), not during interpreter shutdown.
        del runner

    print("\n" + "=" * 60)
    print("所有数据集处理完成!")
    print("=" * 60)


if __name__ == "__main__":
    # Script entry point: run the OM benchmark only when executed directly,
    # not when this module is imported.
    main()
