import os
import sys
import types
import argparse
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
import pycocotools.mask as mask_util
from pycocotools.cocoeval import COCOeval


def load_py(py_path):
    """Execute the Python file at ``py_path`` and return it as a module object.

    The module is registered in ``sys.modules`` under the key ``py_path``; if an
    entry with that key already exists it is reused and the file is executed
    again into its namespace, so loading the same path twice returns the same
    object.

    The file is read as bytes so that ``compile`` decodes it per PEP 263 (UTF-8
    unless the file declares another encoding) rather than with the platform's
    locale encoding.

    NOTE: this executes arbitrary code from ``py_path``; only pass trusted paths.
    """
    with open(py_path, 'rb') as f:
        source = f.read()
    mod = sys.modules.setdefault(py_path, types.ModuleType(py_path))
    code = compile(source, py_path, 'exec')
    mod.__file__ = py_path
    mod.__package__ = ''
    exec(code, mod.__dict__)
    return mod

class COCODataset(object):
    """Indexable collection of COCO val2017 images plus a preprocessing transform.

    ``coco`` is an evaluator *factory* (``main`` passes an evaluator class). It
    is called once with the path of ``annotations/instances_val2017.json`` and
    ``progress_bar=False``, and the resulting instance replaces it as
    ``self.coco``. When ``total`` is given, only the first ``total`` images are
    kept.

    Indexing returns ``transform(path, with_meta=True)`` with the image id added
    to the meta dict under ``'img_id'``.
    """

    def __init__(self, coco_dir, coco, transform, total=None):
        self.coco = coco
        self.coco_dir = coco_dir
        self.transform = transform
        ids, files = self.get_coco_files()
        if total is not None:
            ids, files = ids[:total], files[:total]
        self.img_ids, self.img_files = ids, files

    def __getitem__(self, index):
        path = self.img_files[index]
        image_id = self.img_ids[index]
        inputs, meta = self.transform(path, with_meta=True)
        meta['img_id'] = image_id
        return inputs, meta

    def __len__(self):
        return len(self.img_files)

    def get_coco_files(self):
        """Instantiate the evaluator and list ``(image ids, image paths)`` in order."""
        ann_file = os.path.join(self.coco_dir, 'annotations', 'instances_val2017.json')
        self.coco = self.coco(ann_file, progress_bar=False)
        image_root = os.path.join(self.coco_dir, 'val2017')
        # file_names() is a generator that updates now_img before each yield,
        # so the id must be read inside the same iteration step.
        pairs = [(self.coco.now_img['id'], os.path.join(image_root, name))
                 for name in self.coco.file_names()]
        ids = [image_id for image_id, _ in pairs]
        files = [path for _, path in pairs]
        return ids, files

