
PyTorch - Calculate a few steps of gradient descent

Calculate a few steps of gradient descent with PyTorch

Reaching a local minimum (or any local extremum) usually takes several steps in the direction given by the gradient; to descend toward a minimum, each step goes against the gradient. Below we perform this tensor optimisation in two different ways.
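For reference, this is the standard fixed-step gradient-descent rule, where α is the step size (learning rate) and n is the number of steps used later in this post:

```latex
w^{(k+1)} = w^{(k)} - \alpha \, \nabla_w f\bigl(w^{(k)}\bigr), \qquad k = 0, 1, \dots, n - 1
```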

As an example, we will use the function:

$$f(w) = \prod_{i,j} \ln\bigl(\ln(w_{i,j} + 7)\bigr)$$

where the initial tensor is w = [[5., 10.], [1., 2.]], and we will perform n = 500 steps.
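By the product rule, the gradient of this f also has a closed form: df/dw_ij = f(w) / (ln(ln(w_ij + 7)) · ln(w_ij + 7) · (w_ij + 7)), which is valid here because every factor is non-zero for this w. A minimal sketch comparing that formula with what autograd stores in w.grad (the helper names f and analytic_grad are ours, not from the original post):

```python
import torch

def f(w: torch.Tensor) -> torch.Tensor:
    # f(w) = product over all elements of ln(ln(w_ij + 7))
    return (w + 7).log().log().prod()

def analytic_grad(w: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: closed-form partial derivatives
    # df/dw_ij = f(w) / (ln(ln(w_ij + 7)) * ln(w_ij + 7) * (w_ij + 7))
    u = w + 7
    return f(w) / (u.log().log() * u.log() * u)

w = torch.tensor([[5., 10.], [1., 2.]], requires_grad=True)
value = f(w)
value.backward()  # autograd fills w.grad

with torch.no_grad():
    expected = analytic_grad(w)

print(value.item())
print(w.grad)
# True when autograd agrees with the closed-form gradient
print(torch.allclose(w.grad, expected, rtol=1e-4))
```

All four gradient entries are positive, so every step of gradient descent decreases every entry of w, which matches the outputs printed below.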

Optimisation with fixed step

For the fixed step, we take the gradient step size (learning rate) alpha = 0.001 and keep it constant for all steps.


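import torch

# initial tensor; requires_grad=True lets autograd track it
w = torch.tensor([[5., 10.], [1., 2.]], requires_grad=True)
alpha = 0.001  # fixed step size (learning rate)
n = 500        # number of gradient-descent steps
for _ in range(n):
    # calculate the target function according to our equation
    function = (w + 7).log().log().prod()
    # calculate the gradient df/dw (stored in w.grad)
    function.backward()
    # step against the gradient; no_grad() keeps the update out of the autograd graph
    with torch.no_grad():
        w -= alpha * w.grad
    # clear the gradient, otherwise the next backward() would add to it
    w.grad.zero_()

print(w)
#tensor([[4.9900, 9.9948],
#        [0.9775, 1.9825]], requires_grad=True)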
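To see how the fixed step size affects progress, a minimal sketch can wrap the same loop in a function, record f(w) before every update and compare two step sizes. The function name gradient_descent and the second value alpha = 0.01 are illustrative assumptions, not part of the original post:

```python
import torch

def gradient_descent(alpha: float, n: int = 500):
    """Fixed-step gradient descent on f(w) = prod ln(ln(w + 7)).

    Returns the final tensor and the value of f recorded before each step.
    """
    w = torch.tensor([[5., 10.], [1., 2.]], requires_grad=True)
    history = []
    for _ in range(n):
        function = (w + 7).log().log().prod()
        history.append(function.item())
        function.backward()
        with torch.no_grad():
            w -= alpha * w.grad  # step against the gradient
        w.grad.zero_()           # avoid gradient accumulation
    return w.detach(), history

for alpha in (0.001, 0.01):
    w_final, history = gradient_descent(alpha)
    # starting value of f, value before the last step, final tensor
    print(alpha, history[0], history[-1])
    print(w_final)
```

With such small steps f decreases at every iteration, and the larger step reaches a lower value in the same 500 steps; this is why choosing the step size matters in practice.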

Optimisation with a PyTorch optimizer

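For real tasks, choosing a proper step size is important. PyTorch provides ready-made optimizers in torch.optim that perform the update step for us. Note that torch.optim.SGD does not calculate an optimal step size: with its default settings (no momentum, no weight decay) it applies the same update, w minus lr times the gradient, as the manual loop above, which is why both runs print the same tensor. Optimizers that adapt the step, such as torch.optim.Adam, are available when a fixed step is not good enough.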

In this task we again perform 500 optimisation steps, this time letting torch.optim.SGD apply each update. All other parameters are the same.


import torch
# define the initial tensor
w = torch.tensor([[5., 10.], [1., 2.]], requires_grad=True)
# define the step size (learning rate)
alpha = 0.001
n = 500
# plain stochastic gradient descent optimizer: w <- w - lr * w.grad
optimizer = torch.optim.SGD([w], lr=alpha)

for _ in range(n):
    # calculate the target function according to our equation
    function = (w + 7).log().log().prod()
    # calculate the gradient
    function.backward()
    # perform one update step with the optimizer
    optimizer.step()
    # clear the gradient to avoid gradient accumulation
    optimizer.zero_grad()
print(w)
#tensor([[4.9900, 9.9948],
#        [0.9775, 1.9825]], requires_grad=True)
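torch.optim.SGD keeps the step fixed. If the step size should actually be chosen at every iteration, one common technique (a different method from what SGD does) is a backtracking line search with the Armijo sufficient-decrease condition: start from a trial step t0 and halve it until f(w - t·g) <= f(w) - c·t·||g||², where g is the gradient. The sketch below assumes t0 = 1.0, beta = 0.5 and c = 1e-4, and the helper backtracking_step is hypothetical. It runs only 20 iterations because this f is unbounded below: as one entry approaches -6, its factor ln(ln(w + 7)) tends to minus infinity while the other factors stay positive.

```python
import torch

def f(w: torch.Tensor) -> torch.Tensor:
    # same objective as above: product of ln(ln(w_ij + 7))
    return (w + 7).log().log().prod()

def backtracking_step(w, t0=1.0, beta=0.5, c=1e-4, max_halvings=30):
    """One gradient-descent step whose size is chosen by Armijo backtracking.

    t0, beta, c and max_halvings are illustrative defaults (assumptions).
    Returns the new tensor and the accepted step size.
    """
    w = w.detach().requires_grad_(True)
    value = f(w)
    (grad,) = torch.autograd.grad(value, w)
    grad_sq = grad.pow(2).sum()
    t = t0
    with torch.no_grad():
        for _ in range(max_halvings):
            candidate = w - t * grad
            # sufficient-decrease test; a NaN value simply fails it
            if f(candidate) <= value - c * t * grad_sq:
                return candidate, t
            t *= beta  # shrink the trial step and try again
    return w.detach(), 0.0  # no acceptable step was found

w = torch.tensor([[5., 10.], [1., 2.]])
steps = []
values = [f(w).item()]
for _ in range(20):
    w, t = backtracking_step(w)
    steps.append(t)
    values.append(f(w).item())

print(steps)                  # accepted step size per iteration
print(values[0], values[-1])  # objective before and after
print(w)
```

For this smooth objective the first trial step is likely to be accepted on most iterations; the halving loop matters when a trial step overshoots or leaves the domain of ln.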


Published: 2022-05-05 02:43:14
Updated: 2022-05-05 02:49:35


© 2020 MyCoding.uk - My blog about coding and further learning. This blog was written with pure Perl and front-end output was performed with TemplateToolkit.