취미가 좋다
Progressive Domain Expansion Network for Single Domain Generalization 논문 읽기 본문
Progressive Domain Expansion Network for Single Domain Generalization 논문 읽기
benlee73 2021. 6. 15. 15:45
Progressive Domain Expansion Network for Single Domain Generalization
Lei Li, Ke Gao, Juan Cao, Ziyao Huang, Yepeng Weng, Xiaoyue Mi, Zhengze Yu, Xiaoya Li, Boyang xia
CVPR 2021
Abstract
Single domain generalization은 하나의 도메인에서 학습되고 unseen 도메인에서 테스트하는 과제로, model generalization 중에서 어려운 케이스이다.
이 과제를 확실하게 해결하는 방법은, 학습 도메인의 범위를 확장시켜서 도메인에 따라 변하지 않는 representation을 학습하는 것이다.
이 방법은 적절한 safety and effectiveness constraints 가 없기 때문에, 실제 어플리케이션에서 활용할 만큼 일반화 성능 향상을 가져오기가 어렵다.
그래서 이 논문은 새로운 학습 프레임워크 progressive domain expansion network(PDEN)을 제안한다.
PDEN의 도메인 확장 서브넷과 representation 학습 서브넷이 공동 학습을 통해 효과를 발휘한다.
도메인 확장 서브넷은 여러 도메인을 점점 생성하고, 확장된 도메인에서의 safety and effectiveness 를 보장한다.
representation 학습 서브넷에서는, 각 클래스 별로 잘 클러스터된 invariant representation을 학습하기 위해 constrastive learning을 사용한다.
더 나은 decision boundary로부터 generalization을 향상할 수 있다.
classification 과 segmentation 에 대한 여러 실험을 통해, PDEN이 SOTA보다 15.28% 향상된 성능을 보여준다.
코드는 https://github.com/lileicv/PDEN 에 공개된다.
1. Introduction
이 논문에서 우리는 도메인을 외부 조건(날씨, 배경, 조명...) 과 내재된 속성(색, 질감, 포즈...) 로 인한 오브젝트 표현의 분포라고 정의한다.
deep model의 성능은 unseen 도메인에서 항상 떨어진다.
예를 들어, MNIST 로 학습시킨 CNN 모델의 성능이 MNIST 에 대해 99% 이지만, SVHN 에서는 30% 에 불과하다.
그래서 머신 러닝의 Model generalization은 중요하다.
이러한 이슈를 다루기 위해 제안된 두 가지 방법이 domain adaptation, domain generalization이다.
domain adaptation은 타겟 도메인은 알지만 label이 없는 도메인으로 일반화하는 것이 목표이다.
Distribution alignment(MMD) 와 style transfer(CycleGAN) 이 주로 사용되지만, 학습에 필요한 타겟 도메인의 데이터가 부족하여 어렵다.
domain generalization은 타겟 도메인으로 접근하지 않는 과제이다.
하나 이상의 도메인으로부터 domain-agnostic(도메인에 구애받지 않는) 모델을 학습시키는 과제이다.
그중 하나의 source domain만 주어진 문제를 single domain generalization이라고 하고, 최근 관련된 연구가 진행되고 있다.
결국 data augmentation인 지금까지의 연구들은, 소스 도메인의 분포를 확장시켜 unseen 도메인에 대한 robustness를 향상시켰다.
수동으로 augmentation 방법을 선택하거나, neural networks를 통해 학습한 augmentation을 사용했다.
data augmentation은 모델 일반화에 효과적이었다.
그러나 타겟 도메인에 따라 방법과 강도를 달리해야 하므로 여러 task에 동일하게 적용하기는 어렵다.
합성 데이터의 safety 와 effectiveness 를 보장하지 못하고, 성능을 떨어뜨릴 수도 있다.
이 논문은 single domain generalization 문제를 해결하기 위해 progressive domain expansion network(PDEN) 을 제안한다.
PDEN의 task 모델과 generator는 같이 학습하면서 서로 많은 이점을 얻는다.
task 모델의 정확한 지시에 따른 generator가 safe and effective 도메인을 생성한다.
생성된 도메인은 점점 확장되면서 완성도를 높인다.
생성된 도메인으로부터 invariant representation을 학습하기 위해 constrastive learning이 사용된다.
PDEN의 generator를 쉽게 교체하여 다른 타입의 도메인으로 확장할 수 있다는 점이 주목할만한 점이다.
논문의 주요 기여는 다음과 같다.
- 새롭게 제안한 프레임워크 PDEN은 도메인 확장 서브넷과 도메인 invariant representation을 학습하는 서브넷을 가지며, joint learning을 통해 서로 이익을 얻는다.
- 도메인 확장 서브넷에서는, 여러 도메인을 점점 생성하고, 생성된 도메인의 safety 와 effectiveness를 보장한다.
- 도메인 Invariant representation 서브넷에서는, 좋은 deceision boundary로 잘 클러스터 된 각 클래스의 invariant representation 을 학습하는 contrastive learning을 사용한다.
- classification 과 segmenation에 대한 여러 실험들을 통해 SOTA를 뛰어넘는 결과를 보여준다.
2. Related Work
Domain Adaptation
최근 소스 도메인과 타겟 도메인의 drift를 해결하기 위해 많은 domain adaptation 방법들이 제안되었다.
딥러닝 기반의 방법들은 주로 두 도메인의 분포를 정렬하는 MMD 기반 혹은 adversarial 기반의 방법들을 사용한다.
DDC는 처음 제안된 방법으로 AlexNet 의 첫 7개의 레이어의 weight를 고정하였다.
MMD는 8번 레이어를 두 도메인 분포의 차이를 줄이는 데 사용하였다.
DAN은 classifier head 앞에 있는 adaptive 레이어의 수를 줄여서 MK-MMD를 제안하였다.
AdaBN은 BN 레이어에서 두 도메인의 분포를 측정하였다.
GAN이 등장하면서 adversarial 학습 기반의 여러 논문들이 나왔다.
DANN은 adversarial 학습으로 두 도메인의 분포 차이를 줄이는 첫 논문이었다.
DSN은 각 도메인이 공유하는 분포와 도메인 만의 분포를 가진다고 가정하였다.
이 가정에 따라, DSN은 공유하는 feature와 각 도메인만의 feature를 따로 학습하였다.
DANN은 학습 가능한 weight로 marginal distribtuion과 conditional distribution을 측정한다.
Domain Generalization
domain generalizaztion은 domain adaption보다 어려운 과제로, 소스 데이터의 학습만으로 unseen 도메인으로의 일반화를 목표로 한다.
domain generalization은 domain alignment 와 domain ensemble 로 카테고리를 나눌 수 있다.
domain alignment는 다른 도메인들과 공유하는 분포가 있다고 가정한다.
이 방법은 다른 도메인의 분포를 공유하는 분포로 매핑한다.
CCSA는 도메인은 다르지만 라벨이 같은 데이터의 거리를 최소화하고, 라벨이 다른 데이터의 거리를 최대화하는 contrastive semantic alignment loss를 제안한다.
MMD-AAE에서, 소스 도메인과 타겟 도메인의 분포는 MMD에 의해 정렬되고, feature representation은 AAE에 의해 이전 laplace 분포와 매치시킨다.
모델 앙상블 방법은, 여러 소스 도메인의 데이터셋으로 학습시킨 모델을 앙상블하여 각 모델의 confidence 값에 따라 결과를 결정한다.
single domain generalization은 소스 도메인이 하나밖에 없을 때의 과제로 최근 많은 논문들이 나오고 있다.
최근 방법들은 합성을 통해 훈련 데이터셋의 distribution을 확장시킨다.
BigAug 는 의료 영상의 차이가 주로 3가지 (image quality, image appearance, spatial configuration) 에 따라 달라진다는 것을 보여준다.
그래서 그들은 3가지 측면에서 augmentation을 적용했다.
그러나, 이러한 방법은 타겟 도메인에 따라 augmentation 타입과 강도를 결정해야 하므로, 다른 task에 적용하기 어렵다.
그래서 GUD, MADA 는 모델의 robustness를 향상시키기 위해 adversarial learning을 통해 데이터를 합성한다.
이에 비해 augmentation 타입은 비교적 단순하기도 하지만, 너무 많은 adversarial examples은 classifier 의 성능을 저하시킨다.
Constrastive Learning
constrastive learning은 image recognintion을 위한 unsupervised pre-training의 한 방법으로, 최근 자주 사용된다.
주요 아이디어는, positive pairs 는 가깝게 하고, negative pairs 는 멀게 하는 것이다.
SimCLR 는 모든 이미지에 강한 augmentation을 주어 positive pair를 생성한다.
3. Method
소스 도메인과 타겟 도메인을 아래와 같이 정의한다.
xi, yi 는 i 번째 이미지와 라벨이고, Ns, Nt 는 각 도메인의 샘플 수를 의미한다.
우리는 S 만으로 모델을 학습하고 unseen T 로 일반화하는 것을 목표로 한다.
PDEN 구조는 아래 Fig.2 에서 볼 수 있다.
3.1. The task model M
task 모델 M에는 3개의 파트가 있다.
1) Feature extractor F : X → H
X는 이미지 공간이고 H는 feature 공간이다.
F는 pooling layer와 activation layer가 뒤에 붙은 convolution layer stack 이다.
F의 출력은 global pooling으로 얻은 1차원 벡터이다.
2) Classifier head C : H → Y
Y는 라벨 공간이다.
여기서 우리는 classification task로 확인할 것이고, 그래서 C는 cross-entropy loss로 최적화된다.
우리의 실험에서 C는 nonlinear activation layer가 뒤에 붙은 fully connected layer stack 이고, 마지막 activation function은 softmax이다.
3) Projection head P : H → Z
Z는 constrastive loss가 계산될 hidden space이다.
우리 실험에서 P는 딱 하나의 full connection layer를 포함한다.
우리는 P의 출력 벡터를 하나의 hypersphere 에 있도록 정규화 하고, 내적을 사용하여 Z 공간에서 유사도를 측정할 수 있도록 한다.
3.2. The Unseen Domain Generator G
G는 아래 식으로 소스 도메인의 이미지 x 를 unseen domain S' 의 새로운 이미지 x' 로 바꿀 수 있다.
x' 은 x 와 semantic information을 가지고 있지만 도메인은 다르다.
G는 AutoEncoder, HRNet, STN 또는 이런 네트워크의 조합과 같은 여러 downstream task에 따라 다양한 구조를 갖는다.
1) Autoencoder as G
우리의 실험에서는, Fig 2 의 Gk 와 같이 AdaIN을 generator로 하는 Autoencoder를 사용한다.
G는 인코더 Ge, AdaIN, 디코더 Gd 를 포함하고, Adain 은 두 개의 fully-connected layer를 가진다.
n ~ N(0,1)
Fig 3(a) 는 autoencoder 로 생성된 unseen domain을 보여준다.
2) STN as G
Autoencoder는 STN을 generator로 대체할 수 있다.
STN은 이미지의 공간 구조를 변환할 수 있는 geometry-aware 모듈이다.
아래 Fig 3(b)가 STN으로 생성한 unseen domain 이다.
PDEN 은 task에 따라 generator의 구조가 바뀔 수 있는 프레임워크이고, 우리 실험에서는 autoencoder가 적용되었다.
3.3. Progressive Domain Expansion
생성된 도메인의 완성도를 높이고 coverage를 확장하기 위해, 우리는 학습 가능한 G를 가지고 점진적으로 K unseen 도메인을 생성한다.
task 모델 M 은 invariant representation을 학습하기 위해 이렇게 생성된 도메인에서 학습된다.
task 모델과 G가 번갈아 학습되는 것을 Fig 2. 에서 볼 수 있다.
kth 도메인 확장을 예를 들어보자.
먼저 G와 M이 safe and effective unseen 도메인 S'k를 합성하기 위해 같이 학습된다.
그리고 M이 Equ 3. 을 최소화하도록 업데이트된 데이터셋 (기존 + 지금까지 합성된) 다시 학습된다.
M의 성능이 올라가면서 Gk+1이 better unseen 도메인을 합성하도록 가이드해준다.
이 과정이 Alg 1. 이다.
3.4. Domain Alignment and Classification
이 섹션에서는 어떻게 cross-domain invariant representation을 학습하는지 소개한다.
xi : source image
xi+ = G(xi, n) : synthetic image
yi : class label
y_i^m 은 mth dimension의 yi 이다.
M은 아래의 식으로 optimized 된다.
Lce는 classification을 위해 사용되는 cross-entropy loss이다.
Lnce는 contrastive learning을 위해 사용되는 InfoNCE loss이다.
minibatch B에서, zi 와 zi+ 는 같은 semantic information을 가지지만, 다른 도메인에 속해있다.
Lnce를 최소화 함으로써, zi 와 zi+ 의 거리는 작아진다.
즉, 동일한 semantic 정보를 가지는 다른 도메인의 샘플들이 z 공간에서 더 가까워질 것이다.
Lnce 가 F 가 domain-invariant representation을 학습하게 돕는 것이다.
3.5. Unseen Domain S' Generation
이 섹션에서는 Gk가 S를 가지고 어떻게 k번 째 unseen domain S'k 를 생성하는지 보여준다.
S'은 safety 와 effectiveness 의 제약 조건을 만족한다.
Safety 는 생성된 샘플이 domain-invariant information을 가지고 있다는 것을 의미한다.
Effectiveness 는 생성된 샘플이 다양한 unseen 도메인의 정보를 가지고 있다는 것을 의미한다.
Safety
S'에 속하는 모든 x가 task 모델 M에 의해 올바르게 예측된다면, S' 은 안전하다.
아래의 loss를 optimize 함으로 safety를 보장한다.
Cycle consistency loss는 S'의 safety 를 보장하기 위해 이전에 소개되었다.
만약 S' 이 Gcyc 를 통해 S 로 바뀔 수 있다면, S'은 안전하다.
Gcyc는 G와 동일한 구조를 가지지만, noise input은 없다.
Effectiveness
adversarial learning은 효과적인 unseen domain을 생성한다고 알려져 왔다.
M 과 G 는 같이 학습된다.
domain-share representation을 추출하는 M은 InfoNCE loss 를 최소화하도록 학습되고, G는 최대화하도록 학습된다.
adversarial training 을 통해, G는 M이 domain-share representation을 추출할 수 없는 도메인을 생성하고, M은 더 잘 추출할 수 있도록 된다.
그러나 위의 Ladv는 수렴하기 어렵기 때문에, 아래로 근사한다.
또한 아래의 loss function을 통해 G가 더 다양한 샘플을 생성하도록 한다.
여기 n1, n2 ~N(0,1) 에서 n1 ≠ n2 이다.
G 를 학습시키는 데 사용하는 모든 loss function을 종합하면 아래와 같다.
Lcls의 weight는 항상 1이고, 나머지는 모두 weight를 갖는다.
4. Experiment
4.1. Datasets and Evaluate
Digits Dataset
총 5가지 데이터셋(MNIST, MNIST-M, SVHN, USPS, SYNDIGIT) 을 포함한다.
각 데이터셋은 도메인으로 간주되어, MNIST 를 소스 도메인으로 사용하고 나머지를 타겟 도메인으로 사용한다.
MNIST의 첫 10,000 장의 이미지로 모델을 학습한다.
CIFAR10-C Dataset
CIFAR10을 소스 도메인으로 사용하고, CIFAR10-C 를 타겟 도메인으로 사용한다.
CIFAR10-C 는 classification 모델의 robustness를 평가하는 벤치마크 데이터셋이다.
알고리즘적으로 생성된 19가지 corruption type을 가지는 이미지로 구성되어 있다.
corruption 은 4개의 카테고리와 5 단계의 강도를 가진다.
SYNTHIA Dataset
SYNTHIA VIDEO SEQUENCES 데이터셋은 traffic scene segmentation에서 사용된다.
데이터셋은 3가지 장소로 구성된다. (High- way, New York ish, Old European Town)
각 장소는 같은 traffic situation 을 가지지만 다른 날씨 / 조명 / 계절을 가진다.
한 도메인에서 학습시키고, 다른 도메인에서 평가한다.
각 도메인에 대해, 왼쪽 카메라에서 900장의 이미지를 샘플링하고, 192 x 320 픽셀로 resize한다.
Evaluate
Digit 과 CIFAR10 데이터셋에 대해서는, 각 unseen domain에 대해 mean accuracy를 계산한다.
SYNTHIA 데이터셋에 대해서는, 각 unseen 도메인데 대해 standard mean Intersection over Union(mIoU) 를 사용한다.
4.2. Evaluation of Single Domain Generalization
우리의 방법을 다음의 SOTA 방법들과 비교한다.
(1) Empirical Rist Minimization(ERM)은 cross-entropy loss 만으로 학습된 baseline method 이다.
(2) CCSA 는 domain 일반화를 위한 robust feature space 를 얻기 위해서, 같은 카테고리면서 다른 도메인의 샘플들을 정렬한다.
(3) d-SNE 는 같은 클래스를 가진 쌍 사이의 최대 거리를 최소화하고, 다른 클래스를 가진 쌍의 최소 거리를 최대화한다.
(4) GUD 는 classifier의 robustness를 향상시키는 더 hard 한 샘플을 합성하는 adversarial data augmentation 방법을 사용한다.
(5) MADA 는 더 effective한 샘플을 생성하기 위해, semantic space의 거리를 최소화하고, pixel space의 거리를 최대화한다.
(6) JiGen 은 target recognition tast 와 Jigsaw classification task 를 결합한 Multi-tast learning 방법을 사용한다.
(7) AutoAugment(AA) 는 특정 데이터셋에 대해 더 좋은 augmentation 방법을 자동으로 찾는 방법을 사용한다.
(8) AA를 기반으로 하는 RandAugment(RA) 는 policies space를 줄여서 더 나은 augment policies 를 가진다.
Comparison on Digits
모델을 MNIST train set의 10,000 개 이미지로 학습시키고, test set으로 validate 하고, MNIST-M, SVHN, USPS and Syndigits datasets 으로 evaluate 한다.
평가 지표로 각 데이터셋의 mean accuracy 를 계산한다.
가장 먼저 Table 1. 의 위쪽 부분을 보면, single domain generalization 방법들과 비교했다.
공정하게 어떤 data augmentation도 사용하지 않았다.
각 데이터셋(도메인)에서우리의 방법이 가장 좋은 성능을 보여준다.
USPS 데이터셋은 MNIST 와 비슷하기 때문에, 다른 방법들과 비슷한 성능을 가진다.
d-SNE가 USPS 에서 좋은 성능을 보이지만, 다른 데이터셋에서는 성능이 좋지 않다.
Table 1. 아래쪽에서는 data augmentation 방법들과 비교한 것을 볼 수 있다.
hyperparameters는 원본 논문의 것을 사용하였다.
역시나 우리의 방법이 가장 성능이 좋았다.
What’s more, our approach is orthogonal to these data augmentation techniques.
Comparison on CIFAR10
CIFAR10 train set으로 학습시키고, test set으로 validate 하고, CIFAR10-C로 evaluate 하였다.
5가지 corruption 강도 단계에 따른 실험 결과를 Table 2. 에서 볼 수 있다.
우리의 방법이 다른 방법들보다 좋은 성능을 보였다.
corruption이 심할수록, MADA를 더욱 능가했다.
data augmentation 방법과 비교해보면, 낮은 단계뿐 아니라 높은 단계에서 우수한 성능을 보였다.
Table 3. 에서는 다른 타입의 corruption에 대한 실험 결과를 보여준다.
우리 방식의 평균 accuracy가 가장 높은 것을 볼 수 있다.
몇 가지 타입에 대해서는 RandAugment가 우리 방법보다 더 성능이 좋다.
하지만 우리 방법은 augmentation을 사용하지 않았을 때 이므로, RandAugment 를 함께 사용하면 더 좋은 성능을 보일 것이다.
Comparison on SYNTHIA
Highway-Dawn, Highway-Fog, Highway-Spring 을 각각 소스 도메인으로 사용하고, 모든 날씨의 New York ish and Old European Town을 unseen 타겟 도메인으로 하는 3가지 실험을 하였다.
scence segmentation 결과 (mIoU)는 Table 4.에서 볼 수 있다.
우리의 방법이 가장 성능이 좋았고, Highway-Dawn 과 Highway-Fog 일 때, 성능 향상이 더 크다.
4.3. Additional Analysis
Validation of K
우리는 Digits dataset의 하이퍼 파라미터 K 에 대해 연구했고, Fig 5.(a) 를 보면 실험 결과를 볼 수 있다.
K 에 따른 타겟 도메인에서의 classification accuracy를 나타낸 그림이다.
K 가 작을 때는, accuracy가 급격히 증가하고, K 가 커질수록 점차 수렴한다.
Digits 실험에서는 K=20으로 하고 실험했었다.
MADA 논문의 Digits 실험에서는, K=3일 때 가장 성능이 좋았고, K 가 더 커지면 성능은 감소했다.
이것은 우리의 방법이 MADA 보다 더 safety 하다는 것을 보여준다.
Validation of Wadv
우리는 Digits dataset 의 하이퍼 파라미터 Wadv 의 효과에 대해 연구했고, Fig 5.(b) 를 보면 실험 결과를 볼 수 있다.
Wadv 가 0.02 / 0.05 / 0.08 / 0.1 / 0.13 / 0.16 / 0.2 일 때, 타겟 도메인의 classification accuracy를 보여준다.
Wadv가 커질수록 unseen 타겟 도메인에서 정확도가 증가한다는 것을 발견했다.
Validation of Wcyc
하이퍼 파라미터 Wcyc 의 효과를 연구했고, Fig 5.(c) 에서 실험 결과를 볼 수 있다.
Wcyc 가 0 / 10 / 20 / 30 / 40 / 50 일 때, 각 데이터셋(도메인) 에서의 classification accuracy를 보여준다.
USPS 를 제외하고는 Wcyc 증가에 따라 accuracy도 증가했다.
USPS 가 MNIST 와 아주 유사하기 때문에 accuracy 의 변화가 없는 것이다.
Validation of Wdiv
하이퍼 파라미터 Wdiv 의 효과를 연구했고, Fig 5.(d) 에서 실험 결과를 볼 수 있다.
Digit dataset의 모든 unseen 도메인에서, Wdiv 에 증가에 따라 accuracy도 같이 증가하였다.
Visualization of the feature space
Figure 4. 는 baseline 모델과 PDEN 간의 2차원 feature space 차이를 보여준다.
PDEN을 보면, 타겟 도메인의 sample distribution 이 소스 도메인의 분포와 일치한다.
baseline 모델을 보면, 대부분의 타겟 샘플이 feature space 에서 섞여 분류하기가 어렵다.
4.4 Evaluation of Few-shot domain Adaptation
우리의 방법을 few-shot domain adaptation 과도 비교하였다.
few-shot domain adaptation 에서는, 소스 도메인의 데이터와 타겟 도메인의 적은 데이터로 모델을 학습시킨다.
MNIST 를 소스 도메인으로 사용하고, SVHN 을 타겟 도메인으로 사용하였다.
먼저 제안된 PDEN 을 MNIST로 학습하고, 적은 SVHN 데이터로 finetune 한다.
그렇게 학습된 모델을 SVHN 에 evalute 한 결과가 Fig 6. 이다.
이 실험을 통해, 타겟 도메인의 적은 데이터로 finetuning 하는 것이 타겟 도메인에 대한 모델의 성능을 향상시키는 것을 발견하였다.
MADA 와 비교하면, PDEN 이 더 좋은 성능을 보였다.
5. Conclusion
이 논문은 모델이 unseen 도메인에서 일반화될 수 있도록, domain-invariant feature 를 학습하는 single domain generalization learning framework 를 제안한다.
소스 도메인과 같은 semantic information을 공유하는 unseen domain을 합성하는 generator를 학습시킨다.
domain-invariant representation은 소스 도메인과 unseen 도메인의 distribution을 정렬하면서 학습할 수 있다.
task 모델로 domain-invariant representation 을 추출할 수 없는 unseen domain을 생성한다.
이렇게 생성된 도메인을 training set에 추가하면서, 모델은 더 robust 해 질 것이다.
결론적으로, 새로 제안된 PDEN 은 single domain generalization 을 해결하기 위한 유망한 방향을 제공한다.
'논문 > Domain Adaptation & Generalization' 카테고리의 다른 글
Learning to Learn Single Domain Generalization 논문 정리 (0) | 2021.05.11 |
---|---|
Do Adversarially Robust ImageNet Models Transfer Better? (0) | 2021.05.04 |
Bidirectional Learning for Domain Adaptation of Semantic Segmentation 논문 정리 (0) | 2021.05.04 |
CyCADA: Cycle-Consistent Adversarial Domain Adaptation 논문 정리 (0) | 2021.05.04 |
Adversarial Discriminative Domain Adaptation 논문 정리 (0) | 2021.05.04 |