import os
import cv2
import argparse
import numpy as np
from preprocess import preprocess, postprocess


# Class names indexed by the integer label id produced by the model's
# post-processing. The 15 entries match the DOTA v1.0 category list.
CLASS_NAMES = (
    "plane", "ship", "storage tank", "baseball diamond", "tennis court", "basketball court",
    "ground track field", "harbor", "bridge", "large vehicle", "small vehicle", "helicopter",
    "roundabout", "soccer ball field", "swimming pool",
)



def draw_boxes(image, boxes, scores, labels, score_threshold=0.5):
    """Draw rotated bounding boxes and their labels on an image, in place.

    Args:
        image: np.array of shape (H, W, C), dtype uint8. It is drawn on in place.
        boxes: np.array of shape (N, 5). Each row is a rotated box
            ``(x_center, y_center, width, height, angle)``, with the angle in
            radians.
        scores: np.array of shape (N,), the score of each box.
        labels: list of str with N entries, the class name of each box.
        score_threshold: float. Boxes whose score is below this value are skipped.

    Returns:
        The same ``image`` object, with the boxes and labels drawn on it.
    """
    for box, score, cls_name in zip(boxes, scores, labels):
        if score < score_threshold:
            continue

        label = '{} {:.2f}'.format(cls_name, score)
        xc, yc, w, h, angle = box

        # Half-extent vectors along the rotated width and height axes.
        wx, wy = w / 2 * np.cos(angle), w / 2 * np.sin(angle)
        hx, hy = -h / 2 * np.sin(angle), h / 2 * np.cos(angle)
        points = np.array([
            (xc - wx - hx, yc - wy - hy),
            (xc + wx - hx, yc + wy - hy),
            (xc + wx + hx, yc + wy + hy),
            (xc - wx + hx, yc - wy + hy),
        ], dtype=np.int32)

        # Draw the closed polygon outline.
        cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=1, lineType=cv2.LINE_AA)

        # Anchor the label at the "top-right" corner: among the right-most
        # vertices, take the one with the smallest y (first one on ties).
        right_points = points[points[:, 0] == points[:, 0].max()]
        top_right = right_points[np.argmin(right_points[:, 1])]
        offset = (5, -10)
        # Convert numpy integer scalars to plain ints for the cv2 binding.
        text_position = (int(top_right[0]) + offset[0], int(top_right[1]) + offset[1])

        # Draw the label text.
        cv2.putText(image, label, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255))

    return image


class WrapRunner(object):
    """Callable wrapper that gives both inference backends the same call interface.

    When ``rbo_mode`` is true, ``runner`` must provide
    ``inference(id, inputs)`` and ``remove_model(id)``, where ``id`` is the
    handle of an already-uploaded model. Otherwise ``runner`` must provide
    ``run(*inputs, as_numpy=True)``.

    The wrapper can be used as a context manager; ``close()`` is called on exit.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        # Model handle used in rbo_mode. The parameter name shadows the
        # builtin ``id`` but is kept for backward compatibility with callers.
        self.id = id
        self._closed = False

    def __call__(self, *args):
        """Run inference on the positional inputs and return the outputs.

        In rbo_mode, a single-element list result is unwrapped to its only element.
        """
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)
        outs = self.runner.inference(self.id, list(args))
        if isinstance(outs, list) and len(outs) == 1:
            return outs[0]
        return outs

    def close(self):
        """Release the model in rbo_mode; this is a no-op otherwise.

        Calling ``close()`` more than once is safe: ``remove_model`` is
        invoked at most once per successful call.
        """
        if self.rbo_mode and not self._closed:
            self.runner.remove_model(self.id)
        # Set only after remove_model returns, so a failed removal can be retried.
        self._closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        # Do not suppress exceptions raised inside the with-block.
        return False

def init_runner(args):
    """Create a ``WrapRunner`` for ``args.model_file``.

    The backend is chosen from the file extension (case-insensitive):

    * ``.sg``: loaded locally with ``rbcc`` and run by ``SGRunner``. It runs
      on ``'cpu'`` when ``args.device_id < 0``, otherwise on
      ``'cuda:<device_id>'``.
    * ``.rbo``: uploaded through a ``GRPCClient`` connected to the host and
      port given by the ``CAISA_REMOTE_IP`` / ``CAISA_REMOTE_PORT``
      environment variables.

    Args:
        args: namespace with ``model_file`` (str) and ``device_id`` (int).

    Returns:
        WrapRunner around the selected backend.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: for ``.rbo`` models when the remote-address environment
            variables are not set.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    kind = suffix.lower()
    if kind == '.sg':
        # Backend-specific imports are deferred so that only the backend
        # actually used needs to be installed.
        import rbcc # type: ignore[import]
        from rbcc.backend.runner import SGRunner # type: ignore[import]
        sg = rbcc.load_sg(args.model_file)
        device = 'cpu' if args.device_id < 0 else f"cuda:{args.device_id}"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if kind == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")


def demo(args):
    """Run the detector on one image and save the visualization.

    Creates the runner with ``init_runner(args)``. It then runs
    ``preprocess`` -> inference -> ``postprocess`` on ``args.img_file``.
    Detections with score >= 0.5 are drawn on the original image, and the
    result is written to ``demo_predict.jpg`` in the current working
    directory.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.img_file``.
        OSError: if the output image cannot be written.
    """
    runner = init_runner(args)
    try:
        # NOTE(review): the meaning of preprocess's second argument is defined
        # in preprocess.py; confirm it there.
        img, img_meta = preprocess(args.img_file, True)
        outs = runner(img)
        bboxes, scores, labels, obb = postprocess(outs, img_meta)
    finally:
        # Always release the runner (for .rbo the model was uploaded to a
        # remote server), even if inference or post-processing fails.
        runner.close()

    # Append obb as an extra last column; draw_boxes unpacks each row as
    # (xc, yc, w, h, angle).
    bboxes = np.concatenate([bboxes, obb[:, None]], -1)
    o_img = cv2.imread(args.img_file)
    if o_img is None:
        raise FileNotFoundError(f"cannot read image {args.img_file}")

    labels_str = [CLASS_NAMES[int(i)] for i in labels]
    r_img = draw_boxes(o_img, bboxes, scores, labels_str, score_threshold=0.5)
    if not cv2.imwrite('demo_predict.jpg', r_img):
        raise OSError("failed to write demo_predict.jpg")


if __name__ == '__main__':
    # Command-line entry point: a model path and an image path, plus options.
    # NOTE(review): ``--e2e`` is parsed, but demo() never reads it.
    cli = argparse.ArgumentParser()
    for positional, help_text in (('model_file', 'rbo or sg file path.'),
                                  ('img_file', "input image.")):
        cli.add_argument(positional, type=str, help=help_text)
    cli.add_argument('--device-id', type=int, default=0, help="device id.")
    cli.add_argument('--e2e', action="store_true", help="test end2end")
    cli_args = cli.parse_args()
    demo(cli_args)