import sys
import rbcc

from rbcc.rbir.sg_ir import SGIR
from rbcc.sg.graph_rewriter import SGRewriter


def apply_fn_cat(g):
    """Replace the '/detect/Transpose_3' output with a reshape + concat chain.

    For each of the nodes 'ChannelF2L/transpose_4', 'ChannelF2L/transpose_8'
    and 'ChannelF2L/transpose_12', the node's first input is reshaped from
    its 4-element shape to ``[batch, -1, c]`` (first and last dimensions
    kept, the middle two merged). The three reshaped values are then
    concatenated with ``1`` as the second argument to ``ir.concat``
    (presumably the axis -- confirm against the SGIR API).

    Args:
        g: Rewriter object exposing ``sg``, ``get_node`` and
            ``replace_output``.

    Returns:
        Whatever ``g.replace_output`` returns for the new output mapping.
    """
    ir = SGIR(g.sg)

    # Resolve every node up front, before any new op is created.
    transpose_nodes = [
        g.get_node(node_name)
        for node_name in (
            'ChannelF2L/transpose_4',
            'ChannelF2L/transpose_8',
            'ChannelF2L/transpose_12',
        )
    ]

    flattened = []
    for node in transpose_nodes:
        # The reshape is applied to the transpose node's *input*, not to
        # the transpose node's own output.
        # NOTE(review): presumably this bypasses the transpose on a
        # channels-last tensor -- confirm the layout.
        batch, _, _, channels = node.inputs[0].shape
        flattened.append(ir.reshape(node.inputs[0], [batch, -1, channels]))

    merged = ir.concat(flattened, 1)

    # Redirect the original detect output to the newly built value.
    old_output = g.get_node('/detect/Transpose_3').output
    return g.replace_output({old_output: merged})

def main(sg_file, new_sg_file):
    """Load a graph, run the ``apply_fn_cat`` path rule on it, and save it.

    Args:
        sg_file: Path passed to ``rbcc.load_sg``.
        new_sg_file: Path passed to ``rbcc.save_sg``.

    Returns:
        The graph object returned by ``rbcc.load_sg``; the same object is
        handed to ``rbcc.save_sg`` after the rewriter has run.
    """
    graph = rbcc.load_sg(sg_file)
    rewriter = SGRewriter(graph)

    # NOTE(review): the two lists are presumably the start and end tensor
    # names of the path to match -- confirm against add_path_rule's docs.
    first_names = [
        '/detect/Sigmoid_output_0',
        '/detect/Sigmoid_1_output_0',
        '/detect/Sigmoid_2_output_0',
    ]
    second_names = ['/detect/Transpose_3_output_0']
    rewriter.add_path_rule(
        first_names,
        second_names,
        apply_fn_cat,
        as_template=False,
    )
    rewriter.run()

    rbcc.save_sg(graph, new_sg_file)
    return graph


if __name__ == '__main__':
    # Usage: <script> <sg_file> <new_sg_file>
    # Both paths are required. Exit with a usage message rather than an
    # IndexError traceback when they are missing. Extra arguments are
    # ignored, as before.
    if len(sys.argv) < 3:
        sys.exit('usage: {} <sg_file> <new_sg_file>'.format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2])
