import sys
import argparse
import numpy as np

import rbcc
import rbpy as rb # type: ignore[import]
from rbcc.backend.rbrt.ir_module import IRModuleConverter
from rbcc.utils.helpers import change_rbo_output_shape


def preprocess(img, input_size=640):
    """Build the letterbox-style preprocessing graph for the model input.

    The image is cast to float32, resized bilinearly with its aspect ratio
    kept (the longer side is scaled towards ``input_size``; sizes are
    truncated to int), padded with the value 114 and finally divided by 255.

    Args:
        img: rb tensor; ``img.shape[1:3]`` is read as (height, width).
        input_size: Side length, in pixels, that the resized image is padded
            up to.

    Returns:
        The rb expression for the padded image divided by 255.
    """
    src_h, src_w = list(img.shape[1:3])

    # One ratio for both axes keeps the aspect ratio.
    ratio = input_size / max(src_h, src_w)
    new_h = int(src_h * ratio)
    new_w = int(src_w * ratio)

    # Split the leftover border in two; the odd pixel (if any) goes to the
    # bottom / right side.
    pad_top = (input_size - new_h) // 2
    pad_bottom = (input_size - new_h) - pad_top
    pad_left = (input_size - new_w) // 2
    pad_right = (input_size - new_w) - pad_left

    as_float = rb.cast(img, dtype='float32')
    resized = rb.resize(as_float, new_h, new_w, interpolation="BILINEAR")

    # NOTE(review): layout looks like (before, after) pairs for each axis of
    # a 4-D N, H, W, C tensor -- confirm against rb.pad.
    border = rb.constant(
        np.array(
            [0, 0, pad_top, pad_bottom, pad_left, pad_right, 0, 0],
            dtype=np.int64,
        )
    )
    fill = rb.constant(114, "float32")
    letterboxed = rb.pad(input=resized, value=fill, pad_size=border)

    return letterboxed / rb.constant(255.0, dtype="float32")

def xywh2xyxy(xy, wh):
    """Convert centre/size boxes into corner boxes.

    Args:
        xy: rb tensor of box centres.
        wh: rb tensor of box sizes, same layout as ``xy``.

    Returns:
        ``xy - wh / 2`` and ``xy + wh / 2`` concatenated along axis 1
        (min corner first, then max corner).
    """
    offset = wh / 2
    return rb.concat([xy - offset, xy + offset], axis=1)

def get_valid_box(boxes, img_h, img_w):
    """Clamp box coordinates to the image area.

    Args:
        boxes: rb tensor of boxes. The bounds are ordered (w, h, w, h), so
            presumably the last axis is (x1, y1, x2, y2) -- confirm with the
            caller.
        img_h: Image height in pixels.
        img_w: Image width in pixels.

    Returns:
        ``boxes`` limited to ``[0, img_w - 1]`` for the x entries and
        ``[0, img_h - 1]`` for the y entries.
    """
    # Largest valid pixel index per coordinate.
    ceiling = rb.constant(np.array([img_w, img_h] * 2, dtype=np.float32) - 1)
    floor = rb.constant(np.zeros(4, dtype=np.float32))

    capped = rb.minimum(boxes, ceiling)
    return rb.maximum(capped, floor)

def scale_boxes(boxes, img_size, model_size):
    """Map boxes from model-input coordinates back to the source image.

    Recomputes the resize ratio and the leading (top / left) padding from
    the two sizes, subtracts the padding, divides by the per-axis resize
    ratio and clamps the result to the source image.

    Args:
        boxes: rb tensor of boxes; offsets and ratios are applied in
            (x, y, x, y) order.
        img_size: ``(height, width)`` of the source image.
        model_size: ``(height, width)`` of the model input; must be square.

    Returns:
        The rescaled and clamped boxes as an rb expression.
    """
    src_h, src_w = img_size
    net_h, net_w = model_size
    assert net_h == net_w

    # Size of the image after the aspect-preserving resize.
    ratio = net_h / max(src_h, src_w)
    fit_h = int(src_h * ratio)
    fit_w = int(src_w * ratio)

    # Border added before the image on each axis.
    off_y = (net_h - fit_h) // 2
    off_x = (net_w - fit_w) // 2
    offsets = rb.constant(np.array([off_x, off_y, off_x, off_y], dtype=np.float32))

    # Effective ratios after integer truncation of the resized size.
    gain_y = fit_h / src_h
    gain_x = fit_w / src_w
    gains = rb.constant(np.array([gain_x, gain_y, gain_x, gain_y], dtype=np.float32))

    restored = (boxes - offsets) / gains
    return get_valid_box(restored, src_h, src_w)

