import os
import cv2
import argparse
import torch
from preprocess import preprocess, non_max_suppression
from preprocess import scale_coords, letterbox, crop_aligned_face


class WrapRunner:
    """Callable adapter giving both inference backends the same call syntax.

    Two kinds of ``runner`` are supported, selected by ``rbo_mode``:

    * ``rbo_mode`` false: ``runner`` exposes ``run(*inputs, as_numpy=...)``;
      every call is forwarded with ``as_numpy=True``.
    * ``rbo_mode`` true: ``runner`` exposes ``inference(model_id, inputs)``
      and ``remove_model(model_id)``; ``id`` is the model handle passed to
      both.

    Attributes:
        runner: the wrapped backend object.
        rbo_mode: which of the two calling conventions above is in use.
        id: model handle used in rbo mode (``None`` otherwise by default).
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __call__(self, *args):
        """Run inference on the positional inputs and return the outputs.

        In rbo mode the inputs are sent as a list, and a one-element output
        list is unwrapped to its sole element; any other result is returned
        as-is.
        """
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        results = self.runner.inference(self.id, list(args))
        unwrap = isinstance(results, list) and len(results) == 1
        return results[0] if unwrap else results

    def close(self):
        """Remove the model from the backend in rbo mode; do nothing otherwise."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Build a :class:`WrapRunner` for the model file named by ``args``.

    The backend is chosen from the file extension:

    * ``.sg``  -- loaded locally with ``rbcc.load_sg`` and executed by an
      ``SGRunner`` on ``cpu`` (when ``args.device_id < 0``) or on
      ``cuda:<device_id>``.
    * ``.rbo`` -- uploaded through a ``GRPCClient`` whose address is read from
      the ``CAISA_REMOTE_IP`` and ``CAISA_REMOTE_PORT`` environment variables.

    Args:
        args: object exposing ``model_file`` (path) and ``device_id`` (int).

    Returns:
        WrapRunner: callable wrapper around the selected backend.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: for ``.rbo`` models, if a required environment variable is
            not set.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        import rbcc  # type: ignore[import]
        from rbcc.backend.runner import SGRunner  # type: ignore[import]
        sg = rbcc.load_sg(args.model_file)
        device = 'cpu' if args.device_id < 0 else f"cuda:{args.device_id}"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        try:
            host = os.environ['CAISA_REMOTE_IP']
            port = int(os.environ['CAISA_REMOTE_PORT'])
        except KeyError as err:
            # Same exception type as before, but say what is actually missing.
            raise KeyError(
                f"environment variable {err.args[0]} must be set to run a .rbo model"
            ) from err
        client = GRPCClient(host, port)
        model_id = client.upload_model(
            model_path=os.path.abspath(args.model_file), device_id=args.device_id
        )
        return WrapRunner(client, True, model_id)
    # !r makes an empty suffix (file without extension) visible as ''.
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")


def demo(args):
    """Detect faces in ``args.img_file`` with the model in ``args.model_file``.

    Two pipelines are supported:

    * ``args.e2e`` -- the image is resized by ``keep_ratio_resize`` to
      ``(1080, 1920)`` and fed to the model, which returns
      ``(boxes, scores, landmarks)`` directly.
    * otherwise -- the image is prepared by ``preprocess`` and the raw model
      output is decoded with ``non_max_suppression`` and ``scale_coords``.

    Detections scoring below ``args.score_thre`` are skipped. The annotated
    image is written to ``./demo_predict.jpg`` (nothing is written when no
    face is detected). With ``args.crop_face`` each kept face is also aligned
    and saved as ``./face_<j>.jpg``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.img_file``.
    """
    im0 = cv2.imread(args.img_file)
    # cv2.imread returns None on failure; check it before creating the runner
    # so that a bad path does not leave a model uploaded.
    if im0 is None:
        raise FileNotFoundError(f"cannot read image: {args.img_file}")
    runner = init_runner(args)
    try:
        if args.e2e:
            result = _demo_e2e(runner, im0, args)
        else:
            result = _demo_two_stage(runner, im0, args)
    finally:
        # Release the backend on every exit path: no faces found, or an error.
        runner.close()
    if result is None:
        return
    cv2.imwrite('demo_predict.jpg', result)
    print(f"INFO: New image has saved in './demo_predict.jpg'")


def _demo_e2e(runner, im0, args):
    """Run the end-to-end model; return the annotated image, or None if no faces."""
    img = keep_ratio_resize(im0, (1080, 1920))
    # Add a batch axis and reverse the channel axis (OpenCV images are BGR).
    boxes, scores, landmarks = runner(img[None, :, :, ::-1])
    det_num = boxes.shape[0]
    print(det_num, 'face' if det_num == 1 else 'faces')
    if not det_num:
        return None
    canvas = img
    for j in range(det_num):
        score = scores[j]
        if score < args.score_thre:
            continue
        landmark = landmarks[j]
        # show_results draws on a copy, so `img` itself stays clean.
        canvas = show_results(canvas, boxes[j], score, landmark, step=3)
        if args.crop_face:
            # Crop from the unannotated image so the saved face carries no drawings.
            _save_aligned_face(img, landmark, j)
    return canvas


