#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
from collections import Counter
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

def analyze_training_data(dict_path='../image_data/Butterfly20_dict.json',
                          butterfly_dir='~/image_data/Butterfly20'):
    """Summarize the class distribution and image sizes of the training set.

    Every sub-directory of ``butterfly_dir`` is treated as one class. The
    number of image files in each class is counted, and the first image (in
    sorted name order) of each class is opened to collect its pixel size.
    The results are printed to stdout.

    Args:
        dict_path: Path of the JSON class-mapping file. Only the number of
            entries in it is reported.
        butterfly_dir: Root directory holding one sub-directory per class.
            A leading ``~`` is expanded to the user's home directory.

    Returns:
        A ``(class_counts, total_images)`` tuple, where ``class_counts`` maps
        each class folder name to its image count and ``total_images`` is the
        sum of those counts.

    Raises:
        OSError: If ``dict_path`` or ``butterfly_dir`` cannot be read.
        json.JSONDecodeError: If the class-mapping file is not valid JSON.
    """
    print("开始分析训练数据...")

    # Load the class mapping; JSON is read as UTF-8 explicitly so the result
    # does not depend on the platform's default encoding.
    with open(dict_path, 'r', encoding='utf-8') as f:
        class_dict = json.load(f)

    print(f"总类别数: {len(class_dict)}")

    # os.listdir() does not expand "~" by itself, so expand it here.
    # NOTE(review): the other data paths in this file live under
    # '../image_data' -- confirm that the home-directory default is intended.
    butterfly_dir = os.path.expanduser(butterfly_dir)
    image_exts = ('.jpg', '.jpeg', '.png')
    class_counts = {}
    image_sizes = []

    # Sorted traversal keeps the report order and the sampled image stable
    # between runs (os.listdir() returns entries in arbitrary order).
    for class_folder in sorted(os.listdir(butterfly_dir)):
        class_path = os.path.join(butterfly_dir, class_folder)
        if not os.path.isdir(class_path):
            continue

        # Match extensions case-insensitively so files such as "IMG.JPG"
        # are counted too.
        image_files = sorted(
            name for name in os.listdir(class_path)
            if name.lower().endswith(image_exts)
        )
        class_counts[class_folder] = len(image_files)
        if not image_files:
            continue

        # Inspect one representative image per class (the first by name).
        sample_image = os.path.join(class_path, image_files[0])
        try:
            with Image.open(sample_image) as img:
                image_sizes.append(img.size)
        except Exception as exc:
            # Best effort: an unreadable sample must not abort the analysis,
            # but it is reported instead of being silently dropped.
            print(f"警告: 无法读取图片 {sample_image}: {exc}")

    print("\n每个类别的图片数量:")
    for class_name, count in class_counts.items():
        print(f"{class_name}: {count} 张图片")

    total_images = sum(class_counts.values())
    print(f"\n总图片数量: {total_images}")

    # Size statistics over the sampled images. The min/max lines report
    # per-dimension extremes, which may come from different images.
    if image_sizes:
        widths, heights = zip(*image_sizes)
        print("\n图片尺寸分析:")
        print(f"平均宽度: {np.mean(widths):.1f}px")
        print(f"平均高度: {np.mean(heights):.1f}px")
        print(f"最小尺寸: {min(widths)}x{min(heights)}")
        print(f"最大尺寸: {max(widths)}x{max(heights)}")

    return class_counts, total_images

def analyze_test_data(test_dir='../image_data/Butterfly20_test', expected_count=200):
    """Report how many test images exist and which numeric IDs are missing.

    Image files are those ending in .jpg, .jpeg or .png (case-insensitive).
    Numeric IDs are parsed only from ``.jpg`` files whose name before the
    first dot is an integer, e.g. ``17.jpg``; other names are ignored for the
    ID check. The results are printed to stdout.

    Args:
        test_dir: Directory containing the test images. A leading ``~`` is
            expanded to the user's home directory.
        expected_count: IDs ``1..expected_count`` are expected to be present.

    Returns:
        The number of image files found in ``test_dir``.

    Raises:
        OSError: If ``test_dir`` cannot be listed.
    """
    print("\n分析测试数据...")

    test_dir = os.path.expanduser(test_dir)
    image_exts = ('.jpg', '.jpeg', '.png')
    test_images = [
        name for name in os.listdir(test_dir)
        if name.lower().endswith(image_exts)
    ]

    print(f"测试图片数量: {len(test_images)}")

    # Collect numeric IDs in a set so the missing-ID scan below is O(1) per
    # lookup instead of rescanning a list for every expected ID.
    test_ids = set()
    for img_name in test_images:
        if not img_name.lower().endswith('.jpg'):
            continue
        try:
            test_ids.add(int(img_name.split('.')[0]))
        except ValueError:
            # Not a numerically named file; it does not take part in the
            # ID completeness check.
            continue

    if test_ids:
        print(f"测试图片ID范围: {min(test_ids)} - {max(test_ids)}")
        missing_ids = [i for i in range(1, expected_count + 1) if i not in test_ids]
        if missing_ids:
            print(f"缺失的图片ID: {missing_ids}")
        else:
            print(f"测试图片ID完整 (1-{expected_count})")

    return len(test_images)

if __name__ == "__main__":
    # Script entry point: print a banner, run the training-set analysis and
    # then the test-set analysis, and finish with a closing banner.
    divider = "=" * 50

    print("蝴蝶图像分类数据分析")
    print(divider)

    class_counts, total_train = analyze_training_data()
    test_count = analyze_test_data()

    print(f"\n{divider}")
    print("数据分析完成!")
