import os
import sys
import rbcc
from preprocess import preprocess


def generate_img(img_dir, list_name='train.txt', transform=None):
    """Yield preprocessed images listed in an image-list file.

    Each non-blank line of ``img_dir/list_name`` holds an image name relative
    to ``img_dir``, optionally followed by a tab and further fields (e.g. a
    label). Everything after the first tab is ignored. Entries whose image
    file does not exist are skipped silently.

    Args:
        img_dir: Directory containing the list file and the images.
        list_name: Name of the list file inside ``img_dir``.
        transform: Callable applied to each existing image path. Defaults to
            ``preprocess``.

    Yields:
        The result of ``transform(img_path)`` for each existing image.
    """
    if transform is None:
        transform = preprocess
    name_txt = os.path.join(img_dir, list_name)

    # Read the (small) list up front so the file is closed before the first
    # yield. A consumer that stops early (e.g. after a fixed number of
    # samples) therefore does not keep the handle open.
    with open(name_txt, encoding='utf-8') as f:
        lines = f.readlines()

    for line in lines:
        entry = line.strip()
        if not entry:
            # Blank lines (commonly a trailing newline at EOF) used to crash
            # the two-value unpacking of the split result.
            continue
        # Only the image name is needed; any label/extra fields are ignored.
        name = entry.split('\t', 1)[0]
        img_path = os.path.join(img_dir, name)
        if not os.path.exists(img_path):
            continue
        yield transform(img_path)


def main(sg_file, img_dir, bit=16, observer="MinMaxObserver", gpu_id=0,
         max_num=50):
    """Quantize the model in ``sg_file`` using calibration images from ``img_dir``.

    The quantized model is written next to the input as
    ``<stem>_<bit>bit<ext>`` (e.g. ``net.sg`` -> ``net_16bit.sg``).

    Args:
        sg_file: Path of the model file loaded with ``rbcc.load_sg``.
        img_dir: Directory passed to ``generate_img`` for calibration data.
        bit: Bit width placed in the ``qconfig`` given to ``rbcc.quantize``.
        observer: Observer name placed in the ``qconfig``.
        gpu_id: Forwarded to ``rbcc.quantize``.
        max_num: Forwarded to ``rbcc.quantize``. Presumably a cap on the
            number of calibration samples; confirm against rbcc docs.

    Returns:
        Path of the written quantized model.
    """
    # Derive the suffix from the actual bit width. The old hard-coded
    # "_8bit" mislabelled 16-bit output. Split off only the real extension:
    # str.replace rewrote every ".sg" in the path and, for inputs without
    # ".sg", returned the input path itself, overwriting the source model.
    root, ext = os.path.splitext(sg_file)
    out_file = '{}_{}bit{}'.format(root, bit, ext)

    sg = rbcc.load_sg(sg_file)
    qsg = rbcc.quantize(
        sg,
        dataset=generate_img(img_dir),
        gpu_id=gpu_id,
        max_num=max_num,
        qconfig={"bit": bit, "observer": observer})
    rbcc.save_sg(qsg, out_file)
    return out_file


if __name__ == '__main__':
    # Fail with a usage message instead of an IndexError traceback when the
    # required positional arguments are missing.
    if len(sys.argv) < 3:
        sys.exit('usage: {} SG_FILE IMG_DIR'.format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2])