#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import json
import os

class ButterflyPredictor:
    """蝴蝶预测器类"""
    
    def __init__(self, model_path='butterfly_classifier.pth'):
        """
        初始化蝴蝶预测器
        
        Args:
            model_path: 预训练模型路径
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"使用设备: {self.device}")
        
        # 加载类别映射
        with open('../image_data/Butterfly20_dict.json', 'r') as f:
            self.class_dict = json.load(f)
        
        # 加载属和种信息
        self.genus_info = self._load_genus_info()
        self.species_info = self._load_species_info()
        
        # 设置模型
        self.model = self._setup_model()
        
        # 加载预训练模型
        self.load_model(model_path)
        
        # 数据预处理
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def _setup_model(self):
        """设置模型架构"""
        model = models.resnet50(pretrained=False)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 20)  # 20个类别
        model = model.to(self.device)
        return model
    
    def _load_genus_info(self):
        """加载属信息"""
        genus_info = {}
        with open('genus.txt', 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(' ', 1)
                if len(parts) == 2:
                    genus_id, genus_name = parts
                    genus_info[genus_id] = genus_name
        return genus_info
    
    def _load_species_info(self):
        """加载种信息"""
        species_info = {}
        with open('species.txt', 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(' ', 1)
                if len(parts) == 2:
                    species_id, species_full_name = parts
                    species_info[species_id] = species_full_name
        return species_info
    
    def load_model(self, model_path):
        """加载预训练模型"""
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件 {model_path} 不存在")
        
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()  # 设置为评估模式
        print(f"模型已从 {model_path} 加载")
    
    def predict_image(self, image_path):
        """
        预测单张图片
        
        Args:
            image_path: 图片文件路径
            
        Returns:
            dict: 包含蝴蝶信息的字典
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"图片文件 {image_path} 不存在")
        
        try:
            # 加载和预处理图片
            image = Image.open(image_path).convert('RGB')
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            
            # 进行预测
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)
            
            # 获取预测结果
            predicted_idx = predicted_idx.item()
            confidence_score = confidence.item()
            
            # 获取类别名称
            class_name = self.class_dict[str(predicted_idx)]
            
            # 解析类别信息
            species_id = class_name.split('.')[0]
            species_parts = class_name.split('_')
            genus_name = species_parts[0].split('.')[1] if len(species_parts) > 1 else "未知"
            species_name = species_parts[1] if len(species_parts) > 1 else "未知"
            
            # 构建返回结果
            result = {
                'image_path': image_path,
                'predicted_class': predicted_idx,
                'class_name': class_name,
                'genus': genus_name,
                'species': species_name,
                'confidence': confidence_score,
                'full_species_name': f"{genus_name} {species_name}",
                'all_probabilities': probabilities.cpu().numpy()[0].tolist()
            }
            
            return result
            
        except Exception as e:
            raise Exception(f"预测过程中出错: {str(e)}")
    
    def get_butterfly_info(self, image_path):
        """
        获取蝴蝶信息 - 主要接口函数
        
        Args:
            image_path: 图片文件路径
            
        Returns:
            dict: 包含详细蝴蝶信息的字典
        """
        prediction = self.predict_image(image_path)
        
        # 构建更详细的信息
        butterfly_info = {
            'image': image_path,
            'prediction': {
                'class_id': prediction['predicted_class'],
                'class_name': prediction['class_name'],
                'confidence': f"{prediction['confidence']:.4f}"
            },
            'taxonomy': {
                'genus': prediction['genus'],
                'species': prediction['species'],
                'full_name': prediction['full_species_name']
            },
            'additional_info': {
                'scientific_name': prediction['full_species_name'],
                'common_name': self._get_common_name(prediction['genus'], prediction['species'])
            }
        }
        
        return butterfly_info
    
    def _get_common_name(self, genus, species):
        """获取常见名称（这里可以扩展更多信息）"""
        # 这里可以添加更多蝴蝶的常见名称映射
        common_names = {
            'Papilio': '凤蝶',
            'Graphium': '青凤蝶',
            'Byasa': '麝凤蝶',
            'Atrophaneura': '红珠凤蝶',
            'Lamproptera': '燕凤蝶'
        }
        
        return common_names.get(genus, "未知常见名")

def predict_butterfly_from_image(image_path):
    """
    主要函数：输入图片地址，输出蝴蝶信息
    
    Args:
        image_path: 图片文件路径
        
    Returns:
        dict: 包含蝴蝶信息的字典
    """
    try:
        predictor = ButterflyPredictor()
        result = predictor.get_butterfly_info(image_path)
        return result
    except Exception as e:
        return {'error': str(e)}

# 示例用法
if __name__ == "__main__":
    # 示例：预测一张蝴蝶图片
    example_image = "../image_data/Butterfly20_test/1.jpg"
    
    if os.path.exists(example_image):
        print(f"预测图片: {example_image}")
        result = predict_butterfly_from_image(example_image)
        
        print("\n预测结果:")
        print(f"图片: {result['image']}")
        print(f"类别ID: {result['prediction']['class_id']}")
        print(f"类别名称: {result['prediction']['class_name']}")
        print(f"置信度: {result['prediction']['confidence']}")
        print(f"属: {result['taxonomy']['genus']}")
        print(f"种: {result['taxonomy']['species']}")
        print(f"全名: {result['taxonomy']['full_name']}")
        print(f"科学名: {result['additional_info']['scientific_name']}")
        print(f"常见名: {result['additional_info']['common_name']}")
    else:
        print(f"示例图片不存在: {example_image}")
        print("请提供有效的图片路径进行测试")
