import os
import sys
import rbcc
from easydict import EasyDict
from preprocess import preprocess, HRNetCOCODataset


def generate_img(img_dir):
    """Yield images from the ``val2017`` split of an ``HRNetCOCODataset``.

    A fixed HRNet-style config is built, wrapped in ``EasyDict`` and handed to
    ``HRNetCOCODataset`` together with *img_dir* and ``transform=preprocess``.
    Each dataset item is unpacked as an ``(image, meta)`` pair; only the image
    is yielded and the metadata is discarded.

    Args:
        img_dir: Root path forwarded unchanged to ``HRNetCOCODataset``
            (presumably the COCO root containing ``val2017`` — TODO confirm).

    Yields:
        The image element of each dataset item (whatever the dataset returns
        after applying ``preprocess``).
    """
    # NOTE(review): IMAGE_SIZE is presumably [width, height]; confirm against
    # HRNetCOCODataset.
    model_settings = {
        'NUM_JOINTS': 17,
        'IMAGE_SIZE': [192, 256],
    }
    test_settings = {
        'COLOR_RGB': True,
        'POST_PROCESS': True,
        'SHIFT_HEATMAP': True,
        'USE_GT_BBOX': True,
        'IMAGE_THRE': 0.0,
        'NMS_THRE': 1.0,
        'SOFT_NMS': False,
        'OKS_THRE': 0.9,
        'IN_VIS_THRE': 0.2,
        'BBOX_THRE': 1.0,
        'COCO_BBOX_FILE': '',
    }
    dataset = HRNetCOCODataset(
        EasyDict({'MODEL': model_settings, 'TEST': test_settings}),
        img_dir,
        'val2017',
        transform=preprocess,
    )

    # Items are (image, meta) pairs; the metadata is not needed downstream.
    for image, _meta in dataset:
        yield image


def main(sg_file, img_dir, *, gpu_id=0, max_num=1000, observer="GaussObserver"):
    """Quantize the model stored in *sg_file* and save it beside the original.

    The model is loaded with ``rbcc.load_sg`` and passed to ``rbcc.quantize``
    with images from ``generate_img(img_dir)`` as its dataset. The result is
    written with ``rbcc.save_sg`` to ``<stem>_8bit<ext>``, where stem and
    extension come from *sg_file* (``model.sg`` -> ``model_8bit.sg``).

    Args:
        sg_file: Path of the input model file.
        img_dir: Image root forwarded to ``generate_img``.
        gpu_id: Forwarded to ``rbcc.quantize`` as ``gpu_id`` (default 0).
        max_num: Forwarded to ``rbcc.quantize`` as ``max_num`` (default 1000);
            presumably caps how many dataset samples are consumed — TODO confirm.
        observer: Value stored under ``"observer"`` in the ``qconfig`` passed to
            ``rbcc.quantize`` (default ``"GaussObserver"``).
    """
    # Derive the output name from the real extension. The previous
    # str.replace('.sg', '_8bit.sg') also rewrote '.sg' inside directory
    # names. When the path contained no '.sg' at all, it returned the input
    # unchanged, so save_sg overwrote the original model.
    root, ext = os.path.splitext(sg_file)
    out_file = root + '_8bit' + ext

    sg = rbcc.load_sg(sg_file)
    qsg = rbcc.quantize(
        sg,
        dataset=generate_img(img_dir),
        gpu_id=gpu_id,
        max_num=max_num,
        qconfig={"observer": observer},
    )
    rbcc.save_sg(qsg, out_file)


if __name__ == '__main__':
    # Report missing arguments with a usage line (exit status 1) instead of
    # crashing with an IndexError traceback. Extra arguments are ignored.
    if len(sys.argv) < 3:
        sys.exit('usage: %s <sg_file> <img_dir>' % sys.argv[0])
    main(sys.argv[1], sys.argv[2])