import os
import string
import argparse
import numpy as np
from tqdm import tqdm

from preprocess import preprocess, postprocess


class WrapRunner(object):
    """Expose either inference backend through one callable interface.

    Two kinds of wrapped ``runner`` are supported, selected by ``rbo_mode``:

    * ``rbo_mode=True``: ``runner`` provides ``inference(model_id, inputs)``
      and ``remove_model(model_id)``; ``id`` identifies the model on it.
    * ``rbo_mode=False``: ``runner`` provides ``run(*args, as_numpy=True)``
      and needs no explicit cleanup.
    """

    def __init__(self, runner, rbo_mode, id=None):
        self.runner = runner
        self.rbo_mode = rbo_mode
        self.id = id

    def __call__(self, *args):
        """Run inference on ``args`` and return the backend's output.

        In rbo mode every argument that is not already an ``np.ndarray`` is
        converted through its ``.numpy()`` method (presumably a framework
        tensor -- confirm with callers). A single-element list result is
        unwrapped to its only element.
        """
        if not self.rbo_mode:
            return self.runner.run(*args, as_numpy=True)

        arrays = [arg if isinstance(arg, np.ndarray) else arg.numpy() for arg in args]
        result = self.runner.inference(self.id, arrays)
        is_singleton = isinstance(result, list) and len(result) == 1
        return result[0] if is_singleton else result

    def close(self):
        """Release the remote model in rbo mode; no-op for the other backend."""
        if not self.rbo_mode:
            return
        self.runner.remove_model(self.id)

def init_runner(args):
    """Create a ``WrapRunner`` for the model file named by ``args.model_file``.

    The backend is selected by file extension:

    * ``.sg``: loaded in-process with ``rbcc.load_sg`` and executed by
      ``SGRunner`` on ``cuda:<device_id>``, or on ``cpu`` when ``device_id``
      is negative.
    * ``.rbo``: uploaded to a remote server through ``GRPCClient``. The
      server address comes from the ``CAISA_REMOTE_IP`` and
      ``CAISA_REMOTE_PORT`` environment variables.

    Args:
        args: parsed CLI namespace; reads ``model_file`` and ``device_id``.

    Returns:
        WrapRunner: callable wrapper around the selected backend.

    Raises:
        ValueError: if the extension is neither ``.sg`` nor ``.rbo``.
        KeyError: for ``.rbo`` files, if a required environment variable is unset.
    """
    suffix = os.path.splitext(args.model_file)[1]
    print(f"INFO: Load model from {args.model_file}")
    if suffix == '.sg':
        # Backend packages are imported lazily so that only the one actually
        # used needs to be installed.
        import rbcc
        from rbcc.backend.runner import SGRunner
        sg = rbcc.load_sg(args.model_file)
        device = f"cuda:{args.device_id}" if args.device_id >= 0 else "cpu"
        runner = SGRunner(sg, device=device, gc_collect=True)
        return WrapRunner(runner, False)
    if suffix == '.rbo':
        from rboexec.core.client import GRPCClient
        client = GRPCClient(os.environ['CAISA_REMOTE_IP'], int(os.environ['CAISA_REMOTE_PORT']))
        model_id = client.upload_model(model_path=os.path.abspath(args.model_file), device_id=args.device_id)
        return WrapRunner(client, True, model_id)
    # repr() makes an empty suffix (file without extension) visible as ''.
    raise ValueError(f"expected .sg or .rbo file, not {suffix!r}")

def main(args):
    """Run the OCR recognition evaluation end to end and print the metrics.

    Builds the runner, streams ``(image, label)`` pairs from ``args.data``,
    runs inference on each sample (only the first ``args.eval_nums`` when
    given), decodes the predictions with ``postprocess`` and scores them with
    ``RecMetric``.

    Args:
        args: parsed CLI namespace; reads ``model_file``, ``device_id``,
            ``data`` and ``eval_nums``.
    """
    from contextlib import closing
    from itertools import islice

    # eval_nums None (flag omitted) or negative means "evaluate every sample",
    # which is what the earlier ``index == eval_nums`` break logic did.
    limit = args.eval_nums
    if limit is not None and limit < 0:
        limit = None
    # Progress-bar total only. 2077 is presumably the size of the full
    # example dataset -- TODO confirm against val.txt.
    num_samples = limit if limit is not None else 2077

    print("INFO: Init Runner.")
    # closing() guarantees runner.close() runs even if loading or inference
    # raises; in .rbo mode that removes the uploaded model from the server.
    with closing(init_runner(args)) as runner:
        print("INFO: Load dataset.")
        val_loader = generate_test_dataset(args, preprocess)

        # Start evaluation.
        print("INFO: Start inference.")
        char_dict_path = os.path.join(args.data, 'ppocr_keys_v1.txt')
        character = gen_character(char_dict_path)
        preds, labels = [], []
        # closing(val_loader) shuts the generator (and the val.txt handle it
        # holds open) when we stop early; tqdm's context closes the bar.
        with closing(val_loader), tqdm(
            desc="Eval in OCR-Rec-Dataset-Examples", unit='img', total=num_samples
        ) as pbar:
            for img, label in islice(val_loader, limit):
                preds.append(runner(img))
                # RecMetric consumes (text, extra) pairs; the second slot is unused.
                labels.append((label, None))
                pbar.update(1)

    if not preds:
        # np.concatenate raises on an empty list; report instead of crashing.
        print("WARNING: No samples were evaluated; check --data and --eval-nums.", flush=True)
        return

    preds = np.concatenate(preds, 0)
    result_list = postprocess(preds, character)
    metric_dict = RecMetric()(result_list, labels)
    acc = round(metric_dict['acc'] * 100, 2)
    norm_edit_dis = round(metric_dict['norm_edit_dis'] * 100, 2)
    print(f"INFO: Acc: {acc}  Norm Edit Dis: {norm_edit_dis}", flush=True)


