Incorporate DINOv3, DINOv2
This commit is contained in:
+49
-21
@@ -1,29 +1,39 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# Partly revised by YZ @UCL&Moorfields
|
||||
# --------------------------------------------------------
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import Subset
|
||||
from torchvision import datasets, transforms
|
||||
from timm.data import create_transform
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
def build_dataset(is_train, args):
    """Build the ``ImageFolder`` dataset for one split, optionally subsampled.

    Args:
        is_train: Name of the split. It is used as the sub-directory of
            ``args.data_path`` and compared against ``'train'`` to decide
            whether subsampling applies.
        args: Namespace providing ``data_path`` and whatever
            ``build_transform`` reads. Optional attributes, consulted only
            for the ``'train'`` split:
            ``dataratio`` (default 1.0) - fraction of samples to keep;
            ``seed`` (default 0) - seed for the subsampling generator;
            ``stratified`` (default False) - subsample each class separately.

    Returns:
        The full ``ImageFolder`` dataset, or a ``Subset`` of it when the
        split is ``'train'`` and ``0 < dataratio < 1``.
    """
    transform = build_transform(is_train, args)
    root = os.path.join(args.data_path, is_train)
    dataset = datasets.ImageFolder(root, transform=transform)

    # Only the training split is ever subsampled; any other split is
    # returned whole.
    if is_train == 'train':
        ratio = float(getattr(args, "dataratio", 1.0))
        seed = int(getattr(args, "seed", 0))
        stratified = bool(getattr(args, "stratified", False))

        # A ratio outside the open interval (0, 1) keeps the full dataset.
        if 0.0 < ratio < 1.0:
            if stratified:
                idx = _stratified_indices(dataset.targets, ratio, seed)
            else:
                # Simple uniform subsample; a seeded torch.Generator makes
                # the selection reproducible.
                g = torch.Generator().manual_seed(seed)
                n = len(dataset)
                k = max(1, int(n * ratio))
                idx = torch.randperm(n, generator=g)[:k].tolist()
            dataset = Subset(dataset, idx)

    return dataset
|
||||
|
||||
def build_transform(is_train, args):
|
||||
mean = IMAGENET_DEFAULT_MEAN
|
||||
std = IMAGENET_DEFAULT_STD
|
||||
# train transform
|
||||
|
||||
if is_train == 'train':
|
||||
# this should always dispatch to transforms_imagenet_train
|
||||
transform = create_transform(
|
||||
return create_transform(
|
||||
input_size=args.input_size,
|
||||
is_training=True,
|
||||
color_jitter=args.color_jitter,
|
||||
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
|
||||
mean=mean,
|
||||
std=std,
|
||||
)
|
||||
return transform
|
||||
|
||||
# eval transform
|
||||
t = []
|
||||
if args.input_size <= 224:
|
||||
crop_pct = 224 / 256
|
||||
else:
|
||||
crop_pct = 1.0
|
||||
crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
|
||||
size = int(args.input_size / crop_pct)
|
||||
t.append(
|
||||
t = [
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
)
|
||||
t.append(transforms.CenterCrop(args.input_size))
|
||||
t.append(transforms.ToTensor())
|
||||
t.append(transforms.Normalize(mean, std))
|
||||
transforms.CenterCrop(args.input_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean, std),
|
||||
]
|
||||
return transforms.Compose(t)
|
||||
|
||||
# ---- helpers ----
|
||||
|
||||
def _stratified_indices(targets, ratio: float, seed: int):
|
||||
"""Maintain class proportions. Ensures at least 1 sample per class when possible."""
|
||||
t = torch.as_tensor(targets)
|
||||
classes = torch.unique(t)
|
||||
g = torch.Generator().manual_seed(seed)
|
||||
|
||||
keep = []
|
||||
for c in classes.tolist():
|
||||
cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
|
||||
if len(cls_idx) == 0:
|
||||
continue
|
||||
k = max(1, int(round(len(cls_idx) * ratio)))
|
||||
sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
|
||||
keep.extend(sel.tolist())
|
||||
|
||||
# shuffle final indices (stable across seed)
|
||||
g2 = torch.Generator().manual_seed(seed + 1)
|
||||
keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
|
||||
return keep
|
||||
|
||||
|
||||
Reference in New Issue
Block a user