import os
import argparse
import rbcc
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import numpy as np


def generate_quant_data(dataset_name, tokenizer_name, max_length=128):
    """Yield tokenized CoLA training sentences as quantization calibration data.

    Args:
        dataset_name: Either a local directory, whose ``train`` split is
            loaded directly, or a dataset name for ``datasets.load_dataset``.
            For a name, the ``cola`` config's ``train`` split is loaded.
        tokenizer_name: Tokenizer name or path passed to
            ``AutoTokenizer.from_pretrained`` (fast tokenizer requested).
        max_length: Length every sentence is padded/truncated to.
            Defaults to 128, the previously hard-coded value.

    Yields:
        ``(input_ids, attention_mask)`` tuples of numpy arrays. Each wraps a
        single tokenized sentence in an extra leading list, so each array
        has a leading dimension of 1.

    Note:
        This is a generator, so the dataset and tokenizer are only loaded
        when the first sample is requested.
    """
    if os.path.isdir(dataset_name):
        cola = load_dataset(dataset_name, split="train")
    else:
        cola = load_dataset(dataset_name, "cola", split="train")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        use_fast=True,
    )
    # Iterate the dataset itself through tqdm (not an enumerate wrapper) so
    # tqdm can read its length and show a real progress total.
    for sample in tqdm(cola):
        # Run preprocessing to obtain the preprocessed numpy arrays.
        tokens = tokenizer(
            sample['sentence'],
            padding='max_length',
            max_length=max_length,
            truncation=True,
        )
        # Wrap in an outer list to add a batch dimension of 1.
        input_ids = np.array([tokens['input_ids']])
        attention_mask = np.array([tokens['attention_mask']])
        yield input_ids, attention_mask


def main(args):
    """Quantize an rbcc graph to 8 bit and save it next to the input file.

    Args:
        args: Parsed CLI namespace providing ``sg_file`` (input graph path),
            ``data`` (calibration dataset name or dir), ``tokenizer_name``
            and ``device_id``.

    The result is written to ``<input path without extension>_8bit.sg``.
    The old ``str.replace('.sg', ...)`` approach had two problems: it
    rewrote any ``.sg`` substring (e.g. in directory names), and it left
    ``.rbo`` or extension-less paths unchanged. In those cases the output
    path equalled the input, which silently overwrote it.
    """
    sg = rbcc.load_sg(args.sg_file)
    quant_data = generate_quant_data(args.data, args.tokenizer_name)
    qsg = rbcc.quantize(
        sg,
        dataset=quant_data,
        gpu_id=args.device_id,
        # NOTE(review): presumably caps how many calibration samples are
        # drawn from the generator — confirm against rbcc docs.
        max_num=100,
        # NOTE(review): empty quant locations for Eltwise presumably disable
        # quantization around Eltwise ops — confirm against rbcc docs.
        qconfig={"quant_ops": {"Eltwise": {"input_quant_loc": [], "output_quant_loc": []}}},
        # NOTE(review): the tokenizer key is 'input_ids' but the graph input
        # is given as 'input_id' — verify this matches the sg file's input names.
        inputs=['input_id', 'attention_mask'],
    )
    stem, _ = os.path.splitext(args.sg_file)
    rbcc.save_sg(qsg, stem + '_8bit.sg')


if __name__ == '__main__':
    # Command-line entry point: parse arguments and run quantization.
    parser = argparse.ArgumentParser()
    parser.add_argument('sg_file', type=str, help='rbo or sg file path.')
    # generate_quant_data accepts either a local directory or a dataset name
    # (the default is a Hub dataset name), so the help text says so.
    parser.add_argument(
        '--data',
        type=str,
        default='nyu-mll/glue',
        help='dataset name (loaded with its "cola" config) or local dataset dir.',
    )
    parser.add_argument('--tokenizer-name', type=str, default='ModelTC/bert-base-uncased-cola', help='tokenizer name or path.')
    parser.add_argument('--device-id', type=int, default=0, help="device id.")
    main(parser.parse_args())