import time

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def average_losses(loss):
    tf.add_to_collection('losses', loss)
    # Total loss = the collected losses plus any regularization losses.
    losses = tf.get_collection('losses')
    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_loss = tf.add_n(losses + regularization_losses, name='total_loss')
    # Track exponential moving averages of the individual and total losses.
    loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
    loss_averages_op = loss_averages.apply(losses + [total_loss])
    # Make the returned total_loss depend on updating the averages first.
    with tf.control_dependencies([loss_averages_op]):
        total_loss = tf.identity(total_loss)
    return total_loss
def average_gradients(tower_grads):
    average_grads = []
    # Each element of zip(*tower_grads) collects the (gradient, variable)
    # pairs that every tower computed for the same shared variable.
    for grad_and_vars in zip(*tower_grads):
        grads = [g for g, _ in grad_and_vars]
        grad = tf.stack(grads, 0)
        grad = tf.reduce_mean(grad, 0)
        # Variables are shared across towers, so the first tower's handle suffices.
        v = grad_and_vars[0][1]
        average_grads.append((grad, v))
    return average_grads
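
# build_model is called by the tower loop below but is not defined in this
# listing; the original defines it elsewhere. The stand-in here is a
# hypothetical single-hidden-layer MNIST classifier, included only so the
# listing runs end to end. It must create weights with tf.get_variable
# (not tf.Variable) so that reuse=True in the tower loop shares them.
def build_model(x):
    w1 = tf.get_variable('w1', [784, 256],
                         initializer=tf.truncated_normal_initializer(stddev=0.1))
    b1 = tf.get_variable('b1', [256], initializer=tf.zeros_initializer())
    hidden = tf.nn.relu(tf.matmul(x, w1) + b1)
    w2 = tf.get_variable('w2', [256, 10],
                         initializer=tf.truncated_normal_initializer(stddev=0.1))
    b2 = tf.get_variable('b2', [10], initializer=tf.zeros_initializer())
    return tf.matmul(hidden, w2) + b2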
def feed_all_gpu(inp_dict, models, payload_per_gpu, batch_x, batch_y):
    for i in range(len(models)):
        x, y, _, _, _ = models[i]
        start_pos = i * payload_per_gpu
        stop_pos = (i + 1) * payload_per_gpu
        inp_dict[x] = batch_x[start_pos:stop_pos]
        inp_dict[y] = batch_y[start_pos:stop_pos]
    return inp_dict


def multi_gpu(num_gpu):
    batch_size = 128 * num_gpu
    mnist = input_data.read_data_sets('/tmp/data/mnist', one_hot=True)
    tf.reset_default_graph()
    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU.
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        with tf.device('/cpu:0'):
            learning_rate = tf.placeholder(tf.float32, shape=[])
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
            print('build model...')
            print('build model on gpu tower...')
            models = []
            for gpu_id in range(num_gpu):
                with tf.device('/gpu:%d' % gpu_id):
                    print('tower:%d...' % gpu_id)
                    with tf.name_scope('tower_%d' % gpu_id):
                        # reuse=True after the first tower shares one set of weights.
                        with tf.variable_scope('cpu_variables', reuse=gpu_id > 0):
                            x = tf.placeholder(tf.float32, [None, 784])
                            y = tf.placeholder(tf.float32, [None, 10])
                            pred = build_model(x)
                            loss = tf.reduce_mean(
                                tf.nn.softmax_cross_entropy_with_logits(
                                    logits=pred, labels=y))
                            grads = opt.compute_gradients(loss)
                            models.append((x, y, pred, loss, grads))
            print('build model on gpu tower done.')
            print('reduce model on cpu...')
            tower_x, tower_y, tower_preds, tower_losses, tower_grads = zip(*models)
            aver_loss_op = tf.reduce_mean(tower_losses)
            apply_gradient_op = opt.apply_gradients(average_gradients(tower_grads))
            all_y = tf.reshape(tf.stack(tower_y, 0), [-1, 10])
            all_pred = tf.reshape(tf.stack(tower_preds, 0), [-1, 10])
            correct_pred = tf.equal(tf.argmax(all_y, 1), tf.argmax(all_pred, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, 'float'))
            print('reduce model on cpu done.')
            print('run train op...')
            sess.run(tf.global_variables_initializer())
            lr = 0.01
            for epoch in range(2):
                start_time = time.time()
                # Integer division: each tower gets an equal slice of the batch.
                payload_per_gpu = batch_size // num_gpu
                total_batch = int(mnist.train.num_examples / batch_size)
                avg_loss = 0.0
                print('\n---------------------')
                print('Epoch:%d, lr:%.4f' % (epoch, lr))
                for batch_idx in range(total_batch):
                    batch_x, batch_y = mnist.train.next_batch(batch_size)
                    inp_dict = {learning_rate: lr}
                    inp_dict = feed_all_gpu(inp_dict, models, payload_per_gpu,
                                            batch_x, batch_y)
                    _, _loss = sess.run([apply_gradient_op, aver_loss_op], inp_dict)
                    avg_loss += _loss
                avg_loss /= total_batch
                print('Train loss:%.4f' % avg_loss)
                lr = max(lr * 0.7, 0.00001)
                val_payload_per_gpu = batch_size // num_gpu
                total_batch = int(mnist.validation.num_examples / batch_size)
                preds = None
                ys = None
                for batch_idx in range(total_batch):
                    batch_x, batch_y = mnist.validation.next_batch(batch_size)
                    inp_dict = feed_all_gpu({}, models, val_payload_per_gpu,
                                            batch_x, batch_y)
                    batch_pred, batch_y = sess.run([all_pred, all_y], inp_dict)
                    if preds is None:
                        preds = batch_pred
                    else:
                        preds = np.concatenate((preds, batch_pred), 0)
                    if ys is None:
                        ys = batch_y
                    else:
                        ys = np.concatenate((ys, batch_y), 0)
                # all_y and all_pred are ordinary tensors, but feed_dict can
                # override any tensor, so the accumulated arrays are fed back in.
                val_accuracy = sess.run([accuracy], {all_y: ys, all_pred: preds})[0]
                print('Val Accuracy: %0.4f%%' % (100.0 * val_accuracy))
                stop_time = time.time()
                elapsed_time = stop_time - start_time
                print('Cost time: ' + str(elapsed_time) + ' sec.')
            print('training done.')
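
# A minimal entry point, assuming the machine exposes at least two GPUs;
# adjust num_gpu to match the available devices.
if __name__ == '__main__':
    multi_gpu(2)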