TensorFlow: fine-tuning (finetune) a trained model

A while ago I finished training a classification model, and during later testing I found some scenes and camera angles that it could not classify accurately. The usual response is to collect additional data for those scenes and angles and retrain the model from scratch. If the dataset is small, retraining finishes quickly. But if the dataset is very large (my training set had 6M+ images), retraining takes a long time and drags out the project schedule. Retraining from scratch also does not fully guarantee good results on the newly added data. In such cases, fine-tuning can drastically cut training time while preserving the good performance of the previous model.

Fine-tuning strategy: pick the checkpoint that performed best in earlier testing as the base. Set the fine-tuning learning rate to the value the pretrained model had shortly before it converged, load the pretrained weights, and train for just 1-2 epochs. For example, if the pretrained model converged once its learning rate had decayed to 0.001, set the fine-tuning learning rate to 0.005.
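For reference, the fine-tuning run below reads its hyperparameters from a Config object. Here is a minimal sketch of the fields it uses; only the field names are taken from the preTrain() code below, and the values are illustrative assumptions, not my actual configuration:

class Config:
    # Illustrative fine-tuning settings; field names mirror preTrain() below.
    learning_rate = 0.005      # LR from just before the pretrained model converged
    decay_steps = 2000         # assumed decay schedule
    decay_rate = 0.9           # assumed decay schedule
    NUM_EPOCH = 2              # 1-2 epochs suffice when starting from a good checkpoint
    BATCH_SIZE = 64            # assumed
    num_samples = 6000000      # ~6M training images
    num_classes = 10           # assumed; set to your label count
    record_path = './data/train.tfrecord'  # assumed TFRecord path
    ckpt_path = './ckpt/'      # checkpoint output directory
    pb_path = './pb/'          # assumed GraphDef output directory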

My pretrained model checkpoints are as follows:

[Screenshot: checkpoint directory listing; the checkpoint used below is step_291800_loss_0.00985_acc_1.00000]

The fine-tuning code is shown below (here checkpoint_path = './ckpt/step_291800_loss_0.00985_acc_1.00000'):

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.quantize import create_training_graph

# Project-specific helpers assumed to come from the surrounding codebase:
# Config, Model, get_record_dataset, processing_image, gpuConfig.

def preTrain(checkpoint_path):
    # Checkpoints are saved with a '.ckpt' suffix, so append it to the restore path.
    model_path_restore = checkpoint_path + '.ckpt'
    # Build the input pipeline from the TFRecord dataset.
    dataset = get_record_dataset(record_path=Config.record_path,
                                 num_samples=Config.num_samples,
                                 num_classes=Config.num_classes)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    image, label = processing_image(image, label)
    images, labels = tf.train.batch([image, label], batch_size=Config.BATCH_SIZE,
                                    num_threads=1, capacity=5)

    # Forward pass; Model returns the pre-softmax class scores (logits).
    logits = Model(images, is_training=True, num_classes=Config.num_classes)

    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
                                                               logits=logits)
    loss = tf.reduce_mean(cross_entropy)

    correct_prediction = tf.equal(tf.argmax(labels, 1), tf.argmax(logits, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Quantization-aware training: insert fake-quant nodes into the graph;
    # quant_delay=40000 postpones quantization until step 40000.
    graph = tf.get_default_graph()
    create_training_graph(graph, 40000)

    global_step = tf.Variable(0, trainable=False)
    # Exponential decay; global_step is divided by 40, so the decay schedule
    # advances once every 40 training steps.
    lr = tf.train.exponential_decay(learning_rate=Config.learning_rate,
                                    global_step=tf.cast(tf.div(global_step, 40),
                                                        tf.int32),
                                    decay_steps=Config.decay_steps,
                                    decay_rate=Config.decay_rate,
                                    staircase=True)

    # Run batch-norm moving-average updates together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(
            loss, global_step=global_step)

    global_init = tf.global_variables_initializer()
    total_step = Config.NUM_EPOCH * Config.num_samples // Config.BATCH_SIZE

    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=12)

    with tf.Session(config=gpuConfig) as sess:
        sess.run([global_init])
        # Load the pretrained weights; this overwrites the fresh initialization.
        saver.restore(sess, model_path_restore)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            for step in range(1, total_step + 1):
                if coord.should_stop():
                    break

                _, loss_val, accuracy_val, lr_val, global_step_val = sess.run(
                    [train_op, loss, accuracy, lr, global_step])
                print('global_step:', global_step_val)
                print("step: %d, lr = %.5f, loss = %.5f, accuracy = %.5f" %
                      (step, lr_val, loss_val, accuracy_val))

                # Save the model.
                if step == 1:
                    tf.train.write_graph(sess.graph, Config.pb_path, "handcnet.pb")
                if step % 200 == 0:
                    save_path = Config.ckpt_path + "step_%d_loss_%.5f_acc_%.5f.ckpt" % (
                        step, loss_val, accuracy_val)
                    saver.save(sess, save_path)
                    print('Saved to', save_path)

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)

        # Final checkpoint and graph definition.
        saver.save(sess, Config.ckpt_path + "completed_model.ckpt")

        tf.train.write_graph(sess.graph, Config.pb_path, "handcnet.pb")

        print("train completed!")
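A minimal usage sketch for the function above (the checkpoint path is the one from my run; substitute your own):

if __name__ == '__main__':
    # Path without the '.ckpt' suffix; preTrain() appends it internally.
    checkpoint_path = './ckpt/step_291800_loss_0.00985_acc_1.00000'
    preTrain(checkpoint_path)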

This TensorFlow fine-tuning was done on my own classification model. Retraining from scratch after every data update used to take several days because of the dataset size. With this approach, newly collected data only needs to be preprocessed and shuffled together with the existing data; after loading the pretrained model, 1-2 epochs are enough to finish training, and at test time the model also fits the newly added data better.
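The data-merging step above can be as simple as shuffling the combined sample list before regenerating the TFRecord. A minimal sketch, assuming training examples are (image_path, label) pairs and a write_tfrecord() helper like the one used to build Config.record_path (both merge_and_shuffle and write_tfrecord are hypothetical names, not part of the code above):

import random

def merge_and_shuffle(old_samples, new_samples, seed=0):
    # old_samples / new_samples: lists of (image_path, label) pairs.
    merged = list(old_samples) + list(new_samples)
    random.Random(seed).shuffle(merged)  # mix the new data evenly into the old
    return merged

# samples = merge_and_shuffle(old_samples, new_samples)
# write_tfrecord(samples, Config.record_path)  # hypothetical helper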