데이터 마다 가중치를 달리하여 학습 시키는 방법!
What
네이버 블로그 | How Deep is the Learning[Tensorflow] Class Imbalance 문제 대처하기 (WCE, Focal Loss)
[Tensorflow] Class Imbalance 문제 대처하기 (WCE, Focal Loss)
Introduction 분류 문제를 다룰 때, 특히 메디컬 이미지의 경우 Normal Data는 수없이 많지만 Abnormal...
검색해 보면 imbalanced class dataset에 대해서 사용하는 방법이라고 나와있다.
그러나 우리는 데이터 class가 불균형 해서 이 방법을 사용하는 것이 아니다.
데이터 퀄리티가 각기 다르기 때문에 사용하는 것이다.
지피티 피셜 ⬇️
각 데이터마다 다른 가중치를 두고 학습을 진행하는 것은 머신러닝에서 Weighted Loss Function을 사용하는 방식으로 해결할 수 있습니다. 특히 감정 데이터셋을 학습할 때, 특정 데이터에 더 많은 가중치를 주어 모델이 해당 데이터를 더 잘 학습하도록 유도할 수 있습니다.
Weighted Loss Function 사용 방법
- 데이터 가중치 설정:
각 데이터마다 다르게 설정할 가중치를
weights라는 리스트나 배열로 준비합니다. 예를 들어, 데이터셋의 크기가 100이라면weights리스트의 길이도 100이어야 합니다. 각 값은 해당 데이터에 부여할 가중치를 나타냅니다.
- 손실 함수 변경: 사용하는 프레임워크에 따라 손실 함수(loss function)에 가중치를 적용하는 방법이 다를 수 있습니다.
- PyTorch:
PyTorch에서는
torch.nn.functional.cross_entropy나torch.nn.CrossEntropyLoss에서weight파라미터를 사용할 수 있습니다. 예를 들어: - TensorFlow/Keras:
TensorFlow/Keras에서는
sample_weight를 사용하여 가중치를 적용할 수 있습니다. 예를 들어:
python코드 복사 import torch.nn as nn # 클래스별로 가중치를 부여할 수 있습니다. criterion = nn.CrossEntropyLoss(weight=class_weights) loss = criterion(outputs, targets)
python코드 복사 model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) # 학습 시 각 샘플에 대해 가중치를 부여 model.fit(x_train, y_train, sample_weight=weights, epochs=10)
- 가중치 적용: 각 데이터에 가중치를 직접 부여하는 방식으로 진행할 수 있습니다. 예를 들어, 데이터가 여러 클래스로 이루어져 있고 특정 클래스에 대해 가중치를 더 부여하고 싶다면, 해당 클래스에 해당하는 데이터의 가중치를 높게 설정합니다.
실제 예제
예를 들어, 당신이 10개의 이미지를 가지고 있고, 각 이미지에 대한 기쁨의 정도가
joy_scores라는 리스트로 주어졌다고 가정하겠습니다. 이 리스트의 값을 통해 가중치를 설정하고, 이 가중치를 사용하여 학습을 진행할 수 있습니다.python코드 복사 # 기쁨 점수에 따른 가중치 설정 joy_scores = [0.8, 0.5, 0.3, 0.9, 0.2, 0.6, 0.7, 0.4, 0.1, 0.5] weights = [score / sum(joy_scores) for score in joy_scores] # 모델 학습 시 sample_weight 파라미터로 가중치 전달 model.fit(x_train, y_train, sample_weight=weights, epochs=10)
이렇게 하면 기쁨의 정도가 높은 데이터에 더 많은 가중치가 부여되어, 모델이 이 데이터를 더 중요하게 학습하게 됩니다.
![[Tensorflow] Class Imbalance 문제 대처하기 (WCE, Focal Loss)](https://blogimgs.pstatic.net/nblog/mylog/post/og_default_image_160610.png)