본문 바로가기
Study/Python

Pytorch의 Reproducibility를 위한 설정들

by 개발새-발 2021. 6. 9.
반응형

같은 학습 데이터로 학습을 하고, 동일한 테스트 데이터로 테스트를 하였음에도 매번 실행해보면 모델의 학습 파라미터와 테스트 결과가 동일하지 않은 경우가 많다. 이는, 높은 수준의 재생산성(Reproducibility)을 요구하는 대회나 업무에 지장을 줄 수 있다. 이 글에선 Pytorch를 사용할 때 최대한 Reproducibility를 유지할 수 있는 방법에 대해 적어보았다.

 

Seed 고정

난수 생성기의 seed를 고정하면, 매번 프로그램을 실행할 때마다 생성되는 난수들의 수열이 같게 할 수 있다. 그래서 pytorch와 관련 라이브러리에서 사용되는 난수 관련 seed를 고정하여야 한다. pytorch_lightning에선 pytorch와 관련된 난수 생성기의 seed를 고정하는 코드가 있다. 그 코드를 보면, pytorch의 seed 설정 함수와 함께 python random 모듈, numpy의 seed를 고정하는 모습을 볼 수 있다.

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

 

Nondeterministic 한 작업 피하기

Cudnn benchmark 해제

cudnn은 convolution을 수행하는 과정에 벤치마킹을 통해서 지금 환경에 가장 적합한 알고리즘을 선정해 수행한다고 한다. 이 과정에서 다른 알고리즘이 선정되면 연산 후 값이 달라질 수 있는 것이다. 이 설정을 켜 놓으면 성능 향상에 도움이 된다고 한다.

torch.backends.cudnn.benchmark = False

 

Deterministic 한 알고리즘만 사용하게 하기

Cudnn에서 수행하는 연산에 대해 적용하려면 아래 코드를 사용한다.

torch.backends.cudnn.deterministic = True

또는 아래 코드를 사용하여 다른 pytorch연산들도 같이 deterministic 하게 설정할 수 있다. Deterministic 한 알고리즘으로 수행 가능한 연산은 Deterministic으로 진행하되, 불가능한 경우에는 RuntimeError를 던지게 한다.

torch.use_deterministic_algorithms(True)

아래 pytorch 공식문서에서 설정에 영향을 받는 연산들을 확인할 수 있다.

https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

 

torch.use_deterministic_algorithms — PyTorch 1.8.1 documentation

Shortcuts

pytorch.org

 

DataLoader worker에 대한 Seed 설정

Dataloader에서 Multiprocess를 사용할 때 각 worker에는 base_seed + worker_id로 시드가 설정된다고 한다. 그런데, 다른 라이브러리의 seed는 이와 같지 않을 수 있다. worker마다 seed를 설정하는 함수를 dataloader 생성 시 worker_init_fn으로 넣어주자.

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker
)

 

PYTHONHASHSEED

파이썬의 자체 hashing 알고리즘은 random요소가 있다. 그리고 그 hash결과에 영향을 주는 요소가 PYTHONHASHSEED이다.

os.environ['PYTHONHASHSEED'] = str(seed)

 

결론

def seed(seed = 42):
    random.seed(seed)
    np.random.seed(seed) 
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    os.environ["PYTHONHASHSEED"] = str(seed)

def seed_worker(_worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

 

참고

https://pytorch.org/docs/stable/notes/randomness.html

 

Reproducibility — PyTorch 1.8.1 documentation

Reproducibility Completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds. However, t

pytorch.org

 

반응형

댓글