상세 컨텐츠

본문 제목

컴퓨터 비전 공부하기(딥러닝 시작하기 - 거리 학습 기반 메타러닝Prototypical Networks, Matching Networks)

카테고리 없음

by zmo 2024. 11. 24. 00:27

본문

오늘은 거리 학습 기반(Distance-Based) 메타러닝의 주요 알고리즘인 Prototypical Networks, Matching Networks에 대해 알아보도록 하자. (오늘은 티스토리에서 수식을 적는 방법을 알아내서 GPT 캡쳐를 안해도 될것 같다!)

먼저 거리 학습 기반 메타러닝이란 간단하게 서포트 셋과 쿼리 셋과의 거리를 측정하여 판단하며 서포트 셋을 통하여 이미지의 특징을 뽑아내고 분류하여 유사성에 따라 거리를 나타낸다.

 

살짝 풀어서 설명하면 다음과 같다

 

임베딩 공간 생성 - 데이터를 고차원에서 저차원의 임베딩 공간으로 매핑하여 유사한 데이터는 가까이, 다른 데이터는 멀리 위치하도록 학습한다.

유사도 측정 - 새로운 데이터 포인트와 기존 클래스 간의 거리를 계산한다. 주로 사용하는 거리 측정 방식은 유클리드 거리(Euclidean Distance) 또는 코사인 유사도(Cosine Similarity)이다.

가장 가까운 클래스 선택 - 새 데이터와 가장 가까운 클래스의 레이블을 할당한다.

 

데이터들의 서로 비슷한 특징들을 임베팅 공간에 매핑해 놓고 새로운 데이터가 왔을때 어떤 특징에 가까운지 판단하여 결과를 내놓은다고 생각하면 될것 같다.

 

 

 

오늘 알아보기 

1. Prototypical Networks

2. Matching Networks

3. Matching Networks 예재

 

 


 

Prototypical Networks

 

Prototypical Networks 알고리즘은 각 클래스의 프로토타입(Prototype)을 계산한다.

프로토타입은 해당 클래스의 임베딩 평균을 뜻하는데 임베딩이란 데이터를 특정 숫자 벡터로 전환하는 과정을 뜻한다.

고로 프로토타입이란 특정 클래스의 데이터 숫자 벡터를 평균을 내어 그 클래스를 대표하는 숫자 벡터로 사용하는 것이라 할수 있다. 이렇게 하면 수없이 많은 클래스의 모든 데이터와 비교해보지 않고 손쉽게 그 클래스와의 거리를 잴수 있다.

 

다음은 프로토 타입을 계산하는 수식이다.

\[
\mathbf{c}_k = \frac{1}{|S_k|} \sum_{(\mathbf{x}_i, y_i) \in S_k} f_\theta(\mathbf{x}_i)
\]

: 클래스 k의 프로토타입

: 클래스 k에 속한 Support Set

: 임베딩 함수 (신경망)

: 입력 샘플

\( \mathbf{ y }_ i \) : 입력 샘플의 레이블

 

 

Query 데이터(테스트 샘플)를 임베딩한 후에는, 각 클래스의 프로토타입과의 거리를 계산한다.

보통 유클리드 거리를 사용하여, 가장 가까운 프로토타입을 가진 클래스로 분류한다.

\[
d(f_\theta(\mathbf{x}_q), \mathbf{c}_k) = \|f_\theta(\mathbf{x}_q) - \mathbf{c}_k\|^2
\]

\[
\hat{y}_q = \arg\min_k d(f_\theta(\mathbf{x}_q), \mathbf{c}_k)
\]

 

이후 Negative Log-Likelihood를 사용하여 모델을 학습할수 있다.

\[
\mathcal{L} = - \frac{1}{|Q|} \sum_{(\mathbf{x}_q, y_q) \in Q} \log \frac{\exp(-d(f_\theta(\mathbf{x}_q), \mathbf{c}_{y_q}))}{\sum_{k} \exp(-d(f_\theta(\mathbf{x}_q), \mathbf{c}_k))}
\]

Matching Networks

Few-Shot Learning 문제를 해결하기 위해 제안된 알고리즘인 Matching Networks는 이미 학습된 클래스 외의 새로운 클래스를 다룰 때 효과적인 방법으로, 메모리 기반 접근 방식과 유사도 계산을 결합하여 작동한다. 모델은 새로운 데이터를 학습하지 않고, Support Set(지원 집합)이라고 불리는 소량의 레이블이 있는 데이터를 참조하고 Support Set과 비교하여 새로운 데이터를 분류한다.

간단히 말해 새로운 데이터를 이용한 학습 없이도 기존의 Support Set을 이용해 기준을 잡고 새로운 데이터(Query)를 이 기준과 비교해 가장 가까운 클래스를 예측하는 것이다.

 



