# ------------------------------------------------------------------------------
# pose.pytorch
# Copyright (c) 2018-present Microsoft
# Licensed under The Apache-2.0 License [see LICENSE for details]
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

import argparse
import torch

from lib.models.pose_hrnet import get_pose_net
from lib.config import cfg
from lib.config import update_config


def parse_args(argv=None):
    """Parse the command line for the ONNX export script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace with ``cfg``, ``onnx_path``, ``modelDir``,
        ``logDir``, ``dataDir``, ``prevModelDir`` and ``opts`` (every
        trailing token left over after the optional flags, intended as
        config overrides).
    """
    parser = argparse.ArgumentParser(description='Export keypoints network to ONNX')
    # general
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('--onnx-path',
                        help='saved onnx path',
                        type=str,
                        default='')

    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')

    # Declared last: REMAINDER swallows every token after the flags, so the
    # config overrides must come at the end of the command line.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    return parser.parse_args(argv)


def main(args):
    """Build the pose network from the experiment config and export it to ONNX.

    Args:
        args: Namespace from ``parse_args``. It is passed to ``update_config``
            together with the global ``cfg``; ``args.onnx_path`` is the
            destination file for the exported graph.

    Raises:
        ValueError: if ``args.onnx_path`` is empty (the CLI default), so the
            run fails before the model is built and the checkpoint is loaded.
    """
    # Fail fast: '' is the default for --onnx-path and is not a usable output
    # path; without this check the error only surfaces inside the exporter.
    if not args.onnx_path:
        raise ValueError("--onnx-path must be set to the output .onnx file path")

    update_config(cfg, args)
    model = get_pose_net(cfg, is_train=False)

    # map_location='cpu' lets a checkpoint saved from GPU training be loaded
    # on a CPU-only machine; the model and dummy input below live on the CPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location='cpu')

    # strict=False tolerates key mismatches, but exporting layers that kept
    # their random initialisation is easy to miss, so report any mismatch.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"Warning: {len(incompatible.missing_keys)} model keys were not found "
              f"in the checkpoint, e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"Warning: {len(incompatible.unexpected_keys)} checkpoint keys were not "
              f"used by the model, e.g. {incompatible.unexpected_keys[:5]}")

    # Export the inference graph explicitly, independent of the exporter's
    # own mode handling.
    model.eval()

    # Dummy tensor used only to trace the graph; presumably NCHW, i.e. one
    # 3-channel image 256 high x 192 wide.
    # NOTE(review): size is hard-coded -- confirm it matches the config's input size.
    dummy_input = torch.rand(1, 3, 256, 192)
    torch.onnx.export(model, dummy_input, args.onnx_path, opset_version=11)
    print(f"ONNX file has been saved in {args.onnx_path}")


if __name__ == '__main__':
    # Script entry point: parse the CLI and hand the namespace straight to main.
    main(parse_args())
