import sys
import rbcc
import numpy as np

from rbcc.rbir.sg_ir import SGIR
from rbcc.sg.graph_rewriter import SGRewriter


def match_apply_einsum2d(g):
    """Swap the '/model_22/cv4_0/Einsum' node for a 1x1 channel-last conv2d.

    The conv reads the input of 'ChannelF2L/transpose_13' directly. Its kernel
    is the Einsum's second input reshaped to ``[80, 1, 1, 512]``, its bias is
    all zeros of length 80, and it uses no padding, unit stride and unit
    dilation. The conv result is permuted with ``[0, 3, 1, 2]`` and then
    substituted for the Einsum's output.

    Args:
        g: the rewriter object passed to this path-rule callback
            (presumably the ``SGRewriter`` from ``rewrite_sg`` — confirm).

    Returns:
        Whatever ``g.replace_output`` returns for the new output mapping.
    """
    src_transpose = g.get_node('ChannelF2L/transpose_13')
    einsum_node = g.get_node('/model_22/cv4_0/Einsum')
    builder = SGIR(g.sg, scope=einsum_node.name)

    feature = src_transpose.inputs[0]
    kernel_array = np.reshape(einsum_node.inputs[1].node.tensor, [80, 1, 1, 512])
    kernel = builder.constant(kernel_array, dtype='float32')
    bias = builder.constant(np.zeros([80], dtype=np.float32), dtype='float32')
    conv_out = builder.conv2d(
        feature,
        kernel,
        bias,
        pad_size=[0, 0, 0, 0],
        dilation=[1, 1],
        stride=[1, 1],
    )
    conv_out.node.set_attr('data_format', 'channel_last')
    # NOTE(review): presumably converts the channel-last conv result back to
    # channel-first (NHWC -> NCHW) to match the Einsum's layout — confirm.
    permuted = builder.transpose(conv_out, [0, 3, 1, 2])

    return g.replace_output({einsum_node.output: permuted})


def match_apply_einsum3d(g):
    """Replace '/model_12/attn/Einsum' with per-group 1x1 channel-last conv2d ops.

    The input of 'ChannelF2L/transpose_3' is unpacked as a 4-D shape
    ``(fake_b, h, w, c)``. It is reshaped to ``[h, w, c // 32, 32]``, and each
    of the ``c // 32`` groups is sliced out and reshaped to ``[1, h, w, 32]``.
    Each group then goes through its own conv2d:

    * weights: slice ``[:, :, i, :]`` of the Einsum's second input, reshaped
      to ``[80, 1, 1, 32]``;
    * bias: all zeros;
    * no padding, unit stride, unit dilation, ``data_format='channel_last'``.

    The per-group conv results are concatenated along axis 0 and given a new
    leading axis. The result then replaces the Einsum's output.

    Args:
        g: the rewriter object passed to this path-rule callback
            (presumably the ``SGRewriter`` from ``rewrite_sg`` — confirm).

    Returns:
        Whatever ``g.replace_output`` returns for the new output mapping.

    Raises:
        ValueError: if the channel count ``c`` is not a positive multiple of
            the group size (32). Otherwise the group reshape would not cover
            every channel.
    """
    group_size = 32    # channels handled by each per-group conv
    out_channels = 80  # output channels of each per-group conv

    transpose = g.get_node('ChannelF2L/transpose_3')
    einsum = g.get_node('/model_12/attn/Einsum')
    sgir = SGIR(g.sg, scope=einsum.name)

    inputs = transpose.inputs[0]
    # NOTE(review): the reshape below drops the leading axis, so this assumes
    # the leading dimension is 1 at run time — confirm against the model.
    _fake_b, height, width, channels = inputs.shape
    if channels < group_size or channels % group_size:
        raise ValueError(
            "match_apply_einsum3d: channel count must be a positive multiple "
            "of {}, got {}".format(group_size, channels)
        )
    real_b = channels // group_size

    # Split the channel axis into real_b groups of group_size channels.
    inputs = sgir.reshape(inputs, [height, width, real_b, group_size])
    inputs = [
        sgir.reshape(
            sgir.stride_slice(inputs, i, i + 1, 1, axes=[2]),
            [1, height, width, group_size],
        )
        for i in range(real_b)
    ]

    weight_tensor = einsum.inputs[1].node.tensor
    weights = [
        sgir.constant(
            np.reshape(weight_tensor[:, :, i, :], [out_channels, 1, 1, group_size]),
            dtype='float32',
        )
        for i in range(real_b)
    ]
    biases = [
        sgir.constant(np.zeros([out_channels], dtype=np.float32), dtype='float32')
        for _ in range(real_b)
    ]
    outputs = [
        sgir.conv2d(x, wt, bs, pad_size=[0, 0, 0, 0], dilation=[1, 1], stride=[1, 1])
        for x, wt, bs in zip(inputs, weights, biases)
    ]
    # A plain loop: these calls are made only for their side effect.
    for out in outputs:
        out.node.set_attr('data_format', 'channel_last')

    merged = sgir.concat(outputs, axis=0)
    merged = sgir.unsqueeze(merged, 0)
    return g.replace_output({einsum.output: merged})
    

def rewrite_sg(sg):
    """Run the Einsum-to-conv2d path rules over *sg* and return it.

    Two path rules are registered on an ``SGRewriter`` built from *sg*. Each
    rule has a start tensor name, an end tensor name and a callback:

    * '/model_22/cv3_0/cv3_0_2/Conv_output_0' ->
      '/model_22/cv4_0/Einsum_output_0' handled by ``match_apply_einsum2d``;
    * '/model_12/m_0/cv2/act/Mul_output_0' ->
      '/model_12/attn/Einsum_output_0' handled by ``match_apply_einsum3d``.

    The rewriter is then run. The same *sg* object is returned, presumably
    modified in place by the rewriter (confirm against rbcc).
    """
    rules = (
        ('/model_22/cv3_0/cv3_0_2/Conv_output_0',
         '/model_22/cv4_0/Einsum_output_0',
         match_apply_einsum2d),
        ('/model_12/m_0/cv2/act/Mul_output_0',
         '/model_12/attn/Einsum_output_0',
         match_apply_einsum3d),
    )
    rewriter = SGRewriter(sg)
    for start_name, end_name, callback in rules:
        rewriter.add_path_rule([start_name], [end_name], callback)
    rewriter.run()
    return sg


if __name__ == "__main__":
    # Usage: python <script> INPUT_SG [OUTPUT_SG]
    # Loads INPUT_SG, applies rewrite_sg, and saves the result to OUTPUT_SG.
    # Without OUTPUT_SG the input file is overwritten in place.
    if len(sys.argv) < 2:
        sys.exit("usage: {} INPUT_SG [OUTPUT_SG]".format(sys.argv[0]))
    src_path = sys.argv[1]
    dst_path = sys.argv[2] if len(sys.argv) > 2 else src_path
    sg = rbcc.load_sg(src_path)
    sg = rewrite_sg(sg)
    rbcc.save_sg(sg, dst_path)