import sys
import rbcc
import numpy as np
from rbcc.rbir.sg_ir import SGIR
from rbcc.sg.graph_rewriter import SGRewriter


def apply_fn(g):
    """Rebuild the ``/model_6`` attention block from three separate q/k/v convs.

    The matched subgraph runs one fused ``qkv`` conv and then slices its output
    per head. This rewrite splits that conv's weight and bias into q, k and v
    parts and runs each part as its own conv. It then computes
    ``softmax((q @ k) * scale) @ v``, with k already transposed for the
    matmul. Consumers of ``attn/Reshape_3`` are rewired to the attention
    result, and consumers of ``attn/Reshape_4`` to the raw v conv output.

    Args:
        g: rewriter handle passed in by the path rule registered in
            ``rewriter_sg``.

    Returns:
        The result of ``g.replace_output`` for the two rewired tensors.
    """
    ir = SGIR(g.sg)
    conv = g.get_node('/model_6/m_0/m_0_0/attn/qkv/conv/Conv')
    conv_in = conv.inputs[0]
    # shape[1:3] is taken as (H, W), i.e. an NHWC activation layout.
    out_h, out_w = conv.output.shape[1:3]
    weight = conv.inputs[1].node.tensor
    bias = conv.inputs[2].node.tensor

    # Reshape_1's leading dim is used as the attention batch and its
    # second-to-last dim as the head count. With area attention the leading
    # dim presumably includes the area split (TODO confirm).
    reshape_1_shape = g.get_node('/model_6/m_0/m_0_0/attn/Reshape_1').output.shape
    batch = reshape_1_shape[0]
    head_num = reshape_1_shape[-2]

    # weight is unpacked as [out_channel, kh, kw, in_channel]; the reshapes
    # below assume kh == kw == 1.
    out_channel, _, _, in_channel = weight.shape
    per_head = out_channel // head_num  # q + k + v channels of one head
    qkv_dim = per_head // 3  # channels of each of q, k, v within one head

    # The split assumes the fused channels are head-major, ordered [q | k | v]
    # within each head.
    w_parts = np.split(weight.reshape([head_num, per_head, 1, 1, in_channel]), 3, axis=1)
    b_parts = np.split(bias.reshape([head_num, per_head]), 3, axis=1)

    def project(idx, perm):
        """Run slice ``idx`` (0=q, 1=k, 2=v) as a conv; return (conv_out, per-head view)."""
        conv_out = ir.conv2d(
            conv_in,
            ir.constant(w_parts[idx].reshape(-1, 1, 1, in_channel)),
            ir.constant(b_parts[idx].reshape(-1)),
        )
        # conv_out has head_num * qkv_dim channels -> [batch, tokens, head_num, qkv_dim]
        heads = ir.reshape(conv_out, [batch, -1, head_num, qkv_dim])
        return conv_out, ir.transpose(heads, perm)

    _, q = project(0, [0, 2, 1, 3])  # [batch, head_num, tokens, qkv_dim]
    _, k = project(1, [0, 2, 3, 1])  # [batch, head_num, qkv_dim, tokens]
    v_conv, v = project(2, [0, 2, 1, 3])  # [batch, head_num, tokens, qkv_dim]

    scores = ir.matmul(q, k)  # [batch, head_num, tokens, tokens]
    scale = g.get_node('/model_6/m_0/m_0_0/attn/Mul').inputs[1].node.get_const_val()
    scores = ir.mul(scores, ir.constant(scale))
    probs = ir.softmax(scores, -1)

    out = ir.matmul(probs, v)  # [batch, head_num, tokens, qkv_dim]
    out = ir.transpose(out, [0, 2, 1, 3])  # [batch, tokens, head_num, qkv_dim]
    # Fold the tokens back into the H x W map. With -1 the leading dim follows
    # the element count instead of hard-coding a model batch of 1.
    out = ir.reshape(out, [-1, out_h, out_w, head_num * qkv_dim])

    return g.replace_output({
        g.get_node('/model_6/m_0/m_0_0/attn/Reshape_3').output: out,
        g.get_node('/model_6/m_0/m_0_0/attn/Reshape_4').output: v_conv,
    })

