import os
import numpy as np
from tqdm import tqdm
from typing import Union, List
import torchvision.transforms as transforms

import torch
from packaging import version

from simple_tokenizer import SimpleTokenizer

def preprocess(img):
    """Turn an input image into a normalized, channel-last tensor.

    The image is resized to 224 with bicubic interpolation, center-cropped
    to 224x224, converted to a tensor, normalized per channel, and finally
    permuted from channel-first to channel-last by ``Channel2Last``.

    Parameters
    ----------
    img :
        Image accepted by the torchvision transforms used below
        (presumably a PIL image -- confirm against the caller).

    Returns
    -------
    The transformed tensor in channel-last layout.
    """
    # Per-channel normalization statistics (presumably the CLIP
    # image statistics -- TODO confirm).
    channel_mean = [0.48145466, 0.4578275, 0.40821073]
    channel_std = [0.26862954, 0.26130258, 0.27577711]
    bicubic = transforms.InterpolationMode.BICUBIC

    steps = [
        transforms.Resize(224, interpolation=bicubic),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
        Channel2Last(),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(img)


class Channel2Last:
    """Callable transform that moves the channel axis to the last position.

    Supported inputs are 3-D (CHW -> HWC), 4-D (NCHW -> NHWC) and 5-D
    (channel at axis 1 moved to the end) tensors. The result is made
    contiguous in memory.
    """

    # Axis order handed to ``permute`` for each supported rank.
    _AXIS_ORDER = {
        3: (1, 2, 0),
        4: (0, 2, 3, 1),
        5: (0, 2, 3, 4, 1),
    }

    def __call__(self, x):
        """Return ``x`` permuted to channel-last layout.

        Raises
        ------
        ValueError
            If ``x.ndim`` is not 3, 4 or 5.
        """
        order = self._AXIS_ORDER.get(x.ndim)
        if order is None:
            raise ValueError(f"Unsupport {x.ndim} to channel last.")
        return x.permute(order).contiguous()


def get_text_weights(args, text_runner):
    """Build zero-shot classifier weights from the CIFAR-100 class names.

    For every class name, each prompt template is filled in, tokenized with
    ``tokenize`` and passed through ``text_runner``. The per-template
    embeddings are L2-normalized, averaged over the template axis, and the
    mean is L2-normalized again.

    Parameters
    ----------
    args :
        Object whose ``data`` attribute is the directory handed to
        ``tokenize`` as its ``dict_dir`` argument.
    text_runner :
        Callable taking the tokenized prompts and returning the text
        embeddings (assumed to be a NumPy array of shape
        (num_templates, embed_dim) -- TODO confirm against the runner).

    Returns
    -------
    numpy.ndarray
        The per-class embeddings stacked along axis 1, i.e. one column per
        class (embed_dim x 100 under the shape assumption above).
    """
    classes = [
        'apple', 'aquarium fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
        'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee',
        'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house',
        'kangaroo', 'keyboard', 'lamp', 'lawn mower', 'leopard', 'lion', 'lizard', 'lobster',
        'man', 'maple tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak tree', 'orange',
        'orchid', 'otter', 'palm tree', 'pear', 'pickup truck', 'pine tree', 'plain', 'plate',
        'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
        'squirrel', 'streetcar', 'sunflower', 'sweet pepper', 'table', 'tank', 'telephone',
        'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe',
        'whale', 'willow tree', 'wolf', 'woman', 'worm'
    ]
    templates = ['itap of a {}.']

    zeroshot_weights = []
    # tqdm takes the total from len(classes), so the progress bar stays
    # correct if the class list is ever edited.
    for classname in tqdm(classes, desc="Encoding CIFAR100 Label"):
        # Fill every prompt template with the class name.
        texts = [template.format(classname) for template in templates]
        inputs = tokenize(args.data, texts, context_length=77)
        class_embeddings = text_runner(inputs)
        # Normalize out-of-place: the array belongs to the runner and may be
        # a buffer it reuses, so it must not be modified here.
        class_embeddings = class_embeddings / np.linalg.norm(
            class_embeddings, axis=-1, keepdims=True)
        # Average over templates, then renormalize the mean embedding.
        class_embeddings = np.mean(class_embeddings, axis=0)
        class_embeddings = class_embeddings / np.linalg.norm(class_embeddings)
        zeroshot_weights.append(class_embeddings)

    return np.stack(zeroshot_weights, 1)


# SimpleTokenizer instances keyed by vocabulary path. tokenize() is called
# once per class label, and building a fresh tokenizer on every call
# (presumably re-reading the vocabulary file each time) is wasted work.
_TOKENIZER_CACHE = {}


def tokenize(dict_dir, texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    dict_dir :
        Directory containing the vocabulary file ``bpe_simple_vocab_16e6.txt.gz``

    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    Rows shorter than context_length are right-padded with zeros.
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.

    Raises
    ------
    RuntimeError
        If an encoded text (including the start/end tokens) is longer than
        ``context_length`` and ``truncate`` is False.
    """
    vocab_path = os.path.join(dict_dir, 'bpe_simple_vocab_16e6.txt.gz')
    tokenizer = _TOKENIZER_CACHE.get(vocab_path)
    if tokenizer is None:
        # First request for this vocabulary: build once and remember it.
        tokenizer = SimpleTokenizer(vocab_path)
        _TOKENIZER_CACHE[vocab_path] = tokenizer

    if isinstance(texts, str):
        texts = [texts]

    sot_token = tokenizer.encoder["<|startoftext|>"]
    eot_token = tokenizer.encoder["<|endoftext|>"]
    # Wrap every encoded text in start-of-text / end-of-text markers.
    all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]

    # Older torch needs long indices; newer versions accept int32.
    if version.parse(torch.__version__) < version.parse("1.8.0"):
        dtype = torch.long
    else:
        dtype = torch.int
    result = torch.zeros(len(all_tokens), context_length, dtype=dtype)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            # Cut to the context length but keep the sequence terminated
            # by the end-of-text token.
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result