import os
import sys
import rbcc
from preprocess import preprocess


def generate_img(img_dir, suffix='.JPEG'):
    """Yield preprocessed images from ``img_dir`` in sorted file-name order.

    Used by ``main`` as the ``dataset`` argument of ``rbcc.quantize``.

    Args:
        img_dir: Directory to scan (non-recursively) for image files.
        suffix: File-name ending that selects images. A tuple of endings is
            also accepted because it is passed straight to ``str.endswith``.
            Matching is case-sensitive. Defaults to ``'.JPEG'``.

    Yields:
        Whatever ``preprocess`` returns for each selected file, one file per
        iteration. Nothing is read until the generator is first advanced.
    """
    # Match on the file-name ending rather than a substring, so side files
    # such as ``x.JPEG.json`` or ``x.JPEG.bak`` are not handed to preprocess.
    img_files = sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if name.endswith(suffix)
    )
    for img_file in img_files:
        yield preprocess(img_file)


def quantized_sg_path(sg_file):
    """Return the output path for the quantized copy of ``sg_file``.

    ``_8bit`` is inserted before the final extension, e.g. ``model.sg`` ->
    ``model_8bit.sg``. Only the last extension is touched, so directory names
    are never rewritten. The result also always differs from ``sg_file``, so
    the input model can never be overwritten (``model.bin`` ->
    ``model_8bit.bin``).
    """
    root, ext = os.path.splitext(sg_file)
    return root + '_8bit' + ext


def main(sg_file, img_dir, gpu_id=0):
    """Quantize the model in ``sg_file`` and save the result beside it.

    Args:
        sg_file: Path of the model, loaded with ``rbcc.load_sg``.
        img_dir: Directory of images fed to ``rbcc.quantize`` through
            ``generate_img``.
        gpu_id: GPU index forwarded to ``rbcc.quantize``. Defaults to 0, the
            value this script always used.

    The quantized model is written with ``rbcc.save_sg`` to
    ``quantized_sg_path(sg_file)``.
    """
    sg = rbcc.load_sg(sg_file)
    # This model suffers severe quantization loss, so use_eq (cross-layer
    # equalization) is enabled to adjust the model weights before quantizing.
    qsg = rbcc.quantize(sg, dataset=generate_img(img_dir), gpu_id=gpu_id, use_eq=True)
    rbcc.save_sg(qsg, quantized_sg_path(sg_file))


if __name__ == '__main__':
    # Expected invocation: python <script> <sg_file> <img_dir>
    if len(sys.argv) < 3:
        # Report a usage line (exit status 1) instead of an IndexError traceback.
        sys.exit('usage: %s <sg_file> <img_dir>' % os.path.basename(sys.argv[0]))
    main(sys.argv[1], sys.argv[2])