Update Part1

master
YunMao 3 years ago
parent 4ba81f201b
commit ab02a2aa6d

@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@ -127,7 +127,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -211,9 +211,9 @@
" x_unfold = F.unfold(x, kernel_size = self.kernel_size, padding = self.padding, stride = self.stride)\n",
"\n",
" if self.bias:\n",
" out_unfold = x_unfold.transpose(1, 2).matmul(self.w.view(self.w.size(0), -1).t()).transpose(1, 2) + self.b.view(-1, 1)\n",
" out_unfold = x_unfold.permute(0, 2, 1).matmul(self.w.view(self.w.size(0), -1).t()).permute(0, 2, 1) + self.b.view(-1, 1)\n",
" else:\n",
" out_unfold = x_unfold.transpose(1, 2).matmul(self.w.view(self.w.size(0), -1).t()).transpose(1, 2)\n",
" out_unfold = x_unfold.permute(0, 2, 1).matmul(self.w.view(self.w.size(0), -1).t()).permute(0, 2, 1)\n",
" out = out_unfold.view(x.shape[0], self.out_channels, out_H, out_W)\n",
"\n",
" # print(nn.functional.conv2d(x, self.w, self.b, padding=1))\n",
@ -232,7 +232,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -250,7 +250,7 @@
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" if type(kernel_size) is tuple :\n",
" if type(kernel_size) is tuple :\n",
" self.kernel_size = kernel_size\n",
" else:\n",
" self.kernel_size = (kernel_size, kernel_size)\n",
@ -272,20 +272,36 @@
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" pass\n",
" H_k, W_k = self.kernel_size\n",
"\n",
" x_unfold = F.unfold(x, kernel_size = self.kernel_size, stride = self.kernel_size)\n",
"\n",
" x_reshape = x_unfold.view(x.shape[0], x.shape[1], self.kernel_size[0] * self.kernel_size[1], -1)\n",
"\n",
" x_max = x_reshape.max(axis = 2)[0]\n",
"\n",
" H_out = x.shape[2] // H_k\n",
" W_out = x.shape[3] // W_k\n",
"\n",
" out = x_max.view(x.shape[0], x.shape[1], H_out, W_out)\n",
" # print(out == nn.functional.max_pool2d(x, kernel_size = self.kernel_size))\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
" ########################################################################\n",
" # END OF YOUR CODE #\n",
" ########################################################################\n",
"\n",
" return out"
" return out\n",
"# inputs = torch.rand(3, 3, 23, 23)\n",
"# maxpool = MaxPool2d(kernel_size=5)\n",
"# out2 = maxpool.forward(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class Linear(nn.Module):\n",
@ -303,7 +319,12 @@
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" pass\n",
"        self.weight = (torch.randn((in_channels, out_channels)) * torch.sqrt(torch.tensor(1.0/in_channels))).requires_grad_(True)\n",
"\n",
" if bias:\n",
"            self.bias = (torch.randn(out_channels) * torch.sqrt(torch.tensor(1.0/in_channels))).requires_grad_(True)\n",
" else:\n",
" self.bias = None\n",
"\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
" ########################################################################\n",
@ -324,19 +345,26 @@
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" pass\n",
"\n",
"        if self.bias is not None:\n",
" out = torch.matmul(x, self.weight) + self.bias\n",
" else:\n",
" out = torch.matmul(x, self.weight)\n",
" # nn.functional.linear(x,self.weight, self.bias)\n",
" # print(out == nn.functional.linear(x,self.weight.t(), self.bias))\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
" ########################################################################\n",
" # END OF YOUR CODE #\n",
" ########################################################################\n",
"\n",
" return out"
" return out\n",
"# inputs = torch.rand(3, 3, 23, 23)\n",
"# linear = Linear(in_channels=23, out_channels = 4)\n",
"# out2 = linear.forward(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@ -363,7 +391,13 @@
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" pass\n",
" self.num_features = num_features\n",
" self.eps = eps\n",
" self.momentum = momentum\n",
" self.gamma = torch.ones(num_features, requires_grad = True)\n",
" self.beta = torch.zeros(num_features, requires_grad = True)\n",
" self.running_mean = torch.zeros(num_features)\n",
" self.running_var = torch.ones(num_features)\n",
"\n",
"\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
@ -385,15 +419,28 @@
" # (be aware of the difference for training and testing) #\n",
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" pass\n",
"\n",
" mean = x.mean([0, 2, 3])\n",
" var = x.var([0, 2, 3])\n",
" self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean\n",
"        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var\n",
" var = x.var([0,2,3], unbiased = False)\n",
" x = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1, 1) * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)\n",
"\n",
"        # cf. nn.functional.batch_norm(x, self.running_mean, self.running_var, self.gamma, self.beta)\n",
" # print(x)\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
" ########################################################################\n",
" # END OF YOUR CODE #\n",
" ########################################################################\n",
"\n",
" return x"
" return x\n",
"# inputs = torch.rand(3, 3, 23, 23)\n",
"# batch = BatchNorm2d(3)\n",
"# test = nn.BatchNorm2d(3)\n",
"# out = test.forward(inputs)\n",
"# out2 = batch.forward(inputs)\n",
"# print(test.running_mean)\n",
"# print(batch.running_mean)"
]
},
{

Loading…
Cancel
Save