import os
import argparse

import cv2
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
from prettytable import PrettyTable

from preprocess import Cityscapes
from preprocess import preprocess, postprocess


class WrapRunner:
    """Callable adapter giving both inference backends one calling convention.

    ``init_runner`` builds it in one of two modes:

    * ``rbo_mode=False`` -- ``runner`` must expose ``run(*args, as_numpy=True)``;
      calls are forwarded unchanged.
    * ``rbo_mode=True`` -- ``runner`` must expose ``inference(id, inputs)`` and
      ``remove_model(id)``. ``id`` identifies the model on that runner.

    Args:
        runner: the backend object being wrapped.
        rbo_mode: selects which of the two calling conventions above is used.
        id: model handle forwarded to ``inference`` / ``remove_model`` in rbo mode.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __call__(self, *args):
        """Run inference on ``args`` and return the backend's output.

        In rbo mode, every argument that is not already an ``np.ndarray`` is
        converted with its ``.numpy()`` method before being sent. A result that
        is a one-element list is unwrapped to that element.
        """
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        arrays = [arg if isinstance(arg, np.ndarray) else arg.numpy() for arg in args]
        result = self.runner.inference(self.id, arrays)
        # Unwrap single-output models so callers get the array itself.
        single_output = isinstance(result, list) and len(result) == 1
        return result[0] if single_output else result

    def close(self):
        """Release the model on the backend (only meaningful in rbo mode)."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Create a ``WrapRunner`` for the model file named by ``args.model_file``.

    The backend is selected by file extension:

    * ``.sg``  -- loaded with ``rbcc.load_sg`` and run locally through
      ``SGRunner`` on ``cuda:<device_id>``, or on ``cpu`` when ``device_id < 0``.
    * ``.rbo`` -- uploaded to a ``GRPCClient``; the server address is read from
      the ``CAISA_REMOTE_IP`` and ``CAISA_REMOTE_PORT`` environment variables.

    Args:
        args: parsed CLI namespace; reads ``model_file`` and ``device_id``.

    Returns:
        WrapRunner: the wrapped backend, ready to be called with model inputs.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: for ``.rbo`` models, if a required environment variable is unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        # Imported inside the branch so rbcc is only needed for .sg models.
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if suffix == '.rbo':
        # Imported inside the branch so rboexec is only needed for .rbo models.
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file),
                                       device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    # repr() keeps an empty suffix visible in the message ('').
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")

def main(args):
    """Evaluate a segmentation model on the Cityscapes validation split.

    Builds the inference backend, streams preprocessed validation samples
    through it, and accumulates the per-class areas returned by
    ``postprocess``. It then prints the per-class table and summary metrics
    via ``evaluate``.

    Args:
        args: parsed CLI namespace. Reads ``model_file``, ``data``,
            ``device_id`` and ``eval_nums``. An ``eval_nums`` of ``None``
            evaluates every sample the loader yields.
    """
    # Length of the per-class accumulators; must match the class names used by evaluate().
    num_classes = 19

    print("INFO: Init Runner.")
    runner = init_runner(args)
    # Always release the backend, even if loading or inference fails midway.
    # In rbo mode this calls remove_model on the remote client.
    try:
        print("INFO: Load dataset.")
        val_loader = generate_test_dataset(args, preprocess)

        # Start the evaluation.
        print("INFO: Start inference.")
        # Only sizes the progress bar; 500 is presumably the Cityscapes val split size.
        num_samples = args.eval_nums if args.eval_nums else 500
        total_intersect = np.zeros([num_classes], dtype=np.float32)
        total_union = np.zeros([num_classes], dtype=np.float32)
        total_pred_label = np.zeros([num_classes], dtype=np.float32)
        total_area_label = np.zeros([num_classes], dtype=np.float32)
        # The context manager closes the progress bar even on error.
        with tqdm(desc="Eval in CityScapes", unit='img', total=num_samples) as pbar:
            for index, (img, label) in enumerate(val_loader):
                if index == args.eval_nums:
                    break

                pred = runner(img)
                # NOTE(review): strips two leading axes of the backend output;
                # presumably batch size 1 — confirm the layout per backend.
                pred = pred[0][0]
                area_intersect, area_union, area_pred_label, area_label = postprocess(pred, label)
                total_intersect += area_intersect
                total_union += area_union
                total_pred_label += area_pred_label
                total_area_label += area_label
                pbar.update(1)
    finally:
        runner.close()

    metrics = evaluate(total_intersect, total_union, total_pred_label, total_area_label)
    print(f"INFO: {metrics}", flush=True)


