import cv2
import torch
import torchvision
import numpy as np


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """Shrink (never enlarge) an image to fit ``new_shape`` and pad the remainder.

    Args:
        img: Image array whose first two dimensions are (height, width).
        new_shape: Target size, either a single number (square target) or a
            ``(height, width)`` pair.
        color: Constant border value passed to ``cv2.copyMakeBorder``.

    Returns:
        A tuple ``(img, (dw, dh))`` where ``img`` is the padded image and
        ``dw`` / ``dh`` are the per-side padding in width / height. These are
        the exact halves of the total padding and may be fractional; the
        border actually applied on each side is rounded to whole pixels.
    """
    shape = img.shape[:2]  # current shape [height, width]
    # Only wrap scalars: wrapping an existing (h, w) pair would nest the tuple
    # and make the divisions below fail.
    if not isinstance(new_shape, (tuple, list)):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old), capped at 1.0 so the image is only ever shrunk.
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1], 1.0)
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
    dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding

    # The padding above is computed for the scaled size, so the image has to
    # actually be brought to that size before the border is added.
    if shape[::-1] != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

    # The -0.1 / +0.1 nudges send the extra pixel of an odd total padding to
    # the bottom / right side.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border

    return img, (dw, dh)

def preprocess(img_path, with_meta=False, input_size=640):
    """Load an image from disk and turn it into a network input batch.

    The image is read with OpenCV, converted from BGR to RGB, resized so that
    its longer side matches ``input_size``, letterboxed to a square of
    ``input_size``, given a leading batch axis and scaled to ``[0, 1]``.
    The channel axis stays last (no transpose is applied).

    Args:
        img_path: Path of the image file to read.
        with_meta: When true, also return the metadata dict described below.
        input_size: Side length of the square network input.

    Returns:
        ``img`` as a float32 array with a leading batch axis of size 1, or
        ``(img, img_meta)`` when ``with_meta`` is true. ``img_meta`` holds:

        * ``'img_size'``: ``[h, w]`` of the image as read from disk.
        * ``'scale_factor'``: ``[resized_h / h, resized_w / w]``.
        * ``'pad_size'``: the ``(dw, dh)`` pair returned by ``letterbox``.

    Raises:
        FileNotFoundError: If OpenCV could not read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Image not found or unreadable: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    r = input_size / max(h, w)  # resize image to img_size
    # Keep each side at least one pixel so extreme aspect ratios stay resizable.
    target_w, target_h = max(int(w * r), 1), max(int(h * r), 1)
    img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    resize_h, resize_w = img.shape[:2]

    img, pad = letterbox(img, input_size)
    img = np.ascontiguousarray(img)
    img = np.expand_dims(img, 0)
    img = img.astype(np.float32) / 255.

    if with_meta:
        img_meta = {
            'img_size': [h, w],
            'scale_factor': [resize_h / h, resize_w / w],
            'pad_size': pad
        }
        return img, img_meta

    return img

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.65, agnostic=False, multi_label=False, max_det=300):
    """Filter raw detections by confidence and run NMS on each image.

    Args:
        prediction: Tensor indexed as (image, box, 5 + num_classes). The last
            axis holds centre x, centre y, width, height, objectness and then
            one score per class.
        conf_thres: Minimum objectness, and minimum combined
            objectness * class score, for a detection to be kept.
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``.
        agnostic: When true, boxes of different classes suppress each other.
        multi_label: When true (and there is more than one class), a box may
            yield one detection per class above ``conf_thres`` instead of only
            its best class.
        max_det: Maximum number of detections kept per image.

    Returns:
        A list with one NumPy array per image, each of shape (n, 6) holding
        ``x1, y1, x2, y2, confidence, class``.
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates
    multi_label &= nc > 1
    max_wh = 7680  # per-class coordinate offset used to keep classes apart in NMS

    # Give every image its own empty placeholder, with the dtype that non-empty
    # results end up with after concatenating with the float32 class column.
    out_dtype = torch.promote_types(prediction.dtype, torch.float32)
    output = [
        torch.zeros((0, 6), dtype=out_dtype, device=prediction.device)
        for _ in range(prediction.shape[0])
    ]

    for xi, x in enumerate(prediction):  # image index, image inference
        # Boolean-mask indexing yields a copy, so the in-place multiply below
        # does not modify the caller's tensor.
        x = x[xc[xi]]  # confidence
        if not x.shape[0]:
            continue  # nothing passed the objectness threshold
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Centre/size -> corner coordinates
        centre, half_wh = x[:, :2], x[:, 2:4] / 2
        box = torch.cat((centre - half_wh, centre + half_wh), 1)

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        if not x.shape[0]:
            continue  # nothing passed the combined confidence threshold

        # Shifting each box by class index * max_wh keeps boxes of different
        # classes from overlapping, so a single NMS call acts per class.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        ind = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        output[xi] = x[ind[:max_det]]

    # detach()/cpu() make the conversion work for grad-tracking and non-CPU tensors.
    return [o.detach().cpu().numpy() for o in output]


