import cv2
import numpy as np
import torch
import torchvision
import torch.nn.functional as F


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """Shrink (never enlarge) an image to fit ``new_shape`` and pad it to exactly that size.

    Aspect ratio is preserved. The image is only resized when it does not already fit.

    Args:
        img (np.ndarray): image as read by OpenCV, shape (H, W[, C]).
        new_shape (int | tuple[int, int]): target (height, width). A scalar means a square target.
        color: constant border fill value passed to ``cv2.copyMakeBorder``.

    Returns:
        tuple: ``(padded_img, (dw, dh))``. ``dw``/``dh`` are the float half-paddings for the
        horizontal/vertical axes. The border actually applied is ``round(d - 0.1)`` on the
        left/top and ``round(d + 0.1)`` on the right/bottom.
    """
    shape = img.shape[:2]  # current shape [height, width]
    if np.ndim(new_shape) == 0:  # scalar target -> square
        new_shape = (int(new_shape), int(new_shape))

    # Scale ratio (new / old), capped at 1.0 so the image is only ever shrunk
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1], 1.0)
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
    if (shape[1], shape[0]) != new_unpad:
        # An oversized image would otherwise yield negative padding below
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

    dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
    # The +-0.1 bias splits an odd total padding with the extra pixel on the bottom/right
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border

    return img, (dw, dh)

