import sys
import torch
import models
import onnx
from onnxsim import simplify
from timm.models import load_checkpoint, create_model


def export_onnx(onnx_file, checkpoint_path="./cswin_tiny_224.pth",
                img_size=224, opset_version=13):
    """Export the CSWin-Tiny classifier to a simplified ONNX file.

    Builds ``CSWin_64_12211_tiny_224`` through timm's ``create_model``
    (presumably registered by the local ``models`` import -- confirm), loads
    weights from ``checkpoint_path``, exports the model with
    ``torch.onnx.export``, then runs onnx-simplifier over the exported graph
    and overwrites ``onnx_file`` with the simplified model.

    Args:
        onnx_file: Destination path. It is written twice: first with the raw
            export, then with the simplified model.
        checkpoint_path: Weights file passed to timm's ``load_checkpoint``.
        img_size: Side length of the square input. It is used both to build
            the model and to shape the dummy export input.
        opset_version: ONNX opset passed to ``torch.onnx.export``.

    Raises:
        RuntimeError: If onnx-simplifier cannot validate the simplified model.
    """
    model = create_model(
        'CSWin_64_12211_tiny_224',
        pretrained=False,
        num_classes=1000,
        drop_rate=0.0,
        drop_connect_rate=None,
        drop_path_rate=0.1,
        drop_block_rate=None,
        global_pool=None,
        bn_tf=False,
        bn_momentum=None,
        bn_eps=None,
        checkpoint_path='',
        img_size=img_size,
        use_chk=False)
    # The third positional parameter of timm's load_checkpoint is use_ema;
    # spell it out so the intent is visible.
    load_checkpoint(model, checkpoint_path, use_ema=True)
    # eval() switches off training-only behaviour (e.g. the drop-path set
    # above) so the traced graph is the inference graph.
    model.eval()

    # PyTorch convolutions expect input in [Batch, Channel, Height, Width]
    # layout by default.
    dummy_input = torch.randn(1, 3, img_size, img_size)
    # No autograd bookkeeping is needed while tracing for export.
    with torch.no_grad():
        torch.onnx.export(model, dummy_input, onnx_file,
                          opset_version=opset_version, input_names=["img"])

    print('Simplying onnx ...')
    onnx_model = onnx.load(onnx_file)
    model_simp, check = simplify(onnx_model)
    # Use an explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would silently save an unvalidated model.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    onnx.save(model_simp, onnx_file)
    print(f"ONNX file has been saved in {onnx_file}.")


if __name__ == '__main__':
    # Expect the output ONNX path as the first CLI argument; report a usage
    # message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} OUTPUT.onnx")
    export_onnx(sys.argv[1])