import sys
import rbcc
from rbcc.rbir.sg_ir import SGIR
from rbcc.sg.graph_rewriter import SGRewriter


def apply_fn_cat(g):
    """Rebuild both outputs of '/model_22/Split' directly from the Concat inputs.

    For each of the two input slots (0 and 1) of the nodes
    '/model_22/Concat_1', '/model_22/Concat_2' and '/model_22/Concat_3':

    1. reshape that slot's input of every Concat node to ``[1, -1, width]``,
    2. concatenate the three reshaped tensors along axis 1,
    3. transpose the result with permutation ``[0, 2, 1]``.

    ``width`` is the last dimension of the corresponding input of
    '/model_22/Concat_1'; the same value is reused for Concat_2 and Concat_3.
    NOTE(review): this presumes the three Concat nodes share the same last
    dimension per slot and that the last axis is the channel axis -- confirm
    against the model layout.

    Args:
        g: rewriter context exposing ``sg``, ``get_node`` and
            ``replace_output`` (an ``SGRewriter`` when invoked through
            ``rewriter_sg``).

    Returns:
        Whatever ``g.replace_output`` returns after mapping the two Split
        outputs to the newly built tensors.
    """
    builder = SGIR(g.sg)

    concat_nodes = [
        g.get_node(node_name)
        for node_name in (
            '/model_22/Concat_1',
            '/model_22/Concat_2',
            '/model_22/Concat_3',
        )
    ]

    # Last-dimension size of each input slot, read from the first Concat only.
    reference_node = concat_nodes[0]
    slot_widths = [
        int(reference_node.inputs[slot].shape[-1]) for slot in (0, 1)
    ]

    # One merged tensor per slot, built in slot order (0 first, then 1).
    merged = []
    for slot, width in enumerate(slot_widths):
        flattened = [
            builder.reshape(node.inputs[slot], [1, -1, width])
            for node in concat_nodes
        ]
        joined = builder.concat(flattened, axis=1)
        merged.append(builder.transpose(joined, [0, 2, 1]))

    split_node = g.get_node('/model_22/Split')
    return g.replace_output({
        split_node.outputs[0]: merged[0],
        split_node.outputs[1]: merged[1],
    })

def rewriter_sg(sg):
    """Run the '/model_22' Concat/Split rewrite over ``sg`` and return the result.

    Registers a single path rule on an ``SGRewriter`` that pairs six Conv
    output tensors with the two '/model_22/Split' outputs and uses
    ``apply_fn_cat`` as the rewrite callback, then runs the rewriter.
    NOTE(review): the two lists are presumably the start and end tensors of
    the matched sub-graph -- confirm against ``SGRewriter.add_path_rule``.

    Args:
        sg: graph object accepted by ``SGRewriter``.

    Returns:
        The value returned by ``SGRewriter.run()``.
    """
    conv_output_names = [
        '/model_22/cv2_0/cv2_0_2/Conv_output_0',
        '/model_22/cv3_0/cv3_0_2/Conv_output_0',
        '/model_22/cv2_1/cv2_1_2/Conv_output_0',
        '/model_22/cv3_1/cv3_1_2/Conv_output_0',
        '/model_22/cv2_2/cv2_2_2/Conv_output_0',
        '/model_22/cv3_2/cv3_2_2/Conv_output_0',
    ]
    split_output_names = [
        '/model_22/Split_output_0',
        '/model_22/Split_output_1',
    ]

    rewriter = SGRewriter(sg)
    rewriter.add_path_rule(
        conv_output_names,
        split_output_names,
        apply_fn=apply_fn_cat,
        as_template=False,
    )
    return rewriter.run()


if __name__ == '__main__':
    # Script entry point: load the graph named by the first command-line
    # argument, rewrite it, and save the result.
    if len(sys.argv) < 2:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit('usage: {} <sg_path>'.format(sys.argv[0]))
    sg_path = sys.argv[1]
    sg = rbcc.load_sg(sg_path)
    sg = rewriter_sg(sg)
    # The rewritten graph is written back to the same path, replacing the input.
    rbcc.save_sg(sg, sg_path)