TensorFlow.js: Epochs, Batch Size, and Learning Rate
In TensorFlow.js, training a model involves configuring three critical hyperparameters: Epochs, Batch Size, and Learning Rate. These determine how many times the model sees the full dataset, how many samples it processes before each weight update, and how much it adjusts its internal parameters at each update.
- Epochs
An epoch is one complete pass of the entire training dataset through the model.
- Function: It defines how long training runs, measured in complete passes over the dataset.
- Purpose: Training for multiple epochs allows the model to see the data repeatedly, which is necessary for the weights to converge to an optimal state.
- Usage: It is defined in the model.fit() or model.fitDataset() configuration.
- Risk: Too few epochs lead to underfitting (the model hasn't learned enough), while too many can lead to overfitting (the model memorizes the training data but fails on new data); a quick monitoring sketch follows this list.
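One way to watch for that overfitting point is to compare training and validation loss per epoch. This is only a sketch: the xs, ys, valXs, and valYs tensors and the compiled model are assumptions for the example, not part of any specific project.
```javascript
// Sketch: log training vs. validation loss each epoch to spot overfitting
// (assumes a compiled `model` plus xs/ys and valXs/valYs tensors)
await model.fit(xs, ys, {
  epochs: 100,
  validationData: [valXs, valYs],
  callbacks: {
    onEpochEnd: async (epoch, logs) => {
      // If loss keeps falling while val_loss starts rising, the model is likely overfitting
      console.log(`Epoch ${epoch + 1}: loss=${logs.loss}, val_loss=${logs.val_loss}`);
    }
  }
});
```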
- Batch Size
The batch size is the number of samples processed before the model updates its internal weights.
- Efficiency: It is often impossible to load a massive dataset into memory at once; batches allow the model to learn in smaller, manageable chunks.
- Impact:
- Small batches (e.g., 32): provide more frequent updates and "noisier" gradients, which can help escape local minima but make each epoch take longer to process.
- Large batches (e.g., 256+): speed up training and provide more stable gradients, but may lead to poorer generalization.
- Learning Rate
The learning rate is a small positive value (often between 0.001 and 0.1) that determines how much the model's weights are adjusted in response to the estimated error each time they are updated.
- Control: It scales the magnitude of weight updates. A learning rate that is too high may cause the model to overshoot the optimal solution, while one that is too low will make training extremely slow.
- High Rate (e.g., 0.1): Training is faster, but the model may overshoot the optimal solution, causing the loss to oscillate or diverge.
- Low Rate (e.g., 0.001): Training is more stable and precise but significantly slower.
- Setting in TF.js: It is typically defined when initializing an optimizer (like SGD or Adam):
```javascript
const optimizer = tf.train.adam(0.001); // 0.001 is the learning rate
```
Summary of Implementation
When training a model, these parameters are typically used together in the model.fit() method, as shown in the example below the table:
| Parameter | Code Example | Description |
|---|---|---|
| Learning Rate | tf.train.adam(0.01) | Set during optimizer initialization. |
| Batch Size | batchSize: 32 | Set in the fit configuration object. |
| Epochs | epochs: 10 | Set in the fit configuration object. |
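The following minimal sketch wires all three values together. The layer sizes and the xs/ys training tensors are placeholders assumed for the example, not values from any particular project.
```javascript
// Sketch: learning rate, batch size, and epochs used together
// (assumes xs and ys training tensors already exist)
const model = tf.sequential({
  layers: [tf.layers.dense({ units: 1, inputShape: [4] })]
});

model.compile({
  optimizer: tf.train.adam(0.01), // Learning Rate: set on the optimizer
  loss: 'meanSquaredError'
});

await model.fit(xs, ys, {
  epochs: 10,    // Epochs: full passes over the dataset
  batchSize: 32  // Batch Size: samples per weight update
});
```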
For advanced control, you can use callbacks to adjust the learning rate dynamically after each epoch or batch.
In TensorFlow.js (TFJS), learning rate scheduling is used to dynamically adjust the step size during training to improve convergence and prevent oscillations.
Implementation Methods
Unlike the Python version of TensorFlow, which has many built-in schedule classes, TFJS typically requires manual updates via the optimizer or custom callbacks.
- Manual Optimizer Updates: You can update the learning rate of an active optimizer object at any time during the training loop using the learningRate property.
- Example for Adam: optimizer.learningRate = 0.0001;
- Example for SGD: optimizer.setLearningRate(0.0001);
- Using Custom Callbacks: When using model.fit(), you can define an onBatchEnd or onEpochEnd callback to adjust the rate based on progress or performance metrics.
Common Scheduling Strategies
- Step Decay: Reduces the learning rate by a fixed factor at specific intervals (e.g., every 10 epochs).
- Exponential Decay: Gradually decreases the rate according to an exponential function, often used to refine weights as the model nears a minimum.
- Polynomial Decay: Functions as a linear schedule by default, starting at an initial value and decreasing at a constant rate until reaching zero at the end of training (a small sketch follows this list).
- Reduce on Plateau: A popular technique that monitors a metric (like validation loss) and reduces the learning rate when improvement stops.
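For the polynomial (linear) case, a callback along these lines interpolates the rate down to zero over the planned number of epochs. This is a sketch under assumptions: model is a compiled tf.LayersModel in scope, and totalEpochs must match the epochs value passed to model.fit().
```javascript
// Linear (power = 1) polynomial decay sketch
const totalEpochs = 50;
const initialLr = 0.01;

const polyDecayCallback = {
  onEpochEnd: async (epoch, logs) => {
    const progress = Math.min((epoch + 1) / totalEpochs, 1);
    // Decrease at a constant rate until the learning rate reaches zero
    model.optimizer.learningRate = initialLr * (1 - progress);
  }
};
```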
Summary of Built-in Optimizer Support (2026)
While specialized schedule classes are more common in Python, the primary way to manage this in TFJS remains direct manipulation of the optimizer instance within the fit() API or a custom train loop.
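To illustrate the custom-train-loop half of that statement, here is a self-contained sketch using raw tf.variable weights rather than a LayersModel; the toy data, step counts, and decay factor are arbitrary choices for the example, and setLearningRate() is the setter exposed by the SGD optimizer.
```javascript
// Sketch: manual learning rate decay inside a custom training loop
const w = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));

// Toy data roughly following y = 2x + 1 (placeholder values)
const xs = tf.tensor1d([0, 1, 2, 3, 4]);
const ys = tf.tensor1d([1.1, 2.9, 5.2, 7.1, 8.8]);

let lr = 0.1;
const optimizer = tf.train.sgd(lr);

for (let step = 0; step < 1000; step++) {
  // One optimization step: predict, measure the loss, update w and b
  optimizer.minimize(() => {
    const preds = xs.mul(w).add(b);
    return tf.losses.meanSquaredError(ys, preds);
  });

  // Halve the learning rate every 200 steps
  if ((step + 1) % 200 === 0) {
    lr *= 0.5;
    optimizer.setLearningRate(lr);
  }
}
```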
In TensorFlow.js, learning rate scheduling is implemented by creating a custom callback object and passing it to the model.fit() method. This callback can modify the optimizer's learningRate property at the end of each epoch or batch.
Example: Step Decay Callback
The following example demonstrates a Step Decay schedule, which reduces the learning rate by a specific factor after a set number of epochs.
```javascript
// Define a custom callback for learning rate scheduling
const lrCallback = {
  onEpochEnd: async (epoch, logs) => {
    const initialLr = 0.01;
    const drop = 0.5;         // Reduce by 50%
    const epochsPerDrop = 10;
    // Calculate new learning rate: lr = initial_lr * drop^(floor((epoch + 1) / epochs_per_drop))
    const newLr = initialLr * Math.pow(drop, Math.floor((epoch + 1) / epochsPerDrop));
    // Update the optimizer's learning rate directly
    model.optimizer.learningRate = newLr;
    console.log(`Epoch ${epoch + 1}: Learning rate is now ${newLr}`);
  }
};

// Use the callback in model.fit
await model.fit(trainData, trainLabels, {
  epochs: 50,
  callbacks: lrCallback // Pass the custom callback here
});
```
Key Implementation Details
- Optimizer Properties: Different optimizers in TFJS handle rate updates slightly differently. For tf.train.adam, you set .learningRate, while for tf.train.sgd, you may use .setLearningRate(value) depending on the specific version (see the helper sketch after this list).
- Metric-Based Scheduling: You can also implement a Reduce on Plateau strategy by checking the logs parameter (e.g., logs.val_loss) in the onEpochEnd function to decide if the rate should be lowered.
- Flexible Timing: While onEpochEnd is common, you can use onBatchEnd if you need more frequent updates when training on large datasets.
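Because of that property-versus-setter difference, one defensive option is a tiny helper that checks which interface the optimizer instance exposes. The helper name is hypothetical; it is not part of the TFJS API.
```javascript
// Works whether the optimizer exposes setLearningRate() (e.g. SGD) or
// expects the learningRate property to be assigned directly (e.g. Adam)
function applyLearningRate(optimizer, lr) {
  if (typeof optimizer.setLearningRate === 'function') {
    optimizer.setLearningRate(lr);
  } else {
    optimizer.learningRate = lr;
  }
}

// Example use inside a callback:
// applyLearningRate(model.optimizer, 0.0001);
```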
What are the hyperparameters for exponential decay?
In TensorFlow.js, exponential decay is typically implemented by updating the optimizer's learning rate in a custom callback. The hyperparameters that define this schedule are as follows:
- Initial Learning Rate (initial_learning_rate): The starting step size at the beginning of training (e.g., 0.1 or 0.01).
- Decay Rate (decay_rate): A scalar factor (usually between 0 and 1, such as 0.96) used to multiply the learning rate at each interval to reduce it.
- Decay Steps (decay_steps): The number of training steps (batches) that must occur before the decay is applied. This controls the "frequency" of the reduction.
- Staircase (staircase): A boolean hyperparameter. If true, the learning rate decays at discrete intervals (like steps in a staircase); if false, it decays continuously at every training step.
Formula
The decayed learning rate is calculated at each step using the following relationship:
decayed_learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)
To apply this in a TensorFlow.js custom callback:
```javascript
const initialLr = 0.1;
const decayRate = 0.96;
const decaySteps = 1000;

// Track the global step manually: the `batch` argument passed to onBatchEnd
// resets to zero at the start of every epoch
let globalStep = 0;

const exponentialDecayCallback = {
  onBatchEnd: async (batch, logs) => {
    globalStep++;
    // Continuous decay: no Math.floor(), so the rate shrinks slightly on every batch
    const newLr = initialLr * Math.pow(decayRate, globalStep / decaySteps);
    model.optimizer.learningRate = newLr;
  }
};
```
In TensorFlow.js (TFJS), a staircase decay is implemented by using a Math.floor() function within a custom training callback. This ensures the learning rate remains constant for a set number of steps or epochs before dropping to a new discrete level.
Implementation with model.fit()
The most common way to do this is to update the optimizer's learning rate at the end of every epoch using the onEpochEnd callback.
```javascript
// Hyperparameters for staircase decay
const initialLr = 0.1;
const decayRate = 0.5; // Reduce by 50%
const decaySteps = 10; // Drop every 10 epochs

const staircaseCallback = {
  onEpochEnd: async (epoch, logs) => {
    // formula: initial_lr * decay_rate ^ floor((epoch + 1) / decay_steps)
    const exponent = Math.floor((epoch + 1) / decaySteps);
    const newLr = initialLr * Math.pow(decayRate, exponent);
    // Apply the new rate to the optimizer
    model.optimizer.learningRate = newLr;
    console.log(`Epoch ${epoch + 1}: LR is ${newLr.toFixed(6)}`);
  }
};

// Start training
await model.fit(trainX, trainY, {
  epochs: 50,
  callbacks: staircaseCallback
});
```
Continuous vs. Staircase Decay
- Staircase (Discrete): Uses Math.floor(step / decaySteps) to hold the learning rate steady within an interval. This is often preferred because it gives the model time to stabilize at a specific rate before the next reduction.
- Exponential (Continuous): Removes the floor function to decrease the rate slightly at every single step, creating a smooth downward curve; the two one-line helpers below make the difference explicit.
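Using the same hyperparameter names as the staircase example above (initialLr, decayRate, decaySteps), and with helper names chosen purely for illustration, the contrast comes down to a single Math.floor() call:
```javascript
// Staircase: the rate is held constant within each interval of decaySteps
const staircaseLr = (step) =>
  initialLr * Math.pow(decayRate, Math.floor(step / decaySteps));

// Continuous exponential: the rate shrinks a little at every step
const continuousLr = (step) =>
  initialLr * Math.pow(decayRate, step / decaySteps);
```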
Batch-Level Staircase
If you have a massive dataset and want to drop the learning rate every few hundred batches instead of epochs, use the onBatchEnd hook:
```javascript
let totalSteps = 0;

const batchStaircaseCallback = {
  onBatchEnd: async (batch, logs) => {
    totalSteps++;
    if (totalSteps % 500 === 0) { // Drop every 500 batches
      model.optimizer.learningRate *= 0.9;
    }
  }
};
```
In TensorFlow.js, a built-in ReduceLROnPlateau class is not available in the core library as it is in Python's Keras. To achieve this, you must implement a custom callback that monitors validation loss and manually updates the optimizer's learning rate when improvement stalls.
Implementation Example
This custom callback tracks the best validation loss and reduces the learning rate by a factor if no improvement is seen for a set number of patience epochs.
```javascript
class ReduceLROnPlateau {
  constructor(config = {}) {
    this.patience = config.patience || 5;
    this.factor = config.factor || 0.1;
    this.minLr = config.minLr || 1e-7;
    this.bestLoss = Infinity;
    this.wait = 0;
    // Bind the handler so `this` still refers to this instance when
    // model.fit() extracts the method into its own callback wrapper
    this.onEpochEnd = this.onEpochEnd.bind(this);
  }

  async onEpochEnd(epoch, logs) {
    const currentLoss = logs.val_loss; // Monitor validation loss
    if (currentLoss < this.bestLoss) {
      this.bestLoss = currentLoss;
      this.wait = 0;
    } else {
      this.wait++;
      if (this.wait >= this.patience) {
        const oldLr = model.optimizer.learningRate;
        const newLr = Math.max(oldLr * this.factor, this.minLr);
        if (oldLr > newLr) {
          model.optimizer.learningRate = newLr;
          console.log(`Epoch ${epoch + 1}: Reducing LR to ${newLr}`);
        }
        this.wait = 0; // Reset wait after reduction
      }
    }
  }
}

// Usage in model.fit()
const reduceLR = new ReduceLROnPlateau({ patience: 3, factor: 0.5 });

await model.fit(trainX, trainY, {
  epochs: 50,
  validationData: [valX, valY],
  callbacks: reduceLR // Pass the custom class instance
});
```
Key Hyperparameters
- Patience: The number of epochs to wait with no improvement before dropping the rate.
- Factor: The multiplier used to decrease the learning rate (e.g., 0.1 reduces it by 90%).
- Min LR: A lower bound to prevent the learning rate from dropping to zero.
- Monitor: While val_loss is standard, you can monitor any metric available in the logs object, such as val_acc.