from ultralytics import YOLO
from ultralytics.models.yolo.segment.val import SegmentationValidator

import os
import sys
import json
import yaml
import argparse
import numpy as np

import torch
import torch.distributed as dist

from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model

from preprocess import postprocess


class WrapRunner(object):
    """Callable wrapper giving `.sg` and `.rbo` model files one inference interface.

    The backend is picked from the file suffix:

    * ``.sg``  - the graph is loaded with ``rbcc.load_sg`` and executed by an
      ``SGRunner`` (``rbo_mode`` is ``False``).
    * ``.rbo`` - the file is uploaded through a ``GRPCClient`` whose address is read
      from the ``CAISA_REMOTE_IP`` / ``CAISA_REMOTE_PORT`` environment variables
      (``rbo_mode`` is ``True``).

    Instances can be used as context managers; leaving the ``with`` block calls
    :meth:`close`.

    Attributes:
        runner: The ``SGRunner`` or ``GRPCClient`` instance doing the work.
        rbo_mode: ``True`` for the remote `.rbo` backend, ``False`` for `.sg`.
        id: Handle returned by ``upload_model`` (`.rbo` only), otherwise ``None``.
    """

    def __init__(self, model_file, device_id=0):
        """Create the backend for `model_file` on device `device_id`.

        Args:
            model_file: Path to a `.sg` or `.rbo` model file.
            device_id: Device index forwarded to the backend. For `.sg` files a
                negative value selects the CPU.

        Raises:
            ValueError: If the file suffix is neither `.sg` nor `.rbo`.
        """
        self.runner = None
        self.rbo_mode = None
        self.id = None
        self.init_runner(model_file, device_id)

    def __enter__(self):
        """Return the runner itself so it can be bound in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the backend model on exit; exceptions are not suppressed."""
        self.close()

    def __call__(self, *args):
        """Run one inference with the given input arrays/tensors.

        Returns:
            For `.rbo`: the result of ``runner.inference``; a one-element list is
            unwrapped to its single element. For `.sg`: whatever
            ``runner.run(..., as_numpy=True)`` returns.
        """
        if self.rbo_mode:
            outs = self.runner.inference(self.id, list(args))
            if isinstance(outs, list) and len(outs) == 1:
                return outs[0]
            return outs
        return self.runner.run(*args, as_numpy=True)

    def close(self):
        """Remove the uploaded model from the remote side (`.rbo` only).

        Safe to call more than once: the handle is cleared after removal.
        """
        if self.rbo_mode and self.id is not None:
            self.runner.remove_model(self.id)
            self.id = None

    def init_runner(self, model_file, device_id=0):
        """Instantiate the backend matching the suffix of `model_file`.

        Args:
            model_file: Path to a `.sg` or `.rbo` model file.
            device_id: Device index; negative means CPU for the `.sg` backend.

        Raises:
            ValueError: If the suffix is neither `.sg` nor `.rbo`.
            KeyError: For `.rbo` files when ``CAISA_REMOTE_IP`` or
                ``CAISA_REMOTE_PORT`` is not set in the environment.
        """
        suffix = os.path.splitext(model_file)[1]
        print(f"INFO: Load model from {model_file}")
        if suffix == '.sg':
            # Backend SDKs are imported lazily so only the one in use must be installed.
            import rbcc # type: ignore[import]
            from rbcc.backend.runner import SGRunner # type: ignore[import]
            # Load from the path we were given, not from any script-level globals.
            sg = rbcc.load_sg(model_file)
            device = 'cpu' if device_id < 0 else f"cuda:{device_id}"
            self.runner = SGRunner(sg, device=device, gc_collect=True)
            self.rbo_mode = False
        elif suffix == '.rbo':
            from rboexec.core.client import GRPCClient
            client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
            model_id = client.upload_model(model_path=os.path.abspath(model_file), device_id=device_id)
            self.runner = client
            self.rbo_mode = True
            self.id = model_id
        else:
            raise ValueError(f"expected .sg or .rbo file, not {suffix}")


