import argparse
import numpy as np

import rbcc
import rbpy as rb # type: ignore[import]
from rbcc.backend.rbrt.ir_module import IRModuleConverter
from rbcc.utils.helpers import change_rbo_output_shape


def preprocess(img, input_size=416):
    """Build the rbpy preprocessing graph that feeds the detector.

    Args:
        img: rbpy tensor holding the raw image (the caller in this file
            passes a uint8 variable of shape [1, 1080, 1920, 3]).
        input_size: target size passed twice to ``rb.resize`` (presumably
            as height and width -- confirm against the rbpy API).

    Returns:
        The float32 tensor, bilinearly resized and divided by 255.
    """
    as_float = rb.cast(img, dtype='float32')
    scaled = rb.resize(as_float, input_size, input_size, interpolation="BILINEAR")
    # Map the 0..255 pixel range onto 0..1.
    return scaled / rb.constant(255.0, dtype="float32")


def postprocess(outputs, img_size, model_size, conf_thres=0.001, iou_thres=0.65, max_det=300):
    """Turn raw detector head outputs into final boxes, scores and labels.

    Args:
        outputs: list of raw head tensors, forwarded to ``decode_box``.
        img_size: size of the original image, forwarded to
            ``yolo_correct_boxes`` as ``image_shape``.
        model_size: size of the network input, forwarded to ``decode_box``
            and ``yolo_correct_boxes`` as ``input_shape``.
        conf_thres: score threshold forwarded to NMS.
        iou_thres: IoU threshold forwarded to NMS.
        max_det: maximum number of detections kept by NMS.

    Returns:
        Tuple ``(boxes, scores, labels)``; scores and labels are flattened
        to 1-D, labels are cast to int32.
    """
    decoded = decode_box(outputs, model_size)
    kept = non_max_suppression(decoded, conf_thres=conf_thres, nms_thres=iou_thres, max_det=max_det)
    boxes = yolo_correct_boxes(kept, model_size, img_size, letterbox_image=False)

    # Columns of `kept`: 0-3 box corners, 4 objectness, 5 best class score,
    # 6 best class index (as float).
    objectness = kept[:, 4]
    class_score = kept[:, 5]
    scores = rb.reshape(objectness * class_score, [-1])

    class_index = rb.cast(kept[:, 6], dtype='int32')
    labels = rb.reshape(class_index, [-1])
    return boxes, scores, labels


