diff --git a/README.md b/README.md index d558a19..bfd7444 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,8 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f --nb_classes 5 \ --data_path ./IDRiD_data/ \ --task ./finetune_IDRiD/ \ - --finetune ./RETFound_cfp_weights.pth + --finetune ./RETFound_cfp_weights.pth \ + --input_size 224 ``` @@ -97,7 +98,8 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f --nb_classes 5 \ --data_path ./IDRiD_data/ \ --task ./internal_IDRiD/ \ - --resume ./finetune_IDRiD/checkpoint-best.pth + --resume ./finetune_IDRiD/checkpoint-best.pth \ + --input_size 224 ``` diff --git a/main_finetune.py b/main_finetune.py index 42c5330..7b611de 100644 --- a/main_finetune.py +++ b/main_finetune.py @@ -239,6 +239,7 @@ def main(args): label_smoothing=args.smoothing, num_classes=args.nb_classes) model = models_vit.__dict__[args.model]( + img_size=args.input_size, num_classes=args.nb_classes, drop_path_rate=args.drop_path, global_pool=args.global_pool, diff --git a/models_vit.py b/models_vit.py index d7334c3..5a1c30e 100644 --- a/models_vit.py +++ b/models_vit.py @@ -49,7 +49,7 @@ class VisionTransformer(timm.models.vision_transformer.VisionTransformer): def vit_large_patch16(**kwargs): model = VisionTransformer( - img_size=224,patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model