import os
import sys
import types
import argparse
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader


def load_py(py_path):
    """Execute the Python file at ``py_path`` and return it as a module.

    The module is registered in ``sys.modules`` under the key ``py_path``.
    If an entry with that key already exists it is reused, and the file is
    executed again inside that module's namespace.

    Args:
        py_path: Path of the Python source file to execute.

    Returns:
        The module object whose namespace received the executed code.
    """
    with open(py_path, 'r') as handle:
        text = handle.read()

    # Register (or reuse) the module before compiling, keyed by the path.
    fresh_module = types.ModuleType(py_path)
    module = sys.modules.setdefault(py_path, fresh_module)

    bytecode = compile(text, py_path, 'exec')
    module.__file__ = py_path
    module.__package__ = ''
    exec(bytecode, vars(module))
    return module
   

class Accuracy(object):
    """Accumulate per-sample correctness and report top-k accuracy.

    Predictions may be either 1-D (compared to the labels element-wise) or
    2-D score arrays, in which case the ``k`` highest-scoring column indices
    per row are compared to the label for every ``k`` in ``topk``.

    Args:
        topk: A single int or an iterable of ints; the ``k`` values to report.
    """

    def __init__(self, topk=(1, )):
        self.topk = (topk, ) if isinstance(topk, int) else tuple(topk)
        self.maxk = max(self.topk)
        # One entry per sample: a scalar for 1-D predictions, otherwise a
        # row holding one 0/1 flag per entry of ``self.topk``.
        self._results = []

    def numpy_topk(self, inputs, k, axis=None):
        """Return the ``k`` largest values of ``inputs`` and their indices.

        Args:
            inputs: Array of scores.
            k: Number of top entries to keep.
            axis: Axis to rank along; ``None`` ranks the flattened array.

        Returns:
            Tuple ``(values, indices)`` ordered from largest to smallest.
        """
        # Sorting the negated scores ascending yields descending order.
        indices = np.argsort(inputs * -1.0, axis=axis)
        indices = np.take(indices, np.arange(k), axis=axis)
        values = np.take_along_axis(inputs, indices, axis=axis)
        return values, indices

    def add(self, predictions, labels):
        """Score one batch and append its per-sample results."""
        self._results.extend(self._compute_corrects(predictions, labels))

    def _compute_corrects(self, predictions, labels):
        """Compute per-sample correctness for one batch.

        Args:
            predictions: 1-D predicted values or 2-D per-class scores; a
                non-ndarray sequence is stacked into an array first.
            labels: One label per sample; a non-ndarray sequence is stacked
                first. Any shape holding one value per sample is accepted.

        Returns:
            For 1-D predictions, an int32 array of 0/1 flags. Otherwise a
            float array of shape ``(num_samples, len(self.topk))``.
        """
        if not isinstance(predictions, np.ndarray):
            predictions = np.stack(predictions)
        if not isinstance(labels, np.ndarray):
            labels = np.stack(labels)

        # Labels may arrive as a column of shape (batch, 1); flatten them so
        # the 1-D comparison below cannot broadcast to a (batch, batch) matrix.
        labels = labels.reshape(-1)

        if predictions.ndim == 1:
            return (predictions == labels).astype(np.int32)

        _, pred_label = self.numpy_topk(predictions, self.maxk, axis=1)

        # Repeat each sample's label across its ``maxk`` candidate columns;
        # broadcast_to raises if the number of labels does not fit the batch.
        labels = np.broadcast_to(labels.reshape(-1, 1), pred_label.shape)
        corrects = (pred_label == labels)

        # A sample is correct at ``k`` if its label is among the first k
        # candidates (at most one column can match).
        corrects_per_sample = np.zeros((len(predictions), len(self.topk)))
        for i, k in enumerate(self.topk):
            corrects_per_sample[:, i] = corrects[:, :k].sum(axis=1)

        return corrects_per_sample

    def is_scalar(self, obj):
        """Check if an object is a zero-dimensional, float-convertible value.

        One-element arrays are deliberately not scalars: treating them as
        such made a single-``k`` result row look like a plain top-1 flag.
        """
        try:
            if np.ndim(obj) != 0:
                return False
            float(obj)
            return True
        except Exception:
            # Any conversion failure simply means "not a scalar".
            return False

    def compute_metric(self, results):
        """Average per-sample results into a ``{'top<k>': accuracy}`` dict.

        Args:
            results: Non-empty list of per-sample results as stored by
                :meth:`add`.

        Returns:
            ``{'top1': acc}`` when the results are scalars (1-D predictions),
            otherwise one ``'top<k>'`` entry per value in ``self.topk``.
        """
        if self.is_scalar(results[0]):
            return {'top1': float(sum(results) / len(results))}

        metric_results = {}
        for i, k in enumerate(self.topk):
            corrects = [result[i] for result in results]
            metric_results[f'top{k}'] = float(sum(corrects) / len(corrects))
        return metric_results

    def compute(self):
        """Return the accuracy over everything added so far."""
        return self.compute_metric(self._results)