class CocoEvaluator(object):
    """Collect model outputs in COCO results format and score them with COCOeval.

    Results are appended to ``eval_resutls`` by the ``collect*`` methods and are
    scored against the ground-truth annotations loaded from ``val_json``.
    ``file_names`` doubles as the image iterator: while it runs, ``now_img``
    holds the current ground-truth image record and ``iter`` its 1-based
    position.
    """

    def __init__(self, val_json, progress_bar=True, iouType='bbox'):
        self.cocoGt = COCO(val_json)
        # (sic) misspelled, but read by subclasses/callers, so the name is kept.
        self.eval_resutls = []
        self.now_img = None  # image record most recently yielded by file_names()
        self.iter = 0  # 1-based index of now_img; 0 before iteration starts
        self.progress_bar = progress_bar
        self.iouType = iouType  # forwarded to COCOeval ('bbox', 'segm', ...)

    @staticmethod
    def _encode_mask(mask):
        """Return the RLE encoding of a binary mask, with ``counts`` decoded to str."""
        rle = mask_util.encode(np.array(mask, dtype=np.uint8, order="F"))
        rle['counts'] = rle['counts'].decode("utf-8")
        return rle

    @staticmethod
    def _summarize(cocoEval):
        """Run evaluate/accumulate/summarize on ``cocoEval`` and return ``stats[0]``."""
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        return cocoEval.stats[0]

    def _has_results(self):
        """Return False (and warn) when nothing has been collected yet.

        ``COCO.loadRes`` indexes the first result entry, so an empty result list
        would crash it; evaluation short-circuits to 0.0 instead.
        """
        if self.eval_resutls:
            return True
        print("WARNING: no results collected; reporting mAP 0.0", flush=True)
        return False

    def collect(self, box, score, clazz, image_id):
        """Record one box detection.

        Args:
            box: ``(xmin, ymin, xmax, ymax)`` corners, stored in the COCO
                ``[x, y, width, height]`` layout.
            score: detection confidence.
            clazz: value stored verbatim as ``category_id``.
            image_id: COCO image id.
        """
        xmin, ymin, xmax, ymax = box
        self.eval_resutls.append({
            "image_id": image_id,
            "category_id": clazz,
            "bbox": [xmin, ymin, xmax - xmin, ymax - ymin],
            "score": score,
        })

    def collect_with_mask(self, box, score, clazz, mask, id):
        """Record a box detection (see ``collect``) and attach its mask as RLE."""
        self.collect(box, score, clazz, id)
        self.eval_resutls[-1]['segmentation'] = self._encode_mask(mask)

    def collect_mask(self, score, clazz, mask, id):
        """Record a mask-only result (no bbox) as RLE ``segmentation``."""
        self.eval_resutls.append({
            "image_id": id,
            "category_id": clazz,
            "score": score,
            "segmentation": self._encode_mask(mask),
        })

    def file_names(self, catNames=None):
        """Yield ground-truth image file names, optionally filtered by category names.

        Updates ``now_img``/``iter`` before each yield and, when ``progress_bar``
        is set, prints a ``\\r``-refreshed ``i/total`` counter.
        """
        # getCatIds(catNms=[]) returns *every* category id, which getImgIds then
        # intersects (only images containing all categories). Treat an empty
        # filter like None: no category filter, i.e. all images.
        catIds = self.cocoGt.getCatIds(catNms=catNames) if catNames else []
        imgIds = self.cocoGt.getImgIds(catIds=catIds)
        cocoimgs = self.cocoGt.loadImgs(imgIds)
        total = len(cocoimgs)
        for i, cocoimg in enumerate(cocoimgs, start=1):
            self.now_img = cocoimg
            self.iter = i
            if self.progress_bar:
                print('\r{}/{}'.format(i, total), end='', flush=True)
            yield cocoimg['file_name']
        if self.progress_bar:
            print()  # terminate the carriage-return progress line

    def evaluate(self):
        """Score all collected results and return COCOeval ``stats[0]``.

        Evaluation is restricted to ``cocoDt.getImgIds()`` of the loaded result
        set. Returns 0.0 when no results were collected.
        """
        if not self._has_results():
            return 0.0
        cocoDt = self.cocoGt.loadRes(self.eval_resutls)
        cocoEval = COCOeval(self.cocoGt, cocoDt, self.iouType)
        cocoEval.params.imgIds = cocoDt.getImgIds()
        return self._summarize(cocoEval)

    def evaluate_imgs(self, img_ids=None, catNames=None):
        """Score collected results on a chosen image (and optional category) subset.

        Args:
            img_ids: image ids to evaluate. When None, the default is the
                ground-truth images containing ``catNames`` if those are given,
                otherwise the images that received at least one result.
            catNames: optional category names restricting the evaluated
                categories.

        Returns:
            COCOeval ``stats[0]``, or 0.0 when no results were collected.
        """
        if not self._has_results():
            return 0.0
        cocoDt = self.cocoGt.loadRes(self.eval_resutls)
        cocoEval = COCOeval(self.cocoGt, cocoDt, self.iouType)
        cat_img_ids = None
        if catNames:
            cocoEval.params.catIds = sorted(self.cocoGt.getCatIds(catNms=catNames))
            # Must be passed by keyword: getImgIds' first positional parameter
            # is imgIds, not catIds.
            cat_img_ids = sorted(self.cocoGt.getImgIds(catIds=cocoEval.params.catIds))
        if img_ids is None:
            if cat_img_ids is not None:
                img_ids = cat_img_ids
            else:
                img_ids = [r['image_id'] for r in self.eval_resutls]
        cocoEval.params.imgIds = img_ids
        return self._summarize(cocoEval)

    def evaluate_when(self, iters, max_iter=None):
        """Evaluate if the current ``iter`` hits a checkpoint; otherwise return None.

        When ``max_iter`` is given and reached, scores the images that have
        results via ``evaluate_imgs``; else, if ``iter`` is in ``iters``, via
        ``evaluate``.
        """
        if max_iter is not None and self.iter == max_iter:
            img_ids = [r['image_id'] for r in self.eval_resutls]
            return self.evaluate_imgs(img_ids)
        if self.iter in iters:
            return self.evaluate()
        return None

