Spaces:
Configuration error
Configuration error
| import torch.optim as optim | |
| import warnings | |
class StepLRScheduler(optim.lr_scheduler._LRScheduler):
    """Multiply every parameter group's learning rate by ``gamma`` each ``step_size`` epochs.

    Args:
        optimizer: Wrapped optimizer whose ``param_groups`` carry the learning rates.
        step_size: Number of epochs between two consecutive decays; must be positive.
        gamma: Multiplicative decay factor applied at each decay epoch.
        last_epoch: Index of the last epoch, forwarded to the base scheduler.
        verbose: Forwarded to the base scheduler only when truthy (see ``__init__``).

    Raises:
        ValueError: If ``step_size`` is not positive.
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        # Fail fast: a zero step size would otherwise only surface later as a
        # ZeroDivisionError inside ``get_lr`` on the first ``step()`` call, and
        # a negative one makes ``get_lr`` and ``_get_closed_form_lr`` disagree.
        if step_size <= 0:
            raise ValueError(
                "step_size must be positive, got {!r}".format(step_size)
            )

        # These must be assigned before the base constructor runs, because it
        # performs an initial ``step()`` that calls ``get_lr``.
        self.step_size = step_size
        self.gamma = gamma

        # Newer PyTorch releases deprecate the ``verbose`` argument (and may no
        # longer accept it), so it is only forwarded when the caller explicitly
        # asked for it. Omitting it is equivalent to the old default of False.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each parameter group for the current epoch.

        The value is derived from the optimizer's *current* learning rates:
        they are returned unchanged unless ``last_epoch`` is a non-zero
        multiple of ``step_size``, in which case each is scaled by ``gamma``.

        Emits a ``UserWarning`` when called outside of ``step()``.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        # Epoch 0 and every epoch that is not on a step boundary keep the rate.
        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return the learning rates computed directly from ``base_lrs`` and ``last_epoch``."""
        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]