def preprocess(img_path, with_meta=False, input_size=640):
    """Load an image and build a normalized, letterboxed batch of one.

    The image is converted to RGB and resized so its longer side matches ``input_size``.
    It is then padded to ``input_size x input_size``, given a leading batch axis and
    scaled to ``[0, 1]`` as float32.

    Args:
        img_path (str): path passed to ``cv2.imread``.
        with_meta (bool): also return metadata needed to map predictions back.
        input_size (int): side length of the square network input.

    Returns:
        np.ndarray of shape (1, input_size, input_size, 3), or ``(array, img_meta)`` when
        ``with_meta`` is True. ``img_meta`` holds ``'img_size'`` [h, w],
        ``'scale_factor'`` [scale_h, scale_w], ``'pad_size'`` (dw, dh) and
        ``'input_size'`` (input_size, input_size).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:  # cv2.imread reports failure by returning None rather than raising
        raise FileNotFoundError(f"Unable to read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    r = input_size / max(h, w)  # scale so the longer side matches input_size
    img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR)
    resize_h, resize_w = img.shape[:2]

    img, pad = letterbox(img, input_size)
    img = np.ascontiguousarray(img)
    img = np.expand_dims(img, 0)  # add batch axis -> (1, H, W, C)
    img = img.astype(np.float32) / 255.

    if with_meta:
        img_meta = {
            'img_size': [h, w],
            'scale_factor': [resize_h / h, resize_w / w],
            'pad_size': pad,
            'input_size': (input_size, input_size)
        }
        return img, img_meta

    return img

def to_tensor(x):
    """Wrap a NumPy array as a torch tensor, moving it to CUDA when a GPU is available.

    Anything that is not an ``np.ndarray`` (including existing tensors) is returned as-is.
    On the CPU path the tensor shares memory with the source array (``torch.from_numpy``).
    """
    if not isinstance(x, np.ndarray):
        return x
    tensor = torch.from_numpy(x)
    return tensor.cuda() if torch.cuda.is_available() else tensor

def postprocess(out, img_meta, conf=0.001, iou_thres=0.65, max_det=300, with_scale=True, return_numpy=True):
    """Turn raw segmentation-model outputs into boxes, scores, labels and binary masks.

    Args:
        out: pair ``(pred, protos)`` of raw model outputs (NumPy arrays or tensors).
        img_meta (dict): metadata with the keys produced by ``preprocess(..., with_meta=True)``.
        conf (float): confidence threshold forwarded to NMS.
        iou_thres (float): IoU threshold forwarded to NMS.
        max_det (int): maximum number of detections kept.
        with_scale (bool): if True, map boxes and masks back to ``img_meta['img_size']``.
            Otherwise masks are built at ``img_meta['input_size']`` and boxes are left
            in network-input coordinates.
        return_numpy (bool): move results to the CPU and convert them to NumPy arrays.

    Returns:
        tuple: ``(boxes, scores, labels, masks)``. Only the first image's prototypes
        (``protos[0]``) are used.
    """
    raw_pred, raw_protos = out
    pred = to_tensor(raw_pred)
    protos = to_tensor(raw_protos)
    dets = non_max_suppression(
        pred,
        conf_thres=conf,
        iou_thres=iou_thres,
        multi_label=True,
        agnostic=False,
        max_det=max_det,
    )
    # Detection row layout: x1, y1, x2, y2, score, class, mask coefficients...
    boxes = dets[:, :4]
    scores = dets[:, 4]
    labels = dets[:, 5]
    coeffs = dets[:, 6:]

    if with_scale:
        boxes = scale_boxes(boxes, img_meta)
        target_hw = img_meta['img_size']
    else:
        target_hw = img_meta['input_size']
    masks = process_mask(protos[0], coeffs, boxes, target_hw)

    results = (boxes, scores, labels, masks)
    if not return_numpy:
        return results
    return tuple(t.cpu().numpy() for t in results)

def xywh2xyxy(x):
    """Convert boxes from center form (cx, cy, w, h) to corner form (x1, y1, x2, y2).

    Args:
        x (torch.Tensor): tensor whose last dimension holds cx, cy, w, h.

    Returns:
        torch.Tensor: a new tensor of the same shape. The dtype is at least float32:
        float64 inputs stay float64 instead of being silently downcast. Integer and
        half-precision inputs become float32.
    """
    dtype = torch.promote_types(x.dtype, torch.float32)
    y = torch.empty_like(x, dtype=dtype)
    centers = x[..., :2]
    half_wh = x[..., 2:] / 2
    y[..., :2] = centers - half_wh  # top-left corner
    y[..., 2:] = centers + half_wh  # bottom-right corner
    return y

def non_max_suppression(prediction, nc=80, conf_thres=0.25, iou_thres=0.45, multi_label=False,
                        max_nms=30000, agnostic=False, max_wh=7680, max_det=300):
    """Class-aware non-maximum suppression for the first image of a YOLO-style head output.

    Args:
        prediction (torch.Tensor): (batch, 4 + nc + extra, num_anchors). Rows 0-3 are
            cx, cy, w, h, followed by ``nc`` class scores and then ``extra`` values
            (e.g. mask coefficients) that are carried through unchanged. Only batch
            index 0 is processed. The input tensor is not modified.
        nc (int): number of classes. A falsy value means ``prediction.shape[1] - 4``.
        conf_thres (float): minimum class score for a candidate.
        iou_thres (float): IoU threshold passed to ``torchvision.ops.nms``.
        multi_label (bool): emit one row per (box, class) above threshold instead of the
            best class only. Forced off when ``nc <= 1``.
        max_nms (int): keep at most this many highest-scoring candidates before NMS.
        agnostic (bool): ignore class when suppressing.
        max_wh (int): per-class coordinate offset so boxes of different classes never overlap.
        max_det (int): maximum number of rows returned.

    Returns:
        torch.Tensor: (M, 6 + extra) rows of [x1, y1, x2, y2, score, class, extra...].
    """
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    extra = prediction.shape[1] - nc - 4  # number of extra info
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    output = torch.zeros((0, 6 + extra), dtype=prediction.dtype, device=prediction.device)

    # (C, A) -> (A, C), then keep candidates. Boolean-mask indexing returns a copy, so
    # converting boxes below never writes into the caller's tensor (or a NumPy array
    # it may share memory with), and only candidate rows pay for the conversion.
    x = prediction[0].transpose(0, 1)[xc[0]]
    x[:, :4] = xywh2xyxy(x[:, :4])  # xywh to xyxy
    box, cls, mask = x[:, :4], x[:, 4:mi], x[:, mi:]

    if multi_label:
        i, j = torch.where(cls > conf_thres)
        x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
    else:
        # best class only
        conf, j = cls.max(1, keepdim=True)
        filt = conf.view(-1) > conf_thres
        x = torch.cat((box, conf, j.float(), mask), 1)[filt]

    n = x.shape[0]
    if not n:  # no boxes
        return output

    if n > max_nms:  # excess boxes
        filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
        x = x[filt]

    c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
    scores = x[:, 4]  # scores
    boxes = x[:, :4] + c  # boxes (offset by class)
    i = torchvision.ops.nms(boxes, scores, iou_thres)

    i = i[:max_det]  # limit detections
    output = x[i]
    return output

def scale_boxes(boxes, img_meta):
    """Map xyxy boxes from letterboxed network-input space back onto the original image.

    The tensor is updated in place and also returned. The letterbox padding is
    subtracted, the resize is undone, and the coordinates are clipped to the image bounds.

    Args:
        boxes (torch.Tensor): (..., >=4) tensor whose first four columns are x1, y1, x2, y2.
        img_meta (dict): needs ``'pad_size'`` as (pad_w, pad_h), ``'scale_factor'`` as
            (scale_h, scale_w) and ``'img_size'`` as (height, width).

    Returns:
        torch.Tensor: the same ``boxes`` object.
    """
    pad = img_meta['pad_size']
    scale_h, scale_w = img_meta['scale_factor']
    size = img_meta['img_size']
    height, width = size[0], size[1]

    # Strided views onto the caller's tensor: columns (x1, x2) and (y1, y2)
    xs = boxes[..., 0:4:2]
    ys = boxes[..., 1:4:2]

    xs -= pad[0]
    ys -= pad[1]
    divisor = torch.tensor([scale_w, scale_h, scale_w, scale_h], dtype=torch.float32, device=boxes.device)
    boxes[..., :4] /= divisor

    xs.clamp_(0, width)
    ys.clamp_(0, height)
    return boxes

def process_mask(protos, masks_in, bboxes, shape):
    """
    Build per-detection binary masks from mask prototypes and coefficients.

    Masks are first rescaled to ``shape`` and then cropped to their boxes, so ``bboxes``
    must be expressed in the coordinate frame of ``shape``.

    Args:
        protos (torch.Tensor): Mask prototypes laid out as (mask_h, mask_w, mask_dim);
            they are permuted to channel-first internally.
        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim), where N is the number of masks after NMS.
        bboxes (torch.Tensor): xyxy boxes with shape (N, 4), in pixels of ``shape``.
        shape (tuple): Target size as (height, width).

    Returns:
        (torch.Tensor): uint8 tensor of shape (N, height, width) holding 1 inside each
            cropped mask and 0 elsewhere. When N == 0, an empty (0, height, width) tensor.
    """
    if masks_in.shape[0] == 0:
        # No detections: F.interpolate rejects an empty (1, 0, H, W) input, so short-circuit.
        return torch.zeros((0, int(shape[0]), int(shape[1])), dtype=torch.uint8, device=masks_in.device)

    protos = protos.permute(2, 0, 1)  # HWC -> CHW
    c, mh, mw = protos.shape
    # Cast both operands: half-precision coefficients times float32 protos would fail to matmul
    masks = (masks_in.float() @ protos.float().reshape(c, -1)).view(-1, mh, mw)  # (N, mh, mw)
    masks = scale_masks(masks[None], shape)[0]  # (N, height, width)
    masks = crop_mask(masks, boxes=bboxes)
    return masks.gt_(0.0).byte()


def crop_mask(masks, boxes):
    """
    Zero every mask pixel that lies outside its bounding box.

    Args:
        masks (torch.Tensor): Masks with shape (N, H, W).
        boxes (torch.Tensor): Boxes with shape (N, 4) as x1, y1, x2, y2 in absolute pixel
            coordinates of the masks' (H, W) grid. Coordinates may fall outside the grid.

    Returns:
        (torch.Tensor): Cropped masks. For N < 50 the input tensor is modified in place and
            returned; otherwise a new tensor is returned.
    """
    n, h, w = masks.shape
    if n < 50:  # faster for fewer masks (predict)
        # Clamp at 0: a negative start would otherwise act as a Python negative slice
        # index (counting from the end) and zero out the wrong region.
        coords = boxes.round().int().clamp(min=0).tolist()
        for i, (x1, y1, x2, y2) in enumerate(coords):
            masks[i, :y1] = 0
            masks[i, y2:] = 0
            masks[i, :, :x1] = 0
            masks[i, :, x2:] = 0
        return masks

    # faster for more masks (val)
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each shape (n, 1, 1)
    cols = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # x index, shape (1, 1, w)
    rows = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # y index, shape (1, h, 1)
    return masks * ((cols >= x1) * (cols < x2) * (rows >= y1) * (rows < y2))


def scale_masks(masks, shape, padding: bool = True):
    """
    Rescale segment masks to a target shape, removing letterbox padding first.

    Args:
        masks (torch.Tensor): Masks with shape (N, C, H, W).
        shape (tuple): Target height and width as (height, width).
        padding (bool): Whether the masks come from a letterboxed input padded on both sides.
            When False, padding is assumed to sit only on the bottom/right.

    Returns:
        (torch.Tensor): Masks with shape (N, C, height, width), bilinearly resampled.
            An empty input yields zeros of that shape.
    """
    mh, mw = masks.shape[2:]
    if masks.numel() == 0:
        # F.interpolate rejects inputs whose non-batch dims multiply to zero (e.g. no detections)
        return masks.new_zeros((masks.shape[0], masks.shape[1], int(shape[0]), int(shape[1])))

    gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
    pad_w = mw - shape[1] * gain
    pad_h = mh - shape[0] * gain
    if padding:
        pad_w /= 2
        pad_h /= 2
    # Same +-0.1 rounding bias as letterbox, so the crop matches the border that was added
    top, left = (int(round(pad_h - 0.1)), int(round(pad_w - 0.1))) if padding else (0, 0)
    bottom = mh - int(round(pad_h + 0.1))
    right = mw - int(round(pad_w + 0.1))

    return F.interpolate(masks[..., top:bottom, left:right], shape, mode="bilinear")  # NCHW masks
