import os
import sys
import rbcc
from preprocess import preprocess


# Names of the final Concat node of each of the three branches, indexed by
# branch position.  The irregular suffixes are copied verbatim from the graph
# node names (presumably exporter-generated; verify if the export changes).
_HEAD_CONCAT_NAMES = ("Concat", "Concat_8", "Concat_16")


def _face_head_qconfig(branch_conv_ids, head_id):
    """Build a per-layer qconfig requesting 16-bit MinMax quantization.

    For each branch position ``i`` (0, 1, 2), seven nodes are listed in this
    order: ``/model_<conv>/conv/Conv``, ``/model_<conv>/act/Mul``, then
    ``ia_i/Add``, ``m_i/Conv``, ``im_i/Mul``, ``m_kpt_i/Conv`` and the branch's
    Concat node under ``/model_<head_id>/``.

    Args:
        branch_conv_ids: The ``model_<N>`` indices of the three conv nodes,
            one per branch, in branch order.
        head_id: The ``model_<N>`` index shared by the remaining nodes.

    Returns:
        ``{"layers": {node_name: {"bit": 16, "observer": "MinMaxObserver"}}}``
        where every node gets its own, independently mutable settings dict.
    """
    layers = {}
    for branch, conv_id in enumerate(branch_conv_ids):
        conv_prefix = f"/model_{conv_id}"
        head_prefix = f"/model_{head_id}"
        branch_nodes = (
            f"{conv_prefix}/conv/Conv",
            f"{conv_prefix}/act/Mul",
            f"{head_prefix}/ia_{branch}/Add",
            f"{head_prefix}/m_{branch}/Conv",
            f"{head_prefix}/im_{branch}/Mul",
            f"{head_prefix}/m_kpt_{branch}/Conv",
            f"{head_prefix}/{_HEAD_CONCAT_NAMES[branch]}",
        )
        for node in branch_nodes:
            # A fresh dict per node, matching the original per-entry literals.
            layers[node] = {"bit": 16, "observer": "MinMaxObserver"}
    return {"layers": layers}


# yolov7-face: branch convs model_102..104, remaining nodes under model_105.
yolov7_qconfig = _face_head_qconfig((102, 103, 104), head_id=105)


# yolov7s-face: branch convs model_101..103, remaining nodes under model_104.
yolov7s_qconfig = _face_head_qconfig((101, 102, 103), head_id=104)


def generate_img(img_dir, input_size=(640, 640), exts=('.jpg',)):
    """Lazily yield preprocessed images from ``img_dir`` in sorted path order.

    Only regular files directly inside ``img_dir`` whose name ends with one of
    ``exts`` (compared case-insensitively) are used.

    Args:
        img_dir: Directory to scan (non-recursively).
        input_size: Size forwarded to ``preprocess`` as a two-element list;
            defaults to the previously hard-coded ``[640, 640]``.
        exts: Accepted filename extensions, including the leading dot.

    Yields:
        The value returned by ``preprocess(img_file, input_size=...)`` for
        each selected file.
    """
    suffixes = tuple(ext.lower() for ext in exts)
    img_files = []
    for name in os.listdir(img_dir):
        path = os.path.join(img_dir, name)
        # Match on the suffix rather than a substring, so that names such as
        # 'x.jpg.bak' are skipped.  Also skip directories named '*.jpg'.
        if name.lower().endswith(suffixes) and os.path.isfile(path):
            img_files.append(path)
    img_files.sort()

    size = list(input_size)
    for img_file in img_files:
        yield preprocess(img_file, input_size=size)


def main(sg_file, img_dir, *, max_num=100, gpu_id=0):
    """Quantize a yolov7(-s) face ``.sg`` model and save the result.

    The qconfig is selected from the file name: names containing
    ``yolov7s-face`` use ``yolov7s_qconfig``, and names containing
    ``yolov7-face`` use ``yolov7_qconfig``.  The result is written next to the
    input as ``<stem>_8bit<ext>``, or ``<stem>_8bit.sg`` if there is no
    extension.

    Args:
        sg_file: Path to the input model.
        img_dir: Directory of calibration ``.jpg`` images.
        max_num: Forwarded to ``rbcc.quantize(max_num=...)``.  Presumably
            this caps how many calibration images are consumed; confirm
            against the rbcc docs.
        gpu_id: Forwarded to ``rbcc.quantize(gpu_id=...)``.

    Raises:
        RuntimeError: If the file name matches neither supported model.
    """
    model_name = os.path.basename(sg_file)
    # Choose the qconfig before loading, so an unsupported model fails fast
    # instead of after an expensive load.
    if 'yolov7s-face' in model_name:
        qconfig = yolov7s_qconfig
    elif 'yolov7-face' in model_name:
        qconfig = yolov7_qconfig
    else:
        raise RuntimeError(f"unexpect model: {sg_file}.")

    sg = rbcc.load_sg(sg_file)
    qsg = rbcc.quantize(
        sg,
        dataset=generate_img(img_dir),
        max_num=max_num,
        qconfig=qconfig,
        gpu_id=gpu_id)

    # Build the output name from the file's own extension only.
    # sg_file.replace('.sg', ...) would also rewrite '.sg' inside directory
    # names, and would save over the input when the path has no '.sg'.
    root, ext = os.path.splitext(sg_file)
    rbcc.save_sg(qsg, f"{root}_8bit{ext or '.sg'}")


if __name__ == '__main__':
    # Show a usage message instead of an IndexError when arguments are missing.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} <model.sg> <calib_img_dir>")
    main(sys.argv[1], sys.argv[2])