import cv2
import math
import torch
import numpy as np


def preprocess(img_path, with_meta=False, input_size=1024):
    """
    Load an image from disk and turn it into a normalized, letterboxed model input.

    The image is read with OpenCV and converted BGR -> RGB. It is resized so that its
    longer side is (about) ``input_size``, keeping the aspect ratio. ``letterbox`` then
    pads it to a square ``input_size`` x ``input_size`` canvas, and pixel values are
    scaled to [0, 1].

    Args:
        img_path (str): Path of the image file.
        with_meta (bool): If True, also return the metadata needed to map
            predictions back to the original image (see ``scale_boxes``).
        input_size (int): Side length of the square model input.

    Returns:
        np.ndarray: float32 array of shape (1, input_size, input_size, 3), RGB, NHWC.
        If ``with_meta`` is True, a tuple ``(img, img_meta)`` is returned instead.
        ``img_meta`` holds:
            - ``'img_size'``: [h, w] of the original image;
            - ``'scale_factor'``: [resized_h / h, resized_w / w];
            - ``'pad_size'``: the ``(dw, dh)`` half-padding returned by ``letterbox``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"Image not found or unreadable: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    r = input_size / max(h, w)  # ratio that brings the longer side to input_size
    img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR)
    resize_h, resize_w = img.shape[:2]

    img, pad = letterbox(img, input_size)
    img = np.ascontiguousarray(img)
    img = np.expand_dims(img, 0)  # add batch dimension -> (1, H, W, C)
    img = img.astype(np.float32) / 255.

    if with_meta:
        img_meta = {
            'img_size': [h, w],
            'scale_factor': [resize_h / h, resize_w / w],
            'pad_size': pad
        }
        return img, img_meta

    return img


def regularize_rboxes(rboxes, obb):
    """
    Canonicalize rotated boxes so that every angle lies in [0, pi/2).

    A box rotated by theta is the same box as one rotated by theta - pi/2 with
    width and height exchanged. Boxes whose angle (mod pi) is at least pi/2 therefore
    get their width and height swapped, and all angles are then folded into [0, pi/2).

    Args:
        rboxes (np.ndarray): Boxes of shape (N, 4+) in xywh order; only the first
            four columns are read.
        obb (np.ndarray): Rotation angles in radians, shape (N,).

    Returns:
        (tuple[np.ndarray, np.ndarray]): The regularized xywh boxes of shape (N, 4)
        and the folded angles of shape (N,).
    """
    cx, cy, bw, bh = (rboxes[:, col] for col in range(4))
    half_pi = math.pi / 2

    # Only swap sides when the angle falls in the second quarter-turn of [0, pi).
    needs_swap = obb % math.pi >= half_pi
    width = np.where(needs_swap, bh, bw)
    height = np.where(needs_swap, bw, bh)

    boxes = np.stack([cx, cy, width, height], axis=-1)
    return boxes, obb % half_pi

def postprocess(out, img_meta, conf=0.001, iou_thres=0.65, with_scale=True, max_det=300):
    """
    Convert raw OBB model output into final detections.

    Args:
        out (np.ndarray): Raw model output, as accepted by ``non_max_suppression``.
        img_meta (dict): Metadata produced by ``preprocess(..., with_meta=True)``.
        conf (float): Confidence threshold.
        iou_thres (float): IoU threshold for rotated NMS.
        with_scale (bool): If True, regularize the boxes/angles and map the boxes
            back to original-image coordinates. If False, return them as they come
            out of NMS.
        max_det (int): Maximum number of detections.

    Returns:
        tuple: ``(boxes, scores, labels, angles)`` where ``boxes`` is (M, 4) xywh,
        ``scores`` and ``labels`` are (M,), and ``angles`` is (M,) in radians.
    """
    detections, angles = non_max_suppression(out, conf_thres=conf, iou_thres=iou_thres, max_det=max_det)
    boxes = detections[:, :4]
    scores = detections[:, 4]
    labels = detections[:, 5]

    if not with_scale:
        return boxes, scores, labels, angles

    boxes, angles = regularize_rboxes(boxes, angles)
    return scale_boxes(boxes, img_meta), scores, labels, angles


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114)):
    """
    Fit an image into ``new_shape`` by shrinking it (never enlarging) and padding the borders.

    The aspect ratio is preserved. Images already no larger than ``new_shape`` are
    only padded.

    Args:
        img (np.ndarray): Image array of shape (H, W[, C]).
        new_shape (int | tuple | list): Target (height, width); an int means a square target.
        color (tuple): Fill value for the added border.

    Returns:
        (tuple[np.ndarray, tuple[float, float]]): The padded image (height and width
        equal to ``new_shape``) and ``(dw, dh)``, the half of the total horizontal /
        vertical padding. When a total padding is odd, the actual left/top border is
        ``round(dw - 0.1)`` / ``round(dh - 0.1)``, which is half a pixel less than the
        returned value.
    """
    shape = img.shape[:2]  # current shape [height, width]
    if not isinstance(new_shape, (tuple, list)):
        new_shape = (new_shape, new_shape)
    # Scale ratio (new / old), capped at 1.0 so images are only ever shrunk.
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1], 1.0)
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
    if shape[::-1] != new_unpad:
        # Without this resize, large inputs would be padded at full size and overflow new_shape.
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
    # Split the padding between both sides; an odd leftover pixel goes to bottom/right.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border

    return img, (dw, dh)

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, multi_label=False, max_det=300):
    """
    Apply confidence filtering and rotated NMS to raw OBB model output.

    Only the first image of the batch (``prediction[0]``) is processed.

    Args:
        prediction (np.ndarray): Raw output of shape (B, 4 + nc + 1, N). The channels
            are box xywh, ``nc`` class scores and, last, the rotation angle.
        conf_thres (float): Minimum class score for a candidate to be kept.
        iou_thres (float): IoU threshold passed to ``fast_nms``.
        agnostic (bool): If True, boxes of different classes may suppress each other.
        multi_label (bool): If True and nc > 1, emit one candidate per class whose
            score exceeds ``conf_thres``. Otherwise keep only each box's best class.
        max_det (int): Maximum number of detections returned.

    Returns:
        (tuple[np.ndarray, np.ndarray]): Detections of shape (M, 6) as
        [x, y, w, h, score, class], and their angles of shape (M,), in the same order.
    """
    # prediction is [1, 20, N]: the last channel is the angle, the rest is xywh + class scores.
    obb = prediction[0, -1, :]
    prediction = prediction[:, :-1, :]
    nc = prediction.shape[1] - 4  # number of classes
    max_wh = 7680  # per-class coordinate offset so boxes of different classes never overlap in NMS
    xc = prediction[:, 4:].max(1) > conf_thres  # candidates: some class score above threshold
    multi_label &= nc > 1
    prediction = np.transpose(prediction, [0, 2, 1])  # (B, 4 + nc, N) -> (B, N, 4 + nc)

    x = prediction[0][xc[0]]  # confidence
    obb = obb[xc[0]]
    box, cls = x[:, :4], x[:, 4:]

    if multi_label:
        i, j = (cls > conf_thres).nonzero()
        x = np.concatenate((box[i], x[i, 4 + j, None], j[:, None].astype(np.float32)), 1)
        # A box may now appear once per qualifying class; keep its angle aligned with each row.
        obb = obb[i]
    else:  # best class only
        conf = np.max(cls, 1, keepdims=True)
        j = np.argmax(cls, 1, keepdims=True)
        keep = conf.reshape(-1) > conf_thres
        x = np.concatenate((box, conf, j.astype(np.float32)), 1)[keep]
        obb = obb[keep]  # same mask as x so rows and angles stay paired

    c = x[:, 5:6] * (0 if agnostic else max_wh)  # class offsets
    scores = x[:, 4]
    boxes = np.concatenate((x[:, :2] + c, x[:, 2:4], obb[:, None]), axis=-1)  # xywhr

    index = fast_nms(torch.from_numpy(boxes), torch.from_numpy(scores), iou_thres)
    index = index.cpu().numpy()[:max_det]

    return x[index], obb[index]

def fast_nms(
    boxes,
    scores,
    iou_threshold: float,
    use_triu: bool = True,
    exit_early: bool = True,
    iou_func=None,
):
    """
    Fast-NMS (https://arxiv.org/pdf/1904.02689) using an upper-triangular IoU matrix.

    A box is suppressed when any higher-scoring box overlaps it with IoU >= ``iou_threshold``.
    Unlike greedy NMS, a box that is itself suppressed can still suppress others. This
    lets the whole step run as a few tensor operations.

    Args:
        boxes (torch.Tensor): Boxes of shape (N, 5) in xywhr format (or whatever a
            custom ``iou_func`` accepts).
        scores (torch.Tensor): Confidence scores of shape (N,). Not modified.
        iou_threshold (float): IoU at or above which the lower-scoring box is suppressed.
        use_triu (bool): If True, use ``Tensor.triu_`` and return only the kept indices.
            If False, build the upper-triangular mask by hand (export friendly) and
            return all N indices ordered by score. Suppressed boxes have their score
            zeroed, so they sort behind the kept ones.
        exit_early (bool): Return an empty index tensor immediately when there are no boxes.
        iou_func (callable | None): Pairwise IoU function ``(boxes, boxes) -> (N, N)``.
            Defaults to ``batch_probiou``.

    Returns:
        (torch.Tensor): int64 indices into the original ``boxes``, highest score first.
    """
    if boxes.numel() == 0 and exit_early:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    if iou_func is None:
        iou_func = batch_probiou

    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = iou_func(boxes, boxes)
    if use_triu:
        # Keep only IoUs against higher-scoring boxes (row < column in sorted order).
        ious = ious.triu_(diagonal=1)
        # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
        pick = torch.nonzero((ious >= iou_threshold).sum(0) <= 0).squeeze_(-1)
    else:
        n = boxes.shape[0]
        row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
        col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
        ious = ious * (row_idx < col_idx)
        # The suppression mask is in sorted order, so it must be applied to sorted scores.
        # Indexing makes a copy, which also leaves the caller's tensor untouched.
        sorted_scores = scores[sorted_idx]
        # Zeroing these scores ensures the additional indices would not affect the final results
        sorted_scores[(ious >= iou_threshold).sum(0) > 0] = 0
        # NOTE: return indices with fixed length to avoid TFLite reshape error
        pick = torch.topk(sorted_scores, sorted_scores.shape[0]).indices
    return sorted_idx[pick]

def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
    """
    Calculate the probabilistic IoU (ProbIoU) between every pair of oriented bounding boxes.

    Each box is modelled as a 2-D Gaussian whose mean is the box center and whose
    covariance comes from ``_get_covariance_matrix``. ``t1 + t2 + t3`` is the
    Bhattacharyya distance between two such Gaussians, ``hd`` is the Hellinger
    distance derived from it, and the returned similarity is ``1 - hd``.

    Args:
        obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
        obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
        eps (float, optional): A small value to avoid division by zero.

    Returns:
        (torch.Tensor): A tensor of shape (N, M) representing obb similarities.

    References:
        https://arxiv.org/pdf/2106.06072v1.pdf
    """
    # NumPy inputs are converted to tensors; tensors pass through unchanged.
    obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
    obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2

    # obb1 quantities are column vectors (N, 1) and obb2 quantities are row vectors (1, M),
    # so every expression below broadcasts to the (N, M) pairwise matrix.
    x1, y1 = obb1[..., :2].split(1, dim=-1)
    x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))

    # Center term, diagonal part: 1/4 * d^T (S1 + S2)^-1 d with d = center difference;
    # the denominator is det(S1 + S2).
    t1 = (
        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
    ) * 0.25
    # Center term, off-diagonal (xy-covariance) part of the same quadratic form.
    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
    # Shape term: 1/2 * ln(det(S1 + S2) / (4 * sqrt(det S1 * det S2))).
    # clamp_(0) keeps the individual determinants non-negative before the square root.
    t3 = (
        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
        + eps
    ).log() * 0.5
    bd = (t1 + t2 + t3).clamp(eps, 100.0)  # Bhattacharyya distance, bounded to [eps, 100]
    hd = (1.0 - (-bd).exp() + eps).sqrt()  # Hellinger distance
    return 1 - hd

def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate covariance matrix from oriented bounding boxes.

    Args:
        boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.

    Returns:
        (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
    """
    # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
    gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
    a, b, c = gbbs.split(1, dim=-1)
    cos = c.cos()
    sin = c.sin()
    cos2 = cos.pow(2)
    sin2 = sin.pow(2)
    return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin

def scale_boxes(boxes, img_meta):
    """
    Map xywh boxes from letterboxed model-input coordinates back to the original image, in place.

    Args:
        boxes (np.ndarray): Float array of shape (..., 4+) in xywh order; modified in place.
        img_meta (dict): Metadata as produced by ``preprocess(..., with_meta=True)``.
            Only ``'pad_size'`` ((pad_w, pad_h)) and ``'scale_factor'``
            ((scale_h, scale_w)) are read; other keys are ignored.

    Returns:
        np.ndarray: ``boxes`` itself, after removing the padding and undoing the resize.
    """
    pad = img_meta['pad_size']  # (w, h)
    scale_h, scale_w = img_meta['scale_factor']  # (h, w)

    # Padding only shifts the box center; it does not change width or height.
    boxes[..., 0] -= pad[0]  # x padding
    boxes[..., 1] -= pad[1]  # y padding
    # Undo the resize: x and w follow the width ratio, y and h follow the height ratio.
    boxes[..., :4] /= np.array([scale_w, scale_h, scale_w, scale_h], dtype=np.float32)

    return boxes