From 62746e7f031713d15640e00e9824cc90ef478950 Mon Sep 17 00:00:00 2001 From: rmaphoh Date: Sun, 15 Oct 2023 23:48:14 +0100 Subject: [PATCH] image size option --- README.md | 6 ++++-- main_finetune.py | 1 + models_vit.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d558a19..bfd7444 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,8 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f --nb_classes 5 \ --data_path ./IDRiD_data/ \ --task ./finetune_IDRiD/ \ - --finetune ./RETFound_cfp_weights.pth + --finetune ./RETFound_cfp_weights.pth \ + --input_size 224 ``` @@ -97,7 +98,8 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f --nb_classes 5 \ --data_path ./IDRiD_data/ \ --task ./internal_IDRiD/ \ - --resume ./finetune_IDRiD/checkpoint-best.pth + --resume ./finetune_IDRiD/checkpoint-best.pth \ + --input_size 224 ``` diff --git a/main_finetune.py b/main_finetune.py index 42c5330..7b611de 100644 --- a/main_finetune.py +++ b/main_finetune.py @@ -239,6 +239,7 @@ def main(args): label_smoothing=args.smoothing, num_classes=args.nb_classes) model = models_vit.__dict__[args.model]( + img_size=args.input_size, num_classes=args.nb_classes, drop_path_rate=args.drop_path, global_pool=args.global_pool, diff --git a/models_vit.py b/models_vit.py index d7334c3..5a1c30e 100644 --- a/models_vit.py +++ b/models_vit.py @@ -49,7 +49,7 @@ class VisionTransformer(timm.models.vision_transformer.VisionTransformer): def vit_large_patch16(**kwargs): model = VisionTransformer( - img_size=224,patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model