Coding

[Python] torch.nn.CrossEntropyLoss 에서 ignore_index

Belter 2023. 4. 18. 01:52

상황는 이렇다. Segmentation task를 진행하다가 CrossentropyLoss에 ground truth label (mask)를 넣어야 하는 상황이었는데, 원하는 segmentation class 개수는 3개(물, 하늘, 장애물)이지만 mask에는 총 4개의 label이 달려있었던 상황. 그래서 Data 설명을 읽어봤다.

Data description

'4'라고 분류된 픽셀들은 물체/물/하늘 사이 boundary에 해당해서 모호한 픽셀로 남겨놨던 것이다. 학습에 사용되면 안되는 픽셀들이기 때문에 '4'라고 분류된 픽셀들은 포함시키면 안된다.

Mask 픽셀값 예시

그래서 이 픽셀들을 어떡하지? 그냥 '0'이나 '1' 혹은 '2'로 labeling 후처리를 해야하나... 그러기엔 픽셀들이 너무 애매한걸?

(좌) 원본 이미지, (우) 시각화 된 mask 이미지. 하얀색 픽셀들이 '4'에 해당하는 픽셀들임.

이러고 있다가 동료가 찾아준 아주 효과적인 해결책이 torch.nn.CrossEntropyLoss에 있는 'ignore_index' 인자이다.

torch.nn.CrossEntropyLoss 공식 문서에서 ignore_index 부분 설명

ignore_index = 4 를 넣어주면 '4'에 해당하는 픽셀들은 loss 계산에 무시되거나 gradient 계산에 무시된다. 즉, 학습에 관여를 안시키겠다는 말!! 너무 아름답다. 끗.