import os
import cv2
import argparse
import numpy as np
from preprocess import preprocess, postprocess


# Class-name lookup table used by demo(): a predicted label ``i`` is shown as
# COCO_INSTANCE_CATEGORY_NAMES[int(i)]. 80 entries, 8 per row below.
# NOTE(review): assumes the model emits contiguous 0..79 class indices rather
# than raw COCO category ids — confirm against the model's postprocess.
COCO_INSTANCE_CATEGORY_NAMES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",  # 0-7
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",  # 8-15
    "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",  # 16-23
    "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",  # 24-31
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",  # 32-39
    "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",  # 40-47
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",  # 48-55
    "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",  # 56-63
    "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",  # 64-71
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",  # 72-79
]

def draw_boxes(image, boxes, scores, labels, score_threshold=0.5):
    """ Draws labelled bounding boxes onto ``image`` in place and returns it.

    Each detection scoring at least ``score_threshold`` gets a 1-px green
    rectangle plus a filled caption box containing ``"<label> <score>"``.
    The caption sits just above the box's top edge when there is room;
    otherwise it is drawn just below that edge.

    Args:
        image: np.array, shape of (H, W, C) and dtype is uint8. Drawn on in place.
        boxes: np.array, shape of (N, 4) containing bounding boxes in
            (xmin, ymin, xmax, ymax) format. Integer or float coordinates are
            accepted; they are converted with ``int()`` before drawing.
        scores: np.array, shape of (N,) the scores of bounding boxes.
        labels: list of str, the class name of each bounding box. Only the
            first ``len(labels)`` entries of ``boxes``/``scores`` are used.
        score_threshold: float, detections scoring below this are skipped.

    Returns:
        The same ``image`` object, with the boxes drawn on it.
    """
    # Filter first, so nothing (not even font setup) happens when no box passes.
    kept = [
        (cls_name, boxes[i], scores[i])
        for i, cls_name in enumerate(labels)
        if scores[i] >= score_threshold
    ]
    if not kept:
        return image

    font_face = cv2.FONT_HERSHEY_DUPLEX
    font_scale = 0.6
    font_thickness = 1

    for cls_name, box, score in kept:
        label = '{} {:.2f}'.format(cls_name, score)
        # cv2 drawing functions reject float point coordinates (and some numpy
        # scalar types on some OpenCV versions); convert every coordinate once
        # instead of only the caption points as before.
        xmin, ymin, xmax, ymax = (int(v) for v in box)
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 250, 0), 1)

        text_w, text_h = cv2.getTextSize(label, font_face, font_scale, font_thickness)[0]
        if ymin - 3 > text_h:
            # Enough room above the box: caption sits on top of its upper edge.
            text_pt = (xmin, ymin - 3)
            text_rec = (xmin + text_w, ymin - text_h - 4)
        else:
            # Too close to the image top: caption goes just below the upper edge.
            text_pt = (xmin, ymin + text_h + 3)
            text_rec = (xmin + text_w, ymin + text_h + 4)

        cv2.rectangle(image, (xmin, ymin), text_rec, (60, 179, 113), -1)
        cv2.putText(image, label, text_pt, font_face, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
    return image


class WrapRunner(object):
    """Uniform callable front-end over the two inference backends.

    * rbo mode (``rbo_mode=True``): ``runner`` is a remote client. Calls are
      forwarded to ``runner.inference(id, [inputs...])``, and ``close()``
      removes model ``id`` via ``runner.remove_model``.
    * local mode (``rbo_mode=False``): calls are forwarded to
      ``runner.run(*inputs, as_numpy=True)``, and ``close()`` does nothing.

    Also usable as a context manager, so ``close()`` runs even when inference
    raises.
    """

    def __init__(self, runner, rbo_mode, id=None):
        """
        Args:
            runner: backend object (remote client in rbo mode, local runner otherwise).
            rbo_mode: bool, selects which backend API is used.
            id: model handle from the remote upload; only used in rbo mode.
                Kept under this name for backward compatibility (it shadows
                the builtin ``id`` inside this method only).
        """
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never swallow an exception raised inside the ``with`` body.
        return False

    def __call__(self, *args):
        """Run inference on ``args`` and return the backend outputs.

        In rbo mode, a single-element output list is unwrapped to that element.
        """
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)
        outs = self.runner.inference(self.id, list(args))
        if isinstance(outs, list) and len(outs) == 1:
            return outs[0]
        return outs

    def close(self):
        """Release backend resources. Only rbo mode has anything to release."""
        if self.rbo_mode:
            self.runner.remove_model(self.id)