class MySegmentationValidator(SegmentationValidator):
    """Segmentation validator that can take predictions from an external `.sg`/`.rbo` runner.

    When a ``model_file`` entry is supplied in ``args``, a ``WrapRunner`` is created and
    used for inference in ``__call__``; otherwise the regular model forward pass is used.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """Set up the optional external runner, then initialise the base validator.

        Args:
            dataloader: Forwarded unchanged to the base validator.
            save_dir: Forwarded unchanged to the base validator.
            args: Validator overrides. When it is a dict, the ``model_file`` key (if
                present) is removed from it and ``device`` is read as the runner's
                device id. Missing keys, ``None`` or non-dict values simply mean
                "no external runner".
            _callbacks: Forwarded unchanged to the base validator.
        """
        model_file = None
        device_id = None
        if isinstance(args, dict):
            # `model_file` is taken out of the overrides before they reach the base class;
            # presumably the base config handling would not recognise it - TODO confirm.
            model_file = args.pop('model_file', None)
            device_id = args.get('device')

        self.net_runner = None
        if model_file is not None:
            # Fall back to WrapRunner's own default device when none was given.
            self.net_runner = WrapRunner(model_file, device_id=0 if device_id is None else device_id)

        super().__init__(dataloader, save_dir, args, _callbacks)

    def _infer_with_runner(self, batch):
        """Run the external runner image by image and collect per-image prediction dicts.

        Args:
            batch (dict): Preprocessed batch; only ``batch["img"]`` is read.

        Returns:
            (list[dict]): One dict per image with keys ``bboxes``, ``conf``, ``cls`` and
            ``masks``, filled from the imported ``postprocess`` helper.
        """
        # Move dim 1 to the end (assumes batch["img"] is NCHW, giving NHWC - TODO confirm).
        img_input = batch["img"].permute(0, 2, 3, 1).contiguous()
        preds = []
        for single_img_input in img_input:
            single_img_input = single_img_input[None]  # restore the leading batch dim of size 1
            if self.net_runner.rbo_mode:
                # The remote client is fed a NumPy array rather than a tensor.
                single_img_input = single_img_input.cpu().numpy()
            outs = self.net_runner(single_img_input)
            boxes, scores, labels, masks = postprocess(
                outs,
                {'input_size': list(img_input.shape[1:3])},
                with_scale=False,
                conf=self.args.conf,
                iou_thres=self.args.iou,
                max_det=self.args.max_det,
                return_numpy=False,
            )
            preds.append({
                "bboxes": boxes,
                "conf": scores,
                "cls": labels,
                "masks": masks,
            })
        return preds

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """
        Execute validation process, running inference on dataloader and computing performance metrics.

        When ``self.net_runner`` is set, predictions for each batch come from the external
        runner (see ``_infer_with_runner``) and ``self.postprocess`` is skipped for them.
        The backend model is still built, because its stride, warmup and metric
        initialisation are used either way.

        Args:
            trainer (object, optional): Trainer object that contains the model to validate.
            model (nn.Module, optional): Model to validate if not using a trainer.

        Returns:
            (dict): Dictionary containing validation statistics.
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # Force FP16 val during training
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            if trainer.args.compile and hasattr(model, "_orig_mod"):
                model = model._orig_mod  # validate non-compiled original model to avoid issues
            model = model.half() if self.args.half else model.float()
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            if str(self.args.model).endswith(".yaml") and model is None:
                LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
            callbacks.add_integration_callbacks(self)
            # NOTE(review): the device is always wrapped as a CUDA device here, so
            # self.args.device is presumably expected to be a single GPU index - confirm.
            model = AutoBackend(
                model=model or self.args.model,
                device=select_device(torch.device("cuda", self.args.device)) if RANK == -1 else torch.device("cuda", RANK),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit = model.stride, model.pt, model.jit
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if not (pt or jit or getattr(model, "dynamic", False)):
                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

            if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in {"cpu", "mps"}:
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            if self.args.compile:
                model = attempt_compile(model, device=self.device)
            model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        # One profiler per stage: preprocess, inference, loss, postprocess.
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(unwrap_model(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                if self.net_runner is None:
                    preds = model(batch["img"], augment=augment)
                else:
                    preds = self._infer_with_runner(batch)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            if self.net_runner is None:
                # Postprocess
                with dt[3]:
                    preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")

        stats = {}
        self.gather_stats()
        if RANK in {-1, 0}:
            stats = self.get_stats()
            self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
            self.finalize_metrics()
            self.print_results()
            self.run_callbacks("on_val_end")

        if self.training:
            model.float()
            # Reduce loss across all GPUs
            loss = self.loss.clone().detach()
            if trainer.world_size > 1:
                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
            if RANK > 0:
                return
            results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            if RANK > 0:
                return stats
            LOGGER.info(
                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                    *tuple(self.speed.values())
                )
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, default=None, help='rbo or sg file path.')
    parser.add_argument('data', type=str, default=None, help='dataset dir.')
    parser.add_argument('--device-id', type=int, default=0, help="device id.")
    args = parser.parse_args()

    if os.path.splitext(args.model_file)[1] == '.pt':
        # A .pt checkpoint is validated directly, so no external runner file is passed on.
        weights = args.model_file
        args.model_file = None
    else:
        # For runner files only a model definition name is handed to YOLO;
        # a real yaml file is not needed.
        weights = 'yolov8s-seg.yaml'

    model = YOLO(weights)

    # Class names listed in index order; the index of each entry is its class id.
    class_names = (
        'person', 'bicycle', 'car', 'motorcycle', 'airplane',                    # 0-4
        'bus', 'train', 'truck', 'boat', 'traffic light',                        # 5-9
        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',           # 10-14
        'cat', 'dog', 'horse', 'sheep', 'cow',                                   # 15-19
        'elephant', 'bear', 'zebra', 'giraffe', 'backpack',                      # 20-24
        'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',                     # 25-29
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',              # 30-34
        'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',  # 35-39
        'wine glass', 'cup', 'fork', 'knife', 'spoon',                           # 40-44
        'bowl', 'banana', 'apple', 'sandwich', 'orange',                         # 45-49
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',                       # 50-54
        'cake', 'chair', 'couch', 'potted plant', 'bed',                         # 55-59
        'dining table', 'toilet', 'tv', 'laptop', 'mouse',                       # 60-64
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven',                 # 65-69
        'toaster', 'sink', 'refrigerator', 'book', 'clock',                      # 70-74
        'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',            # 75-79
    )

    # Dataset description used for validation.
    # Train/val/test sets may be given as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt,
    # or 3) list: [path/to/imgs1, path/to/imgs2, ..]
    dataset_cfg = {
        "path": args.data,  # dataset root dir
        "train": "train.txt",
        "val": "val2017.txt",
        "test": "test-dev2017.txt",
        "names": dict(enumerate(class_names)),
    }

    # Write the dataset description to a YAML file in the current working directory.
    with open('coco.yaml', 'w', encoding='utf-8') as yaml_file:
        yaml.dump(dataset_cfg, yaml_file, default_flow_style=False, allow_unicode=True, indent=2)

    metrics = model.val(
        data='coco.yaml',
        validator=MySegmentationValidator,
        model_file=args.model_file,
        device=args.device_id,
        batch=8,
        rect=False,
    )
    print(f"INFO: box@map {metrics.box.map}, mask@map {metrics.seg.map}")