import numpy as np
from rbcc.custom_op import BaseOp, register_op
import rbpy

import math

# Per-node attention geometry, keyed by graph node name.
# Every value is a (B, H, N, D, C) tuple, in the same order as the
# `B, Head, N, D, C` constants hard-coded in the `to_ast` methods below.
# A further entry that was left disabled: 'other': (2, 4, 98, 32, 256)
shape_info = dict(
    Concat_2770=(56, 1, 56, 32, 64),
    Concat_3284=(14, 2, 56, 32, 128),
    Concat_3760=(14, 2, 56, 32, 128),
    Reshape_14088=(1, 16, 49, 32, 512),
)


@register_op('CSwinAtten')
class CSwinAtten(BaseOp):

    def __init__(self, node):
        """Initialise the op for ``node`` and start with an empty constant cache."""
        super().__init__(node)
        # Looked up by ``attn`` through ``kwargs['scale_name']``; starts out empty.
        self.const_cache = dict()

    def get_scale_zp(self, prefix):
        scale = self.node.get_attr(f'{prefix}_scale')
        zp = self.node.get_attr(f'{prefix}_zp')
        return scale, zp

    def attn(self, x, **kwargs):
        """Fake-quantised attention over q/k/v stacked along axis 0 of ``x``.

        NOTE(review): this method calls ``self.im2cswin``, ``self.get_lepe``,
        ``self.fakequant``, ``F.softmax`` and ``windows2img``. None of them is
        defined or imported in the visible part of this file, and nothing
        visible ever fills ``self.const_cache``. It looks like a reference
        implementation kept next to ``attn_to_ast`` -- confirm before use.

        Args:
            x: object whose ``unbind(0)`` yields ``q, k, v``; each is unpacked
                as a 3-tuple shape ``(B, L, C)``.
            **kwargs: reads ``H``, ``W``, ``num_heads``, ``H_sp``, ``W_sp``,
                ``scale_name`` and the ``bmm1_*`` / ``bmm2_*`` scale and
                zero-point entries; the whole dict is also forwarded to
                ``self.get_lepe``.

        Returns:
            The attention output plus ``lepe``, passed through
            ``windows2img`` and viewed as ``(B, -1, C)``.
        """
        q, k, v = x.unbind(0)
        B, L, C = q.shape

        H, W = kwargs['H'], kwargs['W']
        # Stored on the instance rather than passed on; presumably read by the
        # im2cswin / get_lepe helpers -- TODO confirm.
        self.num_heads = kwargs['num_heads']
        self.H_sp, self.W_sp = kwargs['H_sp'], kwargs['W_sp']

        q = self.im2cswin(q)
        k = self.im2cswin(k)

        v, lepe = self.get_lepe(v, **kwargs)

        # print('q,k,v,x =', q.shape, k.shape, v.shape, x.shape)
        # exit(0)

        # Scale q by the cached constant named by kwargs['scale_name'].
        q = q * self.const_cache[kwargs['scale_name']]

        # Every matmul operand and result goes through fakequant with its own
        # scale / zero point.
        q = self.fakequant(q, kwargs['bmm1_left_in_scale'], kwargs['bmm1_left_in_zp'])
        attn = q @ self.fakequant(k.transpose(-2, -1), kwargs['bmm1_right_in_scale'], kwargs['bmm1_right_in_zp'])
        attn = self.fakequant(attn, kwargs['bmm1_out_scale'], kwargs['bmm1_out_zp'])
        attn = F.softmax(attn, dim=-1)
        attn = self.fakequant(attn, kwargs['bmm2_left_in_scale'], kwargs['bmm2_left_in_zp'])
        x = (attn @ self.fakequant(v, kwargs['bmm2_right_in_scale'], kwargs['bmm2_right_in_zp'])) + lepe
        x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C)  # B head N N @ B head N C
        x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C)  # B H' W' C
        return x

    def padN(self, x, zp):
        N, _ = x.shape
        PC = 64
        newN = (N // PC + 1) * PC
        if N != newN:
            x = rbpy.pad(x, [0, newN - N], axes=[0], value=rbpy.constant(zp, dtype=rbpy.uint8))
        return x

    def padND(self, x, zp, isv=False):
        PC = 64
        if len(x.shape) == 3:
            N, _, D = x.shape
            newN = (N // PC + 1) * PC
            newD = (D // PC + 1) * PC
        else:
            N, _, H, D = x.shape
            HD = H * D
            newN = (N // PC + 1) * PC
            if isv or HD % PC != 0:
                newD = (D // PC + 1) * PC
            else:
                newD = D

        if newD != D or newN != N:
            old_shape = x.shape
            if len(old_shape) == 4:
                x = rbpy.reshape(x, [old_shape[0], -1, old_shape[-1]])
            x = rbpy.pad(x, [0, newN - N, 0, newD - D], axes=[0, 2],
                            value=rbpy.constant(zp, dtype=rbpy.uint8))
            if len(old_shape) == 4:
                x = rbpy.reshape(x, [newN, old_shape[1], old_shape[2], newD])
        return x

    def pad_constant(self, v, has_qparam=False, use_self=True, need_filp=False):
        """Double a constant along axis 0 (to exactly 64 rows) and re-wrap it.

        Args:
            v: constant exposing ``numpy()`` and, when ``has_qparam`` is true,
                a ``qparam`` with a ``zp`` field.
            has_qparam: use ``v.qparam.zp`` as the fill value and attach
                ``v.qparam`` to the returned constant.
            use_self: append a copy of the data itself; otherwise append a
                block filled with the fill value.
            need_filp: reverse the rows inside each group of 64 after the
                concatenation.

        Returns:
            A new ``rbpy.constant`` holding the doubled data.
        """
        original = v.numpy()
        fill_value = v.qparam.zp if has_qparam else 0

        if use_self:
            tail = original
        else:
            tail = np.zeros_like(original, dtype=original.dtype) + np.array(fill_value, dtype=original.dtype)
        doubled = np.concatenate([original, tail], axis=0)

        # The result must be exactly one 64-row group.
        assert doubled.shape[0] == 64

        if need_filp:
            full_shape = doubled.shape
            grouped = doubled.reshape(full_shape[0] // 64, 64, -1)
            doubled = grouped[:, ::-1, :].reshape(full_shape)

        if not has_qparam:
            return rbpy.constant(doubled)
        return rbpy.constant(doubled, qparam=v.qparam)

    def attn_to_ast(self, x, **kwargs):
        """Build the rbpy graph for one CSwin attention branch.

        Steps, in the order they appear below:
          1. derive B (number of windows) and N (= Hsp * Wsp) from the
             window size;
          2. project ``x`` to q, k and v with ``rbpy.nn.conv2d`` or
             ``rbpy.nn.gconv``;
          3. pad k and v with ``padN`` / ``padND`` and rearrange them into
             the weight / bias operands of the two ``gconv`` calls that
             stand in for the q*k and attn*v matmuls;
          4. dequantise the q*k result, multiply by ``bmm_mul_scalar``,
             apply softmax and re-quantise;
          5. compute lepe with a grouped ``rbpy.nn.conv3d`` on the unpadded v;
          6. add the dequantised matmul result and lepe.

        Args:
            x: quantised input whose shape unpacks as ``(xB, H, W, C)``.
            **kwargs: reads ``H_sp``, ``W_sp``, ``q_w``/``q_b``,
                ``k_w``/``k_b``, ``v_w``/``v_b``, ``bmm1_right_in_*``,
                ``bmm1_out_*``, ``bmm2_left_in_*``, ``bmm2_right_in_*``,
                ``bmm2_out_*``, ``depth_conv_out_*``, ``depth_w``,
                ``depth_b``, ``conv_group`` and ``bmm_mul_scalar``.

        Returns:
            ``dequant(bmm2) + dequant(lepe)`` reshaped to ``(xB, -1, HD)``.

        NOTE(review): the positional arguments of ``rbpy.nn.gconv``,
        ``rbpy.fliptranspose`` and ``rbpy.genbias`` are opaque from this
        file; the comments below describe only what is passed, not what
        those ops do with it.
        """
        # print('x =', x)
        # CSwin has to distinguish whether this is the left or the right operand
        # of the concat, because the size taken out of it differs.
        xB, H, W, C = x.shape
        Hsp, Wsp = kwargs['H_sp'], kwargs['W_sp']
        w_scale, w_zp = self.node.inputs[1].get_scale_zp()
        # Note: one of H // Hsp and W // Wsp is necessarily 1, and the one that
        # is not 1 is B. If H // Hsp is B, this matches the case handled before;
        # if W // Wsp is B, the gconv mode is set to 0 and it is handled
        # separately afterwards.
        if Hsp > Wsp: # H // Hsp == 1
            assert H == Hsp
            B = W // Wsp
            N = Wsp * Hsp
            need_transpose = False
        else: # W // Wsp == 1
            assert W == Wsp
            B = H // Hsp
            N = Wsp * Hsp
            need_transpose = True

        def get_const(w_name, b_name, need_filp=False):
            # Wrap a weight / bias pair from kwargs as rbpy constants; the
            # weight gets the quant params read from node input 1.
            # need_filp reverses each group of 64 rows, but no call below
            # passes it as True.
            w = kwargs[w_name]
            b = kwargs[b_name]

            if need_filp and w.shape[0] % 64 == 0:
                ws, bs = w.shape, b.shape
                w = w.reshape(ws[0] // 64, 64, -1)[:, ::-1, :].reshape(ws)
                b = b.reshape(bs[0] // 64, 64)[:, ::-1].reshape(bs)

            return rbpy.constant(w, qparam=rbpy.QuantParam(w_scale, w_zp)), rbpy.constant(b)

        q_w, q_b = get_const('q_w', 'q_b')
        k_w, k_b = get_const('k_w', 'k_b')
        # v_w, v_b = get_const('v_w', 'v_b', True)
        v_w, v_b = get_const('v_w', 'v_b')
        HD, = q_b.shape
        D = 32

        # Every case in CSwin has D == 32, and H (the head count) is even
        # whenever it is not 1.
        Head = HD // 32
        # x_scale, x_zp = self.get_scale_zp('qkv_in')
        scale, zp = self.get_scale_zp('qkv_out')

        # NOTE(review): get_scalezp is defined but never called below (its only
        # uses are in commented-out code).
        def get_scalezp(scale_name, zp_name):
            cur_scale = kwargs[scale_name]
            cur_zp = kwargs[zp_name]
            if scale > cur_scale:
                return scale, zp
            else:
                return cur_scale, cur_zp

        # Single-head case: double the 32 output channels to 64 by appending a
        # copy of each constant (pad_constant with its default use_self=True).
        if HD == 32:
            q_w = self.pad_constant(q_w, True)
            k_w = self.pad_constant(k_w, True)
            v_w = self.pad_constant(v_w, True)
            q_b = self.pad_constant(q_b)
            k_b = self.pad_constant(k_b)
            v_b = self.pad_constant(v_b)

        # q keeps the 'qkv_out' quant params; k and v are produced directly in
        # the quant params of the matmul operand they feed.
        q_scale, q_zp = scale, zp # get_scalezp('bmm1_left_in_scale', 'bmm1_left_in_zp')
        # k_scale, k_zp = scale, zp # get_scalezp('bmm1_right_in_scale', 'bmm1_right_in_zp')
        # v_scale, v_zp = scale, zp # get_scalezp('bmm2_right_in_scale', 'bmm2_right_in_zp')
        k_scale, k_zp = kwargs['bmm1_right_in_scale'], kwargs['bmm1_right_in_zp']
        v_scale, v_zp = kwargs['bmm2_right_in_scale'], kwargs['bmm2_right_in_zp']
        qqparam = rbpy.QuantParam(scale=q_scale, zp=q_zp)
        kqparam = rbpy.QuantParam(scale=k_scale, zp=k_zp)
        vqparam = rbpy.QuantParam(scale=v_scale, zp=v_zp)
        if need_transpose:
            if HD == 32:
                # print('building gconv')
                qconv = rbpy.nn.gconv(x, q_w, q_b, [B, N, 2, D], 1, 0, B, qparam=qqparam)
                kconv = rbpy.nn.gconv(x, k_w, k_b, [B, N, 2, D], 1, 0, B, qparam=kqparam)
                vconv = rbpy.nn.gconv(x, v_w, v_b, [B, N, 2, D], 1, 0, B, qparam=vqparam)

                # Crop one entry off the size-2 axis that the constant
                # doubling introduced, bringing q back to B * HD channels.
                qconv = rbpy.crop(qconv.reshape(N, B, 2, D), crop_size=[0, 1], axes=[2]).reshape(1, 1, N, B * HD)
            else:
                assert(B == H // Hsp and N == Hsp * Wsp)
                x = x.reshape(1, H // Hsp, Hsp * Wsp, C) # == B, N, C
                qconv = rbpy.nn.gconv(x, q_w, q_b, [B, N, Head, D], 1, 0, B, qparam=qqparam)
                kconv = rbpy.nn.gconv(x, k_w, k_b, [B, N, Head, D], 1, 0, B, qparam=kqparam)
                vconv = rbpy.nn.gconv(x, v_w, v_b, [B, N, Head, D], 1, 0, B, qparam=vqparam)
        else:
            # B has to be transposed from the W axis onto the channel axis; the
            # backend does not support fusing this yet, so the transpose is
            # inserted by hand for now.
            if HD == 32:
                # print('building conv')
                x = x.reshape(Hsp, W // Wsp, Wsp, C).transpose(0, 2, 1, 3).reshape(1, N, B // 2, 2 * C)
                qconv = rbpy.nn.conv2d(x, q_w, q_b, group=2, qparam=qqparam).reshape(1, 1, N, B * HD)
                kconv = rbpy.nn.conv2d(x, k_w, k_b, group=2, qparam=kqparam).reshape(1, 1, N, B * HD)
                vconv = rbpy.nn.conv2d(x, v_w, v_b, group=2, qparam=vqparam).reshape(1, 1, N, B * HD)
            else:
                x = x.reshape(Hsp, W // Wsp, Wsp, C).transpose(0, 2, 1, 3).reshape(1, N, B, C)
                qconv = rbpy.nn.conv2d(x, q_w, q_b, qparam=qqparam).reshape(1, 1, N, B * HD)
                kconv = rbpy.nn.conv2d(x, k_w, k_b, qparam=kqparam).reshape(1, 1, N, B * HD)
                vconv = rbpy.nn.conv2d(x, v_w, v_b, qparam=vqparam).reshape(1, 1, N, B * HD)

            # if B == 14:
            #     print('qconv', qconv)
            #     print('x', x)
            #     print('q_w', q_w)
            #     print('============================')

        # Whichever branch was taken, the resulting q/k/v follow the layout
        # 1x1xNx(B*H*D).
        # print(B, H, W, C, D, Head, 'hsp =', Hsp, 'wsp =', Wsp)

        # TODO(kma) force a quant/dequant round trip to see the result
        # vconv = rbpy.nn.dequant(vconv)
        # vconv = rbpy.nn.quant(vconv, rbpy.QuantParam(kwargs['bmm2_right_in_scale'], kwargs['bmm2_right_in_zp']))
        # v_scale, v_zp = kwargs['bmm2_right_in_scale'], kwargs['bmm2_right_in_zp']

        if HD == 32 and need_transpose:
            # Unpadded, cropped copy of v, kept for the lepe convolution.
            oldvconv = rbpy.crop(vconv.reshape(N, B, 2, D), crop_size=[0, 1], axes=[2]).reshape(1, 1, N, B * HD)

            # In this branch D has already been padded, so there is no need to crop.
            kconv = self.padN(kconv.reshape([N, -1]), k_zp).reshape(-1, B, 2 * D)

            # vconv = rbpy.flip(vconv.reshape(-1, 2 * D), [1])
            vconv = self.padN(vconv.reshape([N, -1]), v_zp).reshape(-1, B, 1, 2 * D)
        else:
            # vconv can stay unpadded for the lepe branch, which reduces the
            # amount of computation.
            oldvconv = vconv

            kconv = self.padND(kconv.reshape([N, B * Head, D]), k_zp)

            # vconv = rbpy.flip(vconv.reshape(-1, D), [1])
            vconv = self.padND(vconv.reshape([N, B, Head, D]), v_zp, True)

        # Refresh the shapes after padding.
        oldD = D
        pN, _, kD = kconv.shape
        _, _, _, vD = vconv.shape

        assert pN % 64 == 0 and kD % 64 == 0

        # TODO(kma) force a quant/dequant round trip to see the result
        # kconv = rbpy.nn.dequant(kconv)
        # kconv = rbpy.nn.quant(kconv, rbpy.QuantParam(kwargs['bmm1_right_in_scale'], kwargs['bmm1_right_in_zp']))

        # Rearrange padded k into 64x64 tiles to serve as the gconv weight.
        kWeight = kconv.reshape(pN // 64, 64, B * Head, kD // 64, 64) # N // 64, 64, BH, D // 64, 64
        kWeight = rbpy.fliptranspose(kWeight, [2, 0, 3, 1, 4], 1) # BH, N // 64, D // 64, 64, 64
        kWeight = rbpy.reshape(kWeight, [B*Head*pN//64, 1, 1, kD//64, 64, 64])

        # The matching bias is generated from k laid out as (B*Head, kD, pN)
        # together with the q and k zero points.
        kBias = rbpy.reshape(kconv, [pN, B*Head*kD])
        kBias = rbpy.transpose(kBias, [1, 0])
        kBias = rbpy.reshape(kBias, [B*Head, kD, pN])
        kBias = rbpy.genbias(kBias, q_zp, k_zp, D)

        qkbmmscale, qkbmmzp = kwargs['bmm1_out_scale'], kwargs['bmm1_out_zp']

        # TODO(kma) optimising this section would require re-quantisation
        # qconv = rbpy.nn.dequant(qconv)
        # qconv = qconv * kwargs['bmm_mul_scalar']
        # qconv = rbpy.nn.quant(qconv, rbpy.QuantParam(kwargs['bmm1_left_in_scale'], kwargs['bmm1_left_in_zp']), dtype=rbpy.uint8)

        # 1, 1, qN, B * Head * kN
        qkbmmqparam = rbpy.QuantParam(scale=qkbmmscale, zp=qkbmmzp)
        qkbmm = rbpy.nn.gconv(qconv, kWeight, kBias, [B, N, Head, D], 0, 1, B*Head, qparam=qkbmmqparam)
        # if B == 14:
        #     print('qconv', qconv)
        #     print('qkbmm', qkbmm)
        #     print('kWeight', kWeight)
        #     print(B, Head, N, D, vD, H, W, need_transpose)
        # Move the padded key axis (pN) to the front before the softmax call.
        qkbmm = qkbmm.reshape(N * B * Head, pN).transpose([1, 0])
        # print('qkbmm transpose', qkbmm)

        # The scale multiply and softmax run on dequantised values.
        qkbmm = rbpy.nn.dequant(qkbmm)
        qkbmm = qkbmm * kwargs['bmm_mul_scalar']
        # qkbmm = qkbmm * kwargs['bmm_mul_scalar']
        # NOTE(review): the third argument N presumably restricts the softmax
        # to the unpadded keys -- confirm against the rbpy.nn.softmax signature.
        qkbmm = rbpy.nn.softmax(qkbmm, 0, N)
        qkscale, qkzp = kwargs['bmm2_left_in_scale'], kwargs['bmm2_left_in_zp']
        bmm2inqparam = rbpy.QuantParam(scale=qkscale, zp=qkzp)
        qkbmm = rbpy.nn.quant(qkbmm, qparam=bmm2inqparam, dtype=rbpy.uint8)
        # print('qkbmm softmax', qkbmm)
        qkbmm = qkbmm.transpose(1, 0)
        # print('qkbmm in bmm2', qkbmm)

        qkbmm = rbpy.reshape(qkbmm, [1, 1, N, B*Head*pN])
        # print(qkbmm)

        # Compute lepe from the pre-padding tensor to avoid redundant computation.
        lepe_scale, lepe_zp = kwargs['depth_conv_out_scale'], kwargs['depth_conv_out_zp']
        lepeqparam = rbpy.QuantParam(lepe_scale, lepe_zp)
        lepe_weight_qparam = kwargs['depth_w'].qparam
        weight_data = kwargs['depth_w'].numpy()
        # print(weight_data.shape, oldvconv.shape, Hsp, Wsp, B, Head, oldD)
        # Reshape the depthwise weight to (Head * D, 3, 3, 1, 1) for conv3d.
        weight_data = weight_data.reshape(Head * oldD, 3, 3, 1, 1)
        lepe_weight = rbpy.constant(weight_data, qparam=lepe_weight_qparam)
        lepe_bias = kwargs['depth_b']
        # lepe_wscale, lepe_wzp = lepe_weight.qparam.scale, lepe_weight.qaram.zp

        lepe_group = kwargs['conv_group']
        # Skip validation for now; find out whether there is a problem once
        # this reaches the backend.
        # if lepe_group == 32:
        #     # lepe_weight = np.pad(lepe_weight.numpy(), ((0, 0, 0, 0), (32, 0, 0, 0)), 'constant', lepe_wzp)
        #     # lepe_weight = rbpy.constant(lepe_weight, qparam=rbpy.QuantParam(lepe_wscale, lepe_wzp))
        #     # lepe_bias = np.pad(lepe_bias.numpy(), (0, 32), 'constant', 0)
        #     # lepe_bias = rbpy.constant(lepe_bias)
        #     lepe_group = 64
        # else:
        #     # This branch needs no special handling; the only thing to deal
        #     # with is B and the weight.
        #     oldvconv = oldvconv.reshape(1, N, B, H * D)
        #     assert(lepe_group == H * D and lepe_group == lepe_weight.shape[0])

        # Use a grouped conv3d so that the B dimension does not force the
        # weight filters to be copied.
        oldvconv = oldvconv.reshape(1, Hsp, Wsp, B, Head * oldD)# [..., ::-1]
        lepe = rbpy.nn.conv3d(oldvconv, lepe_weight, lepe_bias, pad_size=[1, 1, 1, 1, 0, 0], group=lepe_group, qparam=lepeqparam)
        # Bring lepe back to the (1, H, W, Head * D) image layout.
        if need_transpose:
            assert(B == H // Hsp and N == Hsp * Wsp)
            lepe = lepe.reshape(1, Hsp * Wsp, H // Hsp, Head * oldD).transpose(0, 2, 1, 3).reshape(1, H, W, Head * oldD)
        else:
            assert(B == W // Wsp and N == Hsp * Wsp)
            lepe = lepe.reshape(Hsp, Wsp, W // Wsp, Head * oldD).transpose(0, 2, 1, 3).reshape(1, H, W, Head * oldD)

        # Rearrange padded v into 64x64 tiles to serve as the weight of the
        # second gconv.
        vWeight = vconv.reshape(pN // 64, 64, B * Head * vD // 64, 64)
        vWeight = rbpy.fliptranspose(vWeight, [2, 0, 3, 1], 3)
        vWeight = rbpy.reshape(vWeight, [B * Head * vD // 64, 1, 1, pN // 64, 64, 64])
        vBias = rbpy.reshape(vconv, [pN, B*Head*vD])
        # print(vBias)
        vBias = rbpy.genbias(vBias, qkzp, v_zp, pN)
        # print(vBias)

        bmmscale, bmmzp = kwargs['bmm2_out_scale'], kwargs['bmm2_out_zp']
        bmm2qparam = rbpy.QuantParam(bmmscale, bmmzp)


        # bmm2 = N, B, Head, oldD
        bmm2 = rbpy.nn.gconv(qkbmm, vWeight, vBias, [B, N, Head, D], 0, 2, B*Head, qparam=bmm2qparam)
        # if vD != oldD:
        #     bmm2 =  rbpy.crop(bmm2.reshape(1, N, B * Head, vD), crop_size=[0, vD - oldD], axes=[3])

        # Bring the attention output back to the (1, H, W, Head * D) image
        # layout, mirroring the lepe reshapes above.
        if need_transpose:
            # B = H // Hsp, N = Hsp * Wsp
            assert(B == H // Hsp and N == Hsp * Wsp)
            bmm2 = bmm2.reshape(1, Hsp * Wsp, H // Hsp, Head * oldD).transpose(0, 2, 1, 3).reshape(1, H, W, Head * oldD)
        else:
            #  B = W // Wsp, N = Hsp * Wsp
            assert(B == W // Wsp and N == Hsp * Wsp)
            bmm2 = bmm2.reshape(Hsp, Wsp, W // Wsp, Head * oldD).transpose(0, 2, 1, 3).reshape(1, H, W, Head * oldD)

        # Then the add: attention output plus lepe, both dequantised.
        res = rbpy.nn.dequant(bmm2) + rbpy.nn.dequant(lepe)

        return res.reshape(xB, -1, HD)

    def to_ast(self):
        """Lower the two-branch attention node to an rbpy expression.

        The input is quantised with the ``qkv_in`` parameters and reshaped to
        ``(1, H, W, C)``. The fused qkv weight and bias are split per
        projection (q/k/v) and per branch. Each branch is lowered by
        ``attn_to_ast`` with its own ``b1_*`` / ``b2_*`` node attributes and
        window size (56x1 for the first branch, 1x56 for the second), and the
        two results are concatenated on axis 2.
        """
        (x, qkv_w, qkv_b,
         b1_scale, b1_depth_w, b1_depth_b,
         b2_scale, b2_depth_w, b2_depth_b) = self.inputs

        in_scale, in_zp = self.get_scale_zp('qkv_in')
        x = rbpy.nn.quant(x, rbpy.QuantParam(in_scale, in_zp), dtype=rbpy.uint8)

        # Geometry is hard-coded for this op.
        B, N, Head, C, D = 56, 56, 1, 64, 32
        H = W = int(math.sqrt(B * N))
        x = x.reshape(1, H, W, C)

        # Axis 0 selects q / k / v, axis 1 selects the branch.
        weights = qkv_w.numpy().reshape(3, 2, Head * D, 1, 1, C)
        biases = qkv_b.numpy().reshape(3, 2, Head * D)

        names = (
            'bmm1_left_in_scale', 'bmm1_left_in_zp',
            'bmm1_right_in_scale', 'bmm1_right_in_zp',
            'bmm1_out_scale', 'bmm1_out_zp',
            'bmm2_left_in_scale', 'bmm2_left_in_zp',
            'bmm2_right_in_scale', 'bmm2_right_in_zp',
            'bmm2_out_scale', 'bmm2_out_zp',
            'depth_conv_out_scale', 'depth_conv_out_zp',
            'conv_stride', 'conv_dilation', 'conv_pad_size', 'conv_group',
        )

        def complete(attrs, branch, mul_scalar, depth_w, depth_b, h_sp, w_sp):
            # Add the per-branch tensors and window size to the node attributes.
            attrs['q_w'] = weights[0][branch]
            attrs['q_b'] = biases[0][branch]
            attrs['k_w'] = weights[1][branch]
            attrs['k_b'] = biases[1][branch]
            attrs['v_w'] = weights[2][branch]
            attrs['v_b'] = biases[2][branch]
            attrs['depth_w'] = depth_w
            attrs['depth_b'] = depth_b
            attrs['bmm_mul_scalar'] = mul_scalar
            attrs['H_sp'] = h_sp
            attrs['W_sp'] = w_sp
            return attrs

        first = complete({n: self.node.get_attr(f"b1_{n}") for n in names},
                         0, b1_scale, b1_depth_w, b1_depth_b, 56, 1)
        x1 = self.attn_to_ast(x, **first)

        second = complete({n: self.node.get_attr(f"b2_{n}") for n in names},
                          1, b2_scale, b2_depth_w, b2_depth_b, 1, 56)
        x2 = self.attn_to_ast(x, **second)

        return rbpy.concat([x1, x2], axis=2)


@register_op('CSwinAttenFC')
class CSwinAttenFC(CSwinAtten):
    """Two-branch attention op whose geometry is chosen from the node name."""

    def to_ast(self):
        """Lower the node to an rbpy expression.

        Nodes named ``/stage2_0/Concat`` or ``/stage2_1/Concat`` use the
        14-window geometry with a 28x2 window; every other node uses the
        2-window geometry with a 14x7 window. The second branch runs with the
        window height and width swapped, and the two branch results are
        concatenated on axis 2.
        """
        is_stage2 = self.node.name in ('/stage2_0/Concat', '/stage2_1/Concat')
        (B, Head, N, D, C), (Hsp, Wsp) = (
            ((14, 2, 56, 32, 128), (28, 2)) if is_stage2
            else ((2, 4, 98, 32, 256), (14, 7))
        )

        (x, qkv_w, qkv_b,
         b1_scale, b1_depth_w, b1_depth_b,
         b2_scale, b2_depth_w, b2_depth_b) = self.inputs

        in_scale, in_zp = self.get_scale_zp('qkv_in')
        x = rbpy.nn.quant(x, rbpy.QuantParam(in_scale, in_zp), dtype=rbpy.uint8)

        H = W = int(math.sqrt(B * N))
        x = x.reshape(1, H, W, C)

        # Axis 0 selects q / k / v, axis 1 selects the branch.
        weights = qkv_w.numpy().reshape(3, 2, Head * D, 1, 1, C)
        biases = qkv_b.numpy().reshape(3, 2, Head * D)

        names = (
            'bmm1_left_in_scale', 'bmm1_left_in_zp',
            'bmm1_right_in_scale', 'bmm1_right_in_zp',
            'bmm1_out_scale', 'bmm1_out_zp',
            'bmm2_left_in_scale', 'bmm2_left_in_zp',
            'bmm2_right_in_scale', 'bmm2_right_in_zp',
            'bmm2_out_scale', 'bmm2_out_zp',
            'depth_conv_out_scale', 'depth_conv_out_zp',
            'conv_stride', 'conv_dilation', 'conv_pad_size', 'conv_group',
        )

        def complete(attrs, branch, mul_scalar, depth_w, depth_b, h_sp, w_sp):
            # Add the per-branch tensors and window size to the node attributes.
            attrs['q_w'] = weights[0][branch]
            attrs['q_b'] = biases[0][branch]
            attrs['k_w'] = weights[1][branch]
            attrs['k_b'] = biases[1][branch]
            attrs['v_w'] = weights[2][branch]
            attrs['v_b'] = biases[2][branch]
            attrs['depth_w'] = depth_w
            attrs['depth_b'] = depth_b
            attrs['bmm_mul_scalar'] = mul_scalar
            attrs['H_sp'] = h_sp
            attrs['W_sp'] = w_sp
            return attrs

        first = complete({n: self.node.get_attr(f"b1_{n}") for n in names},
                         0, b1_scale, b1_depth_w, b1_depth_b, Hsp, Wsp)
        x1 = self.attn_to_ast(x, **first)

        # Second branch: same window with height and width exchanged.
        second = complete({n: self.node.get_attr(f"b2_{n}") for n in names},
                          1, b2_scale, b2_depth_w, b2_depth_b, Wsp, Hsp)
        x2 = self.attn_to_ast(x, **second)

        return rbpy.concat([x1, x2], axis=2)


@register_op('CSwinAttenLast')
class CSwinAttenLast(CSwinAtten):
    """Single-branch attention op with a 7x7 window."""

    def to_ast(self):
        """Lower the node to an rbpy expression.

        Unlike the two-branch ops, the node attributes are read without a
        ``b1_`` / ``b2_`` prefix and the fused qkv tensors are split only
        into q, k and v. The ``attn_to_ast`` result gets an extra axis
        inserted at position 1.
        """
        x, qkv_w, qkv_b, b1_scale, b1_depth_w, b1_depth_b = self.inputs

        in_scale, in_zp = self.get_scale_zp('qkv_in')
        x = rbpy.nn.quant(x, rbpy.QuantParam(in_scale, in_zp), dtype=rbpy.uint8)

        # Geometry is hard-coded for this op.
        B, Head, N, D, C = 1, 16, 49, 32, 512
        H = W = int(math.sqrt(B * N))
        x = x.reshape(1, H, W, C)

        # Axis 0 selects q / k / v.
        weights = qkv_w.numpy().reshape(3, Head * D, 1, 1, C)
        biases = qkv_b.numpy().reshape(3, Head * D)

        names = (
            'bmm1_left_in_scale', 'bmm1_left_in_zp',
            'bmm1_right_in_scale', 'bmm1_right_in_zp',
            'bmm1_out_scale', 'bmm1_out_zp',
            'bmm2_left_in_scale', 'bmm2_left_in_zp',
            'bmm2_right_in_scale', 'bmm2_right_in_zp',
            'bmm2_out_scale', 'bmm2_out_zp',
            'depth_conv_out_scale', 'depth_conv_out_zp',
            'conv_stride', 'conv_dilation', 'conv_pad_size', 'conv_group',
        )

        attrs = {n: self.node.get_attr(n) for n in names}
        attrs['q_w'] = weights[0]
        attrs['q_b'] = biases[0]
        attrs['k_w'] = weights[1]
        attrs['k_b'] = biases[1]
        attrs['v_w'] = weights[2]
        attrs['v_b'] = biases[2]
        attrs['depth_w'] = b1_depth_w
        attrs['depth_b'] = b1_depth_b
        attrs['bmm_mul_scalar'] = b1_scale
        attrs['H_sp'] = 7
        attrs['W_sp'] = 7

        lowered = self.attn_to_ast(x, **attrs)
        return rbpy.expand_dims(lowered, 1)
