같은 학습 데이터로 학습을 하고, 동일한 테스트 데이터로 테스트를 하였음에도 매번 실행해보면 모델의 학습 파라미터와 테스트 결과가 동일하지 않은 경우가 많다. 이는, 높은 수준의 재생산성(Reproducibility)을 요구하는 대회나 업무에 지장을 줄 수 있다. 이 글에선 Pytorch를 사용할 때 최대한 Reproducibility를 유지할 수 있는 방법에 대해 적어보았다.
Seed 고정
난수 생성기의 seed를 고정하면, 매번 프로그램을 실행할 때마다 생성되는 난수들의 수열이 같게 할 수 있다. 그래서 pytorch와 관련 라이브러리에서 사용되는 난수 관련 seed를 고정하여야 한다. pytorch_lightning에선 pytorch와 관련된 난수 생성기의 seed를 고정하는 코드가 있다. 그 코드를 보면, pytorch의 seed 설정 함수와 함께 python random 모듈, numpy의 seed를 고정하는 모습을 볼 수 있다.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
Nondeterministic 한 작업 피하기
Cudnn benchmark 해제
cudnn은 convolution을 수행하는 과정에 벤치마킹을 통해서 지금 환경에 가장 적합한 알고리즘을 선정해 수행한다고 한다. 이 과정에서 다른 알고리즘이 선정되면 연산 후 값이 달라질 수 있는 것이다. 이 설정을 켜 놓으면 성능 향상에 도움이 된다고 한다.
torch.backends.cudnn.benchmark = False
Deterministic 한 알고리즘만 사용하게 하기
Cudnn에서 수행하는 연산에 대해 적용하려면 아래 코드를 사용한다.
torch.backends.cudnn.deterministic = True
또는 아래 코드를 사용하여 다른 pytorch연산들도 같이 deterministic 하게 설정할 수 있다. Deterministic 한 알고리즘으로 수행 가능한 연산은 Deterministic으로 진행하되, 불가능한 경우에는 RuntimeError를 던지게 한다.
torch.use_deterministic_algorithms(True)
아래 pytorch 공식문서에서 설정에 영향을 받는 연산들을 확인할 수 있다.
DataLoader worker에 대한 Seed 설정
Dataloader에서 Multiprocess를 사용할 때 각 worker에는 base_seed + worker_id로 시드가 설정된다고 한다. 그런데, 다른 라이브러리의 seed는 이와 같지 않을 수 있다. worker마다 seed를 설정하는 함수를 dataloader 생성 시 worker_init_fn으로 넣어주자.
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
numpy.random.seed(worker_seed)
random.seed(worker_seed)
DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
worker_init_fn=seed_worker
)
PYTHONHASHSEED
파이썬의 자체 hashing 알고리즘은 random요소가 있다. 그리고 그 hash결과에 영향을 주는 요소가 PYTHONHASHSEED이다.
os.environ['PYTHONHASHSEED'] = str(seed)
결론
def seed(seed = 42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
os.environ["PYTHONHASHSEED"] = str(seed)
def seed_worker(_worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
참고
https://pytorch.org/docs/stable/notes/randomness.html
'Study > Python' 카테고리의 다른 글
numpy 배열의 생성과 조작 (0) | 2021.06.22 |
---|---|
Python pathlib.Path로 경로관리하기 (0) | 2021.06.10 |
Python multiprocessing.Pool 멀티프로세싱 2 (1) | 2021.06.07 |
Python multiprocessing.Process 멀티프로세싱 1 (0) | 2021.06.03 |
Python collections의 Counter로 개수 세기 (0) | 2021.05.27 |
댓글