class YoloCocoEvaluator(CocoEvaluator):
    """CocoEvaluator for models that emit contiguous 0-based class indices.

    ``label_dict`` maps ``index + 1`` to a COCO category id, following the
    iteration order of ``cocoGt.cats``; a model class ``clazz`` is therefore
    reported as category ``label_dict[clazz + 1]``.
    """

    def __init__(self, val_json, progress_bar=True, iouType='bbox'):
        super(YoloCocoEvaluator, self).__init__(val_json, progress_bar, iouType)
        catIds = list(self.cocoGt.cats)
        self.label_dict = {i + 1: c for i, c in enumerate(catIds)}

    def collect(self, box, score, clazz, image_id):
        """Record one box detection, translating the 0-based class to a COCO category id.

        The box conversion itself is delegated to ``CocoEvaluator.collect`` so
        both evaluators store results identically.
        """
        category_id = self.label_dict[clazz + 1]
        super(YoloCocoEvaluator, self).collect(box, score, category_id, image_id)

class WrapRunner(object):
    """Uniform callable facade over the two runner backends built by ``init_runner``.

    With ``rbo_mode`` true, ``runner`` is a remote client: calls go through
    ``runner.inference(id, inputs_as_list)``, a single-element list result is
    unwrapped, and ``close`` removes model ``id`` from the client. Otherwise
    ``runner`` is invoked as ``runner.run(*inputs, as_numpy=True)`` and
    ``close`` does nothing.
    """

    def __init__(self, runner, rbo_mode, id):
        self.runner, self.rbo_mode, self.id = runner, rbo_mode, id

    def __call__(self, *args):
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)
        result = self.runner.inference(self.id, list(args))
        single = isinstance(result, list) and len(result) == 1
        return result[0] if single else result

    def close(self):
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Build a callable runner for ``args.model_file`` chosen by its extension.

    * ``.sg``: loaded with ``rbcc.load_sg`` and wrapped in ``SGRunner`` on
      ``'cpu'`` when ``args.device_id < 0``, else on ``cuda:<device_id>``.
    * ``.rbo``: uploaded through a ``GRPCClient`` whose address is read from the
      ``CAISA_REMOTE_IP`` / ``CAISA_REMOTE_PORT`` environment variables.

    The extension match is case-insensitive.

    Raises:
        ValueError: for any other extension.
    """
    suffix = os.path.splitext(args.model_file)[1].lower()
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        import rbcc  # type: ignore[import]
        from rbcc.backend.runner import SGRunner  # type: ignore[import]
        sg = rbcc.load_sg(args.model_file)
        device = 'cpu' if args.device_id < 0 else f"cuda:{args.device_id}"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False, id=None)
    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file),
                                       device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")

def main(args):
    """Run the model over the COCO val2017 images and print the box mAP.

    Every image in the (optionally truncated) dataset is scored, so images on
    which the model produced no detection still count their ground truth as
    missed. The runner is always closed, even if inference fails.
    """
    print("INFO: Init Runner.")
    runner = init_runner(args)
    try:
        print("INFO: Load preprocess.")
        preprocess_module = load_py(args.preprocess_py)
        transform = preprocess_module.preprocess
        postprocess = preprocess_module.postprocess

        print("INFO: Load dataset.")
        Evaluator = YoloCocoEvaluator if args.yolo_coco else CocoEvaluator
        dataset = COCODataset(args.data, Evaluator, transform, total=args.eval_nums)
        coco = dataset.coco

        # Run inference and collect detections.
        print("INFO: Start inference.")
        with tqdm(desc="Eval in COCO2017", unit='img', total=len(dataset)) as pbar:
            for img, img_meta in dataset:
                outs = runner(img)
                image_id = img_meta['img_id']
                boxes, scores, labels = postprocess(outs, img_meta, conf=args.conf)
                for box, score, clazz in zip(boxes, scores, labels):
                    coco.collect(box, score, clazz, image_id)
                pbar.update(1)
    finally:
        # Release the runner on every path (for .rbo this removes the uploaded model).
        runner.close()

    print("INFO: COCO evaluate....", flush=True)
    # Pass the dataset's ids explicitly: the evaluate_imgs() default only covers
    # images that received at least one detection, which inflates mAP.
    map_val = coco.evaluate_imgs(img_ids=dataset.img_ids)
    print(f"INFO: box@map {map_val}", flush=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    # Both are needed by main(); without them it would fail with a TypeError only
    # after the runner was already initialised (and an .rbo model uploaded).
    parser.add_argument('--preprocess-py', type=str, required=True, help='preprocessing script.')
    parser.add_argument('--data', type=str, required=True, help='dataset dir.')
    parser.add_argument('--device-id', type=int, default=0, help="device id.")
    parser.add_argument('--eval-nums', type=int, default=5000, help="The number of data sets that need to be verified.")
    parser.add_argument('--conf', type=float, default=0.001, help="The value of nms score threshold.")
    parser.add_argument('--yolo-coco', action="store_true", help="If use yolo coco evaluator.")
    main(parser.parse_args())