import sys
import rbcc
import numpy as np
from rbcc.rbir.sg_ir import SGIR
from rbcc.sg.graph_rewriter import SGRewriter


def match_apply(g):
    """Replace the output of ``ChannelF2L/transpose_0`` with a fixed 2x2, stride-2 conv.

    The replacement is built in the transpose node's scope:

    * a custom ``Input`` op producing a ``[1, 640, 640, 3]`` float32 tensor;
    * a conv2d with a ``[12, 2, 2, 3]`` 0/1 selection kernel, zero bias and
      stride 2, tagged with ``data_format = 'channel_last'``.

    Output channels ``3*k .. 3*k+2`` copy input-channel index 0..2 (last weight
    axis) at kernel tap ``(i, j)`` = (0, 0), (1, 0), (0, 1), (1, 1) for k = 0..3.

    NOTE(review): this assumes the rbcc conv weight layout is
    (out_ch, kh, kw, in_ch). Under that assumption the conv acts like a
    space-to-depth rearrangement of 2x2 patches. Confirm against the source model.

    Args:
        g: the rewriter object handed to the path-rule callback (presumably
            the ``SGRewriter``; confirm).

    Returns:
        Whatever ``g.replace_output`` returns for the new output mapping.
    """
    node = g.get_node("ChannelF2L/transpose_0")
    builder = SGIR(g.sg, scope=node.name)

    in_shape = [1, 640, 640, 3]
    # Pass fresh list objects, as the original did, so the op never sees aliased lists.
    src = builder.custom_op(
        op_name='Input',
        outputs=[(list(in_shape), 'float32')],
        shape=list(in_shape),
    )

    eye3 = np.eye(3, dtype=np.float32)
    weight = np.zeros([12, 2, 2, 3], dtype=np.float32)
    taps = ((0, 0), (1, 0), (0, 1), (1, 1))
    for k, (i, j) in enumerate(taps):
        weight[3 * k:3 * k + 3, i, j] = eye3
    bias = np.zeros([12], dtype=np.float32)

    # Keep graph-node creation order: weight constant, then bias constant, then conv.
    weight_node = builder.constant(weight)
    bias_node = builder.constant(bias)
    conv = builder.conv2d(src, weight_node, bias=bias_node, stride=2)
    conv.node.attrs['data_format'] = 'channel_last'

    return g.replace_output({node.output: conv})


def main(sg_file):
    """Rewrite the SG file at *sg_file* in place.

    Loads the graph with ``rbcc.load_sg`` and wraps it in an ``SGRewriter``.
    It then registers ``match_apply`` as a path rule for the node list
    ``['ChannelF2L/transpose_0']`` (empty first argument, ``as_template=False``;
    exact rule semantics are defined by ``SGRewriter.add_path_rule``).
    Finally it runs the rewriter and saves the result back to the same path,
    overwriting the input file.

    Args:
        sg_file: path of the SG file to load and overwrite.
    """
    # Use the parameter rather than re-reading sys.argv, so that main() can be
    # called from other code with an explicit path.
    sg = rbcc.load_sg(sg_file)
    g = SGRewriter(sg)
    g.add_path_rule([], ['ChannelF2L/transpose_0'], match_apply, as_template=False)
    sg = g.run()
    rbcc.save_sg(sg, sg_file)


if __name__ == '__main__':
    # The first positional argument is the SG file to rewrite in place;
    # any extra arguments are ignored.
    if len(sys.argv) < 2:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <sg_file>")
    main(sys.argv[1])

