import os
import argparse
import numpy as np

import rbcc
import rbpy as rb # type: ignore[import]
from rbcc.backend.rbrt.ir_module import IRModuleConverter
from rbcc.utils.helpers import change_rbo_output_shape


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """Pad an already-resized NHWC image graph to exactly ``new_shape``.

    This function never resizes: ``img`` must already fit inside
    ``new_shape``. Each border gets half of the total padding. When the
    total is odd, the extra pixel goes to the bottom/right.

    Args:
        img: rb tensor indexed as [batch, height, width, channels]. Its
            ``shape`` must be static because the padding blocks are built as
            NumPy constants.
        new_shape: target (height, width), or an int for a square target.
        color: padding value. NOTE: only ``color[0]`` is used, for every
            channel.

    Returns:
        ``(padded_img, (dw, dh))``. ``dw``/``dh`` are the (possibly
        fractional) per-side horizontal/vertical padding.

    Raises:
        ValueError: if ``img`` is larger than ``new_shape`` in either
            dimension. Padding alone cannot reach the target then, and the
            pad sizes would be negative.
    """
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    target_h, target_w = new_shape[0], new_shape[1]
    batch, height, width, channels = list(img.shape)

    if height > target_h or width > target_w:
        raise ValueError(
            f"letterbox only pads: image of size {height}x{width} does not fit "
            f"into target {target_h}x{target_w}; resize it first"
        )

    # Per-side padding; the +-0.1 rounding puts an odd extra pixel on the
    # bottom/right side.
    dw, dh = (target_w - width) / 2, (target_h - height) / 2
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    fill = color[0]

    def _pad_block(shape):
        # Constant block of padding pixels; float32 so it can be concatenated
        # with the float image.
        return rb.constant(np.full(shape, fill, dtype=np.float32))

    # Concatenating constant blocks is used instead of rb.pad because it
    # performs better (per the original author's measurement).
    w_parts = []
    if left:
        w_parts.append(_pad_block([batch, height, left, channels]))
    w_parts.append(img)
    if right:
        w_parts.append(_pad_block([batch, height, right, channels]))
    padded = rb.concat(w_parts, axis=2)

    # The top/bottom blocks must match the width after horizontal padding.
    padded_w = width + left + right
    h_parts = []
    if top:
        h_parts.append(_pad_block([batch, top, padded_w, channels]))
    h_parts.append(padded)
    if bottom:
        h_parts.append(_pad_block([batch, bottom, padded_w, channels]))
    padded = rb.concat(h_parts, axis=1)

    return padded, (dw, dh)


def preprocess(img, input_size=(640, 640)):
    """Build the graph that turns a raw NHWC image into the model input.

    Steps: cast to float32, bilinear-resize (keeping the aspect ratio) so the
    image fits inside ``input_size``, letterbox-pad to exactly
    ``input_size``, then divide by 255.

    Args:
        img: rb tensor indexed as [batch, height, width, channels] with a
            static shape.
        input_size: model input (height, width), or an int for a square
            input.

    Returns:
        ``(model_input, img_meta)``. ``img_meta`` holds the original
        ``'img_size'`` ([h, w]) and the ``'input_size'``, which
        ``scale_coords`` needs to map detections back.
    """
    if isinstance(input_size, int):
        input_size = (input_size, input_size)
    h, w = list(img.shape[1:3])
    scale = min(input_size[0] / h, input_size[1] / w)
    # Round instead of truncating: the product can land just below the exact
    # value (e.g. 512 / 49 * 49 == 511.99999999999994), and int() would drop
    # a pixel on the limiting side. Rounding can never exceed input_size.
    resize_h, resize_w = int(round(h * scale)), int(round(w * scale))
    img = rb.cast(img, dtype='float32')
    img = rb.resize(img, resize_h, resize_w, interpolation="BILINEAR")

    img, _ = letterbox(img, input_size)

    img = img / rb.constant(255.0, dtype="float32")

    img_meta = {
        'img_size': [h, w],
        'input_size': input_size
    }
    return img, img_meta


def clip_box(boxes, ori_size):
    """Clamp (x1, y1, x2, y2) boxes to the original image extent.

    x coordinates are limited to [0, img_w] and y coordinates to [0, img_h].

    Args:
        boxes: rb tensor whose last axis holds 4 values (x1, y1, x2, y2).
        ori_size: (img_h, img_w) of the original image.
    """
    height, width = ori_size
    hi = rb.constant(np.array([width, height] * 2, dtype=np.float32))
    lo = rb.constant(np.zeros(4, dtype=np.float32))
    return rb.maximum(rb.minimum(boxes, hi), lo)


def clip_landmark(landmark, ori_size):
    """Clamp five flattened (x, y, third) landmark triples to valid ranges.

    In each triple, x is limited to [0, img_w], y to [0, img_h] and the
    third value to [0, 1].

    Args:
        landmark: rb tensor whose last axis holds 15 values (5 triples).
        ori_size: (img_h, img_w) of the original image.
    """
    height, width = ori_size
    hi = rb.constant(np.tile(np.array([width, height, 1.], dtype=np.float32), 5))
    lo = rb.constant(np.zeros(15, dtype=np.float32))
    return rb.maximum(rb.minimum(landmark, hi), lo)


