import os
import subprocess
import sys

import clip
import torch
from torchvision.datasets import CIFAR100


class TextModel(torch.nn.Module):
    """Expose the text branch of a CLIP-style model as a standalone module.

    The wrapped ``model`` must provide ``token_embedding``,
    ``positional_embedding``, ``transformer``, ``ln_final`` and
    ``text_projection``; ``forward`` chains them so the text path can be
    exported on its own.
    """

    def __init__(self, model):
        """Keep a reference to the full model whose text components are used."""
        super().__init__()
        self.model = model

    def forward(self, text):
        """Map a batch of token-id rows to one projected feature row each.

        Args:
            text: Integer token ids, indexed as ``[batch, position]``.

        Returns:
            The hidden state taken, for every row, at the position holding
            the largest token id, multiplied by ``text_projection``.
        """
        clip_model = self.model

        hidden = clip_model.token_embedding(text)
        hidden = hidden + clip_model.positional_embedding

        # The transformer is fed sequence-first data: NLD -> LND and back.
        hidden = clip_model.transformer(hidden.permute(1, 0, 2))
        hidden = clip_model.ln_final(hidden.permute(1, 0, 2))

        # One position per row: where the token id is largest (presumably the
        # end-of-text token in CLIP's vocabulary -- confirm against tokenizer).
        row_index = torch.arange(hidden.shape[0])
        pooled = hidden[row_index, text.argmax(dim=-1)]

        return pooled @ clip_model.text_projection


def _simplify_onnx(onnx_path):
    """Run the ``onnxsim`` command-line tool on ``onnx_path``, best-effort.

    The same path is passed as input and output, so the file is rewritten
    in place. A missing tool or a non-zero exit status only produces a
    warning on stderr; the exported file is left as it was written.

    Args:
        onnx_path: Path of the ONNX file to simplify.

    Returns:
        True if ``onnxsim`` ran and exited with status 0, False otherwise.
    """
    try:
        # Argument list instead of a shell string: paths containing spaces
        # or shell metacharacters are passed through untouched.
        completed = subprocess.run(["onnxsim", onnx_path, onnx_path], check=False)
    except OSError as err:
        print(f"Warning: could not run onnxsim on {onnx_path}: {err}", file=sys.stderr)
        return False

    if completed.returncode != 0:
        print(
            f"Warning: onnxsim exited with status {completed.returncode} for {onnx_path}",
            file=sys.stderr,
        )
        return False
    return True


def main(onnx_dir):
    """Export the CLIP ViT-B/16 image and text encoders to ONNX files.

    Writes ``clip_vision.onnx`` and ``clip_text.onnx`` into ``onnx_dir`` and
    then tries to simplify each file with ``onnxsim``.

    Args:
        onnx_dir: Directory that receives the exported files. It is created
            if it does not exist; an empty string means the current directory.
    """
    # Load the model
    device = "cpu"
    model, preprocess = clip.load('ViT-B/16', device)

    # Open the CIFAR-100 test split expected under ./cifar100
    # (download=False: nothing is fetched here).
    cifar100 = CIFAR100(root='./cifar100', download=False, train=False)

    # Prepare one example image and one example prompt to trace the export.
    image, _ = cifar100[3637]
    image_input = preprocess(image).unsqueeze(0).to(device)
    # Only a single prompt is needed as the example text input, so tokenize
    # just the first class instead of all of them.
    text_input = clip.tokenize(f"a photo of a {cifar100.classes[0]}").to(device)

    # os.makedirs("") raises, and an empty directory already means "cwd".
    if onnx_dir:
        os.makedirs(onnx_dir, exist_ok=True)

    # Export both encoders without tracking gradients.
    with torch.no_grad():
        vision_onnx_path = os.path.join(onnx_dir, "clip_vision.onnx")
        torch.onnx.export(model.visual, image_input, vision_onnx_path)
        _simplify_onnx(vision_onnx_path)
        print(f"ONNX file has been saved in {vision_onnx_path}")

        text_onnx_path = os.path.join(onnx_dir, "clip_text.onnx")
        torch.onnx.export(TextModel(model), text_input, text_onnx_path)
        _simplify_onnx(text_onnx_path)
        print(f"ONNX file has been saved in {text_onnx_path}")

if __name__ == "__main__":
    # The output directory is the first command-line argument; exit with a
    # usage message instead of an IndexError traceback when it is missing.
    # Any further arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <onnx_dir>")
    main(sys.argv[1])