def apply_fn_cat(g):
    """Replace the ``/model_21`` concat/split round trip with two direct concats.

    ``Concat``, ``Concat_1`` and ``Concat_2`` each join two per-scale tensors,
    and ``Split`` later separates the combined result again. This rewrite
    flattens every scale's ``inputs[0]`` to ``[batch, tokens, C]``,
    concatenates the results along the token axis and transposes them to
    ``[batch, C, tokens]``. It does the same for ``inputs[1]``. The two
    results replace ``Split``'s two outputs.

    Args:
        g: rewriter handle passed in by the path rule registered in
            ``rewriter_sg``.

    Returns:
        The result of ``g.replace_output`` for the two Split outputs.
    """
    ir = SGIR(g.sg)
    cats = [
        g.get_node('/model_21/Concat'),
        g.get_node('/model_21/Concat_1'),
        g.get_node('/model_21/Concat_2'),
    ]

    def gather(idx):
        """Concatenate input ``idx`` of every Concat over tokens, channel-first."""
        flat = []
        for cat in cats:
            src = cat.inputs[idx]
            # Batch and channels come from each source tensor rather than
            # assuming batch 1 and reusing the first scale's channel count.
            # A mismatched count therefore cannot be silently folded into
            # the token axis by the reshape.
            batch = int(src.shape[0])
            channels = int(src.shape[-1])
            flat.append(ir.reshape(src, [batch, -1, channels]))
        return ir.transpose(ir.concat(flat, axis=1), [0, 2, 1])

    out_1 = gather(0)
    out_2 = gather(1)

    split = g.get_node('/model_21/Split')
    return g.replace_output({
        split.outputs[0]: out_1,
        split.outputs[1]: out_2,
    })

def rewriter_sg(sg):
    """Register this file's two path rules on ``sg`` and return the rewritten graph.

    * ``apply_fn`` handles the path from ``/model_6/cv1/act/Mul_output_0`` to
      the attention ``Reshape_5``/``Reshape_6`` outputs. ``as_template`` is
      left at the library default.
    * ``apply_fn_cat`` handles the path from the six ``/model_21`` cv2/cv3
      conv outputs to the two ``Split`` outputs, with ``as_template=False``.

    Args:
        sg: graph as loaded by ``rbcc.load_sg``.

    Returns:
        The graph produced by ``SGRewriter.run()``.
    """
    rewriter = SGRewriter(sg)

    attn_inputs = ['/model_6/cv1/act/Mul_output_0']
    attn_outputs = [
        '/model_6/m_0/m_0_0/attn/Reshape_5_output_0',
        '/model_6/m_0/m_0_0/attn/Reshape_6_output_0',
    ]
    rewriter.add_path_rule(attn_inputs, attn_outputs, apply_fn=apply_fn)

    head_inputs = [
        '/model_21/cv2_0/cv2_0_2/Conv_output_0',
        '/model_21/cv3_0/cv3_0_2/Conv_output_0',
        '/model_21/cv2_1/cv2_1_2/Conv_output_0',
        '/model_21/cv3_1/cv3_1_2/Conv_output_0',
        '/model_21/cv2_2/cv2_2_2/Conv_output_0',
        '/model_21/cv3_2/cv3_2_2/Conv_output_0',
    ]
    head_outputs = [
        '/model_21/Split_output_0',
        '/model_21/Split_output_1',
    ]
    rewriter.add_path_rule(
        head_inputs,
        head_outputs,
        apply_fn=apply_fn_cat,
        as_template=False,
    )

    return rewriter.run()


if __name__ == '__main__':
    # Usage: python <script> MODEL_SG [OUTPUT_SG]
    # Without OUTPUT_SG the input file is overwritten in place.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} MODEL_SG [OUTPUT_SG]')
    src_path = sys.argv[1]
    dst_path = sys.argv[2] if len(sys.argv) > 2 else src_path
    sg = rbcc.load_sg(src_path)
    sg = rewriter_sg(sg)
    rbcc.save_sg(sg, dst_path)