#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.models as models
from data_loader import get_data_loaders
import time
import os
import json
from tqdm import tqdm

class ButterflyClassifier:
    """Butterfly image classifier built on a pretrained torchvision ResNet-50.

    The final fully connected layer of the network is replaced with a new
    linear layer that has ``num_classes`` outputs. The instance owns the
    model, the loss function, the optimizer and the learning-rate scheduler.

    Attributes:
        num_classes: Number of output classes.
        device: Device the model and batches are moved to.
        model: The network (set up in ``_setup_model``).
        criterion: Loss function (cross entropy).
        optimizer: Adam optimizer over all model parameters.
        scheduler: StepLR scheduler attached to ``optimizer``.
        class_dict: Class mapping loaded from the JSON file.
    """

    def __init__(self, num_classes=20, device=None,
                 class_dict_path='../image_data/Butterfly20_dict.json'):
        """Load the class mapping and build the model.

        Args:
            num_classes: Number of output classes.
            device: Training device (e.g. ``'cuda'`` or ``'cpu'``). When
                falsy, CUDA is used if available, otherwise the CPU.
            class_dict_path: Path of the JSON file holding the class mapping.
                Defaults to the location used by the original script.

        Raises:
            OSError: If the class mapping file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        self.num_classes = num_classes
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.criterion = None
        self.optimizer = None
        self.scheduler = None

        # Read the class mapping. The encoding is explicit so the result does
        # not depend on the platform's default locale.
        with open(class_dict_path, 'r', encoding='utf-8') as f:
            self.class_dict = json.load(f)

        print(f"使用设备: {self.device}")
        self._setup_model()

    def _setup_model(self):
        """Create the network, loss function, optimizer and scheduler."""
        # Pretrained ResNet-50.
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept as-is for compatibility.
        self.model = models.resnet50(pretrained=True)

        # Replace the final fully connected layer with one sized for our classes.
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, self.num_classes)

        self.model = self.model.to(self.device)

        # Loss function and optimizer (all parameters are trained).
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

        # Multiply the learning rate by 0.1 every 7 scheduler steps.
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=7, gamma=0.1)

    @staticmethod
    def _snapshot_state(model):
        """Return a copy of ``model``'s state dict that owns its own tensors.

        ``state_dict()`` hands back references to the live parameters, and a
        plain ``dict.copy()`` is shallow, so later in-place optimizer updates
        would silently change the "saved" weights. Cloning every tensor
        decouples the snapshot from further training.
        """
        return {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

    def _run_epoch(self, phase, dataloader, epoch):
        """Run one pass over ``dataloader`` and return ``(loss, accuracy)``.

        Args:
            phase: ``'train'`` to update the weights; any other value runs
                the pass in evaluation mode without gradients.
            dataloader: Iterable yielding ``(inputs, labels)`` batches and
                exposing ``.dataset`` with a length.
            epoch: Epoch index, used only for the progress-bar label.

        Returns:
            Tuple ``(mean loss per sample, fraction of correct predictions)``
            as Python floats.
        """
        is_training = phase == 'train'
        if is_training:
            self.model.train()
        else:
            self.model.eval()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in tqdm(dataloader, desc=f'{phase} epoch {epoch}'):
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # Gradients only exist (and only need clearing) while training.
            if is_training:
                self.optimizer.zero_grad()

            with torch.set_grad_enabled(is_training):
                outputs = self.model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = self.criterion(outputs, labels)

                if is_training:
                    loss.backward()
                    self.optimizer.step()

            # The criterion returns the batch mean, so weight it by the batch
            # size to get a correct per-sample average at the end.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        # The scheduler advances once per training epoch.
        if is_training:
            self.scheduler.step()

        dataset_size = len(dataloader.dataset)
        return running_loss / dataset_size, running_corrects / dataset_size

    def train(self, train_loader, val_loader, num_epochs=25):
        """Train the model and restore the weights with the best validation accuracy.

        Args:
            train_loader: Loader yielding ``(inputs, labels)`` training batches.
            val_loader: Loader yielding ``(inputs, labels)`` validation batches.
            num_epochs: Number of epochs to run.

        Returns:
            Dict with the keys ``'train_loss'``, ``'train_acc'``,
            ``'val_loss'`` and ``'val_acc'``; each maps to a list holding one
            float per epoch.
        """
        since = time.time()

        best_model_wts = None
        best_acc = 0.0

        # Per-epoch training history.
        history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_acc': []
        }
        loaders = {'train': train_loader, 'val': val_loader}

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Every epoch has a training phase followed by a validation phase.
            for phase in ('train', 'val'):
                epoch_loss, epoch_acc = self._run_epoch(phase, loaders[phase], epoch)

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                history[f'{phase}_loss'].append(epoch_loss)
                history[f'{phase}_acc'].append(epoch_acc)

                # Keep an independent copy of the best-performing weights.
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = self._snapshot_state(self.model)

            print()

        time_elapsed = time.time() - since
        print(f'训练完成于 {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'最佳验证准确率: {best_acc:.4f}')

        # Restore the best weights. If no epoch ever improved on the initial
        # accuracy of 0 (or no epoch ran), keep the current weights instead of
        # failing on a missing snapshot.
        if best_model_wts is not None:
            self.model.load_state_dict(best_model_wts)

        return history

    def predict(self, test_loader):
        """Predict class indices for every sample in ``test_loader``.

        Args:
            test_loader: Loader yielding ``(inputs, labels, filenames)``
                batches; the labels are ignored.

        Returns:
            Tuple ``(all_preds, all_filenames)``: the predicted class indices
            and the filenames, both as flat lists in loader order.
        """
        self.model.eval()
        all_preds = []
        all_filenames = []

        with torch.no_grad():
            for inputs, _, filenames in tqdm(test_loader, desc='预测'):
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                _, preds = torch.max(outputs, 1)

                all_preds.extend(preds.cpu().numpy())
                all_filenames.extend(filenames)

        return all_preds, all_filenames

    def save_model(self, path='butterfly_model.pth'):
        """Save model weights, optimizer state and class metadata to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'class_dict': self.class_dict,
            'num_classes': self.num_classes
        }, path)
        print(f"模型已保存到 {path}")

    def load_model(self, path='butterfly_model.pth'):
        """Load model weights and optimizer state from a checkpoint at ``path``.

        Only ``'model_state_dict'`` and ``'optimizer_state_dict'`` are read;
        the class metadata stored in the checkpoint is not applied.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"模型已从 {path} 加载")

def main():
    """Run the whole pipeline: load data, train, save, predict, write results."""
    # Build the train / validation / test loaders.
    loaders = get_data_loaders(batch_size=32)
    train_data, val_data, test_data = loaders

    # Create the classifier.
    butterfly_net = ButterflyClassifier(num_classes=20)

    # Train; the returned per-epoch history is not used by this script.
    print("开始训练模型...")
    butterfly_net.train(train_data, val_data, num_epochs=20)

    # Persist the trained model.
    butterfly_net.save_model('butterfly_classifier.pth')

    # Predict on the test set and write the result file.
    print("在测试集上进行预测...")
    test_predictions, test_filenames = butterfly_net.predict(test_data)
    generate_results(test_predictions, test_filenames, butterfly_net.class_dict)

    print("训练和预测完成!")

def generate_results(predictions, filenames, class_dict):
    """Write the predicted class names to ``model_result.txt``, one per line.

    Lines are ordered by the integer that precedes the first ``.`` in each
    file's base name (e.g. ``12.jpg`` -> 12), so the output order does not
    depend on the order in which the samples were predicted.

    Args:
        predictions: Predicted class indices, aligned with ``filenames``.
        filenames: File names (an optional directory prefix is ignored).
        class_dict: Mapping from the class index as a string to the class name.

    Raises:
        ValueError: If ``predictions`` and ``filenames`` differ in length, or
            a file name does not start with an integer.
        KeyError: If a predicted index is missing from ``class_dict``.
    """
    if len(predictions) != len(filenames):
        raise ValueError(
            f"predictions ({len(predictions)}) and filenames "
            f"({len(filenames)}) must have the same length"
        )

    def numeric_stem(filename):
        # Drop any directory part first so 'dir/12.jpg' sorts as 12.
        return int(os.path.basename(filename).split('.')[0])

    # Sort by file number (keeps the output aligned with the files).
    ordered = sorted(zip(filenames, predictions), key=lambda pair: numeric_stem(pair[0]))

    # Resolve every name before opening the file so that a bad index cannot
    # leave a partially written result behind.
    class_names = [class_dict[str(pred_idx)] for _, pred_idx in ordered]

    with open('model_result.txt', 'w', encoding='utf-8') as f:
        f.writelines(f"{class_name}\n" for class_name in class_names)

    print("结果已保存到 model_result.txt")
    print(f"共 {len(class_names)} 个预测结果")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
