https://mz-moonzoo.tistory.com/18
이전 글 CycleGAN 논문 리뷰입니다. 이론적인 부분을 참고하시면 될 것 같습니다.
CycleGAN 논문 구현
환경 : Anaconda, python, pytorch
데이터셋 : summer2winter-yosemite
CycleGAN은 과거 스터디를 진행할 때 구현 코드에 주석을 달면서 공부를 진행했었습니다.
그래서 주석이 된 부분을 참고하시면 좋을 것 같습니다.
라이브러리 불러오기
데이터 불러오기
여기서 image_size에 따라 잔차 블록의 갯수가 달라지는데
128*128이미지에서는 6개, 256*256이상의 고해상도 이미지에서는 9개의 잔차 블록을 사용합니다.
로컬에 저장해둔 데이터셋 경로인 summer2winter_yosemite에서 dataloader를 통해 이미지를 불러옵니다.
데이터를 1개, 1개 학습을 시키는 방법도 가능하지만 Pytorch의 dataloader를 활용하면 Mini-Batch 단위의 학습이 가능하고 무작위로 데이터를 섞을 수 있다는 장점이 있기 때문에 거의 대부분의 파이토치 학습 코드에선 dataloader을 사용하고 있습니다. 여기서 주의 하실점은 테스트 dataloader에서는 shuffle을 False로 지정해주셔야 합니다.
이미지 확인
데이터 전처리
여기서 약간의 데이터 전처리를 진행하게 됩니다. 우리는 tanh 활성화 함수의 출력이 -1에서 1 사이의 픽셀 값을 포함할 것이라는 것을 알고 있기 때문에, 우리는 훈련 이미지의 크기를 -1에서 1 사이의 범위로 조정해주도록 합니다.
(현재, 이미지들의 크기는 0에서 1 사이의 범위에 있습니다.)
모델 정의 및 학습
Discriminator
CycleGAN에서 Discriminator Dx와 𝐷y는 이미지를 보고 실제 또는 가짜로 분류하려는 컨볼루션 신경망입니다.
여기서 real은 1에 가까운 출력으로 표시되고 fake는 0에 가까운 출력으로 표시됩니다.
이 네트워크는 이미지가 5개의 컨볼루션 레이어를 통과합니다.
처음 4개의 컨볼루션 레이어는 출력에 BatchNorm 및 ReLu 활성화 기능이 적용되며, 마지막은 하나의 값을 출력하는 분류 레이어 역할을 합니다.
ResidualBlock
제너레이터를 정의하려면 제너레이터의 인코더 및 디코더 부분을 연결하는데 필요한 ResidentBlock클래스를 정의해야 합니다. 딥러닝 모델은 시간이 지남에 따라 실제로 훈련 정확도가 악화되는 것을 볼 수 있습니다. 이 문제에 대한 한가지 해결 책으로 입력 레이어에 적용되는 residual functions를 학습할 수 있는 Resnet 블록을 사용하는 것입니다.
Generator
생성자는 이미지를 더 작은 특징 표현으로 바꾸는 것을 책임지는 인코더와그 표현을 변환된 이미지로 바꾸는 것을 책임지는 transpose_convnet(de-conv layers)인 디코더로 구성됩니다.
이 네트워크는 128x128x3 이미지를 보고 3개의 컨볼루션 레이어를 통과하여 일련의 잔여 블록에 도달할 때 이를 특징 표현으로 압축한다. 이는 잔여 블록 중 일반적으로 6개 이상(여기서는 6개)를 통과한 다음 , Resnet 블록의 출력을 업샘플링하고 새로운 이미지를 생성하는 3개의 전치 컨볼루션 레이어(디콘볼루션 레이어라고도 함)를 통과합니다.
대부분의 컨볼루션 및 전치 컨볼루션 레이어는 출력에 적용된 탄 활성화 기능이 있는 최종 전치 컨볼루션 레이어를 제외하고 출력에 적용된 BatchNorm및 ReLu 함수를 가지고 있다. 또한, 잔여 블록은 컨볼루션 및 배치 정규화 레이어로 구성되어 있습니다. 당연하듯이 마지막 레이어에는 ReLu가아닌 tanh 활성화 함수를 사용해주도록 합니다.
de-conv
Generator을 구현하기 위해 위의 conv 함수와 ResidentBlock 클래스 그리고 위의 deconv 함수를 사용해야 합니다.
모델 생성하기
이제 전체 네트워크를 생성하도록 하겠습니다.
지금까지 정의한 클래스와 함수를 사용하여 CycleGAN을 만드는데 필요한 생성자와 판별자를 생성하도록 하겠습니다.
손실함수 정의하기
훈련 중 판별자와 생성자의 손실을 계산하는데 도움이 되도록 몇 가지 유용한 손실 함수를 정의하겠습니다.
1. real_mse_loss
real_mse_loss는 판별자의 출력을 보고 해당 출력이 실제로 분류되는 것에 얼마나 가까운지에 따라 오류를 반환합니다.
2. fake_mse_loss
fake_mse_loss는 판별기의 출력을 보고 해당 출력이 가짜로 분류되는 것에 얼마나 가까운지에 따라 오류를 반환합니다.
3. cycle_consistency_loss
cycle_reconsistency_loss는 실제 영상 세트와 재구성/생성된 영상 세트를 보고 이들 사이의 평균 절대 오차를 반환합니다.
이 값에는 배치의 평균 절대 오차에 가중치를 부여하는 람다_weight 매개 변수가 있습니다.
Optimizers 정의하기
이전에 리뷰했던 GAN과 마찬가지로 판별자와 생성자에 Adam 옵티마이저를 사용합니다.
하이퍼 파라미터는 앞서 논문의 Training detail을 참고해 정의해줬습니다.
그리고 논문 리뷰에서 가볍게 언급한 부분이 있는데 여기서 LSGAN을 활용해 일반 GAN보다 더 높은 품질의 이미지를 생성하도록 했습니다. 논문 리뷰를 참고하시면 자세한 설명을 보실 수 있습니다.
학습하기
학습하는 부분은 코드를 앞의 설명과 함께 보시면 이해하기 쉬우실테니 넘어가도록 하겠습니다.
결과
저번 GAN처럼 Generator loss는 통통 튀는 것을 볼 수 있습니다.
그렇지만 학습이 진행될 수록 점점 loss가 떨어지는 것을 보실 수 있습니다.
시각화
Train 100_epoch
Train 5000_epoch
아직 학습이 많이 부족해보이지만 epochs가 늘어날수록 이미지가 선명해지는 것을 확인할 수 있습니다.
다음 논문 리뷰로는 StarGAN으로 돌아오도록 하겠습니다.