import cv2
import torch
import torchvision
import numpy as np


def preprocess(img_path, with_meta=False, input_size=(640, 640)):
    """Load an image and letterbox it into a fixed-size float32 batch of one.

    The image is resized with its aspect ratio preserved so that it fits
    inside ``input_size``. It is then pasted into the top-left corner of a
    canvas filled with the value 114; the rest of the canvas is padding.

    Args:
        img_path: Path of the image file, passed to ``cv2.imread``.
        with_meta: When True, also return a dict holding the resize ``scale``.
        input_size: Target ``(height, width)`` of the padded canvas.

    Returns:
        A float32 array of shape ``(1, input_size[0], input_size[1], 3)`` in
        the channel order produced by ``cv2.imread`` (BGR). When
        ``with_meta`` is True, returns the tuple ``(array, {"scale": scale})``
        instead, where dividing resized coordinates by ``scale`` maps them
        back to the original image.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/unreadable/undecodable file by
        # returning None rather than raising; fail loudly with the path
        # instead of an opaque AttributeError on ``img.shape`` below.
        raise OSError(f"could not read image: {img_path!r}")

    ori_h, ori_w = img.shape[:2]
    scale = min(input_size[0] / ori_h, input_size[1] / ori_w)
    # Floor so the resized image always fits in the canvas; keep at least one
    # pixel per side so degenerate (extremely thin) images don't ask
    # cv2.resize for a zero-sized output.
    new_h = max(1, int(ori_h * scale))
    new_w = max(1, int(ori_w * scale))

    # cv2.resize takes dsize as (width, height).
    resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    padded_img = np.full((input_size[0], input_size[1], 3), 114.0, dtype=np.float32)
    padded_img[:new_h, :new_w] = resized_img.astype(np.float32)
    batch = np.expand_dims(padded_img, 0)

    if with_meta:
        return batch, {"scale": scale}
    return batch


def postprocess(prediction, img_meta, conf=0.001, nms_thre=0.65, num_classes=80):
    """Turn raw detector output into confidence-filtered, per-class-NMS'd boxes.

    Only ``prediction[0]`` is used. It is treated as a 2-D numpy array with
    one row per candidate box:

    * columns 0-3: box center x, center y, width, height;
    * column 4: a per-box confidence that is multiplied by the best class score;
    * columns ``5 .. 5 + num_classes``: per-class scores.

    The input array is NOT modified.

    Args:
        prediction: Sequence (or batched array) whose first element is the
            2-D numpy array described above.
        img_meta: Dict with key ``"scale"``, the factor the image was resized
            by (as produced by ``preprocess(..., with_meta=True)``).
        conf: Minimum combined score (column 4 times best class score) a box
            must reach to be kept.
        nms_thre: IoU threshold for non-maximum suppression.
        num_classes: Number of class-score columns after column 4.

    Returns:
        ``(bboxes, scores, labels)``: ``bboxes`` is a float32 numpy array of
        shape ``(K, 4)`` in ``x1, y1, x2, y2`` order divided by
        ``img_meta["scale"]``; ``scores`` is a float32 numpy array of shape
        ``(K,)``; ``labels`` is a list of ``K`` Python ints. Entries are
        ordered by decreasing score (``torchvision.ops.nms`` ordering).
    """
    # torch.from_numpy shares memory with the caller's array, and .to() is a
    # no-op for float32 input, so everything below must stay out-of-place or
    # it would silently rewrite the caller's model output.
    pred = torch.from_numpy(prediction[0]).to(torch.float32)

    # (cx, cy, w, h) -> (x1, y1, x2, y2), built as a new tensor.
    cx, cy, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
    boxes = torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=1)

    # Best class per box; dim-wise max keeps 1-D shapes even for a single row.
    class_conf, class_pred = torch.max(pred[:, 5:5 + num_classes], dim=1)
    scores = pred[:, 4] * class_conf

    keep = scores >= conf
    boxes, scores, labels = boxes[keep], scores[keep], class_pred[keep]

    if boxes.numel() == 0:
        # Nothing passed the confidence filter; boxes.max() below would fail.
        idx = torch.empty(0, dtype=torch.int64)
    else:
        # Per-class NMS via the offset trick: shift each class by a multiple of
        # the full coordinate span so boxes of different classes can never
        # overlap. Deriving the span from the data (rather than a fixed
        # constant) keeps classes separated for any coordinate range,
        # including slightly negative edge boxes.
        span = boxes.max() - boxes.min() + 1.0
        offsets = labels.to(boxes.dtype) * span
        idx = torchvision.ops.nms(boxes + offsets[:, None], scores, nms_thre)

    bboxes = (boxes[idx] / img_meta["scale"]).cpu().numpy()
    out_scores = scores[idx].cpu().numpy()
    out_labels = labels[idx].cpu().numpy().astype(np.int64).tolist()
    return bboxes, out_scores, out_labels