def generate_test_dataset(args, transform):
    """Yield ``(transform(img_path), label)`` for each entry of ``<args.data>/val.txt``.

    Each non-blank line of ``val.txt`` is ``<image path relative to args.data>``,
    a TAB, then the label text. Blank lines are skipped. Entries whose image
    file does not exist are silently skipped.

    Args:
        args: namespace providing ``data``, the dataset root directory.
        transform: callable mapping an image path to the model input.

    Yields:
        tuple: ``(transformed_image, label_str)``.

    Raises:
        ValueError: if a non-blank line contains no TAB separator.
    """
    root = args.data
    name_txt = os.path.join(root, 'val.txt')
    # Labels can contain non-ASCII (e.g. CJK) text, so decode explicitly as
    # UTF-8 rather than with the platform default encoding.
    with open(name_txt, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            # maxsplit=1 keeps any further TABs as part of the label.
            name, label = line.split('\t', 1)
            img_path = os.path.join(root, name)
            if not os.path.exists(img_path):
                continue
            yield transform(img_path), label


def gen_character(char_dict_path):
    """Build the recognition character table from a UTF-8 dictionary file.

    The table is ``'blank'`` (index 0, presumably the CTC blank placeholder),
    then one entry per line of ``char_dict_path`` with surrounding ``\\r`` /
    ``\\n`` characters removed, then a final ``' '`` entry.

    Args:
        char_dict_path: path to the dictionary file, one character per line.

    Returns:
        list[str]: the character table.
    """
    with open(char_dict_path, "rb") as fin:
        # strip("\r\n") removes both CRLF and LF line terminators.
        entries = [raw.decode('utf-8').strip("\r\n") for raw in fin]
    return ['blank', *entries, " "]


class RecMetric(object):
    """Accuracy and normalized-edit-distance metric for text recognition.

    Calling an instance on ``(preds, labels)`` returns metrics for that batch.
    It also adds the batch counts to the running totals ``correct_num``,
    ``all_num`` and ``norm_edit_dis``, which ``reset()`` clears.

    Args:
        main_indicator: name of the primary metric key (stored only).
        is_filter: if True, compare only ASCII letters/digits, case-insensitively.
        ignore_space: if True, remove all ``' '`` characters before comparing.
        **kwargs: accepted and ignored.
    """

    # Characters kept by ``_normalize_text``; a set gives O(1) membership tests.
    _ALLOWED_CHARS = frozenset(string.digits + string.ascii_letters)

    def __init__(
        self, main_indicator="acc", is_filter=False, ignore_space=True, **kwargs
    ):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        # Added to denominators so an empty batch does not divide by zero.
        self.eps = 1e-5
        self.reset()

    def _normalize_text(self, text):
        """Keep only ASCII letters and digits from ``text``, lower-cased."""
        return "".join(ch for ch in text if ch in self._ALLOWED_CHARS).lower()

    def reset(self):
        """Clear the running totals."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0

    def normalized_levenshtein_distance(self, pred, target):
        """Return ``levenshtein(pred, target) / max(len(pred), len(target))``.

        The result lies in ``[0.0, 1.0]``. Two empty strings give ``0.0``;
        an empty string against a non-empty one gives ``1.0``.
        """
        if not pred or not target:
            return 1.0 if pred != target else 0.0
        # Two-row dynamic programme: ``prev`` is row i-1 and ``curr`` is row i
        # of the classic edit-distance table. This needs O(len(target)) memory
        # and plain list indexing instead of per-element NumPy scalar access.
        prev = list(range(len(target) + 1))
        for i, p_ch in enumerate(pred, 1):
            curr = [i]
            for j, t_ch in enumerate(target, 1):
                cost = 0 if p_ch == t_ch else 1
                curr.append(min(
                    prev[j] + 1,         # deletion
                    curr[j - 1] + 1,     # insertion
                    prev[j - 1] + cost,  # substitution / match
                ))
            prev = curr
        return prev[-1] / max(len(pred), len(target))

    def __call__(self, preds, labels, *args, **kwargs):
        """Score one batch.

        Args:
            preds: iterable of ``(pred_text, confidence)`` pairs; confidence is unused.
            labels: iterable of ``(target_text, extra)`` pairs; extra is unused.

        Returns:
            dict: ``{"acc": ..., "norm_edit_dis": ...}``. ``norm_edit_dis`` is
            ``1 - mean normalized edit distance``, so higher is better.
        """
        correct_num = 0
        all_num = 0
        norm_edit_dis = 0.0
        for (pred, _pred_conf), (target, _) in zip(preds, labels):
            if self.ignore_space:
                pred = pred.replace(" ", "")
                target = target.replace(" ", "")
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            norm_edit_dis += self.normalized_levenshtein_distance(pred, target)
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        return {
            "acc": correct_num / (all_num + self.eps),
            "norm_edit_dis": 1 - norm_edit_dis / (all_num + self.eps),
        }



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', type=str, help='rbo or sg file path.')
    # main() joins paths onto --data (val.txt, ppocr_keys_v1.txt), so omitting it
    # used to fail later with a TypeError; let argparse reject it up front.
    parser.add_argument('--data', type=str, required=True, help='dataset dir.')
    # For .sg models a negative id selects the CPU (see init_runner).
    parser.add_argument('--device-id', type=int, default=0, help="device id.")
    parser.add_argument('--eval-nums', type=int, help="The number of data sets that need to be verified.")
    main(parser.parse_args())
