TensorFlow - Writing a Custom Training Loop
04 Jun 2022
Create the data structure using:

    import tensorflow as tf

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

Ingredients: (x_train, y_train), the training inputs and labels, plus an integer batch_size.
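To make the effect of shuffle(buffer_size=...) followed by batch(batch_size) concrete, here is a plain-Python sketch of the same idea. It assumes the documented behaviour of a buffered shuffle (sample uniformly from a fixed-size buffer and refill it from the input) and that the final partial batch is kept. The helpers `buffered_shuffle` and `make_batches` and the toy data are hypothetical and are not part of tf.data.

```python
import random


def buffered_shuffle(samples, buffer_size, rng):
    """Yield samples in a locally shuffled order using a fixed-size buffer."""
    buffer = []
    for sample in samples:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(sample)
            continue
        # Buffer is full: emit a random element and put the new sample in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = sample
    # Input exhausted: drain the remaining elements in random order.
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


def make_batches(stream, batch_size):
    """Group a stream into lists of batch_size items, keeping the last partial batch."""
    batch = []
    for item in stream:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


# Toy (feature, label) pairs standing in for (x_train, y_train).
pairs = [(x, x % 2) for x in range(10)]
shuffled = buffered_shuffle(pairs, buffer_size=4, rng=random.Random(0))
batches = list(make_batches(shuffled, batch_size=4))

print([len(b) for b in batches])  # [4, 4, 2]
```

A buffer smaller than the dataset only shuffles locally, so a buffer_size at least as large as the dataset is needed for a full uniform shuffle.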
Create the model:

    from tensorflow import keras
    from tensorflow.keras import layers

    inputs = keras.Input(shape=(784,), name="digits")
    x1 = layers.Dense(64, activation="relu")(inputs)
    x2 = layers.Dense(64, activation="relu")(x1)
    outputs = layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)

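As a quick sanity check on the size of this model, the number of trainable weights can be worked out by hand: a Dense layer with n_in inputs and n_out units has n_in * n_out kernel weights plus n_out biases. The sketch below is plain Python arithmetic, and `dense_params` is a hypothetical helper, not a Keras function.

```python
def dense_params(n_in, n_out):
    """Kernel weights plus biases for one fully connected layer."""
    return n_in * n_out + n_out


# Layer widths taken from the model above: 784 -> 64 -> 64 -> 10.
layer_sizes = [784, 64, 64, 10]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)

print(per_layer, total)  # [50240, 4160, 650] 55050
```

These are the values that model.trainable_weights holds and that the optimizer updates on every step.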
Instantiate Optimizer and Loss function:

    optimizer = keras.optimizers.SGD(learning_rate=1e-3)
    # from_logits=True because the last Dense layer has no softmax activation.
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

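To show what a sparse categorical cross-entropy on logits computes, here is a plain-Python sketch: for each sample it takes the log-softmax of the logits, picks the entry of the true class, negates it, and averages over the batch. It assumes integer class labels and the default batch-mean reduction. The function name `sparse_cce_from_logits` and the example values are hypothetical.

```python
import math


def sparse_cce_from_logits(labels, logits):
    """Mean over the batch of -log(softmax(logits)[label])."""
    losses = []
    for label, row in zip(labels, logits):
        # Numerically stable log-sum-exp: subtract the row maximum first.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)


# Uniform logits over 10 classes: every class gets probability 1/10.
loss_uniform = sparse_cce_from_logits([3], [[0.0] * 10])

# Logits that strongly favour the correct class give a much smaller loss.
confident_row = [0.0] * 10
confident_row[3] = 8.0
loss_confident = sparse_cce_from_logits([3], [confident_row])

print(round(loss_uniform, 4))  # 2.3026, i.e. ln(10)
```

An untrained 10-class model therefore tends to start near a loss of ln(10), which is a useful reference when reading the first few training logs.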
Define Training Loop:
Here’s our training loop:
- We open a for loop that iterates over epochs
- For each epoch, we open a for loop that iterates over the dataset, in batches
- For each batch, we open a GradientTape() scope
- Inside this scope, we call the model (forward pass) and compute the loss
- Outside the scope, we retrieve the gradients of the weights of the model with regard to the loss
- Finally, we use the optimizer to update the weights of the model based on the gradients

    epochs = 2
    for epoch in range(epochs):
        for batch, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            # Record the forward pass so that it can be differentiated.
            with tf.GradientTape() as tape:
                logits = model(x_batch_train, training=True)
                loss_value = loss_fn(y_batch_train, logits)
            # Outside the scope: gradients of the loss w.r.t. the trainable weights.
            grads = tape.gradient(loss_value, model.trainable_weights)
            # Apply one optimizer step using those gradients.
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            # Log every 200 batches.
            if batch % 200 == 0:
                print(
                    "Training loss (for one batch) at batch %d: %.4f"
                    % (batch, float(loss_value))
                )
                print("Seen so far: %s samples" % ((batch + 1) * batch_size))

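To see what the tape and the optimizer do numerically, the same four steps (forward pass, loss, gradient, weight update) can be written by hand for a toy one-weight model. This is a sketch under stated assumptions, not the TensorFlow implementation: the model is y = w * x with a mean-squared-error loss, the gradient is derived analytically instead of being recorded by a tape, and the update is plain SGD without momentum. All names and data are hypothetical.

```python
# Toy data generated from y = 3 * x, so the ideal weight is 3.0.
data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]

w = 0.0
learning_rate = 0.1
batch_size = 2
history = []  # loss of the last batch of each epoch

for epoch in range(50):
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        # Forward pass and loss (the part that sits inside the tape scope).
        preds = [w * x for x, _ in batch]
        loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / len(batch)
        # Gradient of the mean squared error with respect to w.
        grad = sum(2.0 * (p - y) * x for p, (x, y) in zip(preds, batch)) / len(batch)
        # Plain SGD update: w <- w - learning_rate * grad.
        w -= learning_rate * grad
    history.append(loss)

print(round(w, 4))  # 3.0
```

In the TensorFlow loop, tape.gradient() replaces the hand-derived `grad` line and optimizer.apply_gradients() replaces the manual update.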
Others:
- To use a custom learning_rate, see this.
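- Once tape.gradient() is called, the resources held by a non-persistent GradientTape are released, so gradient() can be called only once on it. Create the tape with persistent=True to compute more than one gradient.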
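The note on tape memory can be checked directly. The following is a minimal sketch, assuming TensorFlow 2.x running eagerly. The variable names are illustrative only.

```python
import tensorflow as tf

x = tf.Variable(3.0)

# Default (non-persistent) tape: usable for a single gradient() call.
with tf.GradientTape() as tape:
    y = x * x
first = tape.gradient(y, x)  # d(x^2)/dx at x = 3

try:
    tape.gradient(y, x)
    second_call_failed = False
except RuntimeError:
    # The tape's recorded operations were released by the first call.
    second_call_failed = True

# Persistent tape: several gradients can be taken from the same recording.
with tf.GradientTape(persistent=True) as ptape:
    y = x * x
    z = y * x
dy_dx = ptape.gradient(y, x)  # 2 * x
dz_dx = ptape.gradient(z, x)  # 3 * x^2
del ptape  # release the tape's resources explicitly when done

print(float(first), second_call_failed, float(dy_dx), float(dz_dx))
```

A persistent tape keeps its recorded operations alive until it is deleted, so it is best reserved for cases that need more than one gradient from the same forward pass.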
Ref:
- From the TensorFlow documentation