import sys
import torch
from torchvision.models import inception_v3, Inception_V3_Weights


def export_sg(onnx_file, opset_version=13):
    """Export a pretrained torchvision Inception v3 classifier to ONNX.

    The model is loaded with the ``IMAGENET1K_V1`` weights, switched to eval
    mode, and traced with a single random input of shape (1, 3, 299, 299).
    The ONNX input is named ``"img"``.

    Args:
        onnx_file: Destination passed straight to ``torch.onnx.export``
            (a file path or writable file-like object).
        opset_version: ONNX opset to target. Defaults to 13, the value this
            script has always used.
    """
    model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
    # Skip the model's built-in ``_transform_input`` preprocessing so it is not
    # baked into the exported graph. Clearing the public ``transform_input``
    # flag makes that step a no-op. It avoids monkey-patching the private
    # ``_forward`` method, which would also change the eval-mode output from a
    # plain logits tensor into a ``(logits, None)`` tuple.
    model.transform_input = False
    model.eval()

    # PyTorch convolutions take input laid out as [Batch, Channel, Height, Width].
    dummy_input = torch.randn(1, 3, 299, 299)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_file,
        opset_version=opset_version,
        input_names=["img"],
    )


if __name__ == '__main__':
    # Expect exactly one argument: where to write the .onnx file. Fail with a
    # usage message instead of a bare IndexError traceback when it is missing.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} OUTPUT.onnx")
    export_sg(sys.argv[1])