상황는 이렇다. Segmentation task를 진행하다가 CrossentropyLoss에 ground truth label (mask)를 넣어야 하는 상황이었는데, 원하는 segmentation class 개수는 3개(물, 하늘, 장애물)이지만 mask에는 총 4개의 label이 달려있었던 상황. 그래서 Data 설명을 읽어봤다.
'4'라고 분류된 픽셀들은 물체/물/하늘 사이 boundary에 해당해서 모호한 픽셀로 남겨놨던 것이다. 학습에 사용되면 안되는 픽셀들이기 때문에 '4'라고 분류된 픽셀들은 포함시키면 안된다.
그래서 이 픽셀들을 어떡하지? 그냥 '0'이나 '1' 혹은 '2'로 labeling 후처리를 해야하나... 그러기엔 픽셀들이 너무 애매한걸?
이러고 있다가 동료가 찾아준 아주 효과적인 해결책이 torch.nn.CrossEntropyLoss에 있는 'ignore_index' 인자이다.
ignore_index = 4 를 넣어주면 '4'에 해당하는 픽셀들은 loss 계산에 무시되거나 gradient 계산에 무시된다. 즉, 학습에 관여를 안시키겠다는 말!! 너무 아름답다. 끗.
'Coding' 카테고리의 다른 글
[Python] HackerRank - Capitalize! (0) | 2023.07.23 |
---|---|
[Python] torch.scatter_ 이해하기 (0) | 2023.04.18 |
[Python] string split과 rsplit method 차이 (0) | 2023.04.16 |
[Python] 클래스 상속(Class inheritance) 그리고 Pytorch 모델에서의 해석. (1) | 2023.02.24 |
torch.nn.utils.prune에서 L1Unstructured와 l1_unstructured의 차이 (0) | 2023.01.25 |