import cv2
import numpy as np

import torch
import torchvision


def preprocess(img_path, with_meta=False, input_size=416):
    """Load an image from disk and turn it into a network input batch.

    The image is read with OpenCV, converted from BGR to RGB, resized to a
    square of ``input_size`` pixels (aspect ratio is not preserved), scaled
    to the ``[0, 1]`` range as ``float32`` and given a leading batch axis.

    Args:
        img_path: Path of the image file to read.
        with_meta: When true, also return the metadata needed to map
            detections back onto the original image.
        input_size: Side length, in pixels, of the square network input.

    Returns:
        The image batch of shape ``(1, input_size, input_size, channels)``.
        When ``with_meta`` is true, a tuple ``(batch, img_meta)`` where
        ``img_meta`` holds ``'img_size'`` (original ``(height, width)``) and
        ``'scale_factor'`` (``[height_ratio, width_ratio]`` of resized over
        original size).

    Raises:
        OSError: If the image cannot be read (missing file, or a format
            OpenCV cannot decode).
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an obscure cv2.error inside cvtColor.
    if img is None:
        raise OSError(
            f"Could not read image {img_path!r}: "
            "file is missing or not a decodable image"
        )

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    shape = img.shape[:2]  # original (height, width)
    img = cv2.resize(img, (input_size, input_size), interpolation=cv2.INTER_LINEAR)
    resize_h, resize_w = img.shape[:2]

    batch = np.expand_dims(img.astype(np.float32) / 255., 0)
    if not with_meta:
        return batch

    img_meta = {
        'img_size': shape,
        'scale_factor': [resize_h / shape[0], resize_w / shape[1]],
    }
    return batch, img_meta


def postprocess(logits, img_meta, conf=0.001):
    """Convert raw network output into boxes in original-image coordinates.

    Only the first image of the batch is post-processed.

    Args:
        logits: NumPy array with the raw network output. It is permuted with
            ``(0, 3, 1, 2)`` before decoding, i.e. it is expected in
            ``(batch, height, width, channels)`` layout. The array is left
            unmodified.
        img_meta: Dict with ``'img_size'`` (original ``(height, width)``)
            and ``'scale_factor'`` (``[height_ratio, width_ratio]``), as
            produced by ``preprocess(..., with_meta=True)``.
        conf: Score threshold forwarded to ``post_processing``.

    Returns:
        ``(boxes, scores, labels)``. ``boxes`` is an ``int64`` array of
        shape ``(n, 4)`` holding ``[xmin, ymin, xmax, ymax]`` clipped to the
        original image; ``scores`` and ``labels`` are float arrays of length
        ``n``. When nothing is detected, three empty lists are returned.
    """
    # torch.from_numpy shares memory with the NumPy array and the decoder
    # works in place, so decode a private copy to keep the caller's array
    # intact (and repeated calls reproducible).
    tensor = torch.from_numpy(logits).permute(0, 3, 1, 2)
    tensor = tensor.clone(memory_format=torch.contiguous_format)

    # Anchor (width, height) pairs, one row per anchor.
    anchors = torch.tensor([
        0.57273, 0.677385,
        1.87446, 2.06253,
        3.33843, 5.47434,
        7.88282, 3.52778,
        9.77052, 9.16828], dtype=torch.float32).reshape(-1, 2)

    predictions = post_processing(tensor, anchors, conf_threshold=conf, nms_threshold=0.5)

    # An empty result, or an empty list for the first image, means no boxes.
    if len(predictions) == 0 or len(predictions[0]) == 0:
        return [], [], []

    # Rows are [x, y, w, h, score, class] with (x, y) the top-left corner in
    # network-input pixels.
    detections = np.array(predictions[0])

    height, width = img_meta['img_size']
    height_ratio, width_ratio = img_meta['scale_factor']

    # Undo the resize and clip the corners to the original image bounds.
    left, top = detections[:, 0], detections[:, 1]
    xmin = np.maximum(left / width_ratio, 0)
    ymin = np.maximum(top / height_ratio, 0)
    xmax = np.minimum((left + detections[:, 2]) / width_ratio, width)
    ymax = np.minimum((top + detections[:, 3]) / height_ratio, height)

    # astype truncates toward zero, matching int() on each coordinate.
    boxes = np.stack([xmin, ymin, xmax, ymax], axis=1).astype(np.int64)
    scores, labels = detections[:, 4], detections[:, 5]

    return boxes, scores, labels



def post_processing(logits, anchors, conf_threshold, nms_threshold, image_size=416):
    """Decode raw detection-head output, threshold it and apply NMS.

    The channel axis is read as ``num_anchors`` groups of
    ``(tx, ty, tw, th, objectness, class scores...)``. Box centres are
    ``sigmoid(t) + grid offset``, sizes are ``exp(t) * anchor``, both
    normalised by the grid size. The score of a box is its objectness
    times its highest class softmax probability.

    Args:
        logits: Tensor of shape ``(batch, num_anchors * (5 + num_classes),
            h, w)``, or the same without the batch axis. It is not modified.
        anchors: Tensor of shape ``(num_anchors, 2)`` with anchor
            ``(width, height)`` expressed in grid cells.
        conf_threshold: Boxes whose score is not above this are dropped.
        nms_threshold: IoU threshold passed to ``torchvision.ops.nms``
            (applied across all classes together).
        image_size: Side length in pixels used to scale the normalised
            boxes.

    Returns:
        A list with one entry per image; each entry is a list of
        ``[x, y, w, h, score, class_index]`` rows where ``(x, y)`` is the
        top-left corner in pixels and ``class_index`` is an ``int``.
        If an image has no box above ``conf_threshold``, an empty tensor is
        returned immediately instead (see the NOTE in the body).
    """
    num_anchors = len(anchors)

    if logits.dim() == 3:
        # Out-of-place so the caller's tensor keeps its shape.
        logits = logits.unsqueeze(0)

    batch = logits.size(0)
    h = logits.size(2)
    w = logits.size(3)

    # Everything below decodes in place; do it on a private contiguous copy
    # so the caller's tensor (and any NumPy array sharing its memory) is
    # left untouched and ``view`` is always valid.
    logits = logits.clone(memory_format=torch.contiguous_format)

    # Per-cell column / row offsets, flattened in row-major order.
    lin_x = torch.linspace(0, w - 1, w).repeat(h, 1).view(h * w)
    lin_y = torch.linspace(0, h - 1, h).repeat(w, 1).t().contiguous().view(h * w)
    anchor_w = anchors[:, 0].contiguous().view(1, num_anchors, 1)
    anchor_h = anchors[:, 1].contiguous().view(1, num_anchors, 1)

    # Decode centre, size and objectness, all normalised to [0, 1].
    logits = logits.view(batch, num_anchors, -1, h * w)
    logits[:, :, 0, :].sigmoid_().add_(lin_x).div_(w)
    logits[:, :, 1, :].sigmoid_().add_(lin_y).div_(h)
    logits[:, :, 2, :].exp_().mul_(anchor_w).div_(w)
    logits[:, :, 3, :].exp_().mul_(anchor_h).div_(h)
    logits[:, :, 4, :].sigmoid_()

    # Score = best class probability * objectness.
    cls_scores = torch.nn.functional.softmax(logits[:, :, 5:, :], 2)
    cls_max, cls_max_idx = torch.max(cls_scores, 2)
    cls_max_idx = cls_max_idx.float()
    cls_max.mul_(logits[:, :, 4, :])

    keep = cls_max > conf_threshold

    if keep.sum() == 0:
        predicted_boxes = [torch.Tensor([]) for _ in range(batch)]
    else:
        coords = logits.transpose(2, 3)[..., 0:4]
        coords = coords[keep[..., None].expand_as(coords)].view(-1, 4)
        scores = cls_max[keep]
        idx = cls_max_idx[keep]
        detections = torch.cat([coords, scores[:, None], idx[:, None]], dim=1)

        # The mask is applied in row-major order, so detections are already
        # grouped by image; split them by per-image counts.
        det_per_batch = keep.reshape(batch, -1).sum(dim=1)
        predicted_boxes = torch.split(detections, det_per_batch.tolist())

    selected_boxes = []
    for boxes in predicted_boxes:
        if boxes.numel() == 0:
            # NOTE(review): legacy behaviour kept for existing callers that
            # test ``len(result) == 0``: an image without detections ends
            # processing for the whole batch and yields an empty tensor.
            return boxes

        centers = boxes[:, :2]
        sizes = boxes[:, 2:4]
        corners = torch.cat([centers - sizes / 2, centers + sizes / 2], 1)

        kept = torchvision.ops.nms(corners, boxes[:, 4], nms_threshold)
        selected_boxes.append(boxes[kept])

    final_boxes = []
    for boxes in selected_boxes:
        # Scale to pixels, then turn the centre into the top-left corner.
        boxes[:, 0:3:2] *= image_size
        boxes[:, 0] -= boxes[:, 2] / 2
        boxes[:, 1:4:2] *= image_size
        boxes[:, 1] -= boxes[:, 3] / 2

        final_boxes.append([row[:5] + [int(row[5])] for row in boxes.tolist()])
    return final_boxes



