import csv
from pathlib import Path

from config import ATTACKED_DIR, VERIFY_CSV
from common_image import sha256_file, phash_file, phash_distance, extract_watermark
from common_chain import get_contract, get_all_records_from_chain


# Largest pHash distance still accepted as light-to-moderate distortion picked
# up during redistribution; tune it against your own sample set.
PHASH_THRESHOLD = 15


def _make_result(status, record=None, distance=0, extracted_wm=""):
    """Build the result dict shared by every verification outcome.

    ``record`` is the matched chain record, or ``None`` when nothing matched;
    its ``file_name`` / ``watermark_id`` fill the ``matched_*`` fields.
    """
    return {
        "status": status,
        "matched_file": record["file_name"] if record is not None else "",
        "matched_wm": record["watermark_id"] if record is not None else "",
        "distance": distance,
        "extracted_wm": extracted_wm,
    }


def verify_one_image(
    img_path: Path,
    chain_records: list,
    *,
    threshold: "int | None" = None,
    wm_len_bytes: int = 8,
) -> dict:
    """Classify one (possibly attacked) image against the on-chain records.

    Strategies are tried in order:

    1. Exact match: the file's SHA-256 equals a record's ``content_id``
       -> ``EXACT_MATCH`` with distance 0.
    2. Watermark: the extracted watermark equals a record's
       ``watermark_id``; the pHash distance to that record decides between
       ``SIMILAR_VARIANT`` (<= threshold) and ``TAMPERED``. If several
       records carry the same watermark, the closest one is used.
    3. pHash fallback: the record with the smallest pHash distance is
       reported as ``SIMILAR_VARIANT`` when within the threshold.

    Otherwise the status is ``NOT_FOUND``.

    Args:
        img_path: Image file to verify.
        chain_records: Records read from the chain; each must provide the
            keys ``content_id``, ``file_name``, ``watermark_id`` and
            ``phash_value``.
        threshold: Maximum pHash distance accepted as a benign variant.
            ``None`` (default) means the module-level ``PHASH_THRESHOLD``.
        wm_len_bytes: Watermark payload length passed to
            ``extract_watermark``.

    Returns:
        A dict with keys ``status``, ``matched_file``, ``matched_wm``,
        ``distance`` (-1 when there were no records to compare against) and
        ``extracted_wm`` ("" when no watermark was extracted).
    """
    if threshold is None:
        # Resolved at call time so runtime tuning of the constant is honoured.
        threshold = PHASH_THRESHOLD

    # 1. Exact content-hash match. Done before any pHash work so an
    #    unmodified file only costs one SHA-256.
    current_sha = sha256_file(img_path)
    for r in chain_records:
        if current_sha == r["content_id"]:
            return _make_result("EXACT_MATCH", r, 0, "")

    current_phash = phash_file(img_path)

    # 2. Watermark first: a recovered watermark identifies the source record,
    #    and the pHash distance tells a benign variant from a tampered one.
    wm = extract_watermark(img_path, wm_len_bytes=wm_len_bytes)
    extracted_wm = wm if wm else ""  # keep the field a consistent "" when absent
    if wm:
        wm_hits = [
            (phash_distance(current_phash, r["phash_value"]), r)
            for r in chain_records
            if wm == r["watermark_id"]
        ]
        if wm_hits:
            # key= avoids comparing the record dicts when distances tie;
            # min() keeps the first record among equal distances.
            dist, record = min(wm_hits, key=lambda hit: hit[0])
            status = "SIMILAR_VARIANT" if dist <= threshold else "TAMPERED"
            return _make_result(status, record, dist, extracted_wm)

    # 3. pHash fallback: nearest record overall.
    if not chain_records:
        return _make_result("NOT_FOUND", None, -1, extracted_wm)

    best_dist, best_record = min(
        ((phash_distance(current_phash, r["phash_value"]), r) for r in chain_records),
        key=lambda hit: hit[0],
    )
    if best_dist <= threshold:
        return _make_result("SIMILAR_VARIANT", best_record, best_dist, extracted_wm)

    return _make_result("NOT_FOUND", None, best_dist, extracted_wm)


# Image suffixes considered for verification (compared lower-cased).
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}


def _collect_images(folder: Path) -> list:
    """Return the image files directly inside *folder*, sorted by path.

    Suffixes are matched case-insensitively so files such as ``IMG.JPG`` are
    not skipped on case-sensitive file systems; directories are ignored even
    if their name ends in an image suffix.
    """
    return sorted(
        p
        for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
    )


def main():
    """Verify every attacked image against the chain and write a CSV report.

    Reads all registered records via ``get_contract`` /
    ``get_all_records_from_chain``, then walks each sub-directory of
    ``ATTACKED_DIR``. Every image in it goes through ``verify_one_image``,
    the result is printed, and a row is written to ``VERIFY_CSV``.
    Returns early, with a message, when the chain has no records or no
    attacked-sample directories exist.
    """
    VERIFY_CSV.parent.mkdir(parents=True, exist_ok=True)

    _, contract = get_contract()
    chain_records = get_all_records_from_chain(contract)

    if not chain_records:
        print("链上没有记录，请先运行 ex4_batch_register_fullchain.py")
        return

    # iterdir() raises FileNotFoundError on a missing directory; treat that
    # the same as "no sample groups" so the user gets the hint below.
    subdirs = (
        [p for p in sorted(ATTACKED_DIR.iterdir()) if p.is_dir()]
        if ATTACKED_DIR.is_dir()
        else []
    )
    if not subdirs:
        print("未找到攻击样本目录，请先运行 ex4_batch_attack_watermarked.py")
        return

    with VERIFY_CSV.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([
            "group", "test_image", "status",
            "matched_file", "matched_wm", "extracted_wm", "phash_distance"
        ])

        for subdir in subdirs:
            print("\n" + "=" * 80)
            print(f"原图组: {subdir.name}")
            print("=" * 80)

            for img_path in _collect_images(subdir):
                result = verify_one_image(img_path, chain_records)

                print("-" * 60)
                print("测试图片:", img_path.name)
                print("状态:", result["status"])
                print("匹配原图:", result["matched_file"])
                print("提取水印:", result["extracted_wm"])
                print("pHash距离:", result["distance"])

                writer.writerow([
                    subdir.name,
                    img_path.name,
                    result["status"],
                    result["matched_file"],
                    result["matched_wm"],
                    result["extracted_wm"],
                    result["distance"]
                ])

    print(f"\n批量验证完成，结果已写入: {VERIFY_CSV}")


# Run the batch verification only when executed as a script, not on import.
if __name__ == "__main__":
    main()