在建立CNN模型时,使用如下代码,在构建方法里面新建了一些对象,例如self.conv1,在下面的forward方法中直接把对象名作为方法名,传入变量x
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1=nn.Conv2d(3,6,5)self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(6,16,5)self.fc1=nn.Linear(5*5*6,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self, x):x=self.conv1(x)x=self.conv1.forward(x)x=F.relu(x)x=self.pool(x)x=self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5) x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return x
实际上x=self.conv1(x) 等价于 x=self.conv1.forward(x)
原因:self.conv1=nn.Conv2d(3,6,5)
其中nn.Conv2d是一个类,继承关系Conv2d ——> _Convnd ——> Module,在建立Module类时,定义了__call()__,说明这个类的子类实例是可调用的
def __call__(self, *input, **kwargs):for hook in self._forward_pre_hooks.values():hook(self, input)if torch.jit._tracing:result = self._slow_forward(*input, **kwargs)else:result = self.forward(*input, **kwargs)for hook in self._forward_hooks.values():hook_result = hook(self, input, result)if hook_result is not None:raise RuntimeError("forward hooks should never return any values, but '{}'""didn't return None".format(hook))if len(self._backward_hooks) > 0:var = resultwhile not isinstance(var, torch.Tensor):if isinstance(var, dict):var = next((v for v in var.values() if isinstance(v, torch.Tensor)))else:var = var[0]grad_fn = var.grad_fnif grad_fn is not None:for hook in self._backward_hooks.values():wrapper = functools.partial(hook, self)functools.update_wrapper(wrapper, hook)grad_fn.register_hook(wrapper)return result
具体调用什么方法:
使用x=self.conv1(x) 等价于调用了forward(),是因为:
1、Module定义了forward(),Conv2d, _Convnd 继承并改写了forward()
2、在Module的__call()__里面写了
....result = self.forward(*input, **kwargs)....return result
所以,使用对象名作为方法名时,调用的方法是 对象的 类的 forward()
做一个简单例子
class A():def __init__(self):self.a = 1def func(self, input):print('A_call '+ input)def __call__(self, *args, **kwargs):return self.func(*args)class B(A):def __init__(self):super(B, self).__init__()def func(self, input):print('B_call ' + input)a=A()
a('ok')b = B()
b('ok')
输出:
A_call ok
B_call ok