import os
import argparse
import cv2
import torch
import glob
import numpy as np
from tqdm import tqdm

from preprocess import evaluate


class WrapRunner:
    """Give local (SGRunner) and remote (gRPC client) backends one call interface.

    Attributes:
        runner: The wrapped backend object.
        rbo_mode: True when ``runner`` is a remote client driven through
            ``inference``/``remove_model``; False when it exposes ``run``.
        id: Handle of the uploaded model, used only in ``rbo_mode``.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __call__(self, *args):
        """Run one inference and return the backend's outputs.

        In local mode the arguments go straight to ``runner.run`` with
        ``as_numpy=True``. In rbo mode every argument that is not already a
        ``np.ndarray`` is converted via its ``.numpy()`` method, and a
        single-element list result is unwrapped to that element.
        """
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        arrays = [arg if isinstance(arg, np.ndarray) else arg.numpy() for arg in args]
        outputs = self.runner.inference(self.id, arrays)
        is_single = isinstance(outputs, list) and len(outputs) == 1
        return outputs[0] if is_single else outputs

    def close(self):
        """Release backend resources (removes the uploaded model in rbo mode)."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)


def init_runner(args):
    """Create a callable inference runner for ``args.model_file``.

    The backend is selected from the file extension (case-insensitive):

    * ``.sg``  - loaded locally with ``rbcc`` and executed by ``SGRunner`` on
      ``cuda:<device_id>``, or on CPU when ``device_id`` is negative.
    * ``.rbo`` - uploaded to a remote ``rboexec`` gRPC server whose address is
      read from the ``CAISA_REMOTE_IP`` / ``CAISA_REMOTE_PORT`` environment
      variables (``os.environ`` raises ``KeyError`` if either is unset).

    Args:
        args: Namespace with ``model_file`` (path) and ``device_id`` (int).

    Returns:
        WrapRunner: the backend behind a uniform call/close interface.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
    """
    suffix = os.path.splitext(args.model_file)[1]
    kind = suffix.lower()
    print(f"INFO: Load model from {args.model_file}")
    if kind == '.sg':
        # Backend packages are imported lazily so only the one in use must be installed.
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if kind == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")


def generate_test_dataset(args, transform, input_size=None):
    """Yield preprocessed WIDER-Face validation images one at a time.

    Images are discovered under ``<args.data>/images/<event>/*.jpg`` and
    visited in sorted path order, so repeated runs (and ``--eval-nums``
    subsets) see the same images in the same order.

    Args:
        args: Namespace providing ``data``, the dataset root directory.
        transform: Callable invoked as
            ``transform(path, with_meta=True, input_size=input_size)`` that
            returns an ``(image, meta)`` pair.
        input_size: Size forwarded to *transform*. Defaults to
            ``[1080, 1920]``, the previously hard-coded value.

    Yields:
        ``(image, meta)`` tuples produced by *transform*.
    """
    if input_size is None:
        input_size = [1080, 1920]
    pattern = os.path.join(args.data, 'images', '*', '*.jpg')
    # glob() returns paths in arbitrary filesystem order; sort for reproducibility.
    # The '*.jpg' pattern already restricts the extension, so no extra filter is needed.
    for image_path in sorted(glob.glob(pattern)):
        img, meta = transform(image_path, with_meta=True, input_size=input_size)
        yield img, meta


def main(args):
    """Run the model over the WIDER-Face validation images and evaluate.

    Expects ``args.data`` to contain ``images/<event>/*.jpg`` and a
    ``ground_truth`` entry that is passed to ``evaluate``. When
    ``args.eval_nums`` is set, only that many images are processed.

    Raises:
        ValueError: if ``args.eval_nums`` is negative.
    """
    import itertools

    if args.eval_nums is not None and args.eval_nums < 0:
        raise ValueError(f"--eval-nums must be non-negative, got {args.eval_nums}")

    print("INFO: Init Runner.")
    runner = init_runner(args)
    # Always release the runner; in .rbo mode this removes the uploaded remote model.
    try:
        print("INFO: Load dataset.")
        val_loader = generate_test_dataset(args, preprocess)
        gt_file = os.path.join(args.data, 'ground_truth')
        event_list = sorted(os.listdir(os.path.join(args.data, 'images')))

        # Start evaluation.
        print("INFO: Start inference.")
        # 3226 is used only as the progress-bar total when no limit is given
        # (presumably the size of the WIDER-Face val split).
        num_samples = args.eval_nums if args.eval_nums is not None else 3226
        result = {event: {} for event in event_list}
        # islice stops *before* preprocessing image eval_nums+1; None means "all".
        samples = itertools.islice(val_loader, args.eval_nums)
        for img, img_meta in tqdm(samples, desc="Eval in Wider-Face", unit='img', total=num_samples):
            pred = runner(img)
            boxes = postprocess(pred, img_meta)
            result[img_meta['event_name']][img_meta['image_name']] = boxes
    finally:
        runner.close()
    evaluate(result, gt_file)