def scale_boxes(boxes, img_meta):
    """Map boxes from the padded network input back onto the original image.

    ``boxes`` is modified in place and the same array is returned.

    Args:
        boxes: Array whose last axis starts with ``x1, y1, x2, y2``.
        img_meta: Dict with ``'pad_size'`` (width, height padding),
            ``'scale_factor'`` (height, width resize ratios) and
            ``'img_size'`` (original height, width).

    Returns:
        The input array, shifted by the padding, divided by the resize ratios
        and clipped to the original image bounds.
    """
    padding = img_meta['pad_size']  # (w, h)
    ratios = img_meta['scale_factor']  # (h, w)
    original_hw = img_meta['img_size']

    # Remove the letterbox border: x columns use the width padding, y columns the height padding.
    boxes[..., [0, 2]] -= padding[0]
    boxes[..., [1, 3]] -= padding[1]

    # Undo the resize; the ratios are stored (h, w) but the columns run x, y, x, y.
    ratio_h, ratio_w = ratios
    boxes[..., :4] /= np.array((ratio_w, ratio_h) * 2, dtype=np.float32)

    # Clamp every coordinate to the original image: x to the width, y to the height.
    limits = (original_hw[1], original_hw[0], original_hw[1], original_hw[0])
    for column, upper in enumerate(limits):
        boxes[..., column] = np.clip(boxes[..., column], 0, upper)
    return boxes

def postprocess(out, img_meta, conf=0.001):
    """Turn raw network output for one image into boxes, scores and labels.

    Args:
        out: Either a NumPy array that ``non_max_suppression`` can consume
            directly, or a tuple/list of three arrays ``(xy, wh, conf_cls)``
            that are concatenated along axis 1 and then have their last two
            axes swapped. NOTE(review): the original inline comment gives the
            concatenated layout as [B, 85, N] -- confirm against the model.
        img_meta: Metadata dict forwarded to ``scale_boxes``.
        conf: Confidence threshold forwarded to ``non_max_suppression``.

    Returns:
        ``(boxes, scores, labels)`` taken from columns 0-3, 4 and 5 of the
        single image's detections, with the boxes passed through
        ``scale_boxes``.
    """
    if isinstance(out, (tuple, list)):
        xy, wh, conf_cls = out
        stacked = np.concatenate([xy, wh, conf_cls], 1)
        out = stacked.transpose(0, 2, 1)

    per_image = non_max_suppression(
        torch.from_numpy(out),
        conf_thres=conf,
        iou_thres=0.65,
        agnostic=False,
        multi_label=True,
        max_det=300,
    )
    # Only a batch of exactly one image is supported here.
    assert len(per_image) == 1
    detections = per_image[0]

    # scale_boxes works in place on a view of the detections' first four columns.
    boxes = scale_boxes(detections[:, :4], img_meta)
    return boxes, detections[:, 4], detections[:, 5]