머신러닝

style transfer model-2

content0474 2024. 11. 6. 13:25

이전 모델에서 Adam 대신 LGFBS를 사용했다.

 

전체코드

더보기

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

vgg16 = models.vgg16(pretrained=True).features
for layer in vgg16:
    if isinstance(layer, nn.ReLU):
        layer.inplace = False

def gram_matrix(features):
    with torch.no_grad():
        _, C, H, W = features.size()
        features = features.view(C, H * W)
        gram = torch.mm(features, features.t()) / (C * H * W)
    return gram

def get_features(image, model, layer_idx):
    content_feature = None
    style_features = []
    x = image
    
    for i, layer in enumerate(model):
        x = layer(x)
        
        if i == layer_idx['content']:
            content_feature = x.clone()
        elif i in layer_idx['style']:
            style_features.append(gram_matrix(x))
    
    return content_feature, style_features


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

content_image = transform(Image.open("your content image")).unsqueeze(0)
style_image = transform(Image.open("your style image")).unsqueeze(0)

layer_idx = {'content': 21, 'style': [0, 5, 7, 10, 16, 24, 26]}

with torch.no_grad():
    content_features_from_content, style_features_from_content = get_features(content_image, vgg16, layer_idx)
    content_features_from_style, style_features_from_style = get_features(style_image, vgg16, layer_idx)


generated_image = content_image.clone().requires_grad_(True)
optimizer = torch.optim.LBFGS([generated_image], lr=0.01)
criterion = nn.MSELoss()

content_weight = 1
style_weight = 1000

num_iterations = 5000
for i in range(num_iterations):
    
    def closure():
        optimizer.zero_grad()
        
        content_features_from_generated, style_features_from_generated = get_features(generated_image, vgg16, layer_idx)
        

        content_loss = criterion(content_features_from_content, content_features_from_generated)
        
        style_loss = sum(criterion(style_gram, generated_gram)
                         for style_gram, generated_gram in zip(style_features_from_style, style_features_from_generated))
        
        total_loss = content_weight * content_loss + style_weight * style_loss
        
        total_loss.backward()
        
        return total_loss

    optimizer.step(closure)
    with torch.no_grad():
        # 생성 이미지 값의 범위를 유지
        generated_image.clamp_(0, 1)
        
        content_features_from_generated, style_features_from_generated = get_features(generated_image, vgg16, layer_idx)
        content_loss = criterion(content_features_from_content, content_features_from_generated)
        style_loss = sum(criterion(style_gram, generated_gram)
                         for style_gram, generated_gram in zip(style_features_from_style, style_features_from_generated))
        
        total_loss = content_weight * content_loss + style_weight * style_loss
        print(f"Iteration {i+1}/{num_iterations}, Total Loss: {total_loss.item()}")

    if (i + 1) % 5 == 0 or i == num_iterations - 1:
        plt.figure()
        unloader = transforms.ToPILImage()
        image = generated_image.cpu().clone().squeeze(0)  # 차원 줄이기
        image = unloader(image)
        plt.imshow(image)
        plt.title(f"Generated Image at Iteration {i+1}")
        plt.axis('off')
        plt.show()

 

 

주요 코드 수정은

optimizer = torch.optim.LBFGS([generated_image], lr=0.01)

def closure():
    optimizer.zero_grad()
        ...
    return total_loss

optimizer.step(closure)

 

LBFGS는 여러 차례 손실을 평가하고 그라디언으를 계산한다.

그래서 closer함수가 필요하다. 

LBFGS는 closer 함수를 호출해서 손실정보를 수집하고 역전파를 한다.

optimizer.zero_grad() 로 그라디언트를 초기화하여 새로운 step에서 그라디언트 계산을 시작하고

optimizer.step(closure) 는 optimizer.step에 closure를 전달하여 LBFGS가 closure를 여러 번 호출하게 해준다.

 

LBFGS는 Adam에 비해 손실함수의 최적화 방향을 더욱 세밀하게 조정한다.

 

content image

gpt가 그려준 이미지

 

 

style image

gpt가 그려준 이미지2

 

 

흰색->고등어

 

colab 무료버전을 사용하기에 오천번까지 돌릴수 없었다. 손실도 42 이후로는 오르락내리락 반복함

 

가능한 문제와 해결점

.vgg16 외에 다른 모델을 사용해볼까

.layer 선택에서 실수

.optimizer재선택 또는 학습률 조정, 가중치 조정

. 이전 코드에서 역전파 오류가 계속 발생했는데 그 부분을 고치면서 뭔가 문제가 생겼다..?

 

 

더보기

논문 결과물들은 멋있던데.. ㅜㅜ

'머신러닝' 카테고리의 다른 글

Langchain  (0) 2024.11.08
예측모델(데이터전처리)  (0) 2024.11.07
style transfer model  (0) 2024.11.05
openAI API를 활용한 챗봇 구현  (0) 2024.11.04
Sentiment Analysis with LSTM  (2) 2024.11.01