사냥꾼의 IT 노트

TensorFlow를 이용한 YOLO v1 논문 구현 #7 - train.py 본문

YOLO

TensorFlow를 이용한 YOLO v1 논문 구현 #7 - train.py

가면 쓴 사냥꾼 2022. 7. 8. 14:02

이전 글: https://it-the-hunter.tistory.com/34

 

TensorFlow를 이용한 YOLO v1 논문 구현 #6 - model.py

이전 글: https://it-the-hunter.tistory.com/33 TensorFlow를 이용한 YOLO v1 논문 구현 #5 - utils.py 이전 글: https://it-the-hunter.tistory.com/32 TensorFlow를 이용한 YOLO v1 논문 구현 #4 - datasets.py..

it-the-hunter.tistory.com

train.py

목표: 모델 class의 인스턴스를 생성한 뒤 for 루프를 돌면서 gradient descent를 수행해 파라미터를 업데이트한다.


필요한 모듈, 라이브러리 import

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os
import random

from absl import flags
from absl import app

from loss import yolo_loss
from model import YOLOv1
from dataset import process_each_ground_truth
from utils import draw_bounding_box_and_label_info, generate_color, find_max_confidence_bounding_box, yolo_format_to_bounding_box_dict

flags는 absl(Abseil) 라이브러리에서 제공하는 모듈로, 기본값을 코드에 고정해 두면서도 실행 시 커맨드라인 인자로 값을 유동적으로 변경할 수 있다는 장점이 있다.

# Command-line flags (absl). Every default below can be overridden when the
# script is launched, e.g. `python train.py --num_epochs=150`.
flags.DEFINE_string('checkpoint_path', default='saved_model', help='path to a directory to save model checkpoints during training')
flags.DEFINE_integer('save_checkpoint_steps', default=50, help='period at which checkpoints are saved (defaults to every 50 steps)')
flags.DEFINE_string('tensorboard_log_path', default='tensorboard_log', help='path to a directory to save tensorboard log')
# main() checks this against the iteration index inside the current epoch.
flags.DEFINE_integer('validation_steps', default=50, help='period (in iterations within an epoch) at which validation loss is computed and prediction images are saved')
flags.DEFINE_integer('num_epochs', default=135, help='training epochs') # original paper : 135 epoch
flags.DEFINE_float('init_learning_rate', default=0.0001, help='initial learning rate') # original paper : 0.001 (1epoch) -> 0.01 (75epoch) -> 0.001 (30epoch) -> 0.0001 (30epoch)
flags.DEFINE_float('lr_decay_rate', default=0.5, help='decay rate for the learning rate')
flags.DEFINE_integer('lr_decay_steps', default=2000, help='number of steps after which the learning rate is decayed by decay rate')
flags.DEFINE_integer('num_visualize_image', default=8, help='number of validation images to visualize')

FLAGS = flags.FLAGS

위 코드에서는 epoch의 default값이 135로 설정되어 있는데, 다음과 같이 실행할 때 유동적으로 설정이 가능하다.

python train.py --num_epochs=150

상세 코드

# Class index -> human-readable class name (only "cat" is trained here).
cat_label_dict = {0: "cat"}
# Reverse lookup: class name -> class index (used to pick a drawing colour).
cat_class_to_label_dict = dict(zip(cat_label_dict.values(), cat_label_dict.keys()))

모델이 출력하는 class index(0)를 사람이 이해할 수 있는 이름 "cat"으로 바꿔주기 위한 코드이다. 우리가 사용할 데이터셋은 PASCAL VOC 데이터셋인데, 이 중에서 cat이 포함된 이미지만 가져와 학습시킬 것이다.


# Model / training hyper-parameters. "original paper" marks the value used in the YOLO v1 paper.
batch_size = 24 # original paper : 64
input_width = 224 # original paper : 448
input_height = 224 # original paper : 448
cell_size = 7  # the output grid is cell_size x cell_size (S in the paper)
num_classes = 1 # original paper : 20 (only "cat" is trained here)
boxes_per_cell = 2  # boxes predicted per grid cell (B in the paper)

# Colours used when drawing boxes; indexed by class index via cat_class_to_label_dict.
color_list = generate_color(num_classes)

# Loss-term weights, passed to yolo_loss.
coord_scale = 10 # original paper : 5
class_scale = 0.1  # original paper : 1
object_scale = 1
noobject_scale = 0.5

각 함수별로 필요한 인자와 변수 설정이다.

original paper는 원본 논문의 값이며, 필자는 좀 더 계산을 편하게 하기 위해 위와 같이 설정해주었다.


# Load PASCAL VOC 2007/2012 through tensorflow_datasets.
# batch_size=1 gives every element a leading axis of size 1; the training and
# validation loops later remove it with tf.squeeze(..., axis=1).
# NOTE(review): training uses the VOC2007 *test* split (4,952 images per the
# original note) together with the VOC2012 train and validation splits.
voc2007_test_split_data = tfds.load("voc/2007", split=tfds.Split.TEST, batch_size=1)
voc2012_train_split_data = tfds.load("voc/2012", split=tfds.Split.TRAIN, batch_size=1)
voc2012_validation_split_data = tfds.load("voc/2012", split=tfds.Split.VALIDATION, batch_size=1)
train_data = voc2007_test_split_data.concatenate(voc2012_train_split_data).concatenate(voc2012_validation_split_data)

# Validation data: the VOC2007 validation split.
voc2007_validation_split_data = tfds.load("voc/2007", split=tfds.Split.VALIDATION, batch_size=1)
validation_data = voc2007_validation_split_data

위에서 말했듯이 PASCAL VOC 데이터셋을 사용하며, 위 코드는 이를 불러오는 코드이다. VOC2007의 test split과 VOC2012의 train, validation split을 합쳐 train dataset으로 사용하고, VOC2007의 validation split을 validation dataset으로 사용한다.


#label 7 : cat
# Reference : https://stackoverflow.com/questions/55731774/filter-dataset-to-get-just-images-from-specific-class
def predicate(x, allowed_labels=tf.constant([7.0])):
  label = x['objects']['label']
  isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
  reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))

  return tf.greater(reduced, tf.constant(0.))

# Keep only images with at least one cat, then batch them. padded_batch pads
# each component to the largest shape found in the batch.
train_data = train_data.filter(predicate).padded_batch(batch_size)
validation_data = validation_data.filter(predicate).padded_batch(batch_size)

필자가 사용할 label은 cat(VOC의 label 7)이다. predicate 함수로 cat이 하나 이상 포함된 이미지만 남긴 뒤, padded_batch로 batch를 구성해 훈련 데이터와 validation 데이터를 위와 같이 설정해주었다.


def reshape_yolo_preds(preds):
  """Reshape flat YOLO outputs into a per-cell grid.

  Args:
    preds: model output whose first dimension is the batch size; the elements
      of each example must total cell_size * cell_size * (num_classes + 5 * boxes_per_cell).

  Returns:
    A tensor of shape [batch, cell_size, cell_size, num_classes + 5 * boxes_per_cell].
  """
  channels_per_cell = num_classes + 5 * boxes_per_cell
  batch = tf.shape(preds)[0]
  return tf.reshape(preds, [batch, cell_size, cell_size, channels_per_cell])

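YOLO 모델의 최종 output은 $S \times S \times (5B + C)$이다. 여기서 $S$는 cell_size, $B$는 boxes_per_cell, $C$는 num_classes이며, 5는 각 box의 (x, y, w, h, confidence)를 뜻한다. 위 함수는 모델이 출력한 flatten된 벡터를 이 형태, 즉 cell_size x cell_size x (num_classes + 5 * boxes_per_cell)로 reshape해준다.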


def calculate_loss(model, batch_image, batch_bbox, batch_labels):
  """Sum the YOLO loss terms over every object of every image in a batch.

  Each image is preprocessed on its own with process_each_ground_truth and run
  through the model as a batch of one. yolo_loss is then evaluated once per
  ground-truth object of that image, and its five outputs are accumulated.

  Args:
    model: the YOLOv1 model.
    batch_image, batch_bbox, batch_labels: per-image tensors, indexed along axis 0.

  Returns:
    (total_loss, coord_loss, object_loss, noobject_loss, class_loss). Each is
    the plain float 0.0 if no object was processed; otherwise it is the running
    sum of the matching yolo_loss output.
  """
  loss_sums = [0.0, 0.0, 0.0, 0.0, 0.0]
  num_images = batch_image.shape[0]
  for idx in range(num_images):
    processed_image, processed_labels, num_objects = process_each_ground_truth(
      batch_image[idx], batch_bbox[idx], batch_labels[idx], input_width, input_height)

    # The model expects a batch axis, so feed a batch of one image.
    prediction = reshape_yolo_preds(model(tf.expand_dims(processed_image, axis=0)))
    single_prediction = prediction[0]

    for obj_idx in range(num_objects):
      (obj_total, obj_coord, obj_object, obj_noobject, obj_class) = yolo_loss(
        single_prediction,
        processed_labels,
        obj_idx,
        num_classes,
        boxes_per_cell,
        cell_size,
        input_width,
        input_height,
        coord_scale,
        object_scale,
        noobject_scale,
        class_scale)
      loss_sums = [acc + term for acc, term in
                   zip(loss_sums, (obj_total, obj_coord, obj_object, obj_noobject, obj_class))]

  return tuple(loss_sums)

모델, batch image, batch bounding box, batch label을 받아 loss를 계산하는 함수이다. batch 안의 각 이미지와 각 오브젝트마다 yolo_loss를 계산해 더한 뒤, 전체 loss인 total_loss, 좌표에 대한 coord_loss, 오브젝트에 대한 object_loss, 오브젝트가 없는 부분에 대한 noobject_loss, 마지막으로 class_loss를 return해준다.


def train_step(optimizer, model, batch_image, batch_bbox, batch_labels):
  """Run one optimisation step on a batch and return its loss terms.

  The forward pass (calculate_loss) is recorded under a GradientTape. The
  gradient of the total loss with respect to the model's trainable variables
  is then applied with ``optimizer``.

  Returns:
    (total_loss, coord_loss, object_loss, noobject_loss, class_loss)
  """
  with tf.GradientTape() as tape:
    (step_total, step_coord, step_object,
     step_noobject, step_class) = calculate_loss(model, batch_image, batch_bbox, batch_labels)

  trainable = model.trainable_variables
  grads = tape.gradient(step_total, trainable)
  optimizer.apply_gradients(zip(grads, trainable))

  return step_total, step_coord, step_object, step_noobject, step_class

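gradient를 계산하고 파라미터를 업데이트하는 함수이다. GradientTape로 loss 계산 과정을 기록한 뒤, total_loss에 대한 gradient를 구해 optimizer로 적용한다. train_step이 한 번 호출될 때마다 YOLO v1 모델의 파라미터가 loss를 줄이는 방향, 즉 데이터셋의 오브젝트를 더 잘 검출하는 방향으로 한 번 업데이트된다.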


def _draw_bounding_box_info(image, box_info):
  """Draw one bounding-box dict (from yolo_format_to_bounding_box_dict) onto image."""
  draw_bounding_box_and_label_info(
    image,
    box_info['left'],
    box_info['top'],
    box_info['right'],
    box_info['bottom'],
    box_info['class_name'],
    box_info['confidence'],
    color_list[cat_class_to_label_dict[box_info['class_name']]]
  )


def _predicted_bounding_boxes(predict):
  """Convert a reshaped prediction (batch of one) into a list of box dicts, one per predicted box."""
  # Channel layout used by these slices:
  # [class scores | box confidences | box coordinates (x, y, w, h) per box].
  predict_boxes = predict[0, :, :, num_classes + boxes_per_cell:]
  predict_boxes = tf.reshape(predict_boxes, [cell_size, cell_size, boxes_per_cell, 4])

  confidence_boxes = predict[0, :, :, num_classes:num_classes + boxes_per_cell]
  confidence_boxes = tf.reshape(confidence_boxes, [cell_size, cell_size, boxes_per_cell, 1])

  class_prediction = tf.argmax(predict[0, :, :, 0:num_classes], axis=2)

  bounding_box_info_list = []
  for i in range(cell_size):
    for j in range(cell_size):
      pred_class_name = cat_label_dict[class_prediction[i][j].numpy()]
      for k in range(boxes_per_cell):
        pred_xcenter = predict_boxes[i][j][k][0]
        pred_ycenter = predict_boxes[i][j][k][1]
        # Clamp predicted width/height to [0, input size].
        pred_box_w = tf.minimum(input_width * 1.0, tf.maximum(0.0, predict_boxes[i][j][k][2]))
        pred_box_h = tf.minimum(input_height * 1.0, tf.maximum(0.0, predict_boxes[i][j][k][3]))
        pred_confidence = confidence_boxes[i][j][k].numpy()[0]

        bounding_box_info_list.append(yolo_format_to_bounding_box_dict(
          pred_xcenter, pred_ycenter, pred_box_w, pred_box_h, pred_class_name, pred_confidence))
  return bounding_box_info_list


def _ground_truth_bounding_boxes(labels, object_num):
  """Return box dicts for the cat objects (class label 7) among the first object_num label rows."""
  # Convert once; each row is read as [xcenter, ycenter, w, h, class_label, ...].
  labels = np.array(labels).astype('float32')
  ground_truth_bounding_box_info_list = []
  for each_object_num in range(object_num):
    label = labels[each_object_num, :]
    class_label = label[4]
    # label 7 : cat
    if class_label == 7:
      ground_truth_bounding_box_info_list.append(
        yolo_format_to_bounding_box_dict(label[0], label[1], label[2], label[3], 'cat', 1.0))
  return ground_truth_bounding_box_info_list


def _validation_comparison_image(model, image, bbox, labels):
  """Build a side-by-side image: ground truth (left) vs best prediction (right).

  The result is scaled by 1/255 and given a leading batch axis, ready for
  tf.summary.image. (Assumes process_each_ground_truth returns an (H, W, C)
  array that supports .copy() -- TODO confirm.)
  """
  image, labels, object_num = process_each_ground_truth(image, bbox, labels, input_width, input_height)
  drawing_image = image

  # Predict before drawing anything, so the model sees the unmodified image.
  predict = reshape_yolo_preds(model(tf.expand_dims(image, axis=0)))
  bounding_box_info_list = _predicted_bounding_boxes(predict)
  ground_truth_bounding_box_info_list = _ground_truth_bounding_boxes(labels, object_num)

  # Ground-truth boxes go on a copy, so the prediction panel is not affected.
  ground_truth_drawing_image = drawing_image.copy()
  for ground_truth_bounding_box_info in ground_truth_bounding_box_info_list:
    _draw_bounding_box_info(ground_truth_drawing_image, ground_truth_bounding_box_info)

  # Only the single highest-confidence predicted box is drawn.
  _draw_bounding_box_info(drawing_image, find_max_confidence_bounding_box(bounding_box_info_list))

  # Left: ground truth, right: prediction.
  comparison = np.concatenate((ground_truth_drawing_image, drawing_image), axis=1)
  comparison = comparison / 255
  return tf.expand_dims(comparison, axis=0)


def save_validation_result(model, ckpt, validation_summary_writer, num_visualize_image):
  """Log validation losses and sample prediction images to TensorBoard.

  Iterates over the whole module-level ``validation_data``. The five loss terms
  returned by calculate_loss are summed over all batches and written as scalars
  at step ``ckpt.step``. Then ``num_visualize_image`` images are picked at random
  from the *last* validation batch; for each, ground truth and the
  highest-confidence prediction are drawn side by side and logged.

  Args:
    model: the YOLOv1 model.
    ckpt: a tf.train.Checkpoint whose ``step`` variable is used as the TensorBoard step.
    validation_summary_writer: tf.summary writer for validation logs.
    num_visualize_image: number of sample images to draw and log.
  """
  total_validation_total_loss = 0.0
  total_validation_coord_loss = 0.0
  total_validation_object_loss = 0.0
  total_validation_noobject_loss = 0.0
  total_validation_class_loss = 0.0

  # Stays None if validation_data yields no batch at all.
  batch_validation_image = None
  for features in validation_data:
    # Drop the size-1 axis left by tfds.load(batch_size=1) before padded_batch.
    batch_validation_image = tf.squeeze(features['image'], axis=1)
    batch_validation_bbox = tf.squeeze(features['objects']['bbox'], axis=1)
    batch_validation_labels = tf.squeeze(features['objects']['label'], axis=1)

    (validation_total_loss, validation_coord_loss, validation_object_loss,
     validation_noobject_loss, validation_class_loss) = calculate_loss(
      model, batch_validation_image, batch_validation_bbox, batch_validation_labels)

    total_validation_total_loss = total_validation_total_loss + validation_total_loss
    total_validation_coord_loss = total_validation_coord_loss + validation_coord_loss
    total_validation_object_loss = total_validation_object_loss + validation_object_loss
    total_validation_noobject_loss = total_validation_noobject_loss + validation_noobject_loss
    total_validation_class_loss = total_validation_class_loss + validation_class_loss

  step = int(ckpt.step)

  # Validation loss scalars.
  with validation_summary_writer.as_default():
    tf.summary.scalar('total_validation_total_loss', total_validation_total_loss, step=step)
    tf.summary.scalar('total_validation_coord_loss', total_validation_coord_loss, step=step)
    tf.summary.scalar('total_validation_object_loss ', total_validation_object_loss, step=step)
    tf.summary.scalar('total_validation_noobject_loss ', total_validation_noobject_loss, step=step)
    tf.summary.scalar('total_validation_class_loss ', total_validation_class_loss, step=step)

  if batch_validation_image is None:
    # Empty validation set: there is nothing to visualise (previously this
    # raised UnboundLocalError).
    return

  # Sample images from the last validation batch.
  for validation_image_index in range(num_visualize_image):
    random_idx = random.randint(0, batch_validation_image.shape[0] - 1)
    drawing_image = _validation_comparison_image(
      model,
      batch_validation_image[random_idx],
      batch_validation_bbox[random_idx],
      batch_validation_labels[random_idx])

    with validation_summary_writer.as_default():
      tf.summary.image('validation_image_'+str(validation_image_index), drawing_image, step=step)

전체 validation 데이터를 가져와 YOLO v1 모델의 예측 결과를 정답과 비교하는 함수다. 먼저 validation 데이터셋 전체에 대한 loss를 합산해 total_validation_total_loss 등의 변수에 저장하고, 이를 tensorboard에 기록해 모델이 오브젝트를 잘 검출하고 있는지 확인할 수 있게 한다.

이후 마지막 validation batch에서 임의로 고른 이미지에 대해 정답 영역(왼쪽)과 confidence가 가장 높은 예측 영역(오른쪽)을 시각화하고, 이를 tensorboard log로 저장하는 것까지가 위 코드의 내용이다.


def main(_):
  """Train YOLOv1 on the cat-only PASCAL VOC data prepared at module level.

  Steps:
    1. Build an exponentially decaying learning-rate schedule and an Adam optimizer.
    2. Create the checkpoint directory.
    3. Build the model.
    4. Restore the latest checkpoint, if one exists.
    5. Run FLAGS.num_epochs epochs over train_data.

  Every iteration logs the losses and the learning rate to TensorBoard. A
  checkpoint is saved every FLAGS.save_checkpoint_steps global steps.
  save_validation_result runs every FLAGS.validation_steps iterations within
  an epoch, including the first iteration of each epoch.
  """
  # Learning-rate decay: multiply by lr_decay_rate every lr_decay_steps steps
  # (staircase=True keeps the rate constant between decays).
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    FLAGS.init_learning_rate,
    decay_steps=FLAGS.lr_decay_steps,
    decay_rate=FLAGS.lr_decay_rate,
    staircase=True)

  # Optimizer.
  optimizer = tf.optimizers.Adam(lr_schedule)  # original paper : SGD with momentum 0.9, decay 0.0005

  # Create the checkpoint directory, and any missing parents, if it does not exist yet.
  os.makedirs(FLAGS.checkpoint_path, exist_ok=True)

  # Build the YOLO model.
  YOLOv1_model = YOLOv1(input_height, input_width, cell_size, boxes_per_cell, num_classes)

  # Checkpoint the optimizer as well as the model. Resuming then keeps Adam's
  # state and the optimizer step counter that lr_schedule decays with.
  # Checkpoints written without an optimizer entry still restore.
  ckpt = tf.train.Checkpoint(step=tf.Variable(0), model=YOLOv1_model, optimizer=optimizer)
  ckpt_manager = tf.train.CheckpointManager(ckpt,
                                            directory=FLAGS.checkpoint_path,
                                            max_to_keep=None)
  latest_ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path)

  # Restore the most recent checkpoint, if there is one.
  if latest_ckpt:
    ckpt.restore(latest_ckpt)
    print('global_step : {}, checkpoint is restored!'.format(int(ckpt.step)))

  # TensorBoard writers.
  train_summary_writer = tf.summary.create_file_writer(os.path.join(FLAGS.tensorboard_log_path, 'train'))
  validation_summary_writer = tf.summary.create_file_writer(os.path.join(FLAGS.tensorboard_log_path, 'validation'))

  # The filtered dataset's length is not known up front, so count its batches
  # by iterating once. The count is the same every epoch, so do it once. A
  # generator avoids holding every batch in memory the way list() did.
  num_batch = sum(1 for _ in train_data)

  for epoch in range(FLAGS.num_epochs):
    for step_in_epoch, features in enumerate(train_data):
      # Drop the size-1 axis left by tfds.load(batch_size=1) before padded_batch.
      batch_image = tf.squeeze(features['image'], axis=1)
      batch_bbox = tf.squeeze(features['objects']['bbox'], axis=1)
      batch_labels = tf.squeeze(features['objects']['label'], axis=1)

      # One optimisation step; returns the loss terms for logging.
      total_loss, coord_loss, object_loss, noobject_loss, class_loss = train_step(optimizer, YOLOv1_model, batch_image, batch_bbox, batch_labels)

      # Progress log.
      print("Epoch: %d, Iter: %d/%d, Loss: %f" % ((epoch+1), (step_in_epoch+1), num_batch, total_loss.numpy()))

      global_step = int(ckpt.step)

      # TensorBoard training log.
      with train_summary_writer.as_default():
        # Evaluate the schedule directly, so this does not depend on how a
        # given TF/Keras version exposes the optimizer's learning-rate attribute.
        tf.summary.scalar('learning_rate ', lr_schedule(ckpt.step).numpy(), step=global_step)
        tf.summary.scalar('total_loss', total_loss, step=global_step)
        tf.summary.scalar('coord_loss', coord_loss, step=global_step)
        tf.summary.scalar('object_loss ', object_loss, step=global_step)
        tf.summary.scalar('noobject_loss ', noobject_loss, step=global_step)
        tf.summary.scalar('class_loss ', class_loss, step=global_step)

      # Periodic checkpoint, keyed on the global step.
      if global_step % FLAGS.save_checkpoint_steps == 0:
        ckpt_manager.save(checkpoint_number=ckpt.step)
        print('global_step : {}, checkpoint is saved!'.format(int(ckpt.step)))

      ckpt.step.assign_add(1)

      # Occasionally evaluate on validation data and save TensorBoard logs
      # (the period counts iterations within the current epoch).
      if step_in_epoch % FLAGS.validation_steps == 0:
        save_validation_result(YOLOv1_model, ckpt, validation_summary_writer, FLAGS.num_visualize_image)

# Script entry point: absl's app.run parses the command-line flags, then calls main(argv).
if __name__ == '__main__':
  app.run(main)

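main 함수다. main 함수의 로직은 다음과 같다.

  1. tensorflow api로 learning rate decay(ExponentialDecay)를 설정하고 Adam optimizer를 생성한다.
  2. 체크포인트를 저장할 경로가 없으면 디렉토리를 생성한다.
  3. 앞서 생성했던 YOLO v1 클래스의 인스턴스를 생성한다.
  4. YOLO v1 모델의 중간 파라미터를 계속 저장하기 위한 체크포인트 manager를 설정하고, 저장된 체크포인트가 있으면 복원한다.
  5. tensorboard log를 저장하기 위한 summary_writer를 생성한다.
  6. epoch마다 train_data를 돌면서 train_step으로 파라미터를 업데이트하고, loss와 learning rate를 tensorboard에 기록한다.
  7. save_checkpoint_steps마다 체크포인트를 저장하고, validation_steps마다 save_validation_result로 validation 결과를 기록한다.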