import os
import argparse
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.metrics import matthews_corrcoef


class WrapRunner(object):
    """Callable facade giving SG runners and RBO clients one calling convention.

    When ``rbo_mode`` is true, calls go to ``runner.inference(id, [args...])``.
    A result that is a one-element list is unwrapped to that element.
    Otherwise calls go to ``runner.run(*args, as_numpy=True)`` and the result
    is returned as-is.
    """

    def __init__(self, runner, rbo_mode, id):
        # ``id`` shadows the builtin but is kept: callers pass it by keyword.
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __call__(self, *args):
        """Run one inference with positional ``args`` as the model inputs."""
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)
        result = self.runner.inference(self.id, list(args))
        single = isinstance(result, list) and len(result) == 1
        return result[0] if single else result

    def close(self):
        """Release backend resources; only RBO mode has anything to release."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Build a WrapRunner for the model file named by ``args.model_file``.

    The backend is chosen from the file extension, compared case-insensitively:

    * ``.sg``: loaded locally with ``rbcc.load_sg`` and executed by ``SGRunner``
      on ``cuda:<device_id>``, or on the CPU when ``args.device_id`` is negative.
    * ``.rbo``: uploaded to a remote server through ``GRPCClient``. The server
      address comes from the ``CAISA_REMOTE_IP`` and ``CAISA_REMOTE_PORT``
      environment variables.

    Args:
        args: namespace providing ``model_file`` (str) and ``device_id`` (int).

    Returns:
        WrapRunner wrapping the selected backend.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: (``.rbo`` only) if a required environment variable is unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    # Lower-case only for the comparison so that e.g. ``model.SG`` is accepted.
    kind = suffix.lower()
    if kind == '.sg':
        # Backend packages are imported inside their branch, so only the
        # package for the backend actually used has to be installed.
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, inputs=["input_id", "attention_mask"], device=device, gc_collect=True)
        return WrapRunner(runner, False, id=None)
    if kind == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")

def generate_test_dataset(args, max_length=128):
    """Yield tokenized CoLA validation samples one at a time (batch size 1).

    This is a generator, so the dataset and tokenizer are loaded only when the
    first sample is requested, not when the function is called.

    Args:
        args: namespace providing ``data`` and ``tokenizer_name``.
            - If ``args.data`` is an existing local directory, it is loaded
              directly.
            - Otherwise it is treated as a dataset name and its ``"cola"``
              configuration is loaded.
            - In both cases the ``"validation"`` split is used.
        max_length: length every sentence is padded or truncated to. Defaults
            to 128, the value previously hard-coded here.

    Yields:
        ``(input_ids, attention_mask, label)``. The first two are numpy arrays
        with a leading batch axis of size 1. ``label`` is the sample's
        ``'label'`` field.
    """
    if os.path.isdir(args.data):
        cola_val = load_dataset(args.data, split="validation")
    else:
        cola_val = load_dataset(args.data, "cola", split="validation")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
    for sample in cola_val:
        # Tokenize one sentence to a fixed length, then wrap each field in a
        # list so the resulting numpy arrays carry a leading batch axis of 1.
        tokens = tokenizer(sample['sentence'], padding='max_length', max_length=max_length, truncation=True)
        input_ids = np.array([tokens['input_ids']])
        attention_mask = np.array([tokens['attention_mask']])
        yield input_ids, attention_mask, sample['label']

def main(args):
    """Evaluate a BERT CoLA classifier and print its Matthews correlation.

    Steps:
        1. Feed every validation sample from ``generate_test_dataset`` through
           the runner built by ``init_runner``.
        2. Take the arg-max over axis 1 of each output as the prediction.
        3. Print ``matthews_corrcoef`` against the gold labels, scaled by 100.

    The runner is always closed, even if inference or scoring fails.

    Raises:
        RuntimeError: if the runner returns ``None`` for any sample.
    """
    print("INFO: Init Runner.")
    runner = init_runner(args)
    # try/finally guarantees runner.close() runs on every exit path. For .rbo
    # models that removes the uploaded model from the remote server.
    try:
        print("INFO: Load dataset.")
        val_loader = generate_test_dataset(args)

        # Start the evaluation.
        print("INFO: Start inference.")
        preds, labels = [], []
        with tqdm(desc="Eval bert on cola", unit='sample') as pbar:
            for input_ids, attention_mask, label in val_loader:
                logit = runner(input_ids, attention_mask)
                if logit is None:
                    raise RuntimeError("Run rbo error.")
                # Each sample is a batch of one, so argmax over axis 1 yields a
                # single class index. .item() converts it to a plain int, which
                # keeps `preds` 1-D for sklearn, and fails loudly if the output
                # unexpectedly holds more than one row.
                preds.append(np.argmax(logit, 1).item())
                labels.append(label)
                pbar.update(1)

        mat_cor = matthews_corrcoef(labels, preds)
        print(f"INFO: matthews_corrcoef: {mat_cor * 100:.3f}", flush=True)
    finally:
        runner.close()


if __name__ == '__main__':
    # Command-line entry point: parse options and run the CoLA evaluation.
    parser = argparse.ArgumentParser(
        description="Evaluate a BERT model (.sg or .rbo) on the GLUE CoLA validation split."
    )
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    parser.add_argument(
        '--data',
        type=str,
        default='nyu-mll/glue',
        help='local dataset dir, or a dataset name whose "cola" config is loaded.',
    )
    parser.add_argument(
        '--tokenizer-name',
        type=str,
        default='ModelTC/bert-base-uncased-cola',
        help='tokenizer name or path.',
    )
    parser.add_argument(
        '--device-id',
        type=int,
        default=0,
        help="device id; a negative value selects the CPU for .sg models.",
    )
    main(parser.parse_args())