import sys
import torch
from torchvision.models import squeezenet1_1, SqueezeNet1_1_Weights


def export_sg(onnx_file, *, opset_version=13):
    """Export an ImageNet-pretrained SqueezeNet 1.1 model to ONNX.

    The model is built with torchvision's ``IMAGENET1K_V1`` weights, switched
    to eval mode, and exported by tracing it on a random dummy input of shape
    ``(1, 3, 224, 224)``. The graph input is named ``"img"``.

    Args:
        onnx_file: Destination for the exported ONNX model (passed straight
            through to ``torch.onnx.export``).
        opset_version: ONNX opset to target. Defaults to 13, the value this
            script has always used.
    """
    model = squeezenet1_1(weights=SqueezeNet1_1_Weights.IMAGENET1K_V1)
    # Put layers such as dropout into inference behavior before exporting.
    model.eval()

    # PyTorch convolutions expect NCHW input: [batch, channel, height, width].
    # No dynamic_axes are declared, so this batch size and spatial size are
    # baked into the exported graph as static dimensions.
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_file,
        opset_version=opset_version,
        input_names=["img"],
    )


if __name__ == '__main__':
    # Usage: python <script> OUTPUT.onnx
    # Fail with a usage message instead of an IndexError traceback when the
    # output path is missing; any extra arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} OUTPUT.onnx")
    export_sg(sys.argv[1])