import os
import sys
import rbcc
import torch
from preprocess import preprocess


def _list_img_files(img_dir, exts=('.jpg',)):
    """Return the sorted full paths of regular files in *img_dir* ending in one of *exts*.

    Extensions are compared case-insensitively against the end of the file
    name, so ``photo.JPG`` matches but ``photo.jpg.bak`` does not.
    Directories are skipped, even ones whose names end in an extension.
    """
    if isinstance(exts, str):
        # A bare string would otherwise be iterated character by character.
        exts = (exts,)
    exts = tuple(e.lower() for e in exts)
    candidates = (
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if name.lower().endswith(exts)
    )
    # Sort full paths so the dataset order is deterministic across runs.
    return sorted(path for path in candidates if os.path.isfile(path))


def generate_img(img_dir, exts=('.jpg',)):
    """Yield ``preprocess(path)`` for each image file in *img_dir*, in sorted path order.

    Args:
        img_dir: Directory to scan (non-recursive).
        exts: File-name extension(s) to accept, matched case-insensitively
            as a suffix. Defaults to ``('.jpg',)``.

    Yields:
        Whatever ``preprocess`` returns for each file.

    Raises:
        OSError: If *img_dir* cannot be listed. This is raised on the first
            ``next()``, because the function is a generator.
    """
    for img_file in _list_img_files(img_dir, exts):
        yield preprocess(img_file)


def _quantized_sg_path(sg_file):
    """Return the output path for the quantized version of *sg_file*.

    A trailing ``.sg`` is replaced by ``_8bit.sg``, e.g. ``model.sg`` becomes
    ``model_8bit.sg``. Any other path gets ``_8bit.sg`` appended. The result
    therefore never equals the input, and ``.sg`` text inside directory
    names is left untouched.
    """
    base = sg_file[:-len('.sg')] if sg_file.endswith('.sg') else sg_file
    return base + '_8bit.sg'


def main(sg_file, img_dir):
    """Quantize the graph stored in *sg_file* and save it next to the input.

    Args:
        sg_file: Path to the input graph, loaded with ``rbcc.load_sg``.
        img_dir: Directory of images. They are fed to ``rbcc.quantize`` as
            its ``dataset`` (presumably calibration data; confirm with the
            rbcc docs).

    The result is written with ``rbcc.save_sg`` to ``<base>_8bit.sg``.
    """
    sg = rbcc.load_sg(sg_file)
    # Per-layer override: run this one Conv at 16 bits with a MinMaxObserver.
    # NOTE(review): presumably kept at higher precision because this head
    # layer is accuracy-sensitive; confirm before changing.
    qconfig = {
        "layers": {
            "/model_22/dfl/conv/Conv": {"bit": 16, "observer": "MinMaxObserver"}
        }
    }

    # Use the first CUDA device if available; None presumably means CPU.
    gpu_id = 0 if torch.cuda.is_available() else None

    qsg = rbcc.quantize(sg, dataset=generate_img(img_dir), gpu_id=gpu_id, qconfig=qconfig)
    rbcc.save_sg(qsg, _quantized_sg_path(sg_file))


if __name__ == '__main__':
    # Usage: python <script> <model.sg> <image_dir>
    if len(sys.argv) < 3:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} <model.sg> <image_dir>")
    main(sys.argv[1], sys.argv[2])