
[ pytorch ] Basic Principles

  1. Autograd and weight updates
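The core mechanism used below: a tensor created with requires_grad=True records every operation applied to it (exposed through its grad_fn), and calling .backward() on a scalar result fills each leaf tensor's .grad with the derivative of that result with respect to it. A minimal sketch of just that idea (the variable names x and y are illustrative, not from the script that follows):

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x        # y records how it was built from x
print(y.grad_fn)          # e.g. <AddBackward0 object ...>
y.backward()              # compute dy/dx
print(x.grad)             # dy/dx = 2*x + 3 = tensor(7.)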
import torch
# from torch.autograd import Variable   # not needed since PyTorch 0.4: plain tensors track gradients

def func(input, theta):
    # function: element-wise product of theta and input, reduced to a scalar
    y = theta * input
    return y.mean(), theta

#--------------------#
#   Hyperparameter   #
#--------------------#
learning_rate = 0.1

##################
#   Input Data   #
##################
input = torch.ones(2, 2)  # Note: input doesn't require grad because it is never updated
print('Value of input:\t', input, '\n')

####################################################
#  Create an operation (func with parameter theta) #
####################################################
### Init weight of func
# parameter which requires grad (method 1)
theta = torch.rand([2, 2], requires_grad=True)
print('Value of theta:\t', theta, '\n')
# # (method 2)
# theta = torch.rand([2, 2])
# theta.requires_grad_(True)

### Forward
output, theta = func(input, theta)
print('Value of output:\t', output)
print('Record Operation in Output:\t', output.grad_fn, '\n')

#################################
#  Compute grad by .backward() #
#################################
### Watch grad in theta before backward
print('Grad in theta Before Backward:\t', theta.grad, '\n')
### Watch grad in theta after backward
output.backward()
print('Grad in theta After Backward:\t', theta.grad, '\n')

#######################################
#  Update weight of parameter theta   #
#######################################
theta.data.sub_(theta.grad.data * learning_rate)  # in-place SGD step: theta <- theta - lr * grad (via .data, outside autograd)
print('theta After Updating:\t', theta)

########################
#  Next iteration ...  #
########################
# # new forward
# output, theta = func(input, theta)
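
Two things are worth knowing before running the next iteration: PyTorch accumulates gradients into .grad across successive calls to .backward(), so they must be cleared between steps, and writing through .data sidesteps autograd, so the more common idiom today is to update parameters under torch.no_grad() or to let torch.optim.SGD perform the same update. The loop below is a minimal sketch under those assumptions; the number of iterations and the loss are illustrative choices, not values from the original script.

import torch

learning_rate = 0.1
input = torch.ones(2, 2)

# Option 1: manual SGD with explicit gradient zeroing
theta = torch.rand([2, 2], requires_grad=True)
for step in range(3):
    loss = (theta * input).mean()          # forward
    loss.backward()                        # accumulates d(loss)/d(theta) into theta.grad
    with torch.no_grad():                  # update parameters outside the autograd graph
        theta -= learning_rate * theta.grad
    theta.grad.zero_()                     # clear the gradient before the next iteration
    print('step', step, 'loss', loss.item())

# Option 2: the same update expressed with an optimizer
theta = torch.rand([2, 2], requires_grad=True)
optimizer = torch.optim.SGD([theta], lr=learning_rate)
for step in range(3):
    optimizer.zero_grad()                  # clear gradients from the previous step
    loss = (theta * input).mean()
    loss.backward()
    optimizer.step()                       # theta <- theta - lr * theta.grad

If the gradients were not zeroed, the second call to .backward() would add the new gradient on top of the old one, effectively doubling the step.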