v1.0
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# DeiT: https://github.com/facebookresearch/deit
|
||||
# --------------------------------------------------------
|
||||
|
||||
import os
|
||||
import PIL
|
||||
from torchvision import datasets, transforms
|
||||
from timm.data import create_transform
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
def build_dataset(is_train, args):
|
||||
|
||||
transform = build_transform(is_train, args)
|
||||
root = os.path.join(args.data_path, is_train)
|
||||
dataset = datasets.ImageFolder(root, transform=transform)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def build_transform(is_train, args):
|
||||
mean = IMAGENET_DEFAULT_MEAN
|
||||
std = IMAGENET_DEFAULT_STD
|
||||
# train transform
|
||||
if is_train=='train':
|
||||
# this should always dispatch to transforms_imagenet_train
|
||||
transform = create_transform(
|
||||
input_size=args.input_size,
|
||||
is_training=True,
|
||||
color_jitter=args.color_jitter,
|
||||
auto_augment=args.aa,
|
||||
interpolation='bicubic',
|
||||
re_prob=args.reprob,
|
||||
re_mode=args.remode,
|
||||
re_count=args.recount,
|
||||
mean=mean,
|
||||
std=std,
|
||||
)
|
||||
return transform
|
||||
|
||||
# eval transform
|
||||
t = []
|
||||
if args.input_size <= 224:
|
||||
crop_pct = 224 / 256
|
||||
else:
|
||||
crop_pct = 1.0
|
||||
size = int(args.input_size / crop_pct)
|
||||
t.append(
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
)
|
||||
t.append(transforms.CenterCrop(args.input_size))
|
||||
t.append(transforms.ToTensor())
|
||||
t.append(transforms.Normalize(mean, std))
|
||||
return transforms.Compose(t)
|
||||
Reference in New Issue
Block a user