import sys
import yaml
import rbcc
import math
import numpy as np
from rbcc.rbir.sg_ir import SGIR
from rbcc.utils.helpers import extract_subgraph


def get_grid(outputs_shape):
    """Build per-cell (x, y) coordinate grids for a list of feature-map sizes.

    Args:
        outputs_shape: iterable of shapes whose first two entries are the
            (height, width) of a feature map; extra entries are ignored.

    Returns:
        list with one float32 array per input shape, each of shape
        [height, width, 1, 2]; entry [i, j, 0] holds (x=j, y=i).
    """
    def _cell_coords(height, width):
        # meshgrid (default 'xy' indexing) gives [H, W] arrays where
        # cols[i, j] == j and rows[i, j] == i.
        cols, rows = np.meshgrid(np.arange(width), np.arange(height))
        coords = np.stack((cols, rows), axis=-1)                  # [H, W, 2]
        return coords[:, :, np.newaxis, :].astype(np.float32)    # [H, W, 1, 2]

    return [_cell_coords(shape[0], shape[1]) for shape in outputs_shape]


def make_detect(ir, conv_outs, anchor, nc, no, na, stride, grid):
    """Append box decoding for one detection layer to ``ir``.

    Applies to the three conv outputs produced by ``split_conv``:

        xy   = (2 * sigmoid(xy_conv) - 0.5 + grid) * stride
        wh   = (2 * sigmoid(wh_conv)) ** 2 * anchor
        conf = sigmoid(conf_conv)

    Args:
        ir: graph builder; every op call on it creates a new graph value.
        conv_outs: ``(xy_conv, wh_conv, conf_conv)``. Each is unpacked as a
            4-D ``[batch, H, W, C]`` tensor whose channel dim may be
            zero-padded past the real channel count.
        anchor: flat sequence ``[w0, h0, w1, h1, ...]`` of exactly
            ``2 * na`` anchor sizes for this layer.
        nc: number of classes (unused here; kept for interface compatibility).
        no: number of outputs per anchor; the first ``na * (no - 4)``
            channels of ``conf_conv`` are kept.
        na: number of anchors.
        stride: downsampling factor of this detection layer.
        grid: cell-coordinate array whose last axis matches the (x, y)
            channel order of ``xy_conv`` and whose leading axes flatten in
            the same H*W order as the conv outputs (e.g. ``[H, W, 1, 2]``
            as built by ``get_grid``).

    Returns:
        ``(xy_out, wh_out, conf_cls_out)`` shaped ``[batch, 2, na*H*W]``,
        ``[batch, 2, na*H*W]`` and ``[batch, no-4, na*H*W]``. The last axis is
        ordered anchor-major, then H*W.

    Raises:
        ValueError: (from numpy) if ``anchor`` does not hold ``2 * na`` values.
    """
    xy_conv, wh_conv, conf_conv = conv_outs

    xy_conv = ir.sigmoid(xy_conv)
    wh_conv = ir.sigmoid(wh_conv)
    conf_conv = ir.sigmoid(conf_conv)

    def _crop_channels(t, keep):
        # split_conv padded the output channels; crop them back here.
        # [batch, H, W, C] -> [batch, C, H*W] -> [batch, keep, H*W]
        b, _, _, c = t.shape
        t = ir.reshape(t, [b, -1, c])
        t = ir.transpose(t, [0, 2, 1])
        return ir.stride_slice(t, start=[0], end=[keep], stride=[1], axes=[1])

    def _merge_anchor_axis(t, per_anchor):
        # [batch, na, per_anchor, H*W] -> [batch, per_anchor, na, H*W]
        #   -> [batch, per_anchor, na*H*W]
        t = ir.transpose(t, [0, 2, 1, 3])
        return ir.reshape(t, [batch, per_anchor, -1])

    batch = xy_conv.shape[0]
    xy_conv = _crop_channels(xy_conv, na * 2)
    wh_conv = _crop_channels(wh_conv, na * 2)
    conf_cls_out = _crop_channels(conf_conv, na * (no - 4))

    # xy scale term 2 * sigmoid * stride; the (grid - 0.5) * stride offset is added below.
    xy_out = ir.mul(xy_conv, ir.constant(np.array(2 * stride, dtype=np.float32)))

    # wh = (2 * sigmoid)^2 * anchor, anchor broadcast as [1, 2*na, 1].
    # (Previously hard-coded to 6 channels, i.e. only valid for na == 3.)
    wh_out = ir.mul(wh_conv, ir.constant(np.array(2, dtype=np.float32)))
    wh_out = ir.square(wh_out)
    anchor_const = np.asarray(anchor, dtype=np.float32).reshape(1, 2 * na, 1)
    wh_out = ir.mul(wh_out, ir.constant(anchor_const))

    # add_rhs is [1, 1, 2, H*W]
    add_rhs = (-0.5 + grid.reshape(1, 1, -1, 2).transpose(0, 1, 3, 2)) * stride
    # [batch, 2*na, H*W] -> [batch, na, 2, H*W], then + [1, 1, 2, H*W]
    xy_out = ir.reshape(xy_out, [batch, na, 2, -1])
    xy_out = ir.add(xy_out, ir.constant(add_rhs.astype(np.float32)))
    xy_out = _merge_anchor_axis(xy_out, 2)

    wh_out = ir.reshape(wh_out, [batch, na, 2, -1])
    wh_out = _merge_anchor_axis(wh_out, 2)

    conf_cls_out = ir.reshape(conf_cls_out, [batch, na, no - 4, -1])
    conf_cls_out = _merge_anchor_axis(conf_cls_out, no - 4)

    return xy_out, wh_out, conf_cls_out