def postprocess(outputs, img_size, model_size, conf_thres=0.001, iou_thres=0.65, max_det=300):
    """Build the detection postprocessing graph (best class + NMS + rescale).

    Args:
        outputs: Sequence ``(xy, wh, confs)`` of model output tensors.
        img_size: ``(height, width)`` of the source image.
        model_size: ``(height, width)`` of the model input.
        conf_thres: Score threshold passed to NMS.
        iou_thres: IoU threshold passed to NMS.
        max_det: Maximum number of boxes requested from NMS.

    Returns:
        Tuple ``(boxes, scores, labels)`` gathered at the indices returned
        by NMS, with ``boxes`` mapped back to source-image coordinates.
    """
    # xy is [1, 2, N], wh is [1, 2, N], confs is [1, 81, N]
    centers, sizes, confs = outputs
    corners = xywh2xyxy(centers, sizes)
    _batch, num_rows, _anchors = list(confs.shape)

    # Drop the batch axis: boxes become Nx4, confidences num_rows x N.
    corners = rb.transpose(rb.reshape(corners, [4, -1]), [1, 0])
    confs = rb.reshape(confs, [num_rows, -1])

    # Row 0 is used as the objectness score, the remaining rows as the
    # per-class confidences.
    objectness, class_conf = confs[:1, :], confs[1:, :]

    # Single label per box: keep only the best-scoring class.
    combined = objectness * class_conf
    labels = rb.argmax(combined, axis=0)
    best = rb.reshape(rb.reduce.max(combined, axis=[0]), [-1])

    # Boxes are Nx4 and scores are N at this point.
    keep = rb.nms(
        corners,
        best,
        max_output_size=max_det,
        iou_threshold=iou_thres,
        score_threshold=conf_thres,
    )
    kept_boxes = rb.gather(corners, keep)
    kept_scores = rb.gather(best, keep)
    kept_labels = rb.gather(labels, keep)
    return scale_boxes(kept_boxes, img_size, model_size), kept_scores, kept_labels


if __name__ == "__main__":
    # Command line: <sg_file> <rbo_file> [--conf-thres F] [--iou-thres F] [--max-det N]
    parser = argparse.ArgumentParser()
    parser.add_argument('sg_file', type=str, help="8bit sg file path.")
    parser.add_argument('rbo_file', type=str, help="rbo file path.")
    parser.add_argument('--conf-thres', type=float, default=0.1, help="nms score thresold.")
    parser.add_argument('--iou-thres', type=float, default=0.65, help="nms iou thresold.")
    parser.add_argument('--max-det', type=int, default=300, help="nms max output.")
    args = parser.parse_args()

    # Geometry of the pipeline, named once so that the input variable, the
    # preprocessing and the postprocessing cannot drift apart.
    src_h, src_w = 1080, 1920
    net_size = 640

    sg = rbcc.load_sg(args.sg_file)
    ir_module = IRModuleConverter(sg)

    # Define the preprocessing with rbpy.
    img = rb.var("img", shape=[1, src_h, src_w, 3], dtype="uint8")
    model_input = preprocess(img, input_size=net_size)

    # Replace the sg's input with model_input.
    # The "images" key must match the name of the model input node in the sg.
    ir_module.to_ast({"images": model_input})

    # Fetch the sg output nodes and define the postprocessing with rbpy.
    # Postprocessing starts from the [xy, wh, conf_cls] branches.
    outputs = ir_module.get_outputs(return_list=True)
    outputs = postprocess(outputs, [src_h, src_w], [net_size, net_size],
                          conf_thres=args.conf_thres,
                          iou_thres=args.iou_thres,
                          max_det=args.max_det)

    ir_mod = rb.IRModule.from_expr([img], list(outputs))

    # Use the parsed positional argument rather than sys.argv[2]: argparse
    # accepts options before positionals, in which case sys.argv[2] is not
    # the rbo path.
    rbo = rb.build(ir_mod, args.rbo_file)

    # After NMS the model outputs have a dynamic shape; set their maximum
    # shapes here (boxes, scores, labels).
    change_rbo_output_shape(args.rbo_file, [[args.max_det, 4], [args.max_det], [args.max_det]])