HomeAboutMeBlogGuest
© 2025 Sejin Cha. All rights reserved.
Built with Next.js, deployed on Vercel
장지원 페이지/
📕
2024 UGRP
/
Member Page
Member Page
/
장지원
장지원
/
#15. 적합한 model 찾아보기

#15. 적합한 model 찾아보기

태그
자료 조사
날짜
May 1, 2024
상태
완료
++ attribute를 함께 활용할 수 있는 모델을 찾아보면 좋을 것 같다. → label을 두 개로 둘지, 아니면 다른 방식으로 attribute를 넣을지 고민 필요.
  1. CoCa (frozen / finetuned)
    1. about frozen….
    2. imagenet을 기반으로 학습된 모델
    3. additional
      1. open_clip: OpenAI가 발표한 CLIP을 오픈소스로 구현한 라이브러리(OpenAI가 아니라 mlfoundations/LAION 커뮤니티에서 개발). 이미지와 텍스트를 같은 임베딩 공간으로 인코딩해서 서로 비교할 수 있게 해 준다.
      1. API 활용해서 model 생성
        1. model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg')
      1. 전처리
        1. …
          description 표시
          import os import skimage import IPython.display import matplotlib.pyplot as plt from PIL import Image import numpy as np from collections import OrderedDict import torch %matplotlib inline %config InlineBackend.figure_format = 'retina' // 그냥 그래프 표시 옵션 # images in skimage to use and their textual descriptions descriptions = { // description에 label 적으면 될 것 같다. "page": "a page of text about segmentation", "chelsea": "a facial photo of a tabby cat", "astronaut": "a portrait of an astronaut with the American flag", "rocket": "a rocket standing on a launchpad", "motorcycle_right": "a red motorcycle standing in a garage", "camera": "a person looking at a camera on a tripod", "horse": "a black-and-white silhouette of a horse", "coffee": "a cup of coffee on a saucer" }
          이미지 표시
          original_images = [] images = [] texts = [] plt.figure(figsize=(16, 5)) for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]: name = os.path.splitext(filename)[0] if name not in descriptions: continue image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB") plt.subplot(2, 4, len(images) + 1) plt.imshow(image) plt.title(f"{filename}\n{descriptions[name]}") plt.xticks([]) plt.yticks([]) original_images.append(image) images.append(preprocess(image)) texts.append(descriptions[name]) plt.tight_layout()
          전처리 (tokenizer는 `from open_clip import tokenizer`로 미리 import 해 두어야 한다)
          image_input = torch.tensor(np.stack(images)) text_tokens = tokenizer.tokenize(["This is " + desc for desc in texts])
      1. 유사도 비교
        1. …
          이미지와 텍스트의 벡터 추출
          with torch.no_grad(): image_features = model.encode_image(image_input).float() text_features = model.encode_text(text_tokens).float()
          벡터로 유사도 비교
          image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
      1. 결과
        1. …
          1. 유사도 시각화 (1)
          count = len(descriptions) plt.figure(figsize=(20, 14)) plt.imshow(similarity, vmin=0.1, vmax=0.3) # plt.colorbar() plt.yticks(range(count), texts, fontsize=18) plt.xticks([]) for i, image in enumerate(original_images): plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower") for x in range(similarity.shape[1]): for y in range(similarity.shape[0]): plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12) for side in ["left", "top", "right", "bottom"]: plt.gca().spines[side].set_visible(False) plt.xlim([-0.5, count - 0.5]) plt.ylim([count + 0.5, -2]) plt.title("Cosine similarity between text and image features", size=20)
          1. 유사도 시각화 (2)
          데이터 셋 호출 후 위와 같은 전처리
          from torchvision.datasets import CIFAR100 cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
          label 구성(?)
          text_descriptions = [f"A photo of a {label}" for label in cifar100.classes] text_tokens = tokenizer.tokenize(text_descriptions)
          시각화
          with torch.no_grad(): text_features = model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True) text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
          plt.figure(figsize=(16, 16)) for i, image in enumerate(original_images): plt.subplot(4, 4, 2 * i + 1) plt.imshow(image) plt.axis("off") plt.subplot(4, 4, 2 * i + 2) y = np.arange(top_probs.shape[-1]) plt.grid() plt.barh(y, top_probs[i]) plt.gca().invert_yaxis() plt.gca().set_axisbelow(True) plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()]) plt.xlabel("probability") plt.subplots_adjust(wspace=0.5) plt.show()
label은 다 달려있으므로 semi supervised는 필요 X
attribute threshold를 정해줄까?
ImageNet 말고 크기가 작은 dataset으로 학습된 모델 찾아보기 → 우리 데이터셋 크기가 작기 때문