def decode_box(inputs, input_shape):
    """Decode raw YOLO head outputs into one tensor of box predictions.

    Args:
        inputs: sequence of head tensors indexed as
            [batch, height, width, 255] (3 anchors x 85 channels). Head ``i``
            uses anchor group ``anchors_mask[i]``, so at most three heads are
            supported. NOTE(review): this assumes the heads arrive
            coarsest-first (e.g. 13x13 first for a 416 input) -- confirm
            against the model's output order.
        input_shape: ``(height, width)`` of the network input in pixels,
            used to derive the stride of every head.

    Returns:
        Tensor of shape [batch, total_cells * 3, 85], heads concatenated
        along axis 1. The last axis holds
        ``(cx, cy, w, h, objectness, 80 class scores)``; the box values are
        divided by the feature-map size, i.e. expressed as fractions of the
        network input.
    """
    # Loop-invariant anchor tables (pixel sizes at network-input resolution).
    anchors = np.array([[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]])
    anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

    outputs = []
    for i, feature in enumerate(inputs):
        # feature: [1, 13, 13, 255]
        batch_size = feature.shape[0]
        input_height = feature.shape[1]
        input_width = feature.shape[2]
        # [1, 13, 13, 255] -> [1, 13, 13, 3, 85]
        prediction = rb.reshape(feature, [batch_size, input_height, input_width, 3, 85])

        # Cell-index grids built directly in the [B, H, W, 3, 1] layout of
        # `prediction`. grid_x varies along the width axis and grid_y along
        # the height axis, so non-square feature maps are handled as well.
        grid_shape = [batch_size, input_height, input_width, 3, 1]
        grid_x = np.broadcast_to(
            np.arange(input_width, dtype=np.float32).reshape(1, 1, -1, 1, 1), grid_shape)
        grid_y = np.broadcast_to(
            np.arange(input_height, dtype=np.float32).reshape(1, -1, 1, 1, 1), grid_shape)

        # Box centre: sigmoid offset inside the cell plus the cell index,
        # divided by the grid size.
        x = prediction[..., 0:1]
        y = prediction[..., 1:2]
        x = (rb.math.sigmoid(x) + rb.constant(grid_x)) / rb.constant(input_width, dtype='float32')
        y = (rb.math.sigmoid(y) + rb.constant(grid_y)) / rb.constant(input_height, dtype='float32')

        # Anchors are given in input pixels; express them in grid cells.
        stride_h = input_shape[0] / input_height
        stride_w = input_shape[1] / input_width
        scaled_anchors = np.array(
            [(anchor_width / stride_w, anchor_height / stride_h)
             for anchor_width, anchor_height in anchors[anchors_mask[i]]],
            dtype=np.float32)
        # Leading dimension is 1; this relies on rb broadcasting over batch.
        anchor_w = np.broadcast_to(np.reshape(scaled_anchors[:, 0:1], [1, 1, 1, -1, 1]), [1, input_height, input_width, 3, 1])
        anchor_h = np.broadcast_to(np.reshape(scaled_anchors[:, 1:2], [1, 1, 1, -1, 1]), [1, input_height, input_width, 3, 1])

        # Box size: exponential of the raw value times the anchor size,
        # divided by the grid size.
        w = prediction[..., 2:3]
        h = prediction[..., 3:4]
        w = (rb.math.exp(w) * rb.constant(anchor_w)) / rb.constant(input_width, dtype='float32')
        h = (rb.math.exp(h) * rb.constant(anchor_h)) / rb.constant(input_height, dtype='float32')

        # Flatten the (H, W, anchor) axes into one prediction axis.
        pred_boxes = rb.reshape(rb.concat([x, y, w, h], -1), [batch_size, -1, 4])
        conf = rb.reshape(rb.math.sigmoid(prediction[..., 4]), [batch_size, -1, 1])
        pred_cls = rb.reshape(rb.math.sigmoid(prediction[..., 5:]), [batch_size, -1, 80])

        outputs.append(rb.concat((pred_boxes, conf, pred_cls), -1))

    return rb.concat(outputs, 1)
    
    
def non_max_suppression(prediction, conf_thres=0.001, nms_thres=0.65, max_det=300):
    """Select the final detections for the first image via NMS.

    Args:
        prediction: decoded predictions, e.g. shape [1, 10647, 85], with
            ``(cx, cy, w, h, objectness, 80 class scores)`` on the last
            axis. Only ``prediction[0]`` is processed.
        conf_thres: passed to ``rb.nms`` as ``score_threshold``.
        nms_thres: passed to ``rb.nms`` as ``iou_threshold``.
        max_det: passed to ``rb.nms`` as ``max_output_size``.

    Returns:
        Rows of ``(x1, y1, x2, y2, objectness, class score, class index)``
        gathered at the indices returned by ``rb.nms``.
    """
    num_classes = 80
    two = rb.constant(2.0, dtype='float32')

    # Centre/size -> corner representation.
    centre_x = prediction[:, :, 0:1]
    centre_y = prediction[:, :, 1:2]
    half_w = prediction[:, :, 2:3] / two
    half_h = prediction[:, :, 3:4] / two
    corners = rb.concat(
        [centre_x - half_w, centre_y - half_h, centre_x + half_w, centre_y + half_h], -1)
    prediction = rb.concat([corners, prediction[:, :, 4:]], -1)

    # Only the first image of the batch is handled.
    image_pred = prediction[0]
    class_scores = image_pred[:, 5:5 + num_classes]
    best_class = rb.cast(rb.argmax(class_scores, axis=1, keepdims=True), dtype='float32')
    best_score = rb.cast(rb.reduce.max(class_scores, axis=[1], keepdims=True), dtype='float32')
    detections = rb.concat([image_pred[:, :5], best_score, best_class], 1)

    # Shift every box by a class-dependent amount so that boxes with
    # different predicted classes end up far apart (presumably to make the
    # single rb.nms call class-aware -- confirm against rb.nms semantics).
    class_shift = best_class * rb.constant(7680., dtype="float32")
    shifted_boxes = detections[:, :4] + class_shift
    combined_scores = detections[:, 4] * detections[:, 5]
    keep = rb.nms(shifted_boxes, combined_scores,
                  max_output_size=max_det,
                  iou_threshold=nms_thres,
                  score_threshold=conf_thres)

    return rb.gather(detections, keep)


