import os
import cv2
import argparse
import numpy as np
from preprocess import preprocess, postprocess


def draw_boxes_with_pose(image, boxes, scores, kpts, labels, score_threshold=0.5):
    """Draw labelled bounding boxes and pose keypoints on an image.

    Drawing is done in place with OpenCV; the same ``image`` object is returned.

    Args:
        image: np.array, shape of (H, W, C) and dtype is uint8.
        boxes: np.array, shape of (N, 4) containing bounding boxes in
            (xmin, ymin, xmax, ymax) format. Coordinates may be float; they
            are truncated to int before drawing.
        scores: np.array, shape of (N,) the scores of bounding boxes.
        kpts: per-detection keypoints; ``kpts[i]`` is passed to ``draw_pose``,
            which expects rows of (x, y) or (x, y, confidence).
        labels: list of str, one label per detection. Its length decides how
            many detections are visited; ``boxes``, ``scores`` and ``kpts``
            are indexed at the same positions.
        score_threshold: float, detections scoring below this are skipped.

    Returns:
        The annotated ``image``.
    """
    font_face = cv2.FONT_HERSHEY_DUPLEX
    font_scale = 0.6
    font_thickness = 1

    for i, cls_name in enumerate(labels):
        score = scores[i]
        if score < score_threshold:
            continue
        kpt = kpts[i]

        # cv2 drawing functions require integer pixel coordinates.
        xmin, ymin, xmax, ymax = (int(v) for v in boxes[i])
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 250, 0), 1)

        label = '{} {:.2f}'.format(cls_name, score)
        text_w, text_h = cv2.getTextSize(label, font_face, font_scale, font_thickness)[0]

        if ymin - 3 > text_h:
            # Enough room above the box: the label sits on its top edge.
            text_pt = (xmin, ymin - 3)
            text_rec = (xmin + text_w, ymin - text_h - 4)
        else:
            # Too close to the image top: put the label just below the top edge.
            text_pt = (xmin, ymin + text_h + 3)
            text_rec = (xmin + text_w, ymin + text_h + 4)

        # Filled background first, then the text on top of it.
        cv2.rectangle(image, (xmin, ymin), text_rec, (60, 179, 113), -1)
        cv2.putText(image, label, text_pt, font_face, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
        draw_pose(image, kpt)

    return image


def draw_pose(img, kpt, kpt_line=True, conf_thres=0.25, radius=1):
    """
    Plot keypoints (and optionally skeleton limbs) on the image in place.

    Args:
        img (np.ndarray): Image to draw on; modified in place.
        kpt (array-like): Keypoints, shape [17, 2] (x, y) or [17, 3]
            (x, y, confidence). Converted with ``np.asarray``, so lists or
            CPU tensors are accepted. Presumably in COCO keypoint order
            (the skeleton below uses 1-based indices) — TODO confirm.
        kpt_line (bool, optional): Draw lines between connected keypoints.
        conf_thres (float, optional): When a confidence column is present,
            keypoints below this are not drawn, nor are limbs touching them.
        radius (int, optional): Radius in pixels of each keypoint circle.

    Returns:
        np.ndarray: The same ``img`` object.
    """
    # Pairs of 1-based keypoint indices joined by a limb.
    skeleton = [
        [16, 14],
        [14, 12],
        [17, 15],
        [15, 13],
        [12, 13],
        [6, 12],
        [7, 13],
        [6, 7],
        [6, 8],
        [7, 9],
        [8, 10],
        [9, 11],
        [2, 3],
        [1, 2],
        [1, 3],
        [2, 4],
        [3, 5],
        [4, 6],
        [5, 7],
    ]
    # One colour per keypoint row.
    kpt_color = [
        [0, 255, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0], [255, 128, 0],
        [255, 128, 0], [255, 128, 0], [255, 128, 0], [255, 128, 0], [255, 128, 0], [51, 153, 255],
        [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255]
    ]
    # One colour per skeleton pair, in the same order as `skeleton`.
    limb_color = [
        [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [255, 51, 255],
        [255, 51, 255], [255, 51, 255], [255, 128, 0], [255, 128, 0], [255, 128, 0],
        [255, 128, 0], [255, 128, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0], [0, 255, 0],
        [0, 255, 0], [0, 255, 0], [0, 255, 0]
    ]

    kpt = np.asarray(kpt)
    # A third column carries per-keypoint confidence; without it nothing is filtered.
    has_conf = kpt.shape[-1] == 3

    for i, k in enumerate(kpt):
        if has_conf and k[2] < conf_thres:
            continue
        cv2.circle(img, (int(k[0]), int(k[1])), radius, kpt_color[i], -1, lineType=cv2.LINE_AA)

    if kpt_line:
        for (start, end), color in zip(skeleton, limb_color):
            p1 = kpt[start - 1]
            p2 = kpt[end - 1]
            if has_conf and (p1[2] < conf_thres or p2[2] < conf_thres):
                continue
            cv2.line(
                img,
                (int(p1[0]), int(p1[1])),
                (int(p2[0]), int(p2[1])),
                color,
                thickness=1,
                lineType=cv2.LINE_AA,
            )

    return img


class WrapRunner:
    """Uniform callable wrapper over a local runner or a remote model client.

    Args:
        runner: backend object. In rbo mode it must provide
            ``inference(model_id, inputs)`` and ``remove_model(model_id)``;
            otherwise it must provide ``run(*inputs, as_numpy=True)``.
        rbo_mode: bool, True when ``runner`` is a remote client serving an
            uploaded .rbo model.
        id: model id passed to the remote client; only used in rbo mode.

    Instances can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id
        # Guards against removing the remote model more than once.
        self._closed = False

    def __call__(self, *args):
        """Run inference on positional inputs and return the outputs.

        In rbo mode a single-element list result is unwrapped to its element.
        """
        if self.rbo_mode:
            outs = self.runner.inference(self.id, list(args))
            if isinstance(outs, list) and len(outs) == 1:
                return outs[0]
            return outs
        return self.runner.run(*args, as_numpy=True)

    def close(self):
        """Release the remote model (rbo mode only). Safe to call repeatedly.

        If ``remove_model`` raises, the wrapper stays open so close can be retried.
        """
        if self._closed:
            return
        if self.rbo_mode:
            self.runner.remove_model(self.id)
        self._closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        # Do not suppress exceptions raised inside the `with` block.
        return False

def init_runner(args):
    """Create a WrapRunner for the model file named in ``args``.

    Prints the model path before loading.

    Args:
        args: namespace with ``model_file`` (path to a ``.sg`` or ``.rbo``
            file) and ``device_id`` (int). For ``.sg`` models a negative
            ``device_id`` selects the CPU, otherwise ``cuda:<device_id>``.
            For ``.rbo`` models ``device_id`` is forwarded to the upload.

    Returns:
        WrapRunner wrapping a local SGRunner (``.sg``) or a GRPC client that
        holds the uploaded model (``.rbo``).

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: for ``.rbo`` models, if the ``CAISA_REMOTE_IP`` or
            ``CAISA_REMOTE_PORT`` environment variable is unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        import rbcc # type: ignore[import]
        from rbcc.backend.runner import SGRunner # type: ignore[import]
        sg = rbcc.load_sg(args.model_file)
        device = 'cpu' if args.device_id < 0 else f"cuda:{args.device_id}"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    elif suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    else:
        # repr() makes an empty suffix visible as '' instead of a blank.
        raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")


def demo(args):
    """Run pose estimation on one image and save the annotated result.

    Loads the model via ``init_runner`` and runs it on ``args.img_file``. It
    draws detections scoring at least 0.5 with their keypoints, then writes
    the image to ``demo_predict.jpg`` in the current working directory.

    Args:
        args: parsed CLI namespace with ``model_file``, ``img_file`` and
            ``device_id``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.img_file``.
        OSError: if OpenCV fails to write ``demo_predict.jpg``.
    """
    # Read the image first so a bad path fails before the model is loaded/uploaded.
    o_img = cv2.imread(args.img_file)
    if o_img is None:
        raise FileNotFoundError(f"cannot read image: {args.img_file}")

    runner = init_runner(args)
    try:
        img, img_meta = preprocess(args.img_file, True)
        outs = runner(img)
    finally:
        # Release the runner (removes an uploaded .rbo model) even when
        # preprocessing or inference raises.
        runner.close()

    bboxes, scores, _, kpts = postprocess(outs, img_meta)
    labels_str = ["person"] * len(scores)
    r_img = draw_boxes_with_pose(o_img, bboxes.astype(np.int32), scores, kpts, labels_str, score_threshold=0.5)
    if not cv2.imwrite('demo_predict.jpg', r_img):
        raise OSError("failed to write demo_predict.jpg")


if __name__ == '__main__':
    # Command-line entry point: positional model/image paths plus optional flags.
    cli = argparse.ArgumentParser()
    cli.add_argument('model_file', type=str, help='rbo or sg file path.')
    cli.add_argument('img_file', type=str, help="input image.")
    cli.add_argument('--device-id', type=int, default=0, help="device id.")
    cli.add_argument('--e2e', action="store_true", help="test end2end")
    cli_args = cli.parse_args()
    demo(cli_args)