import os
import sys
import rbcc
from preprocess import preprocess


def generate_img(img_dir, preprocess_fn=None):
    """Yield preprocessed calibration images from ``img_dir`` in sorted path order.

    Only regular files (not directories) whose names end in ``.jpg`` are
    used. The scan is non-recursive. This is a generator, so the directory
    is not listed until the first item is requested.

    Args:
        img_dir: directory containing the calibration images.
        preprocess_fn: optional callable applied to each image path. When
            omitted, ``preprocess`` (imported from the ``preprocess``
            module) is used, matching the previous behavior.

    Yields:
        Whatever the preprocessing callable returns for each image path.
    """
    fn = preprocess if preprocess_fn is None else preprocess_fn
    # Match on the suffix rather than a substring so names like "a.jpg.bak"
    # are skipped, and require a regular file so a directory named "*.jpg"
    # is never handed to the preprocessor.
    img_files = sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if name.endswith('.jpg') and os.path.isfile(os.path.join(img_dir, name))
    )

    for img_file in img_files:
        yield fn(img_file)

def _quantized_path(sg_file):
    """Return the output path for the quantized graph: ``<root>_8bit<ext>``.

    Only the final extension is touched, so directory names containing
    ".sg" are preserved. A path without an extension gets ``.sg`` appended,
    so the result can never equal the input path.
    """
    root, ext = os.path.splitext(sg_file)
    return root + '_8bit' + (ext or '.sg')


def main(sg_file, img_dir):
    """Quantize the graph in ``sg_file`` and save it beside the original.

    The graph is loaded with ``rbcc.load_sg``. It is then passed to
    ``rbcc.quantize`` with the ``generate_img(img_dir)`` images as the
    ``dataset``, plus the ``qconfig`` and ``advanced_ptq`` settings below.
    The result is written with ``rbcc.save_sg`` to ``<root>_8bit<ext>``
    next to ``sg_file``.

    Args:
        sg_file: path to the input graph file (typically ``*.sg``).
        img_dir: directory holding the ``.jpg`` calibration images.
    """
    sg = rbcc.load_sg(sg_file)
    # Settings forwarded verbatim to rbcc.quantize(advanced_ptq=...); the
    # meaning of each key is defined by rbcc.
    advanced_ptq = {
        "load_all_to_gpu": True,
        "recon_param": {
            "epoch": 1000,
            "bs": 32,
            "lr": 1e-2,
            # NOTE(review): presumably the number of calibration samples
            # consumed; confirm img_dir holds at least this many .jpg files.
            "calib_num": 256,
            "nodes_per_reconstruction": 3,
        },
    }
    qconfig = {
        "mixed_type": "output",
        "quant_ops": {
            # NOTE(review): w_axis 0 presumably selects per-output-channel
            # weight quantization — confirm against rbcc docs.
            "Conv2D": {"w_axis": 0},
            # Empty quant locations presumably leave Sigmoid's input and
            # output unquantized — confirm against rbcc docs.
            'Sigmoid': {'input_quant_loc': [], 'output_quant_loc': []},
        },
    }
    # NOTE(review): generate_img returns a single-pass generator; confirm
    # rbcc.quantize iterates the dataset only once.
    qsg = rbcc.quantize(
        sg,
        dataset=generate_img(img_dir),
        gpu_id=0,
        qconfig=qconfig,
        advanced_ptq=advanced_ptq,
    )
    # Derive the output name from the extension only; the previous
    # str.replace('.sg', ...) rewrote every ".sg" in the path and returned
    # the input path unchanged (overwriting the model) when ".sg" was absent.
    rbcc.save_sg(qsg, _quantized_path(sg_file))


if __name__ == '__main__':
    # Fail with a usage line instead of an IndexError traceback when
    # arguments are missing; any extra arguments are ignored as before.
    if len(sys.argv) < 3:
        sys.exit('usage: %s <sg_file> <img_dir>' % sys.argv[0])
    main(sys.argv[1], sys.argv[2])
