import os
import sys
import numpy as np
from tqdm import tqdm
from preprocess import tokenize


def get_cifar100_label():
    """Return the CIFAR-100 class names and the prompt templates.

    Returns:
        A ``(classes, templates)`` tuple.

        ``classes`` is a list of the 100 CIFAR-100 class names written as
        readable text (multi-word names use spaces). The list is in
        alphabetical order, which presumably matches the dataset's fine-label
        indices; confirm against the evaluation labels.

        ``templates`` is a list of prompt format strings, each with a single
        ``{}`` slot for the class name.
    """
    # Officially recommended prompt ensemble, kept for reference only:
    #   'a photo of a {}.', 'a blurry photo of a {}.',
    #   'a black and white photo of a {}.', 'a low contrast photo of a {}.',
    #   'a high contrast photo of a {}.', 'a bad photo of a {}.',
    #   'a good photo of a {}.', 'a photo of a small {}.',
    #   'a photo of a big {}.', 'a photo of the {}.',
    #   'a blurry photo of the {}.', 'a black and white photo of the {}.',
    #   'a low contrast photo of the {}.', 'a high contrast photo of the {}.',
    #   'a bad photo of the {}.', 'a good photo of the {}.',
    #   'a photo of the small {}.', 'a photo of the big {}.',
    # The single 'itap of a {}.' prompt alone reached 68.45 accuracy, whereas
    # the recommended set gave 68.26, so only that one prompt is used.
    prompt_templates = ['itap of a {}.']

    # One line per initial letter; multi-word names use spaces.
    label_names = [
        'apple', 'aquarium fish',
        'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair',
        'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
        'crocodile', 'cup',
        'dinosaur', 'dolphin',
        'elephant',
        'flatfish', 'forest', 'fox',
        'girl',
        'hamster', 'house',
        'kangaroo', 'keyboard',
        'lamp', 'lawn mower', 'leopard', 'lion', 'lizard', 'lobster',
        'man', 'maple tree', 'motorcycle', 'mountain', 'mouse', 'mushroom',
        'oak tree', 'orange', 'orchid', 'otter',
        'palm tree', 'pear', 'pickup truck', 'pine tree', 'plain', 'plate',
        'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
        'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
        'sweet pepper',
        'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle',
        'wardrobe', 'whale', 'willow tree', 'wolf', 'woman', 'worm',
    ]

    return label_names, prompt_templates



def main(data_dir, npy_dir, context_length=77):
    """Tokenize the prompts for each CIFAR-100 class and save them as .npy files.

    For every class returned by ``get_cifar100_label``, each prompt template
    is filled in with the class name. The resulting texts are tokenized and
    saved to ``<npy_dir>/<class_index>.npy``. Every saved file name is also
    written, one per line, to ``<npy_dir>/quant.txt``; later quantization
    calibration reads that list to locate its sample files.

    Args:
        data_dir: Passed through as the first argument of ``tokenize``.
            Presumably this is the tokenizer's vocabulary/resource location;
            confirm against ``preprocess.tokenize``.
        npy_dir: Output directory. It is created, including any missing
            parents, if it does not exist.
        context_length: Token sequence length passed to ``tokenize``.
            Defaults to 77.
    """
    classes, templates = get_cifar100_label()

    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of isdir() followed by mkdir().
    os.makedirs(npy_dir, exist_ok=True)

    print("INFO: Convert text prompts to npy files.")
    with open(os.path.join(npy_dir, 'quant.txt'), 'w', encoding='utf-8') as f:
        # One .npy file per class, holding that class's tokenized prompts.
        for index, classname in enumerate(tqdm(classes)):
            texts = [template.format(classname) for template in templates]
            text = tokenize(data_dir, texts, context_length=context_length)
            # Any file name works as long as it is unique; the class index is.
            npy_name = f"{index}.npy"
            np.save(os.path.join(npy_dir, npy_name), text)
            # Record the file name so quantization calibration can find it.
            f.write(npy_name + '\n')


if __name__ == '__main__':
    # Usage: python <script> DATA_DIR NPY_DIR  (extra arguments are ignored).
    # Fail with a usage message instead of an IndexError traceback when
    # arguments are missing; sys.exit(str) prints to stderr and exits with 1.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} DATA_DIR NPY_DIR")
    main(sys.argv[1], sys.argv[2])
