import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable

import numpy as np
import pdb, os, argparse
import datetime

from model.OCINet_models import OCINet
from data import get_loader
from data import test_dataset
from utils import clip_gradient, adjust_lr
from scipy import misc
import imageio
import time
import cv2
from PIL import Image
import logging
import pytorch_iou

# Loss functions shared by the training loop: BCE on raw logits, IoU on
# sigmoid probabilities (size_average=True is forwarded to pytorch_iou.IOU).
CE = torch.nn.BCEWithLogitsLoss()
IOU = pytorch_iou.IOU(size_average = True)

# Select GPU 0 once for the whole process (this was previously done twice).
torch.cuda.set_device(0)

# Configure logging: INFO and above to both the console and a log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),  # console output
        # Explicit UTF-8 so the (Chinese) log messages can be written on any
        # platform locale instead of failing with UnicodeEncodeError.
        logging.FileHandler('train_debug.log', encoding='utf-8'),  # file output
    ]
)
logger = logging.getLogger(__name__)

def run():
    """Train the global ``model`` for exactly ``opt.epoch`` epochs.

    Relies on module-level globals created by the script section of this
    file: ``opt``, ``model``, ``optimizer``, ``train_loader``, ``total_step``,
    ``logger`` and the loss objects ``CE`` / ``IOU``.

    Each epoch: call ``adjust_lr``, make one pass over ``train_loader`` with a
    weighted deep-supervision loss over the model's four outputs, log progress
    every 10 steps (and on the last step), and save a checkpoint under
    ``models/`` every 5 epochs and after the final epoch.
    """
    logger.info(f"开始训练，总轮数: {opt.epoch}")

    save_path = 'models/'
    os.makedirs(save_path, exist_ok=True)

    # 1-based and inclusive of opt.epoch, so exactly opt.epoch epochs run and
    # the "epoch/opt.epoch" log line can actually reach its total.
    for epoch in range(1, opt.epoch + 1):
        logger.info(f"=== 第 {epoch}/{opt.epoch} 轮 ===")
        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
        model.train()

        # Per-epoch timing.
        epoch_start_time = time.time()
        batch_times = []

        for i, (images, gts) in enumerate(train_loader, start=1):
            batch_start_time = time.time()

            optimizer.zero_grad()
            images = images.cuda()
            gts = gts.cuda()

            # Forward pass: four raw outputs followed by their sigmoid versions.
            s1, s2, s3, s4, s1_sig, s2_sig, s3_sig, s4_sig = model(images)

            # Deep supervision: BCE-with-logits on raw outputs plus IoU on the
            # sigmoid outputs; s1 and s2 weighted 1, s3 weighted 1/2, s4 1/4.
            loss = (CE(s1, gts) + IOU(s1_sig, gts)) \
                   + (CE(s2, gts) + IOU(s2_sig, gts)) \
                   + (CE(s3, gts) + IOU(s3_sig, gts)) / 2 \
                   + (CE(s4, gts) + IOU(s4_sig, gts)) / 4

            # Backward pass with gradient clipping.
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()

            batch_time = time.time() - batch_start_time
            batch_times.append(batch_time)

            if i % 10 == 0 or i == total_step:
                recent = batch_times[-10:]
                avg_batch_time = sum(recent) / len(recent)
                remaining_batches = total_step - i
                # ETA covers only the remaining batches of the current epoch.
                eta = datetime.timedelta(seconds=int(avg_batch_time * remaining_batches))
                # Report the LR actually in effect, as set by adjust_lr, rather
                # than re-deriving it from a formula that may disagree.
                current_lr = optimizer.param_groups[0]['lr']

                logger.info(
                    f"Epoch [{epoch:03d}/{opt.epoch}], Step [{i:04d}/{total_step}], "
                    f"Loss: {loss.item():.4f}, LR: {current_lr:.6f}, "
                    f"BatchTime: {batch_time:.3f}s, ETA: {eta}"
                )

        epoch_time = time.time() - epoch_start_time
        # Guard against an empty loader, which would otherwise divide by zero.
        mean_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
        logger.info(f"Epoch [{epoch}] 完成, 用时: {epoch_time:.2f}秒, 平均batch时间: {mean_batch_time:.3f}s")

        # Keep the original every-5-epochs schedule (epochs 4, 9, 14, ...) and
        # also always save the final epoch so the fully trained weights persist.
        if (epoch + 1) % 5 == 0 or epoch == opt.epoch:
            model_path = os.path.join(save_path, 'OCINet.pth' + '.%d' % epoch)
            torch.save(model.state_dict(), model_path, _use_new_zipfile_serialization=False)
            logger.info(f"模型保存到: {model_path}")

# Script entry point: parse arguments, build model/optimizer, load data, train.
logger.info("程序启动: Let's go!")


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or '0'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=70, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=16, help='training batch size')
parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--is_ResNet', type=_str2bool, default=False, help='VGG or ResNet backbone')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate')
opt = parser.parse_args()

# Build the model and optimizer.
logger.info("构建模型...")
model = OCINet()
model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)

# Data loading.
logger.info("加载数据集...")
# Paths are built from components with os.path.join so they are valid on both
# Linux and Windows. The previous r'\root\OCINet-main' string is a *relative*
# name containing backslashes on Linux, so the existence check always failed.
BASE_DIR = '/root/OCINet-main'
datasets = os.path.join('Rail', '2086')
# The trailing '' keeps a trailing separator, as the original 'img\\' did.
# Presumably get_loader concatenates file names onto these roots — confirm.
image_root = os.path.join(BASE_DIR, 'dataset', datasets, 'img', '')
gt_root = os.path.join(BASE_DIR, 'dataset', datasets, 'gt', '')

logger.info(f"图像路径: {image_root}")
logger.info(f"标注路径: {gt_root}")

# Fail fast (non-zero exit) if either dataset directory is missing.
# SystemExit is used instead of the site-provided exit() builtin, which is not
# guaranteed to exist (e.g. under `python -S` or frozen executables).
if not os.path.exists(image_root):
    logger.error(f"图像路径不存在: {image_root}")
    raise SystemExit(1)
if not os.path.exists(gt_root):
    logger.error(f"标注路径不存在: {gt_root}")
    raise SystemExit(1)

# Build the training data loader.
try:
    train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
    total_step = len(train_loader)
    logger.info(f"数据集加载成功, 总批次数: {total_step}, 每批大小: {opt.batchsize}")
except Exception as e:
    # logger.exception also records the traceback for diagnosis.
    logger.exception(f"数据加载失败: {str(e)}")
    raise SystemExit(1) from e

# Report the initial learning rate.
logger.info(f'学习率: {opt.lr}')

# Report GPU availability and memory.
logger.info(f"CUDA 可用: {torch.cuda.is_available()}")
logger.info(f"GPU 数量: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    logger.info(f"当前 GPU: {torch.cuda.get_device_name(0)}")
    logger.info(f"GPU 内存使用: {torch.cuda.memory_allocated(0)/1024**2:.2f} MB / {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")

# Run training. On failure, log the traceback and exit with a non-zero status.
# Previously the exception was swallowed and the process reported success (0).
try:
    run()
    logger.info("训练完成!")
except Exception as e:
    logger.exception(f"训练过程中出错: {str(e)}")
    raise SystemExit(1) from e