def yolo_correct_boxes(outputs, input_shape, image_shape, letterbox_image):
    """Scale box corners from network-input fractions to image pixels.

    Args:
        outputs: detections whose first four columns are
            ``(x1, y1, x2, y2)``.
        input_shape: network input size; reversed internally before use
            (the caller in this file passes ``[416, 416]``).
        image_shape: original image size; reversed internally before use
            (the caller in this file passes ``[1080, 1920]``).
        letterbox_image: when true, first undo the padding offset and the
            scale derived from ``input_shape`` and ``image_shape``.

    Returns:
        Boxes multiplied element-wise by the reversed image size repeated
        twice.
    """
    model_dims = np.array(input_shape[::-1], dtype=np.float32)
    image_dims = np.array(image_shape[::-1], dtype=np.float32)

    if not letterbox_image:
        boxes = outputs[:, :4]
    else:
        half = rb.constant(2.0, dtype='float32')
        # Corner form -> centre / size form.
        xy = (outputs[:, 0:2] + outputs[:, 2:4]) / half
        wh = outputs[:, 2:4] - outputs[:, 0:2]

        # Size of the resized image inside the padded network input.
        new_shape = np.round(image_dims * np.min(model_dims / image_dims))
        pad_offset = (model_dims - new_shape) / 2. / model_dims
        rescale = model_dims / new_shape
        offset = rb.constant(np.array(pad_offset, dtype=np.float32).reshape(-1))
        scale = rb.constant(np.array(rescale, dtype=np.float32).reshape(-1))

        xy = (xy - offset) * scale
        wh *= scale

        # Back to corner form.
        half_wh = wh / half
        boxes = rb.concat([xy - half_wh, xy + half_wh], -1)

    # One multiplier per box column: the reversed image size, twice.
    per_column = np.concatenate([image_dims, image_dims], axis=-1)
    return boxes * rb.constant(per_column, dtype='float32')


# Script entry point: load a compiled sg model, wrap it with rbpy
# pre/post-processing, build an rbo file, and then patch that file's
# output shapes.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('sg_file', type=str, help="8bit sg file path.")
    parser.add_argument('rbo_file', type=str, help="rbo file path.")
    parser.add_argument('--conf-thres', type=float, default=0.001, help="nms score thresold.")
    parser.add_argument('--iou-thres', type=float, default=0.65, help="nms iou thresold.")
    parser.add_argument('--max-det', type=int, default=300, help="nms max output.")
    parser.add_argument('--config', type=str, default=None, help="compiler env config path")
    args = parser.parse_args()

    sg = rbcc.load_sg(args.sg_file)
    ir_module = IRModuleConverter(sg)

    # Define the preprocessing with rbpy. The graph input is a fixed-size
    # uint8 image of shape [1, 1080, 1920, 3].
    img = rb.var("img", shape=[1, 1080, 1920, 3], dtype="uint8")
    model_input = preprocess(img)

    # Replace the sg input (named "images") with model_input.
    ir_module.to_ast({"images": model_input})

    # Get the sg output nodes and define the postprocessing with rbpy.
    # [1080, 1920] is the image size and [416, 416] the network input size;
    # they must stay in sync with the `img` shape above and with the default
    # input_size of preprocess().
    outputs = ir_module.get_outputs(return_list=True)
    outputs = postprocess(outputs, [1080, 1920], [416, 416], 
                          conf_thres=args.conf_thres, 
                          iou_thres=args.iou_thres, 
                          max_det=args.max_det)

    ir_mod = rb.IRModule.from_expr([img], outputs)

    # NOTE(review): the value returned by rb.build is not used afterwards.
    rbo = rb.build(ir_mod, args.rbo_file, config=args.config)

    # After NMS the model outputs have a dynamic shape; set the maximum
    # shape of each output (boxes, scores, labels) here.
    change_rbo_output_shape(args.rbo_file, [[args.max_det, 4], [args.max_det], [args.max_det]])
    print(f"new rbo has been saved in {args.rbo_file}")