데이터분석/분석-지도학습

callback을 이용하여 tensorflow 모델을 학습중에 저장하기

씩씩한 IT블로그 2021. 4. 29. 16:06
반응형

용도

모델을 학습할 때 학습 중간에 프로그램이 오류가 난다면 지금까지 학습했던 가중치를 모두 잃게되는 일이 발생한다.

모델을 학습 중간중간마다 저장하면 프로그램이 끊기더라도 체크포인트부터 다시 시작할 수 있다.

텐소플로우의 콜백을 이용하여 모델을 중간중간마다 저장하고 다시 로드해서 학습하는 법을 알아본다.

 

모델 저장

# 1. 저장할 폴더와 형식을 선택
folder_directory = "체크포인트를 저장할 폴더"
checkPoint_path = folder_directory+"/model_{epoch}.ckpt" # 저장할 당시 epoch가 파일이름이 된다.

# 2. 콜백 변수를 생성
my_period = 몇번의 학습마다 저장할 것인가?
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkPoint_path,
					save_weights_only=True, verbose=1, period=my_period) 
    
# 3. fit의 파라미터에 callbacks생성
model.fit(self.x_train, self.y_train, batch_size=30, epochs=100, validation_split=0.2, 
    	callbacks=[cp_callback],  verbose=1) 

세단계로 나눌 수 있다.

1. 저장할 폴더와 형식 선택 

체크포인트들을 저장할 폴더와 체크포인트 파일의 형식을 선택한다. 폴더에 저장할 당시까지 반복한 epoch를 모델파일 이름으로 오도록 형식을 지정하였다

2. 콜백 변수 생성

몇번마다 저장할지(period), 저장하는 과정을 보여줄지(verbose) 등등을 결정하는 콜백 체크포인트 변수를 생성한다.

3. fit함수의 파라미터에 callbacks 생성

앞에서 정의한 콜백변수를 fit함수의 파라미터에 적용한다.

 

period를 3으로 지정했을 때 3번마다 저장되는것을 확인할 수 있다.

 

모델 불러오기 및 재학습

# 1. 체크포인트들이 있는 폴더 선택
checkPoint_dir = os.path.dirname(folder_directory+"/model_{epoch}.ckpt")

# 2. 해당 폴더에서 가장 마지막 체크포인트 선택
latest = tf.train.latest_checkpoint(checkPoint_dir)

# 3. 해당체크포인트에 저장한 모델의 가중치 불러오기
model.load_weights(latest)

위와 같은 단계를 거쳐서 모델을 불러올 수 있다.

이후 추가로 학습하고 싶으면 model.fit()을 하면 된다.

학습시 epoch는 1부터 시작하지만 가중치는 이미 체크포인트까지 학습된 상태이다. 따라서 체크포인트에서 불러온 가중치부터 시작하는 모델은 시작부터 loss가 낮다.

 

1. 체크포인트에서 불러온 모델의 1,2,3에포크

2. 초기화된 모델의 1,2,3에포크

반응형