import sys
import torch
from transformers import ViTForImageClassification


def main(onnx_file):
    """Export the pretrained DeiT-small image classifier to ONNX.

    Loads ``facebook/deit-small-patch16-224`` through
    ``ViTForImageClassification``, switches it to evaluation mode, and
    exports it with a random example input.

    Args:
        onnx_file: Destination handed to ``torch.onnx.export`` for the
            exported graph.
    """
    # The eager attention implementation has to be requested for export.
    load_options = {"attn_implementation": 'eager'}
    classifier = ViTForImageClassification.from_pretrained(
        "facebook/deit-small-patch16-224",
        **load_options,
    )
    classifier.eval()

    # Example input for the export: a random tensor of shape (1, 3, 224, 224).
    example_shape = (1, 3, 224, 224)
    example_image = torch.randn(*example_shape)

    torch.onnx.export(
        classifier,
        example_image,
        onnx_file,
        input_names=["img"],
    )


if __name__ == '__main__':
    # Script entry point: the first command-line argument is the ONNX
    # output path. Exit with a usage message instead of an IndexError
    # traceback when it is missing; extra arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <onnx_file>")
    main(sys.argv[1])