from PIL import Image


def preprocess(img_path, need_batch=True):
    """Load an image file and return it as a normalized, channels-last tensor.

    The image is converted to RGB, resized with ``Resize(342)``,
    center-cropped to 299x299, converted to a tensor, and normalized
    per channel with mean 0.5 and std 0.5. The channel axis is then moved
    to the last position.

    Args:
        img_path: Path of the image file to load.
        need_batch: When true (the default), a leading axis of size 1 is
            added to the result.

    Returns:
        A tensor laid out as ``(height, width, channel)``, or
        ``(1, height, width, channel)`` when ``need_batch`` is true.
    """
    # Function-scope import, kept from the original; presumably so that
    # importing this module does not require torchvision -- confirm with callers.
    import torchvision.transforms as transforms

    # Pillow opens files lazily and keeps the handle until the image is
    # closed. ``convert`` loads the pixel data into a new image, so the file
    # can be released as soon as the conversion is done.
    with Image.open(img_path) as source:
        img = source.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(342),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])
    img = transform(img)

    # ToTensor yields channels-first data; move the channel axis last.
    img = img.permute(1, 2, 0)
    if need_batch:
        return img[None, ...]

    return img