def split_conv(ir, conv, na, pf=64):
    """Split a fused detection Conv2D into separate xy / wh / conf convs.

    The output channels of ``conv`` are treated as ``na`` consecutive groups,
    one per anchor. Within each group:
    - channels ``[0:2]`` go to the xy conv;
    - channels ``[2:4]`` go to the wh conv;
    - channels ``[4:]`` go to the conf conv.

    Each new conv's output channels are zero-padded up to a multiple of ``pf``.

    Args:
        ir: graph builder used to create the new constants and convs.
        conv: the original Conv2D node. ``inputs`` are (x, weight, bias), with
            weight laid out ``[out_c, Kh, Kw, in_c]`` and bias ``[out_c]``.
        na: number of anchors; ``out_c`` must be divisible by it (the
            reshape below raises otherwise).
        pf: hardware parallelism that output channels are aligned to.

    Returns:
        ``[xy_conv, wh_conv, conf_conv]``. These are value objects returned
        by ``ir.conv2d``, not nodes.
    """
    x = conv.inputs[0]
    # the shape of weight is [out_c, Kh, Kw, in_c]; bias is [out_c]
    weight = conv.inputs[1].node.tensor
    bias = conv.inputs[2].node.tensor
    _, kh, kw, in_c = weight.shape

    # Group output channels per anchor: [na, out_c // na, ...].
    weight = weight.reshape(na, -1, kh, kw, in_c)
    bias = bias.reshape(na, -1)

    conv_outs = []
    for part in (slice(0, 2), slice(2, 4), slice(4, None)):
        w = weight[:, part].reshape(-1, kh, kw, in_c)
        b = bias[:, part].reshape(-1)

        # Align output channels to the hardware parallelism ahead of time.
        # np.pad zero-fills while keeping the tensor's own dtype; the former
        # concatenation with float32 zeros silently promoted e.g. float16
        # or integer weights to float32.
        n_out = w.shape[0]
        n_pad = math.ceil(n_out / pf) * pf - n_out
        w = np.pad(w, [(0, n_pad), (0, 0), (0, 0), (0, 0)])
        b = np.pad(b, [(0, n_pad)])

        # NOTE(review): only stride/dilation/pad_size are carried over from
        # the original conv; confirm it has no other relevant attributes.
        new_conv = ir.conv2d(
            x, ir.constant(w), ir.constant(b),
            stride=conv.get_attr('stride'),
            dilation=conv.get_attr('dilation'),
            pad_size=conv.get_attr('pad_size')
        )
        conv_outs.append(new_conv)

    return conv_outs