Matching Networks는 Query 데이터 \( x \)가 클래스 \( k \)에 속할 확률을 다음과 같이 계산한다.
\[
p(y = k \mid x) = \sum_{i} a(x, x_i) \cdot 1(y_i = k)
\]
여기서
- \( a(x, x_i) \): Query 데이터 \( x \)와 Support Set 데이터 \( x_i \) 간의 유사도
- \( 1(y_i = k) \): 샘플 \( x_i \)가 클래스 \( k \)에 속하면 1, 아니면 0

이다

유사도는 Cosine Similarity를 사용하여 계산한다
\[
a(x, x_i) = \frac{f_\theta(x) \cdot f_\theta(x_i)}{\|f_\theta(x)\| \|f_\theta(x_i)\|}
\]
여기서 \( f_\theta \)는 데이터를 임베딩 공간으로 매핑하는 함수이다.

 

이제 수식만 보면 이해도 잘 안가고 어디에 써먹는지 잘 모르겠으므로 Matching Networks 예재를 살펴보자

 

Matching Networks 예재 (코랩사용)

pip install torch torchvision numpy matplotlib
import torch
import numpy as np
import matplotlib.pyplot as plt

# 데이터 생성
def generate_data(num_classes, num_samples, seed=42):
    np.random.seed(seed)
    data = []
    labels = []
    for i in range(num_classes):
        center = np.random.uniform(-5, 5, size=(2,))
        points = np.random.randn(num_samples, 2) + center
        data.append(points)
        labels.append(np.full(num_samples, i))
    return np.vstack(data), np.hstack(labels)

# Support Set (3 classes, 5 samples each)
support_data, support_labels = generate_data(num_classes=3, num_samples=5)

# Query Set (3 classes, 5 samples each)
query_data, query_labels = generate_data(num_classes=3, num_samples=5, seed=99)
import torch.nn as nn

class EmbeddingNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(EmbeddingNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)
from torch.nn.functional import softmax

class MatchingNetwork(nn.Module):
    def __init__(self, embedding_net):
        super(MatchingNetwork, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, support, support_labels, query):
        # 임베딩 계산
        support_embeddings = self.embedding_net(support)  # Support Set
        query_embeddings = self.embedding_net(query)      # Query Set
        
        # 유사도 계산 (Cosine Similarity)
        similarity = torch.mm(query_embeddings, support_embeddings.T)
        similarity = softmax(similarity, dim=1)  # 확률로 변환
        
        # 레이블 기반 가중합
        support_labels_onehot = torch.nn.functional.one_hot(support_labels, num_classes=3).float()
        predictions = torch.mm(similarity, support_labels_onehot)
        return predictions
# 데이터 텐서 변환
support_data = torch.tensor(support_data, dtype=torch.float32)
support_labels = torch.tensor(support_labels, dtype=torch.long)
query_data = torch.tensor(query_data, dtype=torch.float32)
query_labels = torch.tensor(query_labels, dtype=torch.long)

# 모델 초기화
embedding_net = EmbeddingNet(input_dim=2, hidden_dim=16, output_dim=8)
model = MatchingNetwork(embedding_net)

# 예측 수행
predictions = model(support_data, support_labels, query_data)

# 결과 평가
predicted_classes = torch.argmax(predictions, dim=1)
accuracy = (predicted_classes == query_labels).float().mean().item()
print(f"Accuracy: {accuracy * 100:.2f}%")
# Support Set 시각화
plt.scatter(support_data[:, 0], support_data[:, 1], c=support_labels, marker='o', label='Support Set')
# Query Set 시각화
plt.scatter(query_data[:, 0], query_data[:, 1], c=query_labels, marker='x', label='Query Set')
# Prediction 시각화
for i, query in enumerate(query_data):
    plt.text(query[0], query[1], str(predicted_classes[i].item()), color="red")

plt.legend()
plt.title("Matching Networks - Few Shot Learning")
plt.show()

 

이 예재를 실행해보고 비교군은 몇가지인지 찾아보고, 학습 결과는 어떻게 나오는지 확인해보자

더보기
더보기

Accuracy: 6.67% …. 이 나온다 비교군이 15개 밖에 안되서 아주 정상적으로 나온것 같다 실제 사용하려면 15개 정도로는 택도 없다는 것을 알아두자

 

 


 

 

오늘은 거리 학습 기반(Distance-Based) 메타러닝의 주요 알고리즘인 Prototypical Networks, Matching Networks에 대해 알아보았다. 적은 데이터로도 학습할수 있는 이 방법들은 꽤나 유용할것 같다. 하지만 학습 데이터가 적기 때문에 잘못 분류되었거나 잘못 학습하기 시작한다면 걷잡을수 없이 산으로 가는 AI가 될것 같다.