import os
import argparse
import numpy as np
from tqdm import tqdm
from easydict import EasyDict

from preprocess import HRNetCOCODataset
from preprocess import preprocess, postprocess


class WrapRunner:
    """Uniform callable facade over the two supported inference backends.

    * ``rbo_mode=False``: *runner* must provide ``run(*args, as_numpy=True)``
      (the ``SGRunner`` created by ``init_runner`` for ``.sg`` files).
    * ``rbo_mode=True``: *runner* must provide ``inference(id, inputs)`` and
      ``remove_model(id)`` (the ``GRPCClient`` created for ``.rbo`` files);
      *id* is the handle of the model previously uploaded to it.
    """

    def __init__(self, runner, rbo_mode, id=None):
        # `id` shadows the builtin, but the parameter name is kept so existing
        # keyword callers (``WrapRunner(..., id=...)``) keep working.
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __call__(self, *args):
        """Run one inference on the positional inputs and return the outputs.

        In RBO mode, every input that is not already an ``np.ndarray`` is
        converted with its ``.numpy()`` method before being sent (presumably a
        framework tensor — confirm against callers). A single-element list
        result is unwrapped to that element; any other result is returned as-is.
        """
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        inputs = [arg if isinstance(arg, np.ndarray) else arg.numpy() for arg in args]
        outs = self.runner.inference(self.id, inputs)
        if isinstance(outs, list) and len(outs) == 1:
            return outs[0]
        return outs

    def close(self):
        """Release backend resources.

        In RBO mode this removes the uploaded model; otherwise it does nothing.
        """
        if self.rbo_mode:
            self.runner.remove_model(self.id)

def init_runner(args):
    """Create a ``WrapRunner`` for the model file named in *args*.

    The backend is chosen from the extension of ``args.model_file``:

    * ``.sg``: loaded locally with ``rbcc`` and executed by ``SGRunner`` on
      ``cuda:<device_id>``, or on CPU when ``args.device_id`` is negative.
    * ``.rbo``: uploaded to a remote executor through ``GRPCClient``. The
      endpoint is read from the ``CAISA_REMOTE_IP`` and ``CAISA_REMOTE_PORT``
      environment variables.

    Args:
        args: namespace with ``model_file`` (str) and ``device_id`` (int).

    Returns:
        WrapRunner: callable wrapper around the selected backend.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: for ``.rbo`` files, if a required environment variable is unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        # Backend packages are imported inside their branch, so only the one
        # actually used needs to be installed.
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        # The returned handle identifies the uploaded model for later
        # inference and removal.
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    # !r makes an empty suffix (file without extension) visible in the message.
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")

def _build_eval_cfg():
    """Return the fixed HRNet COCO keypoint evaluation config used by ``main``."""
    return EasyDict({
        'MODEL': {
            'NUM_JOINTS': 17,
            'IMAGE_SIZE': [192, 256],
        },
        'TEST': {
            'COLOR_RGB': True,
            'POST_PROCESS': True,
            'SHIFT_HEATMAP': True,
            'USE_GT_BBOX': True,
            'IMAGE_THRE': 0.0,
            'NMS_THRE': 1.0,
            'SOFT_NMS': False,
            'OKS_THRE': 0.9,
            'IN_VIS_THRE': 0.2,
            'BBOX_THRE': 1.0,
            'COCO_BBOX_FILE': '',
        },
    })


def main(args):
    """Run the model over COCO2017 val and print the evaluation indicator.

    Args:
        args: parsed CLI namespace with ``model_file``, ``device_id``, ``data``
            (dataset directory) and ``eval_nums`` (optional cap on the number
            of samples; ``None`` or 0 means the whole split).

    Raises:
        ValueError: if ``args.eval_nums`` is negative.
    """
    # Validate before creating the runner so a bad argument never uploads a model.
    if args.eval_nums is not None and args.eval_nums < 0:
        raise ValueError(f"--eval-nums must be non-negative, got {args.eval_nums}")

    print("INFO: Init Runner.")
    runner = init_runner(args)
    try:
        print("INFO: Load dataset.")
        cfg = _build_eval_cfg()
        val_loader = HRNetCOCODataset(cfg, args.data, 'val2017', transform=preprocess)

        # Start evaluation. Clamp the requested count to the dataset size so the
        # result buffers are sized exactly to the number of samples processed.
        print("INFO: Start inference.")
        num_samples = len(val_loader)
        if args.eval_nums:
            num_samples = min(args.eval_nums, num_samples)
        # NOTE(review): presumably postprocess writes row meta['index'] of these
        # buffers and appends to image_path — confirm in preprocess.postprocess.
        all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3), dtype=np.float32)
        all_boxes = np.zeros((num_samples, 6))
        image_path = []
        with tqdm(desc="Eval in COCO2017", unit='img', total=num_samples) as pbar:
            for index, (img, meta) in enumerate(val_loader):
                if index >= num_samples:
                    break
                output = runner(img)
                meta.update({'index': index, 'cfg': cfg})
                postprocess(output, meta, all_preds, all_boxes, image_path)
                pbar.update(1)
    finally:
        # Always release the backend (for .rbo this removes the uploaded model),
        # even if loading or inference fails.
        runner.close()

    print("INFO: HRNetCOCO evaluate....", flush=True)
    _, perf_indicator = val_loader.evaluate(all_preds, all_boxes, image_path)
    print(f"INFO: box@map {perf_indicator}", flush=True)


if __name__ == '__main__':
    # Command-line entry point: parse arguments and run the evaluation.
    parser = argparse.ArgumentParser(
        description="Evaluate a .sg or .rbo model on COCO2017 val.")
    parser.add_argument('model_file', type=str, help='Path to the .sg or .rbo model file.')
    parser.add_argument('--data', type=str, default=None, help='COCO dataset directory.')
    parser.add_argument('--device-id', type=int, default=0,
                        help="Device id; for .sg models a negative value selects CPU.")
    parser.add_argument('--eval-nums', type=int,
                        help="Number of validation samples to evaluate (default: all).")
    main(parser.parse_args())
