import os
import argparse

import cv2
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
import pycocotools.mask as mask_util
from pycocotools.cocoeval import COCOeval

from preprocess import preprocess, postprocess


class WrapRunner(object):
    """Give local and remote inference back ends one calling convention.

    With ``rbo_mode`` false, ``runner`` must provide ``run(*inputs, as_numpy=True)``.
    With ``rbo_mode`` true, ``runner`` must provide ``inference(model_id, inputs)``
    and ``remove_model(model_id)``, and ``id`` is the handle passed to both.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        # Model handle for the remote client; unused in local mode.
        self.id = id

    def __call__(self, *args):
        """Run one inference and return the back end's outputs."""
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        # The remote client is fed ndarrays; any other input is converted via
        # its ``.numpy()`` method.
        arrays = [arg if isinstance(arg, np.ndarray) else arg.numpy() for arg in args]
        result = self.runner.inference(self.id, arrays)
        # A single-output model is unwrapped so callers get the array itself.
        is_single = isinstance(result, list) and len(result) == 1
        return result[0] if is_single else result

    def close(self):
        """Release the remote model; a no-op for local runners."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Create an inference runner for ``args.model_file``.

    ``.sg`` files are loaded locally with ``rbcc`` and run by ``SGRunner``;
    ``.rbo`` files are uploaded through a ``GRPCClient`` connected to
    ``$CAISA_REMOTE_IP:$CAISA_REMOTE_PORT``.

    Args:
        args: namespace with ``model_file`` (path) and ``device_id`` (int; a
            negative value selects CPU for ``.sg`` models).

    Returns:
        WrapRunner: callable wrapper around the selected back end.

    Raises:
        ValueError: if the file extension is neither ``.sg`` nor ``.rbo``.
        KeyError: for ``.rbo`` models when the CAISA_REMOTE_* variables are unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        # Back-end packages are imported only in the branch that needs them.
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    raise ValueError(f"expected .sg or .rbo file, not {suffix}")

def _to_coco_keypoints(person, joint_list, order):
    """Convert one person's joint assignments into a flat COCO keypoint list.

    Args:
        person: one row of ``person_to_joint_assoc``; entry ``order[k]`` is an
            index into ``joint_list``, or -1 when that joint was not found.
        joint_list: array whose columns 0 and 1 are copied as keypoint x, y.
        order: model joint index for each output keypoint, in COCO order.

    Returns:
        ``[x1, y1, v1, ..., xK, yK, vK]`` with ``v = 1`` for found joints and
        an all-zero triplet for missing ones.
    """
    keypoints = np.zeros((len(order), 3))
    for part, model_joint in enumerate(order):
        joint_idx = int(person[model_joint])
        if joint_idx == -1:
            continue  # missing joint: keep the (0, 0, 0) triplet
        keypoints[part, 0] = joint_list[joint_idx, 0]
        keypoints[part, 1] = joint_list[joint_idx, 1]
        keypoints[part, 2] = 1
    return list(keypoints.reshape(-1))


def main(args):
    """Run keypoint inference over the validation list and print COCO keypoint mAP.

    Args:
        args: parsed CLI namespace; reads ``model_file`` and ``device_id``
            (through ``init_runner``), plus ``data``, ``label`` and ``eval_nums``.
    """
    print("INFO: Init Runner.")
    runner = init_runner(args)

    # Model joint index for each of the 17 COCO keypoints, in COCO order.
    # NOTE(review): presumably an 18-joint OpenPose-style layout (index 1,
    # unused here, would be the neck) -- confirm against the model.
    ORDER_COCO = [0, 15, 14, 17, 16, 5, 2, 6, 3, 7, 4, 11, 8, 12, 9, 13, 10]

    outputs = []
    image_ids = []
    try:
        print("INFO: Load dataset.")
        val_loader = generate_test_dataset(args, preprocess)

        # Start evaluation.
        print("INFO: Start inference.")
        # 2198 is only the progress-bar total used when --eval-nums is absent;
        # presumably the length of val2017.txt -- confirm.
        num_samples = args.eval_nums if args.eval_nums is not None else 2198
        gt_file = os.path.join(args.label, 'valcoco.json')
        label_txt = os.path.join(args.label, 'val2017.txt')
        coco = PoseCocoEvaluator(gt_file, progress_bar=False)
        # Walks the same label file as val_loader, so each name pairs
        # one-to-one with the loader's image at the same position.
        coco_files = coco.file_names(label_txt)

        with tqdm(desc="Eval in COCO2017", unit='img', total=num_samples) as pbar:
            for index, (img, img_size) in enumerate(val_loader):
                if index == args.eval_nums:
                    break

                vec6, heat6 = runner(img)
                h, w = img_size
                # Resize network outputs to img_size (cv2 takes (width, height)).
                vec6 = cv2.resize(vec6[0], (w, h), interpolation=cv2.INTER_CUBIC)
                heat6 = cv2.resize(heat6[0], (w, h), interpolation=cv2.INTER_CUBIC)

                joint_list, person_to_joint_assoc, _ = postprocess(heat6, vec6, img_size)
                img_name = next(coco_files)
                image_id = img_name.split('.')[0]
                image_ids.append(image_id)
                for person in person_to_joint_assoc:
                    outputs.append({
                        "image_id": image_id,
                        "category_id": 1,
                        "keypoints": _to_coco_keypoints(person, joint_list, ORDER_COCO),
                        # Product of the row's last two columns is used as the score.
                        "score": person[-2] * person[-1],
                    })
                pbar.update(1)
    finally:
        # Always release the model (in .rbo mode this removes it from the
        # remote service), even if inference fails part-way.
        runner.close()

    coco.collect(outputs)
    mAP = coco.evaluate_imgs(image_ids)
    print(f"INFO:  mAP: {mAP}", flush=True)


def generate_test_dataset(args, transform):
    """Lazily yield preprocessed validation images in label-file order.

    Image names come from ``<args.label>/val2017.txt``, one per line with
    surrounding whitespace stripped; blank lines are not skipped. Each name is
    resolved against ``<args.data>/val2017``.

    Args:
        args: namespace providing the ``label`` and ``data`` directories.
        transform: callable invoked as ``transform(path, return_size=True)``
            and returning ``(image, image_size)``.

    Yields:
        ``(image, image_size)`` tuples as returned by ``transform``.
    """
    label_txt = os.path.join(args.label, 'val2017.txt')
    # Read the list directly. Going through PoseCocoEvaluator.file_names would
    # first parse the whole ground-truth JSON, which this function never uses.
    with open(label_txt, 'r') as f:
        file_names = [line.strip() for line in f]

    img_dir = os.path.join(args.data, 'val2017')
    for file_name in file_names:
        img, img_size = transform(os.path.join(img_dir, file_name), return_size=True)
        yield img, img_size


class CocoEvaluator(object):
    """Collect predictions in COCO result format and score them with COCOeval.

    Predictions accumulate in ``self.eval_resutls``. The attribute name is kept
    as-is because subclasses and external code read it.
    """

    def __init__(self, val_json, progress_bar=True, iouType='bbox'):
        """Load the ground-truth annotations.

        Args:
            val_json: path to a COCO-format ground-truth annotation file.
            progress_bar: print an ``i/N`` counter while ``file_names`` iterates.
            iouType: COCOeval evaluation type (``'bbox'``, ``'segm'`` or
                ``'keypoints'``).
        """
        self.cocoGt = COCO(val_json)
        # COCO-format result dicts gathered by the collect* methods.
        self.eval_resutls = []
        # Image record most recently yielded by file_names().
        self.now_img = None
        # 1-based position of that image in the current file_names() pass;
        # evaluate_when() compares against it.
        self.iter = 0
        self.progress_bar = progress_bar
        self.iouType = iouType

    @staticmethod
    def _encode_mask(mask):
        """Return the RLE of a binary mask with ``counts`` decoded from bytes to str."""
        rle = mask_util.encode(np.array(mask, dtype=np.uint8, order="F"))
        rle['counts'] = rle['counts'].decode("utf-8")
        return rle

    @staticmethod
    def _summarize(cocoEval):
        """Run evaluate/accumulate/summarize and return ``stats[0]`` (AP@[.50:.95])."""
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        return cocoEval.stats[0]

    def collect(self, box, score, clazz, image_id):
        """Record one detection.

        ``box`` is ``(xmin, ymin, xmax, ymax)`` and is stored in COCO
        ``[x, y, width, height]`` form.
        """
        xmin, ymin, xmax, ymax = box
        self.eval_resutls.append({
            "image_id": image_id,
            "category_id": clazz,
            "bbox": [xmin, ymin, xmax - xmin, ymax - ymin],
            "score": score
        })

    def collect_with_mask(self, box, score, clazz, mask, id):
        """Record a detection as ``collect`` does and attach its RLE-encoded ``mask``."""
        self.collect(box, score, clazz, id)
        self.eval_resutls[-1]['segmentation'] = self._encode_mask(mask)

    def collect_mask(self, score, clazz, mask, id):
        """Record a mask-only prediction (no ``bbox``) for image ``id``."""
        self.eval_resutls.append({
            "image_id": id,
            "category_id": clazz,
            "score": score,
            "segmentation": self._encode_mask(mask)
        })

    def file_names(self, catNames=None):
        """Yield ``file_name`` of ground-truth images selected through ``catNames``.

        The category names are resolved with ``getCatIds`` and the images with
        ``getImgIds``. While iterating, ``now_img`` and ``iter`` are updated.
        """
        catIds = self.cocoGt.getCatIds(catNms=catNames)
        imgIds = self.cocoGt.getImgIds(catIds=catIds)
        cocoimgs = self.cocoGt.loadImgs(imgIds)
        total = len(cocoimgs)
        for i, cocoimg in enumerate(cocoimgs):
            self.now_img = cocoimg
            self.iter = i + 1
            if self.progress_bar:
                print('\r{}/{}'.format(self.iter, total), end='', flush=True)
            yield cocoimg['file_name']
        if self.progress_bar and total:
            # Terminate the carriage-return progress line so later output
            # starts on a fresh line.
            print()

    def evaluate(self):
        """Score all collected results over the image ids of the loaded result set."""
        cocoDt = self.cocoGt.loadRes(self.eval_resutls)
        cocoEval = COCOeval(self.cocoGt, cocoDt, self.iouType)
        cocoEval.params.imgIds = cocoDt.getImgIds()
        return self._summarize(cocoEval)

    def evaluate_imgs(self, img_ids=None, catNames=None):
        """Score collected results on a chosen set of images.

        Args:
            img_ids: image ids to evaluate. Defaults to the ids referenced by
                the collected results.
            catNames: optional category name(s) that restrict the evaluated
                categories. The evaluated images are always given by
                ``img_ids``.

        Returns:
            COCOeval ``stats[0]``.
        """
        cocoDt = self.cocoGt.loadRes(self.eval_resutls)
        cocoEval = COCOeval(self.cocoGt, cocoDt, self.iouType)
        if catNames:
            cocoEval.params.catIds = sorted(self.cocoGt.getCatIds(catNms=catNames))
        if img_ids is None:
            img_ids = [r['image_id'] for r in self.eval_resutls]
        cocoEval.params.imgIds = img_ids
        return self._summarize(cocoEval)

    def evaluate_when(self, iters, max_iter=None):
        """Evaluate according to the current ``self.iter``.

        When ``self.iter == max_iter``, evaluate the images referenced by the
        results. Otherwise, run ``evaluate()`` if ``self.iter`` is in
        ``iters``. Returns None when neither condition holds.
        """
        if max_iter is not None and self.iter == max_iter:
            img_ids = [r['image_id'] for r in self.eval_resutls]
            return self.evaluate_imgs(img_ids)

        if self.iter in iters:
            return self.evaluate()


class PoseCocoEvaluator(CocoEvaluator):
    """Keypoint evaluator fed a precomputed result list and a plain-text image list."""

    def __init__(self, val_json, progress_bar=True, iouType='keypoints'):
        super(PoseCocoEvaluator, self).__init__(val_json, progress_bar, iouType)

    def collect(self, outputs):
        """Replace the stored results with ``outputs`` (COCO keypoint result dicts)."""
        self.eval_resutls = outputs

    def evaluate_imgs(self, img_ids=None):
        """Score the stored results and return COCOeval ``stats[0]``.

        Args:
            img_ids: image ids to evaluate. Defaults to the ids referenced by
                the stored results.

        Returns:
            The primary keypoint AP, or 0.0 when no results were collected.
            pycocotools' ``loadRes`` indexes ``anns[0]`` and fails on an
            empty list.
        """
        if not self.eval_resutls:
            print("WARNING: no keypoint results collected; reporting mAP 0.0.", flush=True)
            return 0.0
        if img_ids is None:
            img_ids = [r['image_id'] for r in self.eval_resutls]
        cocoDt = self.cocoGt.loadRes(self.eval_resutls)
        cocoEval = COCOeval(self.cocoGt, cocoDt, self.iouType)
        cocoEval.params.imgIds = img_ids
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        return cocoEval.stats[0]

    def file_names(self, label_txt):
        """Yield image file names from ``label_txt``, one per line, stripped.

        Blank lines are not skipped (they are yielded as ``''``). ``now_img``
        holds the raw line and ``iter`` its 1-based position in this pass,
        matching the base class's per-pass counting.
        """
        with open(label_txt, 'r') as f:
            lines = f.readlines()

        for i, line in enumerate(lines):
            self.now_img = line
            self.iter = i + 1
            yield line.strip()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    parser.add_argument('--data', type=str, default=None, help='dataset dir.')
    parser.add_argument('--label', type=str, default=None, help='label dir.')
    parser.add_argument('--device-id', type=int, default=0, help="device id.")
    parser.add_argument('--eval-nums', type=int, help="The number of data sets that need to be verified.")
    main(parser.parse_args())