def init_runner(args):
    """Build a WrapRunner for ``args.model_file``, chosen by its file extension.

    ``.sg`` files are loaded locally with rbcc's ``SGRunner``: on CPU when
    ``args.device_id < 0``, otherwise on ``cuda:<device_id>``. ``.rbo`` files
    are uploaded through a ``GRPCClient``. The client's address comes from the
    ``CAISA_REMOTE_IP`` and ``CAISA_REMOTE_PORT`` environment variables.

    Args:
        args: namespace with ``model_file`` (str) and ``device_id`` (int).

    Returns:
        WrapRunner wrapping the selected backend.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: (``.rbo`` only) if either environment variable is unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        # Imported here, not at module level, so only the selected backend's
        # package has to be installed.
        import rbcc  # type: ignore[import]
        from rbcc.backend.runner import SGRunner  # type: ignore[import]
        sg = rbcc.load_sg(args.model_file)
        device = 'cpu' if args.device_id < 0 else f"cuda:{args.device_id}"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    # repr() makes an empty suffix (no extension) visible in the message.
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")


def demo(args):
    """Run detection on ``args.img_file`` and save the visualisation to ./demo_predict.jpg.

    With ``args.e2e``, the letterboxed (1080x1920) RGB image is fed to the model
    as a batch of one, and the model itself returns ``(bboxes, scores, labels)``.
    Otherwise the ``preprocess``/``postprocess`` helpers wrap the model call.
    Only detections scoring at least 0.7 are drawn.

    Raises:
        FileNotFoundError: if ``args.img_file`` cannot be read as an image.
        OSError: if the result image cannot be written.
    """
    # cv2.imread reports failure by returning None rather than raising; check
    # before uploading/loading the model so nothing needs cleaning up.
    ori_img = cv2.imread(args.img_file)
    if ori_img is None:
        raise FileNotFoundError(f"cannot read image: {args.img_file}")

    runner = init_runner(args)
    try:
        if args.e2e:
            ori_img = keep_ratio_resize(ori_img, (1080, 1920))
            img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
            img = np.expand_dims(img, 0)
            bboxes, scores, labels = runner(img)
        else:
            img, img_meta = preprocess(args.img_file, True)
            outs = runner(img)
            bboxes, scores, labels = postprocess(outs, img_meta)
    finally:
        # Always release the backend, even if inference fails. In rbo mode
        # this removes the uploaded model from the remote side.
        runner.close()

    labels_str = [COCO_INSTANCE_CATEGORY_NAMES[int(i)] for i in labels]
    r_img = draw_boxes(ori_img, bboxes.astype(np.int32), scores, labels_str, score_threshold=0.7)
    # cv2.imwrite returns False instead of raising when it cannot write.
    if not cv2.imwrite('demo_predict.jpg', r_img):
        raise OSError("failed to write './demo_predict.jpg'")
    print("result has been saved in './demo_predict.jpg'")


def keep_ratio_resize(ori_img, tar_size):
    """Letterbox ``ori_img`` into ``tar_size`` without distorting its aspect ratio.

    The image is scaled by the largest factor that fits both dimensions, then
    centred and padded with gray (114, 114, 114). Any odd leftover pixel goes
    to the bottom/right.

    Args:
        ori_img: image array whose first two dimensions are (height, width).
        tar_size: (height, width) of the output.

    Returns:
        The resized, padded image of size ``tar_size``.
    """
    src_h, src_w = ori_img.shape[:2]
    dst_h, dst_w = tar_size[0], tar_size[1]

    scale = min(dst_h / src_h, dst_w / src_w)
    fit_h = int(round(src_h * scale))
    fit_w = int(round(src_w * scale))
    # cv2.resize takes its target size as (width, height).
    fitted = cv2.resize(ori_img, (fit_w, fit_h), interpolation=cv2.INTER_LINEAR)

    half_pad_h = (dst_h - fit_h) / 2
    half_pad_w = (dst_w - fit_w) / 2
    # The +/-0.1 nudges split an odd padding total as floor/ceil.
    pad_top = int(round(half_pad_h - 0.1))
    pad_bottom = int(round(half_pad_h + 0.1))
    pad_left = int(round(half_pad_w - 0.1))
    pad_right = int(round(half_pad_w + 0.1))
    return cv2.copyMakeBorder(
        fitted, pad_top, pad_bottom, pad_left, pad_right,
        cv2.BORDER_CONSTANT, value=(114, 114, 114),
    )


if __name__ == '__main__':
    # CLI entry point: parse arguments and run the demo once.
    cli = argparse.ArgumentParser()
    cli.add_argument('model_file', type=str, help='rbo or sg file path.')
    cli.add_argument('img_file', type=str, help="input image.")
    cli.add_argument('--device-id', type=int, default=0, help="device id.")
    cli.add_argument('--e2e', action="store_true", help="test end2end")
    cli_args = cli.parse_args()
    demo(cli_args)
