from PIL import Image
import torchvision.transforms as transforms


def preprocess(
    img_path,
    need_batch=True,
    *,
    resize_size=256,
    crop_size=224,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Load an image file and turn it into a normalized channels-last array.

    The pipeline is: convert to RGB -> resize (shorter side to
    ``resize_size``) -> center-crop to ``crop_size`` -> ``ToTensor`` (CHW,
    values scaled to [0, 1]) -> per-channel normalize -> permute to HWC.
    The defaults reproduce the common ImageNet evaluation preprocessing.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        need_batch: If True, prepend a batch axis of size 1.
        resize_size: Argument forwarded to ``transforms.Resize``.
        crop_size: Argument forwarded to ``transforms.CenterCrop``.
        mean: Per-channel mean forwarded to ``transforms.Normalize``.
        std: Per-channel std forwarded to ``transforms.Normalize``.

    Returns:
        A C-contiguous ``numpy.ndarray`` in HWC layout, i.e. shape
        ``(crop_h, crop_w, 3)``, or ``(1, crop_h, crop_w, 3)`` when
        ``need_batch`` is True.
    """
    # Use a context manager so the file handle is always released; Pillow
    # only auto-closes after load() for single-frame images.
    with Image.open(img_path) as src:
        # convert() returns a new, fully loaded image, so it remains valid
        # after the source file is closed.
        img = src.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    tensor = transform(img)

    # CHW -> HWC. permute() only returns a strided view; make it contiguous
    # so the resulting numpy array has a plain C layout for consumers that
    # require one.
    tensor = tensor.permute(1, 2, 0).contiguous()
    if need_batch:
        tensor = tensor.unsqueeze(0)
    return tensor.numpy()
