import sys
import torch
from transformers import AutoModelForSequenceClassification

def export_onnx(onnx_file, *, model_name="ModelTC/bert-base-uncased-cola", seq_len=128):
    """Export a Hugging Face sequence-classification model to ONNX.

    The model is loaded with ``AutoModelForSequenceClassification`` and traced
    by ``torch.onnx.export`` using an all-ones dummy batch of shape
    ``(1, seq_len)``. No ``dynamic_axes`` are passed, so the exported graph's
    input shapes are fixed to the dummy shapes.

    Args:
        onnx_file: Destination path for the ``.onnx`` file.
        model_name: Model identifier or local path passed to
            ``from_pretrained``. Defaults to the checkpoint this script was
            originally written for.
        seq_len: Sequence length of the dummy inputs, and therefore of the
            exported graph's inputs. Defaults to 128.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Switch to inference mode before tracing (e.g. disables dropout).
    model.eval()

    # Build int64 tensors directly rather than allocating floats and casting.
    input_ids = torch.ones(1, seq_len, dtype=torch.long)
    attention_mask = torch.ones(1, seq_len, dtype=torch.long)
    # The inputs are passed positionally. This presumably matches a forward()
    # signature of (input_ids, attention_mask, ...) as in BERT; confirm this
    # before using other architectures.
    dummy_input = (input_ids, attention_mask)
    # NOTE(review): the first graph input is named "input_id", not the usual
    # Hugging Face "input_ids". It is kept as-is because downstream consumers
    # may feed inputs by this name; confirm before renaming.
    torch.onnx.export(
        model,
        dummy_input,
        onnx_file,
        input_names=["input_id", "attention_mask"],
    )

if __name__ == '__main__':
    # Usage: python <script> OUTPUT.onnx
    # Fail with a readable message instead of an IndexError traceback when the
    # output path is missing. Any extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} OUTPUT.onnx", file=sys.stderr)
        sys.exit(2)
    export_onnx(sys.argv[1])