Files
2025-08-31 18:03:57 +01:00

82 lines
2.8 KiB
Python

import os
import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args):
    """Build the ImageFolder dataset for one split, optionally subsampled.

    Args:
        is_train: Split name. It is used both as the sub-directory of
            ``args.data_path`` to read from and, when equal to ``'train'``,
            as the switch that enables training transforms and subsampling.
        args: Options namespace. ``data_path`` is required here; the optional
            attributes ``dataratio`` (default 1.0), ``seed`` (default 0) and
            ``stratified`` (default False) control training-set subsampling.

    Returns:
        The full ``ImageFolder`` dataset, or a ``Subset`` of it when the
        split is ``'train'`` and ``0 < dataratio < 1``.
    """
    pipeline = build_transform(is_train, args)
    split_dir = os.path.join(args.data_path, is_train)
    full_set = datasets.ImageFolder(split_dir, transform=pipeline)

    # Only the training split is ever subsampled.
    if is_train != 'train':
        return full_set

    fraction = float(getattr(args, "dataratio", 1.0))
    rng_seed = int(getattr(args, "seed", 0))
    use_strata = bool(getattr(args, "stratified", False))

    # Any ratio outside the open interval (0, 1) keeps the whole dataset.
    if not (0.0 < fraction < 1.0):
        return full_set

    if use_strata:
        chosen = _stratified_indices(full_set.targets, fraction, rng_seed)
    else:
        # Uniform subsample; a seeded generator makes it reproducible.
        total = len(full_set)
        keep_count = max(1, int(total * fraction))
        rng = torch.Generator().manual_seed(rng_seed)
        chosen = torch.randperm(total, generator=rng)[:keep_count].tolist()

    return Subset(full_set, chosen)
def build_transform(is_train, args):
    """Return the image preprocessing pipeline for the given split.

    Args:
        is_train: Split name; ``'train'`` selects the augmenting pipeline,
            anything else selects the deterministic evaluation pipeline.
        args: Options namespace. ``input_size`` is always read; the training
            pipeline additionally reads ``color_jitter``, ``aa``, ``reprob``,
            ``remode`` and ``recount``.

    Returns:
        For training, the transform produced by timm's ``create_transform``
        with bicubic interpolation. Otherwise a ``transforms.Compose`` of
        bicubic resize, center crop, tensor conversion and normalization.
        Both use the ImageNet default mean and std.
    """
    if is_train == 'train':
        train_options = {
            "input_size": args.input_size,
            "is_training": True,
            "color_jitter": args.color_jitter,
            "auto_augment": args.aa,
            "interpolation": 'bicubic',
            "re_prob": args.reprob,
            "re_mode": args.remode,
            "re_count": args.recount,
            "mean": IMAGENET_DEFAULT_MEAN,
            "std": IMAGENET_DEFAULT_STD,
        }
        return create_transform(**train_options)

    # Evaluation: inputs up to 224 are resized to input_size / (224 / 256)
    # before the center crop; larger inputs are resized straight to
    # input_size.
    target_size = args.input_size
    if target_size <= 224:
        crop_fraction = 224 / 256
    else:
        crop_fraction = 1.0
    resize_to = int(target_size / crop_fraction)

    return transforms.Compose([
        transforms.Resize(
            resize_to, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
# ---- helpers ----
def _stratified_indices(targets, ratio: float, seed: int):
    """Pick a random subset of sample positions that keeps class proportions.

    Every class present in ``targets`` contributes ``round(count * ratio)``
    positions, but never fewer than one, so each class stays represented.

    Args:
        targets: Per-sample class labels in any form ``torch.as_tensor``
            accepts (assumed to be a flat sequence, e.g. a list of ints).
        ratio: Fraction of each class to keep.
        seed: Seeds the per-class sampling; ``seed + 1`` seeds the final
            shuffle, so the result is reproducible for a given seed.

    Returns:
        A shuffled list of the selected positions in ``targets``; an empty
        list when ``targets`` is empty.
    """
    labels = torch.as_tensor(targets)
    if labels.numel() == 0:
        return []

    # A single stable sort groups positions by class while keeping each
    # group in ascending position order. This avoids one full scan of
    # `labels` per class, which is costly with many classes.
    _, order = torch.sort(labels, stable=True)
    _, counts = torch.unique(labels, return_counts=True)

    rng = torch.Generator().manual_seed(seed)
    keep = []
    # Groups come out in ascending class order, matching `torch.unique`.
    for members in torch.split(order, counts.tolist()):
        size = len(members)
        quota = max(1, int(round(size * ratio)))
        picks = torch.randperm(size, generator=rng)[:quota]
        keep.extend(members[picks].tolist())

    # Shuffle with a separate generator so the output is not grouped by class.
    shuffle_rng = torch.Generator().manual_seed(seed + 1)
    shuffled = torch.tensor(keep)[
        torch.randperm(len(keep), generator=shuffle_rng)
    ]
    return shuffled.tolist()