class Imagenet(object):
    """Indexable dataset built from ``<data_dir>/val.txt``.

    Each non-blank line of ``val.txt`` holds an image name followed by an
    integer label, separated by whitespace.

    Args:
        data_dir: Directory containing ``val.txt``; image names are joined
            onto it.
        transform: Optional callable applied to the image path. When it is
            ``None`` the path itself is returned as the sample.
        total: If not ``None``, keep only the first ``total`` entries.
    """

    def __init__(self, data_dir, transform=None, total=None):
        self.transform = transform
        self.imgs = self.get_imgs(data_dir, os.path.join(data_dir, 'val.txt'))
        if total is not None:
            self.imgs = self.imgs[:total]

    def get_imgs(self, data_dir, gt_file):
        """Parse ``gt_file`` into a list of ``(image_path, label)`` tuples.

        Raises:
            ValueError: If a non-blank line lacks a label field or the label
                is not an integer.
        """
        imgs = []
        with open(gt_file, 'r') as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    # Tolerate blank lines, e.g. a trailing empty line.
                    continue
                # Split on the last whitespace run so tabs, repeated spaces
                # and image names containing spaces all parse.
                parts = line.rsplit(None, 1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{gt_file}:{lineno}: expected '<image> <label>', got {line!r}")
                img_name, label = parts
                imgs.append((os.path.join(data_dir, img_name), int(label)))
        return imgs

    def __getitem__(self, index):
        """Return ``(sample, label_array)`` for the entry at ``index``."""
        img_path, label = self.imgs[index]

        # Without a transform there is nothing to load the image with, so
        # hand back the path instead of failing on an unbound variable.
        img = self.transform(img_path) if self.transform is not None else img_path

        return img, np.array([label])

    def __len__(self):
        return len(self.imgs)

class WrapRunner(object):
    """Callable adapter that hides which kind of runner is in use.

    Args:
        runner: The wrapped object. In rbo mode it must provide
            ``inference`` and ``remove_model``; otherwise it must provide
            ``run``.
        rbo_mode: Selects which of the two call conventions is used.
        id: Identifier passed to ``inference`` / ``remove_model`` in rbo mode.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __call__(self, *args):
        """Run the wrapped runner on ``args`` and return its output."""
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        # Anything that is not already an ndarray is converted via .numpy().
        arrays = [
            arg if isinstance(arg, np.ndarray) else arg.numpy()
            for arg in args
        ]
        result = self.runner.inference(self.id, arrays)

        # Unwrap a single-output list so callers get the output directly.
        is_single_output = isinstance(result, list) and len(result) == 1
        return result[0] if is_single_output else result

    def close(self):
        """In rbo mode, ask the runner to remove the model; otherwise no-op."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Create a ``WrapRunner`` for the model file named in ``args``.

    The backend is chosen from the file extension of ``args.model_file``:
    ``.sg`` is loaded with ``rbcc`` and run through ``SGRunner``; ``.rbo`` is
    uploaded through a ``GRPCClient`` whose address is read from the
    ``CAISA_REMOTE_IP`` and ``CAISA_REMOTE_PORT`` environment variables.

    Args:
        args: Namespace providing ``model_file`` and ``device_id``.

    Returns:
        A ``WrapRunner`` around the selected backend.

    Raises:
        ValueError: If the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: For ``.rbo`` files, if either environment variable is unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")

    if suffix == '.sg':
        # Imported lazily so only the selected backend has to be installed.
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        # A negative device id selects the CPU.
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)

    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        host = os.environ['CAISA_REMOTE_IP']
        port = int(os.environ['CAISA_REMOTE_PORT'])
        client = GRPCClient(host, port)
        model_id = client.upload_model(
            model_path=os.path.abspath(args.model_file),
            device_id=args.device_id)
        return WrapRunner(client, True, model_id)

    # repr() keeps an empty suffix visible in the message.
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")
    

def main(args):
    """Evaluate a model on the dataset and print top-1 / top-5 accuracy.

    Args:
        args: Parsed command-line namespace. Reads ``model_file``,
            ``device_id``, ``preprocess_py``, ``data``, ``eval_nums``,
            ``batch_size`` and ``label_offset``.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    print("INFO: Init Runner.")
    runner = init_runner(args)

    # Everything after the runner exists is guarded so the runner is always
    # closed, even when preprocessing, loading or inference fails.
    try:
        print("INFO: Load preprocess.")
        preprocess_mod = load_py(args.preprocess_py)

        def transform(img_path):
            # Samples are batched by the DataLoader, not by the preprocessor.
            return preprocess_mod.preprocess(img_path, need_batch=False)

        print("INFO: Load dataset.")
        dataset = Imagenet(args.data, transform, total=args.eval_nums)
        if len(dataset) == 0:
            raise ValueError(f"no samples to evaluate in {args.data}")
        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

        print("INFO: Start inference.")
        topk_acc = Accuracy((1, 5))
        # The bar counts images, so its total is the dataset size and each
        # step advances by the number of samples in the batch.
        with tqdm(desc="Eval in ImageNet-1K", unit='img', total=len(dataset)) as pbar:
            for img, label in loader:
                logits = runner(img)
                # Flatten any trailing dimensions to (batch, num_scores).
                logits = logits.reshape(logits.shape[0], -1)
                label = label + args.label_offset
                topk_acc.add(logits, label)
                pbar.update(len(label))

        acc = topk_acc.compute()
        print(f"Top1-Acc: {acc['top1'] * 100:.2f}, Top5-Acc: {acc['top5'] * 100:.2f}", flush=True)
    finally:
        runner.close()


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    # main() always opens the preprocessing script and the dataset directory,
    # so both options are mandatory; omitting them is reported as a usage
    # error instead of a TypeError deep inside the run.
    parser.add_argument('--preprocess-py', type=str, required=True, help='preprocessing script.')
    parser.add_argument('--data', type=str, required=True, help='Imagenet12 dataset dir.')
    parser.add_argument('--eval-nums', type=int, default=50000, help="The number of data sets that need to be verified.")
    parser.add_argument('--label-offset', type=int, default=0, help="Label offset.")
    parser.add_argument('--batch-size', type=int, default=1, help="batch size.")
    parser.add_argument('--device-id', type=int, default=0, help="device id.")

    main(parser.parse_args())