import cv2
import numpy as np


def preprocess(img_path, with_meta=False, input_size=640):
    """Load an image and letterbox it into a square network input.

    The image is read with ``cv2.imread``, converted BGR -> RGB, resized with
    a single scale factor so it fits inside ``input_size x input_size`` while
    keeping its aspect ratio, and then padded with gray (114, 114, 114) up to
    the full square. Pixel values are scaled to [0, 1] as float32 and a
    leading batch axis is added.

    Args:
        img_path: path passed to ``cv2.imread``.
        with_meta: if True, also return the metadata ``scale_boxes`` needs to
            map predictions back onto the original image.
        input_size: side length of the square output.

    Returns:
        ``img`` of shape (1, input_size, input_size, 3), or ``(img, img_meta)``
        when ``with_meta`` is True, where ``img_meta`` contains:
            'img_size': (height, width) of the original image,
            'scale_factor': [resized_h / height, resized_w / width],
            'pad_size': [left, top], the border actually added (in pixels).

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside cvtColor.
        raise FileNotFoundError(f"could not read image: {img_path!r}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    shape = img.shape[:2]  # original (height, width)
    new_shape = (input_size, input_size)

    # One scale ratio (new / old) so the whole image fits inside new_shape.
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (w, h) for cv2.resize
    # Total padding per axis, split between the two sides.
    dw = (new_shape[1] - new_unpad[0]) / 2
    dh = (new_shape[0] - new_unpad[1]) / 2

    img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    resize_h, resize_w = img.shape[:2]

    # The +-0.1 nudge sends an odd leftover pixel to the bottom/right side.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
    )  # add border
    img = img.astype(np.float32) / 255.
    img = img[None]  # add batch axis

    if not with_meta:
        return img

    img_meta = {
        'img_size': shape,
        'scale_factor': [resize_h / shape[0], resize_w / shape[1]],
        # Record the border that was really applied on the left/top. The
        # fractional half-padding (dw, dh) is 0.5 px too large whenever the
        # total padding is odd, which would shift every rescaled box.
        'pad_size': [left, top],
    }
    return img, img_meta


def scale_boxes(boxes, img_meta):
    """Map boxes from the letterboxed input frame back to the original image.

    Args:
        boxes: array whose last axis starts with (x1, y1, x2, y2) in the
            padded/resized input frame; any further columns are left as-is.
        img_meta: dict as produced by ``preprocess(..., with_meta=True)``:
            'pad_size' is [pad_w, pad_h], 'scale_factor' is
            [scale_h, scale_w] and 'img_size' is (height, width) of the
            original image.

    Returns:
        A new floating-point array with the same shape as ``boxes``. The
        first four columns have the padding removed, are divided by the
        scale and are clipped to the original image bounds. The input array
        is not modified.
    """
    pad_w, pad_h = img_meta['pad_size']
    scale_h, scale_w = img_meta['scale_factor']
    img_h, img_w = img_meta['img_size']

    # Work on a copy: callers may pass a view into the raw model output,
    # which in-place arithmetic would silently overwrite. Integer input is
    # promoted so the float arithmetic below is valid.
    boxes = np.array(boxes, copy=True)
    if not np.issubdtype(boxes.dtype, np.floating):
        boxes = boxes.astype(np.float32)

    # Undo the letterbox border, then the resize.
    boxes[..., [0, 2]] -= pad_w  # x padding
    boxes[..., [1, 3]] -= pad_h  # y padding
    boxes[..., :4] /= np.array([scale_w, scale_h, scale_w, scale_h], dtype=boxes.dtype)

    # Clip x to [0, width] and y to [0, height] of the original image.
    boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, img_w)
    boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, img_h)
    return boxes

def topk(array, k, axis=-1, sorted=True):
    """Pick the ``k`` largest entries of ``array`` along ``axis``.

    ``np.argpartition`` moves the k largest elements into the last k slots
    of ``axis`` without fully sorting, and those slots are cut out with
    ``take`` (which, unlike basic slicing, works on an arbitrary axis). When
    ``sorted`` is True the survivors are reordered largest-first; otherwise
    they keep argpartition's arbitrary order.

    Returns:
        ``(values, indices)``, both shaped like ``array`` except that ``axis``
        has length ``k``; ``indices`` index into ``array`` along ``axis``.
    """
    last_k = range(-k, 0)
    top_idx = np.argpartition(array, -k, axis=axis).take(indices=last_k, axis=axis)
    top_vals = np.take_along_axis(array, top_idx, axis=axis)

    if not sorted:
        return top_vals, top_idx

    # argsort is ascending; flipping it yields largest-first order.
    descending = np.flip(np.argsort(top_vals, axis=axis), axis=axis)
    return (
        np.take_along_axis(top_vals, descending, axis=axis),
        np.take_along_axis(top_idx, descending, axis=axis),
    )

def _postprocess(preds, max_det, nc=None):
    """Select the top ``max_det`` (box, class) pairs from per-anchor predictions.

    Two-stage top-k without NMS:
      1. keep the anchors whose best class score is among the ``max_det``
         highest;
      2. among those anchors, keep the ``max_det`` highest individual
         (anchor, class) scores.

    Args:
        preds: 2-D array (num_anchors, 4 + nc), e.g. [8400, 84]; columns 0-3
            are the box and the remaining ``nc`` columns are class scores.
        max_det: maximum number of detections; clamped to ``num_anchors``.
        nc: number of classes. Inferred as ``preds.shape[-1] - 4`` when None.

    Returns:
        ``(boxes, scores, labels)`` with shapes (k, 4), (k,), (k,), where
        ``k = min(max_det, num_anchors)``. Order is unspecified.

    Raises:
        ValueError: if ``nc`` disagrees with ``preds.shape[-1] - 4`` or the
            array has no class columns.
    """
    inferred_nc = preds.shape[-1] - 4
    if nc is None:
        nc = inferred_nc
    elif nc != inferred_nc:
        # A mismatched nc would silently scramble the label/anchor decoding.
        raise ValueError(
            f"preds has {preds.shape[-1]} columns, expected 4 + nc = {4 + nc}"
        )
    if nc < 1:
        raise ValueError(f"preds has no class score columns (shape {preds.shape})")

    # argpartition cannot select more elements than exist (small input sizes
    # give fewer anchors than max_det), so clamp; k == 0 would also fail.
    k = min(max_det, preds.shape[0])
    if k <= 0:
        return preds[:0, :4], preds[:0, 4], np.zeros(0, dtype=np.intp)

    boxes, scores = preds[..., :4], preds[..., 4:]

    # Stage 1: anchors ranked by their best class score.
    _, anchor_idx = topk(np.max(scores, -1), k, axis=-1, sorted=False)
    boxes = boxes[anchor_idx]
    scores = scores[anchor_idx]

    # Stage 2: rank every (anchor, class) pair among the surviving anchors;
    # a flat index encodes anchor * nc + class.
    scores, flat_idx = topk(scores.reshape(-1), k, axis=-1, sorted=False)
    labels = flat_idx % nc
    boxes = boxes[flat_idx // nc]
    return boxes, scores, labels

def postprocess(out, img_meta, conf=None):
    """Turn raw detector output into boxes, scores and labels for one image.

    The output layout is chosen by the size of the last axis:

    * 6 columns, e.g. [1, 300, 6]: detections already laid out as
      (x1, y1, x2, y2, score, label); only the boxes are rescaled.
    * anything else, e.g. [1, 8400, 84]: per-anchor box plus ``nc`` class
      scores, with ``nc = out.shape[-1] - 4``; the top 300 (anchor, class)
      pairs are selected by ``_postprocess``.

    NOTE(review): a raw head with exactly 2 classes also has 6 columns and
    would be mistaken for the decoded layout.

    Args:
        out: model output with a leading batch axis; only ``out[0]`` is used.
        img_meta: metadata from ``preprocess(..., with_meta=True)``.
        conf: optional score threshold; when given, only detections with
            ``score > conf`` are returned.

    Returns:
        ``(boxes, scores, labels)``: boxes as (x1, y1, x2, y2) in original
        image pixels, their scores, and integer class labels.
    """
    out = out[0]
    if out.shape[-1] == 6:
        # Copy the box columns so the caller's output array is not modified.
        boxes = scale_boxes(out[:, :4].copy(), img_meta)
        scores = out[:, 4]
        # The label column shares the array's (float) dtype; return integers
        # like the other branch does.
        labels = out[:, 5].astype(np.intp)
    else:
        # Infer the class count from the output instead of assuming 80.
        boxes, scores, labels = _postprocess(out, 300, out.shape[-1] - 4)
        boxes = scale_boxes(boxes, img_meta)

    if conf is not None:
        keep = scores > conf
        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
    return boxes, scores, labels