import sys
import torch
from transformers import ViTForImageClassification


def main(onnx_file):
    """Export the pretrained ViT image classifier to ONNX.

    Loads the ``google/vit-base-patch16-224`` checkpoint with the eager
    attention implementation, puts it in evaluation mode, and hands it to
    ``torch.onnx.export`` together with a random example input.

    Args:
        onnx_file: Export destination, forwarded unchanged to
            ``torch.onnx.export``.
    """
    vit = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        attn_implementation='eager',
    ).eval()

    # PyTorch convolutions expect input laid out as
    # [Batch, Channel, Height, Width] by default, hence the shape below.
    example_shape = (1, 3, 224, 224)
    example_image = torch.randn(*example_shape)

    torch.onnx.export(
        vit,
        example_image,
        onnx_file,
        input_names=["img"],
    )


if __name__ == '__main__':
    # The output path is a required positional argument; exit with a usage
    # message (written to stderr, non-zero status) instead of letting
    # sys.argv[1] raise an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <output.onnx>")
    main(sys.argv[1])