def scale_coords(coords, img_meta, kpt_label):
    """Map coordinates from letterboxed model-input space to the original image.

    This undoes the letterbox padding and the resize factor, then clips the
    result to the original image.

    Args:
        coords: rb tensor of xyxy boxes (4 values per row) when ``kpt_label``
            is falsy, otherwise 5 flattened (x, y, third) triples (15 values
            per row). The third value of each triple is left unscaled.
        img_meta: dict with ``'img_size'`` (original [h, w]) and
            ``'input_size'`` (model input [h, w]).
        kpt_label: truthy to treat ``coords`` as landmarks.
    """
    ori_size = img_meta['img_size']
    in_size = img_meta['input_size']
    ratio = min(in_size[0] / ori_size[0], in_size[1] / ori_size[1])
    pad_x = (in_size[1] - ori_size[1] * ratio) / 2
    pad_y = (in_size[0] - ori_size[0] * ratio) / 2

    if kpt_label:
        offset = [pad_x, pad_y, 0.] * 5
        factor = [ratio, ratio, 1.] * 5
        clip = clip_landmark
    else:
        offset = [pad_x, pad_y] * 2
        factor = [ratio] * 4
        clip = clip_box

    offset_t = rb.constant(np.array(offset, dtype=np.float32))
    factor_t = rb.constant(np.array(factor, dtype=np.float32))
    return clip((coords - offset_t) / factor_t, ori_size)


def xywh2xyxy(xy, wh):
    """Convert center/size boxes into corner form.

    Args:
        xy: rb tensor of box centers (2 values per row).
        wh: rb tensor of box widths/heights (2 values per row).

    Returns:
        rb tensor ``[x1, y1, x2, y2]`` built by concatenating along axis 1.
    """
    half = wh / 2
    return rb.concat([xy - half, xy + half], axis=1)


def postprocess(outputs, img_meta, conf_thres=0.01, iou_thres=0.5, max_det=300):
    """Decode raw predictions, run NMS and map the kept results back.

    Each prediction row has 21 values laid out as
    ``[cx, cy, w, h, obj_conf, cls_conf, 15 landmark values]``.

    Args:
        outputs: raw model output; the original note says its shape is
            [N, 25200, 21].
        img_meta: metadata returned by ``preprocess``.
        conf_thres: NMS score threshold.
        iou_thres: NMS IoU threshold.
        max_det: maximum number of detections NMS keeps.

    Returns:
        ``(boxes, scores, landmarks)`` in original-image coordinates.
    """
    # Fold the batch axis into the row axis. All rows then go through one
    # NMS call, so mixing batch items is presumably only correct for batch 1.
    preds = rb.reshape(outputs, [-1, 21])
    center, size = preds[:, :2], preds[:, 2:4]
    objectness = preds[:, 4:5]
    class_prob = preds[:, 5:6]
    kpts = preds[:, 6:]
    corners = xywh2xyxy(center, size)

    # Score = class confidence * objectness, reduced over the class column and
    # flattened to a 1-D vector (boxes are Nx4, scores are N).
    conf = rb.reshape(rb.reduce.max(class_prob * objectness, axis=[1]), [-1])

    selected = rb.nms(corners, conf,
                      max_output_size=max_det,
                      iou_threshold=iou_thres,
                      score_threshold=conf_thres)

    corners, conf, kpts = [rb.gather(t, selected) for t in (corners, conf, kpts)]

    corners = scale_coords(corners, img_meta, kpt_label=False)
    kpts = scale_coords(kpts, img_meta, kpt_label=5)
    return corners, conf, kpts


def main(argv=None):
    """Build an rbo that bundles preprocessing, the converted model and NMS.

    Args:
        argv: optional argument list; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('sg_file', type=str, help="8bit sg file path.")
    parser.add_argument('rbo_file', type=str, help="rbo file path.")
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640],
                        help='model input size, (height, width); a single value means a square input')
    parser.add_argument('--src-size', nargs=2, type=int, default=[1080, 1920],
                        metavar=('HEIGHT', 'WIDTH'),
                        help='size of the raw uint8 input image, (height, width)')
    parser.add_argument('--conf-thres', type=float, default=0.01, help="nms score threshold.")
    parser.add_argument('--iou-thres', type=float, default=0.50, help="nms iou threshold.")
    parser.add_argument('--max-det', default=300, type=int, help="maximum number of detections kept by nms.")
    parser.add_argument('--config', type=str, help="build config passed to rb.build.")
    args = parser.parse_args(argv)

    # preprocess/scale_coords index input_size[0] and [1]; expand a single
    # value to a square size instead of crashing later with an IndexError.
    if len(args.img_size) == 1:
        args.img_size = args.img_size * 2
    elif len(args.img_size) != 2:
        parser.error("--img-size takes one value (square) or two values (height width)")

    sg = rbcc.load_sg(args.sg_file)
    ir_module = IRModuleConverter(sg)

    img = rb.var("img", shape=[1, *args.src_size, 3], dtype="uint8")
    model_input, img_meta = preprocess(img, input_size=args.img_size)

    ir_module.to_ast({"images": model_input})

    outputs = ir_module.get_outputs(return_list=True)[0]
    outputs = postprocess(outputs, img_meta,
                          conf_thres=args.conf_thres,
                          iou_thres=args.iou_thres,
                          max_det=args.max_det)

    ir_mod = rb.IRModule.from_expr([img], list(outputs))

    rb.build(ir_mod, args.rbo_file, config=args.config)

    # After NMS the output shapes are dynamic, so record the maximum shapes in
    # the rbo: boxes [max_det, 4], scores [max_det], landmarks [max_det, 15].
    change_rbo_output_shape(args.rbo_file, [[args.max_det, 4], [args.max_det], [args.max_det, 15]])
    print(f"New rbo has been saved in '{args.rbo_file}'")


if __name__ == "__main__":
    main()