import argparse
import numpy as np

import rbcc
import rbpy as rb # type: ignore[import]
from rbcc.backend.rbrt.ir_module import IRModuleConverter
from rbcc.utils.helpers import change_rbo_output_shape


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """Pad ``img`` with a solid border so that it becomes ``new_shape``.

    The image is indexed as [batch, height, width, channels]. This function
    does not resize: the caller must already have scaled the image so that it
    fits inside ``new_shape``. The border is split between the two sides of
    each axis; when the total padding is odd, the extra pixel goes to the
    right / bottom edge.

    Args:
        img: 4-D rbpy expression to pad.
        new_shape: target (height, width), or a single int for a square.
        color: border colour. Only ``color[0]`` is used, so every channel is
            filled with the same value.

    Returns:
        A tuple ``(padded_img, (dw, dh))`` where ``dw`` / ``dh`` are half of
        the total horizontal / vertical padding (possibly fractional).

    Raises:
        ValueError: if ``img`` is larger than ``new_shape`` on either axis,
            since padding alone cannot produce the requested size.
    """
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    target_h, target_w = new_shape[0], new_shape[1]

    batch, img_h, img_w, channels = list(img.shape)
    if img_h > target_h or img_w > target_w:
        raise ValueError(
            f"letterbox only pads: image size ({img_h}, {img_w}) exceeds "
            f"target size ({target_h}, {target_w}); resize the image first."
        )

    # Half of the total padding on each axis (may end in .5).
    dw, dh = (target_w - img_w) / 2, (target_h - img_h) / 2
    # The -0.1 / +0.1 offsets send the odd pixel to the bottom / right edge.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    def border(height, width):
        # Constant float32 block used as one side of the border.
        return rb.constant(
            np.full([batch, height, width, channels], color[0], dtype=np.float32)
        )

    # The border is built with concat rather than rb.pad because, per the
    # original author's note, concat performs better here.
    row_parts = []
    if left != 0:
        row_parts.append(border(img_h, left))
    row_parts.append(img)
    if right != 0:
        row_parts.append(border(img_h, right))
    padded_w = rb.concat(row_parts, axis=2)

    # After horizontal padding the width equals target_w, so the top and
    # bottom strips span the full target width.
    col_parts = []
    if top != 0:
        col_parts.append(border(top, target_w))
    col_parts.append(padded_w)
    if bottom != 0:
        col_parts.append(border(bottom, target_w))
    img = rb.concat(col_parts, axis=1)

    return img, (dw, dh)


def preprocess(img, input_size=(640, 640)):
    """Build the preprocessing graph that feeds the detector.

    Steps: cast to float32, bilinear resize that keeps the aspect ratio and
    fits inside ``input_size``, letterbox padding up to ``input_size``, and
    division by 255.

    Args:
        img: rbpy expression indexed as [batch, height, width, channels].
        input_size: target (height, width). A single int or a one-element
            sequence is treated as a square size.

    Returns:
        A tuple ``(img, img_meta)``. ``img_meta`` holds what is needed to map
        boxes back to the source image:
        ``'scale_factor'`` as [h_ratio, w_ratio], ``'pad_size'`` as the
        (dw, dh) pair returned by ``letterbox``, and ``'img_size'`` as the
        original [h, w].
    """
    # letterbox() accepts an int for its target shape; accept the same here,
    # plus the one-element list produced by `--img-size 640` (nargs='+').
    if isinstance(input_size, int):
        input_size = (input_size, input_size)
    elif len(input_size) == 1:
        input_size = (input_size[0], input_size[0])

    h, w = list(img.shape[1:3])
    scale = min(input_size[0] / h, input_size[1] / w)
    # int() truncates, so the resized image never exceeds input_size;
    # letterbox() fills whatever remains with border pixels.
    resize_h, resize_w = int(h * scale), int(w * scale)
    img = rb.cast(img, dtype='float32')
    img = rb.resize(img, resize_h, resize_w, interpolation="BILINEAR")

    img, pad = letterbox(img, input_size)

    img = img / rb.constant(255.0, dtype="float32")

    img_meta = {
        # Per-axis ratios are stored separately because truncation can make
        # the two effective scales differ slightly.
        'scale_factor': [resize_h / h, resize_w / w],
        'pad_size': pad,
        'img_size': [h, w]
    }
    return img, img_meta

def xywh2xyxy(xy, wh):
    """Turn centre/size boxes into corner boxes.

    Args:
        xy: box centres.
        wh: box sizes, aligned with ``xy``.

    Returns:
        ``xy - wh / 2`` and ``xy + wh / 2`` concatenated along axis 1, i.e.
        the minimum corner followed by the maximum corner.
    """
    half_extent = wh / 2
    corners = [xy - half_extent, xy + half_extent]
    return rb.concat(corners, axis=1)

