

PyTorch custom scheduler


Lower the learning rate when the network's training loss stops decreasing.

1. Straight to the code:
– Scheduler class code

import scipy.stats as ss

class myscheduler():
    def __init__(self, optimizer, n_epoch, decay):
        self.optimizer = optimizer
        self.n_epoch = n_epoch
        self.decay = decay          # learning-rate decay factor
        self.loss_summary = []      # buffer for the last n_epoch losses

    def step(self, loss):
        self.loss_summary.append(loss)
        if len(self.loss_summary) < self.n_epoch:
            return
        # Rank the buffered losses, then compute the Spearman rank
        # correlation between step order (1..n) and the loss ranks.
        r = ss.rankdata(self.loss_summary)
        spearman = 0
        for i in range(self.n_epoch):
            spearman = spearman + (i + 1 - r[i]) ** 2
        spearman = 1 - 6 * spearman / (self.n_epoch * (self.n_epoch ** 2 - 1))
        print(f"spearman: {spearman}")
        # spearman > 0 means the recent losses are flat or rising
        # (not decreasing), so decay the learning rate of every
        # parameter group.
        if spearman > 0:
            for group in self.optimizer.param_groups:
                group['lr'] *= self.decay
            print(f"lr changed to {self.optimizer.param_groups[0]['lr']}")
        self.loss_summary = []

– Calling code

import torch

# Net, device, loss_fun and epochs are assumed to be defined elsewhere.
optimizer = torch.optim.SGD(Net.parameters(), lr=0.001)
scheduler = myscheduler(optimizer, 10, 0.9)  # initialize the scheduler

for epoch in range(epochs):
    running_loss = 0.0
    for i, (input, label) in enumerate(train_loader):
        optimizer.zero_grad()
        output = Net(input.to(device))

        label = label.to(torch.float32)
        loss = loss_fun(output, label.to(device))
        print(i, loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step(loss.item())  # update the scheduler

        running_loss += loss.item()

2. Code explanation
2-1 The Spearman rank correlation characterizes the trend of the last n losses (rising, flat, or falling): a value near -1 means the losses are still falling, while a value above 0 means they are flat or rising, which triggers the decay. A quick sanity check of this sign convention follows below.
2-2 The n_epoch parameter sets how many steps (or epochs, depending on where scheduler.step is called) pass between learning-rate re-evaluations. In the example above, scheduler.step runs once per batch, so the learning rate is re-evaluated every 10 steps.
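
To verify the sign convention concretely, here is a minimal sanity check. This is a sketch: it assumes the myscheduler class above is in scope, and the synthetic loss sequences are made up purely for illustration.

import torch
import scipy.stats as ss

# Synthetic loss sequences (illustrative only).
falling = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
rising = list(reversed(falling))

def spearman(losses):
    # Same formula as in myscheduler.step: rank correlation between
    # the step index (1..n) and the rank of each loss value.
    n = len(losses)
    r = ss.rankdata(losses)
    d2 = sum((i + 1 - r[i]) ** 2 for i in range(n))
    return 1 - 6 * d2 / (n * (n ** 2 - 1))

print(spearman(falling))     # -1.0: loss falling, LR left alone
print(spearman(rising))      # +1.0: loss rising, LR decayed
print(spearman([0.5] * 10))  # 0.5: flat loss (tied ranks) also triggers decay

# Driving the scheduler itself with a throwaway model and optimizer:
net = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
sch = myscheduler(opt, 10, 0.9)
for l in rising:
    sch.step(l)
print(opt.param_groups[0]['lr'])  # 0.001 * 0.9 = 0.0009

Note that because scheduler.step here receives the per-batch loss, the buffered sequence can be noisy; feeding it a running epoch average instead would give a more stable trend estimate.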