#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import json
import os

def predict_butterfly(image_path, model_path='butterfly_classifier.pth'):
    """
    输入图片地址，输出蝴蝶信息
    
    Args:
        image_path: 图片文件路径
        model_path: 预训练模型路径 (默认: 'butterfly_classifier.pth')
        
    Returns:
        dict: 包含蝴蝶信息的字典，包含以下字段：
            - image: 图片路径
            - class_id: 类别ID (0-19)
            - class_name: 类别名称 (如 '001.Atrophaneura_horishanus')
            - confidence: 预测置信度 (0-1)
            - genus: 属名
            - species: 种名
            - scientific_name: 科学名称
            - common_name: 常见中文名
            - error: 错误信息 (如果有错误)
    """
    try:
        # 检查文件是否存在
        if not os.path.exists(image_path):
            return {'error': f'图片文件不存在: {image_path}'}
        
        if not os.path.exists(model_path):
            return {'error': f'模型文件不存在: {model_path}'}
        
        # 设置设备
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # 加载类别映射
        with open('../image_data/Butterfly20_dict.json', 'r') as f:
            class_dict = json.load(f)
        
        # 设置模型
        model = models.resnet50(pretrained=False)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 20)
        model = model.to(device)
        
        # 加载预训练模型
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        
        # 数据预处理
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # 加载和预处理图片
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(device)
        
        # 进行预测
        with torch.no_grad():
            outputs = model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)
        
        # 获取预测结果
        predicted_idx = predicted_idx.item()
        confidence_score = confidence.item()
        class_name = class_dict[str(predicted_idx)]
        
        # 解析类别信息
        species_parts = class_name.split('_')
        genus_name = species_parts[0].split('.')[1] if len(species_parts) > 1 else "未知"
        species_name = species_parts[1] if len(species_parts) > 1 else "未知"
        
        # 获取常见中文名
        common_names = {
            'Papilio': '凤蝶',
            'Graphium': '青凤蝶',
            'Byasa': '麝凤蝶',
            'Atrophaneura': '红珠凤蝶',
            'Lamproptera': '燕凤蝶',
            'Iphiclides': '粉蝶',
            'Losaria': '斑蝶',
            'Meandrusa': '环蝶',
            'Pachliopta': '红纹凤蝶'
        }
        common_name = common_names.get(genus_name, "未知常见名")
        
        # 构建返回结果
        result = {
            'image': image_path,
            'class_id': predicted_idx,
            'class_name': class_name,
            'confidence': confidence_score,
            'genus': genus_name,
            'species': species_name,
            'scientific_name': f"{genus_name} {species_name}",
            'common_name': common_name
        }
        
        return result
        
    except Exception as e:
        return {'error': f'预测失败: {str(e)}'}

# 示例用法
if __name__ == "__main__":
    # 示例：预测一张蝴蝶图片
    example_image = "../image_data/Butterfly20_test/1.jpg"
    
    if os.path.exists(example_image):
        print(f"预测图片: {example_image}")
        result = predict_butterfly(example_image)
        
        if 'error' in result:
            print(f"错误: {result['error']}")
        else:
            print("\n预测结果:")
            print(f"图片: {result['image']}")
            print(f"类别ID: {result['class_id']}")
            print(f"类别名称: {result['class_name']}")
            print(f"置信度: {result['confidence']:.4f}")
            print(f"属: {result['genus']}")
            print(f"种: {result['species']}")
            print(f"科学名: {result['scientific_name']}")
            print(f"常见名: {result['common_name']}")
    else:
        print(f"示例图片不存在: {example_image}")
        print("请提供有效的图片路径进行测试")
