import sys
import torch
import onnx
from onnxsim import simplify
from transformers import SwinForImageClassification


def main(onnx_file, model_name="microsoft/swin-tiny-patch4-window7-224"):
    """Export a Swin image classifier to ONNX, then simplify it in place.

    Args:
        onnx_file: Destination path for the ONNX model. The raw export is
            written here first and then overwritten with the simplified graph.
        model_name: Checkpoint passed to
            ``SwinForImageClassification.from_pretrained``. Defaults to the
            previously hard-coded Swin-Tiny checkpoint.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model failed its
            validation check; in that case nothing is written over the raw
            export.
    """
    model = SwinForImageClassification.from_pretrained(model_name)
    model.eval()

    # Derive the dummy input shape from the checkpoint config instead of
    # hard-coding 3x224x224, so other checkpoints export with a matching shape.
    # image_size may be a single int or a (height, width) pair.
    image_size = model.config.image_size
    if isinstance(image_size, int):
        height = width = image_size
    else:
        height, width = image_size
    # PyTorch convolutions expect NCHW input by default:
    # [batch, channel, height, width].
    dummy_input = torch.randn(1, model.config.num_channels, height, width)
    torch.onnx.export(model, dummy_input, onnx_file, input_names=["img"])

    print('Simplying onnx ...')
    onnx_model = onnx.load(onnx_file)
    model_simp, check = simplify(onnx_model)
    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would let an unvalidated model be saved silently.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    onnx.save(model_simp, onnx_file)


if __name__ == '__main__':
    # The output path is required; report usage instead of an IndexError
    # traceback when it is missing. Extra arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} OUTPUT.onnx")
    main(sys.argv[1])