def postprocess(pred, img_meta):
    """Convert raw detections into ``[x, y, w, h, score]`` rows in original-image pixels.

    Args:
        pred: Sequence whose ``pred[0]`` is a NumPy array of boxes with
            xyxy in columns 0-3, in network-input coordinates, and whose
            ``pred[1]`` holds one score per box.
        img_meta: Dict with ``'input_size'`` and ``'ori_size'`` as produced
            by ``preprocess``; used to undo the resize/letterbox.

    Returns:
        np.ndarray of shape ``(N, 5)``. The box coordinates are rounded to
        integer pixels; width/height are differences of the rounded corners.
        With no detections an empty ``(0, 5)`` array is returned, so callers
        can always index columns.
    """
    boxes = torch.from_numpy(pred[0].astype(np.float32))
    scores = np.asarray(pred[1])
    if boxes.shape[0] == 0:
        return np.zeros((0, 5))
    boxes = scale_coords(img_meta['input_size'], boxes, img_meta['ori_size'])
    # Round half up to whole pixels. scale_coords clamps to >= 0, so
    # floor(v + 0.5) equals the truncating int(v + 0.5).
    xyxy = torch.floor(boxes[:, :4] + 0.5).numpy().astype(np.int64)
    wh = xyxy[:, 2:4] - xyxy[:, 0:2]
    return np.column_stack([xyxy[:, 0:2], wh, scores])

def letterbox(img, new_shape, new_unpad, color=(114, 114, 114)):
    """Pad *img* with a constant-colour border so that it reaches *new_shape*.

    Args:
        img: Already-resized image with size *new_unpad*.
        new_shape: Target ``(width, height)``; a scalar means a square target.
        new_unpad: Current ``(width, height)`` of *img*.
        color: Border fill value.

    Returns:
        The padded image, with the padding split between opposite edges.
    """
    target = new_shape if isinstance(new_shape, (tuple, list)) else [new_shape, new_shape]
    half_w = (target[0] - new_unpad[0]) / 2
    half_h = (target[1] - new_unpad[1]) / 2
    # Nudging by -/+0.1 sidesteps round-half-to-even: an odd leftover pixel
    # always goes to the bottom/right edge.
    top, bottom = (int(round(half_h + delta)) for delta in (-0.1, 0.1))
    left, right = (int(round(half_w + delta)) for delta in (-0.1, 0.1))
    return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)


def scale_coords(img1_shape, coords, img0_shape):
    """Map xyxy boxes from letterboxed network space back to the source image.

    Undoes the centred padding and uniform scaling applied when the
    ``img0_shape`` image was fitted into ``img1_shape``, then clamps the
    result to the source image bounds. *coords* is modified in place and
    also returned.

    Args:
        img1_shape: ``(height, width)`` of the network input.
        coords: Float tensor whose columns 0/2 are x and 1/3 are y.
        img0_shape: ``(height, width)`` of the original image.
    """
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # new / old
    pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
    pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2

    # (columns to shift/scale, clamp start column, padding, upper bound)
    per_axis = (
        ([0, 2], 0, pad_x, img0_shape[1]),
        ([1, 3], 1, pad_y, img0_shape[0]),
    )
    for cols, first, pad, limit in per_axis:
        coords[:, cols] -= pad
        coords[:, cols] /= gain
        coords[:, first::2].clamp_(0, limit)
    return coords


def preprocess(img_path, with_meta=False, input_size=(1080, 1920)):
    """Load an image and letterbox it to the network input size.

    Args:
        img_path: Path like ``.../<event_name>/<image_name>.jpg``.
        with_meta: If True, also return a metadata dict for postprocessing.
        input_size: Target ``(height, width)``; the image is scaled by the
            largest uniform factor that fits, then padded.

    Returns:
        The image batched with a leading axis of size 1, channels reversed
        (BGR from ``cv2.imread`` to RGB). When *with_meta* is True, returns
        ``(image, meta)``, where meta has ``'ori_size'``, ``'input_size'``,
        ``'event_name'`` (parent directory name) and ``'image_name'``
        (file name without extension).

    Raises:
        OSError: if OpenCV cannot read the image.
    """
    ori_img = cv2.imread(img_path)
    if ori_img is None:  # cv2.imread signals failure by returning None
        raise OSError(f"cv2.imread failed to read image: {img_path}")
    ori_size = ori_img.shape[:2]
    r = min(input_size[0] / ori_size[0], input_size[1] / ori_size[1])
    new_unpad = int(round(ori_size[1] * r)), int(round(ori_size[0] * r))  # (w, h) for cv2
    img = cv2.resize(ori_img, new_unpad, interpolation=cv2.INTER_LINEAR)
    img = letterbox(img, input_size[::-1], new_unpad)
    img = img[None, :, :, ::-1]
    if not with_meta:
        return img

    event_dir, file_name = os.path.split(img_path)
    img_meta = {
        'ori_size': ori_size,
        'input_size': input_size,
        'event_name': os.path.basename(event_dir),
        # splitext removes the extension; rstrip('.jpg') would strip any
        # trailing '.', 'j', 'p', 'g' characters and corrupt some names.
        'image_name': os.path.splitext(file_name)[0],
    }
    return img, img_meta



if __name__ == '__main__':
    # Command-line entry point: parse options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    # Required: main() joins this path, so a missing value would crash later.
    parser.add_argument('--data', type=str, required=True,
                        help='dataset dir (must contain images/ and ground_truth).')
    parser.add_argument('--device-id', type=int, default=0, help="device id.")
    parser.add_argument('--eval-nums', type=int,
                        help="Number of images to evaluate (default: all).")
    main(parser.parse_args())
