import sys
import rbcc
from packaging import version
from rbcc.sg.graph_rewriter import SGRewriter

# Version string of the installed rbcc package; compared against '1.2.2'
# further down to choose between the two node/tensor naming schemes.
RBCC_VERSION = rbcc.__version__

def set_qparam(attrs, prefix, quant_node):
    """Store the scale / zero-point of ``quant_node``'s output in ``attrs``.

    The pair returned by ``quant_node.output.get_scale_zp()`` is written
    under the keys ``"<prefix>_scale"`` and ``"<prefix>_zp"``.

    Args:
        attrs: Mutable mapping that is updated in place.
        prefix: Key prefix used for both entries.
        quant_node: Object whose ``output`` exposes ``get_scale_zp()``.
    """
    scale_value, zero_point = quant_node.output.get_scale_zp()
    scale_key = "{}_scale".format(prefix)
    zp_key = "{}_zp".format(prefix)
    attrs[scale_key] = scale_value
    attrs[zp_key] = zero_point

def match_apply(g):
    """Build the fused ``CSwinAtten`` node for the ``/stage1_0`` sub-graph.

    The new node takes the name and outputs of ``/stage1_0/Concat``. Its
    inputs are collected in this order: the tensor feeding the input
    TypeCast node, the qkv MatMul weight/bias, then for each of the two
    attention branches the second input of its ``Mul`` node followed by the
    ``get_v`` conv's non-data inputs. Quantization parameters of the
    intermediate TypeCast / DataCrop nodes and the conv attributes are
    copied into the node attributes.

    Args:
        g: Graph handle providing ``get_node`` and ``node`` (presumably the
            object the rewriter passes to path-rule callbacks).

    Returns:
        Whatever ``g.node('CSwinAtten', ...)`` returns.
    """
    inputs = []
    attrs = {}
    concat = g.get_node('/stage1_0/Concat')
    name = concat.name
    outputs = concat.outputs

    # connect input
    # rbcc < 1.2.2 names the input cast node with '/', newer versions with '_'.
    if version.parse(RBCC_VERSION) < version.parse('1.2.2'):
        input_name = 'unsqueeze_0/Unsqueeze/0/TypeCast'
    else:
        input_name = 'unsqueeze_0_Unsqueeze/0/TypeCast'

    # Resolve the input cast node once and use it for both the input tensor
    # and the qkv_in quantization parameters, so both follow the same
    # version-dependent name.
    input_cast = g.get_node(input_name)
    inputs.append(input_cast.inputs[0])
    inputs.extend(g.get_node('/stage1_0/qkv/MatMul').inputs[1:])  # qkv_w, qkv_b

    set_qparam(attrs, 'qkv_in', input_cast)
    set_qparam(attrs, 'qkv_out', g.get_node('/stage1_0/qkv/MatMul_output_0/DataCrop'))

    # branch 1
    inputs.append(g.get_node('/stage1_0/attns_0/Mul').inputs[1])  # branch 1 q scale

    set_qparam(attrs, 'b1_bmm1_left_in', g.get_node('/stage1_0/attns_0/Transpose_2_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm1_right_in', g.get_node('/stage1_0/attns_0/Transpose_9_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm1_out', g.get_node('/stage1_0/attns_0/MatMul_output_0/DataCrop'))

    set_qparam(attrs, 'b1_bmm2_left_in', g.get_node('/stage1_0/attns_0/Softmax_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm2_right_in', g.get_node('/stage1_0/attns_0/Transpose_8_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm2_out', g.get_node('/stage1_0/attns_0/MatMul_1_output_0/DataCrop'))

    set_qparam(attrs, 'b1_depth_conv_in', g.get_node('ChannelF2L/transpose_4/TypeCast'))
    set_qparam(attrs, 'b1_depth_conv_out', g.get_node('/stage1_0/attns_0/get_v/Conv_output_0/DataCrop'))

    conv = g.get_node('/stage1_0/attns_0/get_v/Conv')
    inputs.extend(conv.inputs[1:])  # depthconv weights and bias
    for n in ('stride', 'pad_size', 'dilation', 'group'):
        attrs[f"b1_conv_{n}"] = conv.get_attr(n)

    # branch 2
    inputs.append(g.get_node('/stage1_0/attns_1/Mul').inputs[1])  # branch 2 q scale

    set_qparam(attrs, 'b2_bmm1_left_in', g.get_node('/stage1_0/attns_1/Transpose_2_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm1_right_in', g.get_node('/stage1_0/attns_1/Transpose_9_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm1_out', g.get_node('/stage1_0/attns_1/MatMul_output_0/DataCrop'))

    set_qparam(attrs, 'b2_bmm2_left_in', g.get_node('/stage1_0/attns_1/Softmax_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm2_right_in', g.get_node('/stage1_0/attns_1/Transpose_8_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm2_out', g.get_node('/stage1_0/attns_1/MatMul_1_output_0/DataCrop'))

    set_qparam(attrs, 'b2_depth_conv_in', g.get_node('ChannelF2L/transpose_2/TypeCast'))
    set_qparam(attrs, 'b2_depth_conv_out', g.get_node('/stage1_0/attns_1/get_v/Conv_output_0/DataCrop'))

    conv = g.get_node('/stage1_0/attns_1/get_v/Conv')
    inputs.extend(conv.inputs[1:])  # depthconv weights and bias
    for n in ('stride', 'pad_size', 'dilation', 'group'):
        attrs[f"b2_conv_{n}"] = conv.get_attr(n)

    return g.node('CSwinAtten', inputs, outputs, name=name, attrs=attrs)



def match_apply_fc(g):
    """Build the fused ``CSwinAttenFC`` node for the ``/stage2_0`` sub-graph.

    The new node takes the name and outputs of ``/stage2_0/Concat``. Its
    inputs are collected in this order: the tensor feeding the input
    TypeCast node, the qkv MatMul weight/bias, then for each of the two
    attention branches the second input of its ``Mul`` node followed by the
    ``get_v`` conv's non-data inputs. Quantization parameters of the
    intermediate TypeCast / DataCrop nodes and the conv attributes are
    copied into the node attributes.

    Args:
        g: Graph handle providing ``get_node`` and ``node`` (presumably the
            object the rewriter passes to path-rule callbacks).

    Returns:
        Whatever ``g.node('CSwinAttenFC', ...)`` returns.
    """
    inputs = []
    attrs = {}
    concat = g.get_node('/stage2_0/Concat')
    name = concat.name
    outputs = concat.outputs

    # connect inputs
    # The path rule for this callback is registered with
    # 'unsqueeze_3/Unsqueeze/0' on rbcc < 1.2.2 and 'unsqueeze_3_Unsqueeze/0'
    # on newer versions, so the cast node name must follow the same scheme.
    # NOTE(review): the >= 1.2.2 node name is inferred from that rule and
    # from match_apply; confirm against a graph produced by rbcc >= 1.2.2.
    if version.parse(RBCC_VERSION) < version.parse('1.2.2'):
        input_name = 'unsqueeze_3/Unsqueeze/0/TypeCast'
    else:
        input_name = 'unsqueeze_3_Unsqueeze/0/TypeCast'

    input_cast = g.get_node(input_name)
    inputs.append(input_cast.inputs[0])
    inputs.extend(g.get_node('/stage2_0/qkv/MatMul').inputs[1:])  # qkv_w and qkv_b

    set_qparam(attrs, 'qkv_in', input_cast)
    set_qparam(attrs, 'qkv_out', g.get_node('/stage2_0/qkv/MatMul_output_0/DataCrop'))

    # branch 1
    inputs.append(g.get_node('/stage2_0/attns_0/Mul').inputs[1])

    set_qparam(attrs, 'b1_bmm1_left_in', g.get_node('/stage2_0/attns_0/Transpose_2_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm1_right_in', g.get_node('/stage2_0/attns_0/Transpose_9_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm1_out', g.get_node('/stage2_0/attns_0/MatMul_output_0/DataCrop'))

    set_qparam(attrs, 'b1_bmm2_left_in', g.get_node('/stage2_0/attns_0/Softmax_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm2_right_in', g.get_node('/stage2_0/attns_0/Transpose_8_output_0/TypeCast'))
    set_qparam(attrs, 'b1_bmm2_out', g.get_node('/stage2_0/attns_0/MatMul_1_output_0/DataCrop'))

    set_qparam(attrs, 'b1_depth_conv_in', g.get_node('ChannelF2L/transpose_10/TypeCast'))
    set_qparam(attrs, 'b1_depth_conv_out', g.get_node('/stage2_0/attns_0/get_v/Conv_output_0/DataCrop'))

    conv = g.get_node('/stage2_0/attns_0/get_v/Conv')
    inputs.extend(conv.inputs[1:])  # depthwise conv weights and bias
    for n in ('stride', 'pad_size', 'dilation', 'group'):
        attrs[f"b1_conv_{n}"] = conv.get_attr(n)

    # branch 2
    inputs.append(g.get_node('/stage2_0/attns_1/Mul').inputs[1])

    set_qparam(attrs, 'b2_bmm1_left_in', g.get_node('/stage2_0/attns_1/Transpose_2_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm1_right_in', g.get_node('/stage2_0/attns_1/Transpose_9_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm1_out', g.get_node('/stage2_0/attns_1/MatMul_output_0/DataCrop'))

    set_qparam(attrs, 'b2_bmm2_left_in', g.get_node('/stage2_0/attns_1/Softmax_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm2_right_in', g.get_node('/stage2_0/attns_1/Transpose_8_output_0/TypeCast'))
    set_qparam(attrs, 'b2_bmm2_out', g.get_node('/stage2_0/attns_1/MatMul_1_output_0/DataCrop'))

    set_qparam(attrs, 'b2_depth_conv_in', g.get_node('ChannelF2L/transpose_8/TypeCast'))
    set_qparam(attrs, 'b2_depth_conv_out', g.get_node('/stage2_0/attns_1/get_v/Conv_output_0/DataCrop'))

    conv = g.get_node('/stage2_0/attns_1/get_v/Conv')
    inputs.extend(conv.inputs[1:])  # depthwise conv weights and bias
    for n in ('stride', 'pad_size', 'dilation', 'group'):
        attrs[f"b2_conv_{n}"] = conv.get_attr(n)

    return g.node('CSwinAttenFC', inputs, outputs, name=name, attrs=attrs)


def match_apply_last(g):
    """Build the fused ``CSwinAttenLast`` node for the ``/stage4_0`` sub-graph.

    The new node takes the name and outputs of the ``unsqueeze_73`` Unsqueeze
    node. Its inputs are collected in this order: the tensor feeding the
    input TypeCast node, the qkv MatMul weight/bias, the second input of
    ``/stage4_0/attns_0/Mul``, and the ``get_v`` conv's non-data inputs.
    Quantization parameters of the intermediate TypeCast / DataCrop nodes
    and the conv attributes are copied into the node attributes. Unlike the
    other two callbacks, only one attention branch (``attns_0``) is read.

    Args:
        g: Graph handle providing ``get_node`` and ``node`` (presumably the
            object the rewriter passes to path-rule callbacks).

    Returns:
        Whatever ``g.node('CSwinAttenLast', ...)`` returns.
    """
    inputs = []
    attrs = {}

    # The path rule for this callback is registered with '/'-style tensor
    # names on rbcc < 1.2.2 and '_'-style names on newer versions, so the
    # node names must follow the same scheme.
    # NOTE(review): the >= 1.2.2 node names are inferred from that rule and
    # from match_apply; confirm against a graph produced by rbcc >= 1.2.2.
    if version.parse(RBCC_VERSION) < version.parse('1.2.2'):
        input_name = 'unsqueeze_72/Unsqueeze/0/TypeCast'
        output_name = 'unsqueeze_73/Unsqueeze'
    else:
        input_name = 'unsqueeze_72_Unsqueeze/0/TypeCast'
        output_name = 'unsqueeze_73_Unsqueeze'

    output_node = g.get_node(output_name)
    name = output_node.name
    outputs = output_node.outputs

    # connect inputs
    input_cast = g.get_node(input_name)
    inputs.append(input_cast.inputs[0])
    inputs.extend(g.get_node('/stage4_0/qkv/MatMul').inputs[1:])  # qkv_w, qkv_b

    # q scale
    inputs.append(g.get_node('/stage4_0/attns_0/Mul').inputs[1])

    set_qparam(attrs, 'qkv_in', input_cast)
    set_qparam(attrs, 'qkv_out', g.get_node('/stage4_0/qkv/MatMul_output_0/DataCrop'))

    # branch 1
    set_qparam(attrs, 'bmm1_left_in', g.get_node('/stage4_0/attns_0/Transpose_2_output_0/TypeCast'))
    set_qparam(attrs, 'bmm1_right_in', g.get_node('/stage4_0/attns_0/Transpose_9_output_0/TypeCast'))
    set_qparam(attrs, 'bmm1_out', g.get_node('/stage4_0/attns_0/MatMul_output_0/DataCrop'))

    set_qparam(attrs, 'bmm2_left_in', g.get_node('/stage4_0/attns_0/Softmax_output_0/TypeCast'))
    set_qparam(attrs, 'bmm2_right_in', g.get_node('/stage4_0/attns_0/Transpose_8_output_0/TypeCast'))
    set_qparam(attrs, 'bmm2_out', g.get_node('/stage4_0/attns_0/MatMul_1_output_0/DataCrop'))

    set_qparam(attrs, 'depth_conv_in', g.get_node('ChannelF2L/transpose_104/TypeCast'))
    set_qparam(attrs, 'depth_conv_out', g.get_node('/stage4_0/attns_0/get_v/Conv_output_0/DataCrop'))

    conv = g.get_node('/stage4_0/attns_0/get_v/Conv')
    inputs.extend(conv.inputs[1:])  # conv inputs after the data input

    for n in ('stride', 'pad_size', 'dilation', 'group'):
        attrs[f"conv_{n}"] = conv.get_attr(n)

    return g.node('CSwinAttenLast', inputs, outputs, name=name, attrs=attrs)

def main(src_path, dst_path):
    """Rewrite the graph stored at ``src_path`` and save it to ``dst_path``.

    Three path rules are registered on an ``SGRewriter`` (one per fused
    attention callback), the rewriter is run, and the resulting graph is
    saved. The boundary tensor names use '/' between the unsqueeze index and
    ``Unsqueeze`` on rbcc < 1.2.2 and '_' on newer versions.

    Args:
        src_path: Path handed to ``rbcc.load_sg``.
        dst_path: Path handed to ``rbcc.save_sg``.
    """
    rewriter = SGRewriter(rbcc.load_sg(src_path))

    # Pick the tensor names that delimit each sub-graph for this rbcc version.
    if version.parse(RBCC_VERSION) < version.parse('1.2.2'):
        stage1_in, stage2_in, last_in, last_out = (
            'unsqueeze_0/Unsqueeze/0',
            'unsqueeze_3/Unsqueeze/0',
            'unsqueeze_72/Unsqueeze/0',
            'unsqueeze_73/Unsqueeze/0',
        )
    else:
        stage1_in, stage2_in, last_in, last_out = (
            'unsqueeze_0_Unsqueeze/0',
            'unsqueeze_3_Unsqueeze/0',
            'unsqueeze_72_Unsqueeze/0',
            'unsqueeze_73_Unsqueeze/0',
        )

    # Only the first rule is registered with as_template=False.
    rewriter.add_path_rule(
        [stage1_in], ['/stage1_0/Concat_output_0'],
        match_apply, as_template=False
    )
    rewriter.add_path_rule(
        [stage2_in], ['/stage2_0/Concat_output_0'],
        match_apply_fc
    )
    rewriter.add_path_rule([last_in], [last_out], match_apply_last)

    rewritten = rewriter.run()
    rbcc.save_sg(rewritten, dst_path)
    


if __name__ == '__main__':
    # Two positional arguments are required: source graph path and
    # destination graph path. Exit with a usage message rather than an
    # IndexError traceback when they are missing; extra arguments are
    # ignored, as before.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <src_sg_path> <dst_sg_path>")
    main(sys.argv[1], sys.argv[2])