def main(sg_file, yolo_yaml, new_sg_file):
    """Move YOLO box decoding into the SG graph and save the result.

    Steps:
    1. Load the SG from ``sg_file`` and the ``anchors`` / ``nc`` keys from
       ``yolo_yaml``.
    2. Treat every Conv2D with ``no * na`` output channels whose only user
       is a Sigmoid as a detection-layer output.
    3. Split each such conv with ``split_conv`` and decode it with
       ``make_detect``.
    4. Concatenate the per-layer xy / wh / conf results along their last
       axis.
    5. Save the subgraph ending at those three tensors to ``new_sg_file``,
       which drops the original detect-head ops.

    Raises:
        ValueError: if an anchor list does not hold ``2 * na`` values, or
            the number of matching convs differs from the number of anchor
            lists.
    """
    sg = rbcc.load_sg(sg_file)
    with open(yolo_yaml, 'r', encoding='utf-8') as file:
        yolo_cfg = yaml.safe_load(file)

    anchors = yolo_cfg['anchors']
    nc = yolo_cfg['nc']             # number of classes
    no = nc + 5                     # number of outputs per anchor
    nl = len(anchors)               # number of detection layers
    na = len(anchors[0]) // 2       # number of anchors
    pf = 64                         # hardware parallelism (channel alignment)

    # make_detect reshapes each layer's anchors to exactly 2 * na values;
    # fail early with a clear message instead of deep inside graph building.
    for i, layer_anchors in enumerate(anchors):
        if len(layer_anchors) != 2 * na:
            raise ValueError(
                f"anchors[{i}] has {len(layer_anchors)} values, expected {2 * na}"
            )

    convs = []
    model_size = 640  # fallback when the graph has no Input node
    for node in sg.nodes():
        if (
            node.op_name == 'Conv2D'
            and node.output.shape[-1] == no * na
            and len(node.output.users) == 1
            and node.output.users[0].op_name == 'Sigmoid'
        ):
            # A conv producing one of the model's final detect outputs.
            convs.append(node)

        if node.op_name == 'Input':
            # NOTE(review): assumes an NHWC input, so shape[1] is the height.
            model_size = node.get_attr('shape')[1]

    if len(convs) != nl:
        raise ValueError(f"expected {nl} outputs conv, not {len(convs)}")

    # Largest feature map first (stable sort), so it is paired with anchors[0].
    convs.sort(key=lambda c: c.output.shape[1], reverse=True)

    ir = SGIR(sg)
    grids = get_grid([c.output.shape[1:3] for c in convs])
    xys, whs, confs = [], [], []
    for conv, anchor, grid in zip(convs, anchors, grids):
        stride = math.ceil(model_size / conv.output.shape[1])
        conv_outs = split_conv(ir, conv, na, pf)
        xy, wh, conf = make_detect(ir, conv_outs, anchor, nc, no, na, stride, grid)
        xys.append(xy)
        whs.append(wh)
        confs.append(conf)

    xy = ir.concat(xys, -1)
    wh = ir.concat(whs, -1)
    conf = ir.concat(confs, -1)

    # Extract a new subgraph ending at the decoded outputs, dropping the
    # original sg's detect part.
    new_sg = extract_subgraph(sg, [], [xy.name, wh.name, conf.name])
    rbcc.save_sg(new_sg, new_sg_file)


if __name__ == '__main__':
    # Usage: <script> <sg_file> <yolo_yaml> <new_sg_file>
    # Extra trailing arguments are ignored, as before.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} <sg_file> <yolo_yaml> <new_sg_file>")
    main(sys.argv[1], sys.argv[2], sys.argv[3])