def generate_test_dataset(args, transform):
    """Yield ``(image, label)`` pairs from the Cityscapes validation split.

    Args:
        args: parsed CLI namespace; ``args.data`` is passed to ``Cityscapes``
            as the dataset root.
        transform: callable applied to each image *path* taken from
            ``data.images``; its return value is yielded as the image.

    Yields:
        tuple: ``(transform(image_path), label)``. ``label`` is channel 0 of
        the ``labelTrainIds.png`` file derived from the sample's first target
        path, as a uint8 array.

    Raises:
        FileNotFoundError: if OpenCV cannot read a label image.
    """
    data = Cityscapes(args.data, split='val', target_type='semantic')
    # Walk the path lists directly. Iterating ``data`` itself goes through
    # Cityscapes.__getitem__, which (presumably, as in torchvision) decodes each
    # image and target only for them to be discarded, since everything below
    # re-reads from the paths.
    for image_path, target_paths in zip(data.images, data.targets):
        img = transform(image_path)
        label_path = target_paths[0].replace('labelIds.png', 'labelTrainIds.png')
        label = cv2.imread(label_path)
        if label is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError(f"cannot read label image: {label_path}")
        yield img, label.astype(np.uint8)[:, :, 0]


def evaluate(total_area_intersect, total_area_union, total_area_pred_label, total_area_label,
             class_names=None):
    """Turn accumulated per-class areas into segmentation metrics.

    Prints a per-class IoU / accuracy table (percentages, 2 decimals) and
    returns the summary metrics.

    Args:
        total_area_intersect: 1-D array of per-class intersection areas.
        total_area_union: 1-D array of per-class union areas.
        total_area_pred_label: per-class predicted areas. Accepted for
            interface compatibility; not used by these metrics.
        total_area_label: 1-D array of per-class label areas.
        class_names: optional row names for the table. Defaults to the 19
            Cityscapes train classes and must match the array length.

    Returns:
        dict: keys ``'aAcc'``, ``'mIoU'`` and ``'mAcc'``, each a percentage
        rounded to 2 decimals. Classes whose union (or label area) is 0 get
        NaN and are skipped by the means.

    Raises:
        ValueError: if ``class_names`` does not match the number of classes.
    """
    if class_names is None:
        class_names = (
            'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
            'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
            'truck', 'bus', 'train', 'motorcycle', 'bicycle')

    # Empty classes yield 0/0. The resulting NaN is intended (nanmean skips it),
    # so silence NumPy's divide/invalid warnings for these ratios.
    with np.errstate(divide='ignore', invalid='ignore'):
        all_acc = total_area_intersect.sum() / total_area_label.sum()
        iou = total_area_intersect / total_area_union
        acc = total_area_intersect / total_area_label

    if len(class_names) != len(iou):
        raise ValueError(
            f"got {len(class_names)} class names for {len(iou)} classes")

    per_class = OrderedDict([('IoU', iou), ('Acc', acc)])

    # Summary: overall pixel accuracy plus the NaN-aware mean of each per-class metric.
    metrics = {'aAcc': np.round(np.nanmean(all_acc) * 100, 2)}
    for name, values in per_class.items():
        metrics['m' + name] = np.round(np.nanmean(values) * 100, 2)

    # Per-class table: the Class column first, then one column per metric.
    class_table_data = PrettyTable()
    class_table_data.add_column('Class', class_names)
    for name, values in per_class.items():
        class_table_data.add_column(name, np.round(values * 100, 2))
    print('per class results:')
    print('\n' + class_table_data.get_string())
    return metrics


if __name__ == '__main__':
    # CLI entry point: parse arguments and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    # Required: generate_test_dataset passes this straight to Cityscapes as the
    # dataset root, so there is no usable default.
    parser.add_argument('--data', type=str, required=True, help='dataset dir.')
    # A negative value selects CPU for .sg models (see init_runner).
    parser.add_argument('--device-id', type=int, default=0, help="device id.")
    # Omitted (None) means evaluate every validation sample.
    parser.add_argument('--eval-nums', type=int, help="The number of data sets that need to be verified.")
    main(parser.parse_args())