def scale_boxes(boxes, img_meta):
    """Map boxes from the letterboxed input back onto the original image.

    Args:
        boxes: corner-format boxes (x_min, y_min, x_max, y_max) in the
            coordinates of the padded model input; the caller passes an
            N x 4 expression.
        img_meta: dict produced by ``preprocess`` with ``'pad_size'`` as
            (dw, dh), ``'scale_factor'`` as (h_ratio, w_ratio) and
            ``'img_size'`` as (h, w).

    Returns:
        Boxes in original-image pixels, clipped to
        [0, img_w - 1] x [0, img_h - 1].
    """
    half_pad_w, half_pad_h = img_meta['pad_size']  # (w, h)
    rescale_h, rescale_w = img_meta['scale_factor']  # (h, w)
    img_h, img_w = img_meta['img_size']

    # 'pad_size' is half of the total padding and may end in .5, whereas
    # letterbox() pads the left/top edge by int(round(d - 0.1)) whole pixels.
    # Subtract the offset that was actually applied so boxes are not shifted
    # by half a pixel when the total padding is odd.
    left = int(round(half_pad_w - 0.1))
    top = int(round(half_pad_h - 0.1))

    pad = rb.constant(np.array([left, top, left, top], dtype=np.float32))
    scale_factor = rb.constant(np.array([rescale_w, rescale_h, rescale_w, rescale_h], dtype=np.float32))

    boxes = (boxes - pad) / scale_factor

    # Clamp to valid pixel indices of the original image.
    upper_bound = rb.constant(np.array([img_w, img_h, img_w, img_h], dtype=np.float32) - 1)
    lower_bound = rb.constant(np.array([0, 0, 0, 0], dtype=np.float32))
    boxes = rb.maximum(rb.minimum(boxes, upper_bound), lower_bound)
    return boxes


def postprocess(outputs, img_meta, conf_thres=0.001, iou_thres=0.65, max_det=300):
    """Build the detection postprocessing graph: decode, NMS, rescale.

    Args:
        outputs: raw head output; the original comment gives its shape as
            [1, 84, 8400]. The first two rows are read as box centres, the
            next two as box sizes, and the remainder as per-class scores.
        img_meta: metadata dict forwarded to ``scale_boxes``.
        conf_thres: score threshold passed to ``rb.nms``.
        iou_thres: IoU threshold passed to ``rb.nms``.
        max_det: maximum number of boxes ``rb.nms`` may keep.

    Returns:
        A tuple ``(boxes, scores, labels)`` gathered at the NMS indices, with
        the boxes mapped back to the original image by ``scale_boxes``.

    Raises:
        RuntimeError: if the batch dimension of ``outputs`` is not 1.
    """
    batch_size = outputs.shape[0]
    if batch_size != 1:
        raise RuntimeError(f"expect the batch size is 1.")

    # Drop the batch axis and put one prediction per row.
    preds = rb.transpose(outputs[0], [1, 0])
    centers = preds[:, :2]
    sizes = preds[:, 2:4]
    class_scores = preds[:, 4:]
    corner_boxes = xywh2xyxy(centers, sizes)

    # Single-label mode: keep only the best class of every prediction.
    best_labels = rb.argmax(class_scores, axis=1)
    best_scores = rb.reduce.max(class_scores, axis=[1])
    best_scores = rb.reshape(best_scores, [-1])

    # NOTE(review): labels are not passed to rb.nms, so suppression looks
    # class-agnostic - confirm this is intended.
    keep = rb.nms(
        corner_boxes,
        best_scores,
        max_output_size=max_det,
        iou_threshold=iou_thres,
        score_threshold=conf_thres,
    )
    kept_boxes = rb.gather(corner_boxes, keep)
    kept_scores = rb.gather(best_scores, keep)
    kept_labels = rb.gather(best_labels, keep)
    return scale_boxes(kept_boxes, img_meta), kept_scores, kept_labels


if __name__ == "__main__":
    # Wrap an 8-bit sg model with rbpy pre/post-processing and compile the
    # combined graph into an rbo file.
    parser = argparse.ArgumentParser()
    parser.add_argument('sg_file', type=str, help="8bit sg file path.")
    parser.add_argument('rbo_file', type=str, help="rbo file path.")
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size, (height, width)')
    parser.add_argument('--conf-thres', type=float, default=0.001, help="nms score thresold.")
    parser.add_argument('--iou-thres', type=float, default=0.70, help="nms iou thresold.")
    parser.add_argument('--max-det', type=int, default=300, help="nms max output.")
    # NOTE(review): --config is parsed but not used anywhere in this script.
    parser.add_argument('--config', type=str, default=None, help="compiler env config path")
    args = parser.parse_args()

    # nargs='+' yields a one-element list for `--img-size 640`; treat that as
    # a square size instead of failing later on the missing width.
    img_size = list(args.img_size)
    if len(img_size) == 1:
        img_size = img_size * 2

    sg = rbcc.load_sg(args.sg_file)
    ir_module = IRModuleConverter(sg)

    # Define the preprocessing with rbpy.
    img = rb.var("img", shape=[1, 1080, 1920, 3], dtype="uint8")
    model_input, img_meta = preprocess(img, input_size=img_size)

    # Replace the sg input with model_input.
    # The key "images" must match the name of the model input node in the sg.
    ir_module.to_ast({"images": model_input})

    # Fetch the sg output node and define the postprocessing with rbpy.
    # Postprocessing starts from the [xy, wh, conf_cls] branch.
    outputs = ir_module.get_outputs(return_list=True)[0]
    outputs = postprocess(outputs, img_meta,
                          conf_thres=args.conf_thres,
                          iou_thres=args.iou_thres,
                          max_det=args.max_det)

    ir_mod = rb.IRModule.from_expr([img], list(outputs))

    rb.build(ir_mod, args.rbo_file)

    # After NMS the output size is dynamic; set the maximum output shapes here.
    change_rbo_output_shape(args.rbo_file, [[args.max_det, 4], [args.max_det], [args.max_det]])
    print(f"New rbo has been saved in '{args.rbo_file}'")