import os
import sys
import numpy as np
from tqdm import tqdm
from preprocess import preprocess
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100


def main(data_dir, npy_dir):
    cifar100 = CIFAR100(root=data_dir, transform=preprocess, train=True)
    loader = DataLoader(cifar100, batch_size=1, shuffle=False, num_workers=8)

    os.makedirs(npy_dir, exist_ok=True)

    print("INFO: Convert jpg to npy file.")
    with open(os.path.join(npy_dir, 'quant.txt'), 'w') as f:
        # 遍历需要处理的图片文件路径
        for index, (img, _) in tqdm(enumerate(loader), total=100):
            if index == 100:
                break
            # 文件名可以随便取，但是需要保持唯一
            npy_name = f"{index}.npy"
            np.save(os.path.join(npy_dir, npy_name), img)
            # 文件名写入 quant.txt 文件中，后续量化校准会根据此文件读取使用
            f.write(npy_name + '\n')


if __name__ == '__main__':
    main(sys.argv[1], sys.argv[2])