#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

class ButterflyDataset(Dataset):
    """Butterfly image dataset read from a directory tree.

    In ``'train'`` mode the samples are ``(image_path, class_index)`` pairs
    collected from ``<data_dir>/Butterfly20/<class_name>/``. In any other
    mode the samples are ``(image_path, image_id)`` pairs collected from
    ``<data_dir>/Butterfly20_test/``, where ``image_id`` is the integer
    prefix of the file name.
    """

    # File name of the JSON class mapping ({"<index>": "<class name>", ...}).
    _DICT_NAME = 'Butterfly20_dict.json'
    # Historical hard-coded location, kept as a fallback so existing callers
    # whose mapping file is not under ``data_dir`` keep working.
    _LEGACY_DICT_PATH = '../image_data/Butterfly20_dict.json'
    # Extensions accepted for training images (compared case-insensitively).
    _TRAIN_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_dir, transform=None, mode='train'):
        """Build the sample list for the requested mode.

        Args:
            data_dir: Root data directory.
            transform: Optional callable applied to each loaded PIL image.
            mode: ``'train'`` loads labelled training images; any other
                value loads the test images.

        Raises:
            OSError: If the class mapping file or an image directory
                cannot be read.
            KeyError: If a training class folder is not listed in the
                class mapping.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.mode = mode
        self.samples = []

        # Prefer the mapping that lives next to the data; fall back to the
        # legacy relative path when it is not there.
        dict_path = os.path.join(data_dir, self._DICT_NAME)
        if not os.path.isfile(dict_path):
            dict_path = self._LEGACY_DICT_PATH
        with open(dict_path, 'r', encoding='utf-8') as f:
            self.class_dict = json.load(f)

        # Invert the mapping: class name -> integer index.
        self.class_to_idx = {v: int(k) for k, v in self.class_dict.items()}

        if mode == 'train':
            self._load_train_data()
        else:
            self._load_test_data()

    def _load_train_data(self):
        """Collect ``(path, class_index)`` for every training image."""
        print("加载训练数据...")
        butterfly_dir = os.path.join(self.data_dir, 'Butterfly20')

        # os.listdir order is unspecified; sort so the sample order (and
        # therefore any seeded split) is reproducible across runs.
        for class_folder in sorted(os.listdir(butterfly_dir)):
            class_path = os.path.join(butterfly_dir, class_folder)
            if not os.path.isdir(class_path):
                continue
            class_idx = self.class_to_idx[class_folder]

            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self._TRAIN_EXTENSIONS):
                    img_path = os.path.join(class_path, img_name)
                    self.samples.append((img_path, class_idx))

        print(f"训练数据加载完成，共 {len(self.samples)} 张图片")

    def _load_test_data(self):
        """Collect ``(path, image_id)`` for test images, ordered by id."""
        print("加载测试数据...")
        test_dir = os.path.join(self.data_dir, 'Butterfly20_test')

        test_images = []
        for img_name in os.listdir(test_dir):
            if not img_name.lower().endswith('.jpg'):
                continue
            try:
                img_id = int(img_name.split('.')[0])
            except ValueError:
                # File name does not start with a numeric id; skip it.
                continue
            test_images.append((img_id, img_name))

        # Sort numerically by id (file name breaks ties deterministically).
        test_images.sort()

        for img_id, img_name in test_images:
            img_path = os.path.join(test_dir, img_name)
            # For the test set the "label" slot carries the image id.
            self.samples.append((img_path, img_id))

        print(f"测试数据加载完成，共 {len(self.samples)} 张图片")

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, label)`` in train mode, or
            ``(image, image_id, file_name)`` otherwise. If loading or
            transforming the image fails, a zero tensor of shape
            ``(3, 224, 224)`` is returned instead of the image (with label
            ``0`` in train mode, or file name ``"error.jpg"`` otherwise).
        """
        img_path, label = self.samples[idx]

        try:
            image = Image.open(img_path).convert('RGB')

            if self.transform:
                image = self.transform(image)

            if self.mode == 'train':
                return image, label
            # os.path.basename handles the platform's path separator.
            return image, label, os.path.basename(img_path)

        except Exception as e:
            # Best effort: keep the batch pipeline alive with a placeholder.
            print(f"Error loading image {img_path}: {e}")
            blank_image = torch.zeros(3, 224, 224)
            if self.mode == 'train':
                return blank_image, 0
            return blank_image, label, "error.jpg"

def get_data_loaders(batch_size=32, data_dir='../image_data', num_workers=2):
    """Build the train, validation and test data loaders.

    The labelled data is split 80% / 20% at random into training and
    validation subsets. Training samples receive random augmentation;
    validation and test samples receive only the deterministic
    resize / tensor / normalize pipeline.

    Args:
        batch_size: Batch size used by all three loaders.
        data_dir: Root data directory passed to ``ButterflyDataset``.
        num_workers: Worker process count used by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    import copy  # only needed here; keeps the file's import block unchanged

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Random augmentation, applied to training samples only.
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    # Deterministic preprocessing shared by validation and test.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = ButterflyDataset(data_dir, transform=train_transform, mode='train')

    # A shallow copy shares the same sample list but carries its own
    # transform, so validation images are NOT randomly augmented. (Splitting
    # a single dataset would apply train_transform to both subsets.)
    val_dataset = copy.copy(train_dataset)
    val_dataset.transform = eval_transform

    # Split 80% train / 20% validation.
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_subset, val_split = torch.utils.data.random_split(
        train_dataset, [train_size, val_size]
    )
    # Re-point the validation indices at the evaluation-transform view.
    val_subset = torch.utils.data.Subset(val_dataset, val_split.indices)

    train_loader = DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    val_loader = DataLoader(
        val_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    test_dataset = ButterflyDataset(data_dir, transform=eval_transform, mode='test')
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader, test_loader

if __name__ == "__main__":
    # Smoke test: build the loaders, report batch counts, and show the
    # first training batch (if there is one).
    print("测试数据加载器...")

    try:
        loaders = get_data_loaders(batch_size=8)
        train_loader, val_loader, test_loader = loaders

        print(f"训练批次数量: {len(train_loader)}")
        print(f"验证批次数量: {len(val_loader)}")
        print(f"测试批次数量: {len(test_loader)}")

        # Pull only the first training batch; an empty loader prints nothing.
        first_batch = next(iter(train_loader), None)
        if first_batch is not None:
            images, labels = first_batch
            print(f"批次图像形状: {images.shape}")
            print(f"批次标签: {labels}")

    except Exception as e:
        # Top-level boundary: report the failure instead of a traceback.
        print(f"数据加载器测试失败: {e}")
        print("可能需要等待依赖包安装完成")
