import cv2
import numpy as np


def preprocess(img_path, with_meta=False, input_size=640):
    """Load an image from disk and turn it into a normalised, letterboxed model input.

    The image is read with OpenCV, converted BGR -> RGB, resized so that its
    longer side becomes ``input_size`` and then padded to a square
    ``input_size x input_size`` canvas by ``letterbox``.

    Args:
        img_path: Path of the image file to read.
        with_meta: If True, also return the metadata that ``scale_boxes`` needs
            to map detections back onto the original image.
        input_size: Side length of the square model input.

    Returns:
        ``img`` -- float32 array of shape ``(1, input_size, input_size, 3)`` with
        values scaled to [0, 1]. The layout is channels-last (no transpose is
        done here); NOTE(review): transpose to NCHW if the model expects it.
        When ``with_meta`` is True, returns ``(img, img_meta)`` where
        ``img_meta`` holds ``img_size`` ``[h, w]`` of the original image,
        ``scale_factor`` ``[resize_h / h, resize_w / w]`` and ``pad_size``
        (the ``(dw, dh)`` padding reported by ``letterbox``).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    r = input_size / max(h, w)  # scale so the longer side becomes input_size
    img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR)
    resize_h, resize_w = img.shape[:2]

    img, pad = letterbox(img, input_size)
    img = np.ascontiguousarray(img)
    img = np.expand_dims(img, 0)  # add batch dimension
    img = img.astype(np.float32) / 255.

    if with_meta:
        img_meta = {
            'img_size': [h, w],
            # Per-axis ratios: int() truncation makes them differ slightly.
            'scale_factor': [resize_h / h, resize_w / w],
            'pad_size': pad
        }
        return img, img_meta

    return img


def postprocess(out, img_meta, conf=0.001, iou=0.7, max_det=300):
    """Run NMS on raw model output and map surviving boxes to original-image pixels.

    Args:
        out: Raw model output in the layout expected by ``non_max_suppression``.
        img_meta: Metadata dict as produced by ``preprocess(..., with_meta=True)``.
        conf: Minimum class score for a candidate detection to be considered.
        iou: IoU threshold above which a lower-scoring box of the same class is
            suppressed.
        max_det: Maximum number of detections to return.

    Returns:
        ``(boxes, scores, labels)``. ``boxes`` is ``(K, 4)`` in ``x1, y1, x2, y2``
        order in original-image pixels. ``scores`` and ``labels`` are ``(K,)``;
        labels are class indices stored as floats.
    """
    det = non_max_suppression(out, conf_thres=conf, iou_thres=iou, multi_label=True,
                              agnostic=False, max_det=max_det)
    boxes = scale_boxes(det[:, :4], img_meta)
    scores, labels = det[:, 4], det[:, 5]
    return boxes, scores, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """Shrink (never enlarge) an image to fit ``new_shape`` and pad it to that size.

    Aspect ratio is preserved. The image is centred and the border is filled
    with ``color``.

    Args:
        img: Image array as accepted by OpenCV (height first).
        new_shape: Target size, either a scalar (square output) or an
            ``(height, width)`` pair.
        color: Border fill value passed to ``cv2.copyMakeBorder``.

    Returns:
        ``(padded_img, (dw, dh))``. ``dw``/``dh`` are the horizontal/vertical
        padding per side *before* rounding. The actual left/top borders are
        ``round(dw - 0.1)`` / ``round(dh - 0.1)``, so they can differ from
        ``dw``/``dh`` by up to half a pixel.
    """
    shape = img.shape[:2]  # current shape [height, width]
    if np.ndim(new_shape) == 0:
        # Scalar target -> square canvas. A tuple is already (height, width).
        new_shape = (new_shape, new_shape)
    # Scale ratio, capped at 1.0 so images are only ever shrunk.
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1], 1.0)
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
    if shape[::-1] != new_unpad:
        # Without this resize an oversized image would get negative padding.
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
    # Nudge by 0.1 so an odd leftover pixel goes to the bottom/right border.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border

    return img, (dw, dh)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, multi_label=False, max_det=300):
    """Filter raw detector output by confidence and apply class-aware NMS.

    Only the first image of the batch (index 0) is processed. The caller's
    ``prediction`` array is not modified.

    Args:
        prediction: Raw output indexed as ``[batch, 4 + nc, N]`` (e.g.
            ``[1, 84, N]``). Rows 0-3 are treated as box centre/size
            ``(cx, cy, w, h)``; rows ``4:`` are per-class scores.
        conf_thres: Minimum class score for a candidate to survive.
        iou_thres: IoU above which a lower-scoring box of the same class is
            suppressed.
        agnostic: If True, suppress across classes as well.
        multi_label: If True (and there is more than one class), emit one row
            for every (box, class) pair above ``conf_thres``. Otherwise keep
            only each box's best class.
        max_det: Maximum number of detections returned.

    Returns:
        ``(K, 6)`` array with rows ``[x1, y1, x2, y2, score, class]``, ordered
        by descending score, where ``K <= max_det``.
    """
    nc = prediction.shape[1] - 4  # number of classes
    max_wh = 7680  # per-class coordinate offset so different classes never overlap
    xc = prediction[:, 4:].max(1) > conf_thres  # candidates: best class score above threshold
    multi_label &= nc > 1  # several labels per box only make sense with >1 class

    # [batch, 4 + nc, N] -> [batch, N, 4 + nc]. np.transpose returns a view, so
    # copy before the in-place box conversion to avoid clobbering the caller's array.
    prediction = np.transpose(prediction, (0, 2, 1)).copy()

    # Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2). The right-hand side
    # is fully evaluated before the assignment, so aliasing with ``xy`` is safe.
    xy = prediction[..., 0:2]
    half_wh = prediction[..., 2:4] / 2
    prediction[..., :4] = np.concatenate((xy - half_wh, xy + half_wh), axis=-1)

    x = prediction[0][xc[0]]  # confidence-filtered candidates of the first image
    box, cls = x[:, :4], x[:, 4:]

    if multi_label:
        # One row per (box, class) pair whose class score clears the threshold.
        i, j = (cls > conf_thres).nonzero()
        x = np.concatenate((box[i], x[i, 4 + j, None], j[:, None].astype(np.float32)), 1)
    else:  # best class only
        conf = cls.max(1, keepdims=True)
        j = cls.argmax(1)[:, None]
        x = np.concatenate((box, conf, j.astype(np.float32)), 1)[conf.reshape(-1) > conf_thres]

    # Shift each box by class index * max_wh (unless agnostic) so one NMS pass
    # never suppresses boxes belonging to different classes.
    c = x[:, 5:6] * (0 if agnostic else max_wh)
    boxes, scores = x[:, :4] + c, x[:, 4]
    keep = nms(boxes, scores, iou_thres, max_det)
    return x[np.asarray(keep, dtype=np.intp)]

def nms(dets, scores, iou_thres, max_det):
    """Greedy IoU-based non-maximum suppression.

    Args:
        dets: ``(M, 4)`` array of boxes as ``x1, y1, x2, y2`` (columns 0-3).
        scores: ``(M,)`` array of box scores.
        iou_thres: Boxes whose IoU with an already kept box exceeds this value
            are discarded.
        max_det: Selection stops as soon as this many boxes have been kept.

    Returns:
        List of indices into ``dets`` for the kept boxes, highest score first.

    Areas and intersections add 1 to each side length (inclusive-pixel
    convention), so IoU differs slightly from the continuous-coordinate formula.
    """
    xmin, ymin = dets[:, 0], dets[:, 1]
    xmax, ymax = dets[:, 2], dets[:, 3]
    area = (xmax - xmin + 1) * (ymax - ymin + 1)

    remaining = scores.argsort()[::-1]  # candidate indices by descending score
    selected = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        selected.append(best)
        if rest.size == 0 or len(selected) == max_det:
            break

        # Area of the intersection between the best box and every other candidate.
        inter_w = np.maximum(0.0, np.minimum(xmax[best], xmax[rest]) - np.maximum(xmin[best], xmin[rest]) + 1)
        inter_h = np.maximum(0.0, np.minimum(ymax[best], ymax[rest]) - np.maximum(ymin[best], ymin[rest]) + 1)
        inter = inter_w * inter_h
        iou = inter / (area[best] + area[rest] - inter)

        # Survivors are the candidates that overlap the kept box little enough.
        remaining = rest[iou <= iou_thres]

    return selected

def scale_boxes(boxes, img_meta):
    """Map ``x1, y1, x2, y2`` boxes from letterboxed input space to original pixels.

    The letterbox padding is subtracted, the per-axis resize is undone, and the
    result is clipped to the original image. ``boxes`` is modified in place
    (including any array it is a view of) and the same object is returned.

    Args:
        boxes: Array whose last axis holds ``x1, y1, x2, y2`` in columns 0-3.
        img_meta: Dict with ``pad_size`` ``(pad_x, pad_y)``, ``scale_factor``
            ``(scale_h, scale_w)`` and ``img_size`` ``(height, width)``.

    Returns:
        ``boxes``, after in-place modification.
    """
    pad = img_meta['pad_size']  # (x, y) padding
    ratio_h, ratio_w = img_meta['scale_factor']
    orig_size = img_meta['img_size']  # (height, width)

    # Remove the letterbox border.
    boxes[..., [0, 2]] -= pad[0]
    boxes[..., [1, 3]] -= pad[1]

    # Undo the resize separately for each axis.
    divisor = np.array([ratio_w, ratio_h, ratio_w, ratio_h], dtype=np.float32)
    boxes[..., :4] /= divisor

    # Keep every coordinate inside the original image.
    limits = ((0, orig_size[1]), (1, orig_size[0]), (2, orig_size[1]), (3, orig_size[0]))
    for col, upper in limits:
        boxes[..., col] = np.clip(boxes[..., col], 0, upper)
    return boxes