import os
import argparse
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100

from preprocess import preprocess, get_text_weights


def accuracy(output, target, topk=(1,)):
    """Count correct top-k predictions for every k in ``topk``.

    Args:
        output: score array, presumably shaped (batch, num_classes); classes
            are ranked along axis 1.
        target: integer labels, one per sample (flattened internally).
        topk: the k values to report; ``max(topk)`` classes are ranked once.

    Returns:
        list[float]: for each k, the number (not the fraction) of samples
        whose target appears among the k highest-scoring classes.
    """
    _, ranked = numpy_topk(output, max(topk), 1)
    # ranked.T is (max_k, batch); row r holds every sample's r-th best class.
    hits = ranked.T == target.reshape(1, -1)
    counts = []
    for k in topk:
        # A sample is a hit for k if its label matches in any of the first k rows.
        counts.append(float(np.sum(hits[:k].astype('float').reshape(-1), 0)))
    return counts


def numpy_topk(inputs: np.ndarray, k: int, axis: int = None):
    """Return the ``k`` largest entries of ``inputs`` and their indices.

    Entries are ordered from largest to smallest along ``axis``. When
    ``axis`` is None the array is flattened first, following numpy's
    ``argsort``/``take``/``take_along_axis`` semantics. The order among
    equal values is not guaranteed, because the default argsort kind is
    not stable.

    Returns:
        tuple(values, indices), each with size ``k`` along ``axis``.
    """
    # Sorting the negated scores ascending gives a descending order of the originals.
    descending = np.argsort(-1.0 * inputs, axis=axis)
    top_idx = descending.take(np.arange(k), axis=axis)
    top_val = np.take_along_axis(inputs, top_idx, axis=axis)
    return top_val, top_idx


class WrapRunner(object):
    """Give a local SG runner and a remote RBO client the same call interface.

    In SG mode (``rbo_mode`` false) calls are forwarded to
    ``runner.run(*args, as_numpy=True)``. In RBO mode every argument is
    turned into a numpy array and sent through
    ``runner.inference(id, inputs)``; a single-element list result is
    unwrapped to that element.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner      # SG runner (local) or RBO client (remote)
        self.rbo_mode = rbo_mode  # selects which of the two call paths is used
        self.id = id              # model handle on the remote side; only used in RBO mode

    def __call__(self, *args):
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        # Non-ndarray inputs (presumably torch tensors from the DataLoader) are
        # converted with their own .numpy() method before being sent.
        arrays = [a if isinstance(a, np.ndarray) else a.numpy() for a in args]
        result = self.runner.inference(self.id, arrays)
        if isinstance(result, list) and len(result) == 1:
            return result[0]
        return result

    def close(self):
        """Remove the remote model in RBO mode; SG mode has nothing to release."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Build a WrapRunner for the model file named by ``args.model_file``.

    * ``.sg`` files are loaded in-process with ``rbcc.load_sg`` and run by
      ``SGRunner`` on ``cuda:<device_id>``, or on ``cpu`` when ``device_id`` is
      negative.
    * ``.rbo`` files are uploaded through rboexec's ``GRPCClient``. The server
      address is read from the ``CAISA_REMOTE_IP`` and ``CAISA_REMOTE_PORT``
      environment variables.

    Args:
        args: namespace providing ``model_file`` (path) and ``device_id`` (int).

    Returns:
        WrapRunner: callable wrapper around the created runner or client.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: if a required ``CAISA_REMOTE_*`` variable is unset (``.rbo`` only).
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        try:
            host = os.environ['CAISA_REMOTE_IP']
            port = int(os.environ['CAISA_REMOTE_PORT'])
        except KeyError as err:
            # Keep the KeyError type, but say which variable is missing and why it is needed.
            raise KeyError(
                f"environment variable {err.args[0]} must be set to run .rbo models"
            ) from err
        client = GRPCClient(host, port)
        model_id = client.upload_model(
            model_path=os.path.abspath(args.model_file), device_id=args.device_id
        )
        return WrapRunner(client, True, model_id)
    # repr() keeps an empty suffix visible in the message.
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")


def _run_zeroshot_eval(args, text_runner, vision_runner):
    """Evaluate top-1/top-5 accuracy on the CIFAR-100 test split and print it.

    Args:
        args: parsed CLI namespace. ``data`` and ``eval_nums`` are read here,
            and the namespace is also forwarded to ``get_text_weights``.
        text_runner: WrapRunner for the text model, consumed by ``get_text_weights``.
        vision_runner: WrapRunner for the image model, called once per batch.
    """
    print("INFO: Load dataset.")
    cifar100 = CIFAR100(root=args.data, transform=preprocess, train=False)
    loader = DataLoader(cifar100, batch_size=1, shuffle=False, num_workers=8)

    print("INFO: Start inference.")
    zeroshot_weights = get_text_weights(args, text_runner)

    # A missing, zero or negative --eval-nums means "whole split", as before.
    # Otherwise the count is clamped to the dataset size so the progress bar total is accurate.
    num_images = len(cifar100)
    if args.eval_nums and args.eval_nums > 0:
        total_num = min(args.eval_nums, num_images)
    else:
        total_num = num_images

    top1, top5, total = 0.0, 0.0, 0
    with tqdm(desc="Eval in CIFAR-100", unit='img', total=total_num) as pbar:
        for img, label in loader:
            if total >= total_num:
                break
            image_features = vision_runner(img)
            # Normalise out of place so the array returned by the runner is not mutated
            # (it may be shared or read-only; not visible from here).
            image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
            logits = 100. * image_features @ zeroshot_weights
            label = label.cpu().numpy().astype(np.int64)
            acc1, acc5 = accuracy(logits, label, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            # accuracy() returns per-batch hit counts, so count images (not batches)
            # to keep the averages right for any batch size.
            batch = len(label)
            total += batch
            pbar.update(batch)

    if total == 0:
        print("WARNING: no images were evaluated; accuracy is undefined.", flush=True)
        return

    top1 = round((top1 / total) * 100, 2)
    top5 = round((top5 / total) * 100, 2)
    print(f"Top1-Acc: {top1}, Top5-Acc: {top5}", flush=True)


def main(args):
    """Load the text and vision models, run the CIFAR-100 evaluation, and release both runners.

    The runners are closed in ``finally`` blocks. A failure while loading the
    second model or during evaluation therefore still removes any model already
    uploaded in RBO mode (``WrapRunner.close``).
    """
    print("INFO: Init Runner.")
    vision_model_file = args.vision_model_file
    text_model_file = args.text_model_file

    # init_runner reads the path from args.model_file, so point it at each model in turn.
    args.model_file = text_model_file
    text_runner = init_runner(args)
    try:
        args.model_file = vision_model_file
        vision_runner = init_runner(args)
        try:
            _run_zeroshot_eval(args, text_runner, vision_runner)
        finally:
            vision_runner.close()
    finally:
        text_runner.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Zero-shot CIFAR-100 evaluation of a vision/text model pair (.sg or .rbo)."
    )
    parser.add_argument('vision_model_file', type=str, help='Path to the vision model (.sg or .rbo file).')
    parser.add_argument('text_model_file', type=str, help='Path to the text model (.sg or .rbo file).')
    # CIFAR100 needs a root directory; fail at argument parsing instead of deep inside torchvision.
    parser.add_argument('--data', type=str, required=True, help='CIFAR100 dataset dir.')
    parser.add_argument('--eval-nums', type=int,
                        help="Number of test images to evaluate (default: all 10000 test images).")
    parser.add_argument('--device-id', type=int, default=0,
                        help="Device id; a negative value runs .sg models on CPU.")

    main(parser.parse_args())