def _demo_two_stage(runner, im0, args):
    """Run the raw model plus NMS decoding; return the annotated image, or None if no faces."""
    img = preprocess(args.img_file)
    pred = torch.from_numpy(runner(img)[0]).float()
    det = non_max_suppression(pred, conf_thres=0.5, iou_thres=0.45)[0]
    print(len(det), 'face' if len(det) == 1 else 'faces')
    if not len(det):
        return None
    # Rescale boxes (cols 0-3) and landmarks (cols 6+) from img_size to im0 size;
    # scale_coords' return value is not used, so it presumably works in place.
    scale_coords(img.shape[1:3], det[:, :4], im0.shape, kpt_label=False)
    scale_coords(img.shape[1:3], det[:, 6:], im0.shape, kpt_label=5, step=3)
    canvas = im0
    for j in range(det.size()[0]):
        conf = det[j, 4].cpu().numpy()
        if conf < args.score_thre:
            continue
        xyxy = det[j, :4].view(-1).tolist()
        landmark = det[j, 6:].view(-1).tolist()
        canvas = show_results(canvas, xyxy, conf, landmark)
        if args.crop_face:
            # Crop from the unannotated image so the saved face carries no drawings.
            _save_aligned_face(im0, landmark, j)
    return canvas


def _save_aligned_face(image, landmark, index):
    """Align the face given by ``landmark`` in ``image`` and save it as ``face_<index>.jpg``."""
    aligned_face = crop_aligned_face(image.copy(), landmark)
    cv2.imwrite(f'face_{index}.jpg', aligned_face)
    print(f"INFO: New image has saved in './face_{index}.jpg'")


def show_results(img, xyxy, conf, landmarks, step=3):
    """Draw one detection (box, five landmarks, score label) on a copy of ``img``.

    Args:
        img: image array; it is copied and never modified in place.
        xyxy: box corners ``(x1, y1, x2, y2)``; each value is truncated to int.
        conf: detection score; the first five characters of ``str(conf)``
            are drawn as the label.
        landmarks: flat sequence holding ``step`` values per point, with x at
            offset 0 and y at offset 1 of each group; the first five points
            are drawn.
        step: stride between consecutive points in ``landmarks``.

    Returns:
        The annotated copy of ``img``.
    """
    # The previous `1 or round(...)` always evaluated to 1; keep that value explicitly.
    line_thickness = 1
    font_thickness = max(line_thickness - 1, 1)
    font_scale = line_thickness / 3
    x1, y1, x2, y2 = (int(xyxy[k]) for k in range(4))

    canvas = img.copy()
    cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), thickness=3, lineType=cv2.LINE_AA)
    point_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
    for i, color in enumerate(point_colors):
        point_x = int(landmarks[step * i])
        point_y = int(landmarks[step * i + 1])
        cv2.circle(canvas, (point_x, point_y), 5, color, -1)

    label = str(conf)[:5]
    # putText's origin is the text's bottom-left corner; clamp it so a box
    # touching the top edge does not push the label above row 0.
    (_, text_h), _ = cv2.getTextSize(label, 0, font_scale, font_thickness)
    text_y = max(y1 - 2, text_h)
    # NOTE(review): 225 in the label colour may have been meant as 255; kept as-is.
    cv2.putText(canvas, label, (x1, text_y), 0, font_scale, [225, 255, 255],
                thickness=font_thickness, lineType=cv2.LINE_AA)
    return canvas


def keep_ratio_resize(ori_img, tar_size):
    """Scale ``ori_img`` to fit ``tar_size`` without distortion, then letterbox it.

    Args:
        ori_img: image array whose first two axes are (height, width).
        tar_size: target ``(height, width)``.

    Returns:
        The output of ``letterbox`` applied to the scaled image, given the
        target as ``(width, height)`` and the scaled ``(width, height)``.
    """
    src_h, src_w = ori_img.shape[:2]
    # The largest uniform scale at which the image still fits in the target.
    scale = min(tar_size[0] / src_h, tar_size[1] / src_w)
    scaled_wh = (int(round(src_w * scale)), int(round(src_h * scale)))
    resized = cv2.resize(ori_img, scaled_wh, interpolation=cv2.INTER_LINEAR)
    return letterbox(resized, list(reversed(tar_size)), scaled_wh)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Face detection demo for .sg / .rbo models.")
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    parser.add_argument('img_file', type=str, help="input image.")
    parser.add_argument('--device-id', type=int, default=0,
                        help="device id (negative runs .sg models on CPU).")
    # '--score-thre' is an alias matching the other hyphenated flags; the
    # first long option keeps the destination name `score_thre`.
    parser.add_argument('--score_thre', '--score-thre', type=float, default=0.7,
                        help="score threshold.")
    parser.add_argument('--e2e', action="store_true", help="test end2end")
    parser.add_argument('--crop-face', action="store_true",
                        help="save an aligned crop of every detected face.")
    demo(parser.parse_args())