import os
import sys
import rbcc
import torch
from preprocess import preprocess


def generate_img(img_dir):
    """Yield preprocessed images for every ``.jpg`` file in a directory.

    The directory is scanned non-recursively and the matching paths are
    visited in sorted order, so the sequence is deterministic across runs.

    Args:
        img_dir: Directory containing the images.

    Yields:
        The value returned by ``preprocess`` for each image path.
    """
    # Match on the file extension rather than on a substring, so that names
    # such as "a.jpg.bak" or "photo.jpg.txt" are not handed to preprocess().
    img_files = sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if name.endswith('.jpg')
    )

    for img_file in img_files:
        yield preprocess(img_file)


def main(sg_file, img_dir):
    """Quantize the graph stored in ``sg_file`` and save the result.

    The graph is loaded with ``rbcc.load_sg``, quantized with
    ``rbcc.quantize`` using the images under ``img_dir`` as the dataset, and
    written next to the input with ``_8bit`` inserted before the file
    extension (``model.sg`` -> ``model_8bit.sg``).

    Args:
        sg_file: Path of the graph file to quantize.
        img_dir: Directory of ``.jpg`` images passed to ``generate_img``.
    """
    sg = rbcc.load_sg(sg_file)

    # Layers that receive a per-layer override in the "layers" section.
    # Every name follows /model_22/cv4_<i>/cv4_<i>_<j>/...; the cv4_2_1 conv
    # entry previously read ".../cv4_1/cv4_2_1/...", which broke that pattern
    # and left the cv4_2 branch conv without its override.
    # NOTE(review): presumably these are the node names in the exported
    # graph -- confirm against the actual model.
    layer_names = [
        "/model_22/cv4_0/cv4_0_0/conv/Conv",
        "/model_22/cv4_1/cv4_1_0/conv/Conv",
        "/model_22/cv4_2/cv4_2_0/conv/Conv",

        "/model_22/cv4_0/cv4_0_0/act/Mul",
        "/model_22/cv4_1/cv4_1_0/act/Mul",
        "/model_22/cv4_2/cv4_2_0/act/Mul",

        "/model_22/cv4_0/cv4_0_1/conv/Conv",
        "/model_22/cv4_1/cv4_1_1/conv/Conv",
        "/model_22/cv4_2/cv4_2_1/conv/Conv",

        "/model_22/cv4_0/cv4_0_1/act/Mul",
        "/model_22/cv4_1/cv4_1_1/act/Mul",
        "/model_22/cv4_2/cv4_2_1/act/Mul",

        "/model_22/cv4_0/cv4_0_2/Conv",
        "/model_22/cv4_1/cv4_1_2/Conv",
        "/model_22/cv4_2/cv4_2_2/Conv",
    ]
    qconfig = {
        "quant_ops": {"Conv2D": {"w_axis": 0}},
        # Each layer gets its own dict so no settings object is shared.
        "layers": {
            name: {"bit": 16, "observer": "MinMaxObserver"}
            for name in layer_names
        },
    }

    # Use the first CUDA device when one is available; otherwise pass None.
    gpu_id = 0 if torch.cuda.is_available() else None

    qsg = rbcc.quantize(sg, dataset=generate_img(img_dir), gpu_id=gpu_id, qconfig=qconfig)

    # Build the output name from the extension only. str.replace() would also
    # rewrite ".sg" inside directory names and, when the path contains no
    # ".sg" at all, would return the input path and overwrite the source file.
    root, ext = os.path.splitext(sg_file)
    rbcc.save_sg(qsg, root + '_8bit' + ext)


if __name__ == '__main__':
    # Exit with a usage hint instead of an IndexError traceback when the two
    # required positional arguments are missing.
    if len(sys.argv) < 3:
        sys.exit('usage: {} <sg_file> <img_dir>'.format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2])