From ab02a2aa6d842ca105ac5160eec75cce82e2627f Mon Sep 17 00:00:00 2001 From: YunMao Date: Sun, 31 Jan 2021 20:51:28 +0000 Subject: [PATCH] Update Part1 --- CW1/460cw1_2020.ipynb | 87 +++++++++++++++++++++++++++++++++---------- 1 file changed, 67 insertions(+), 20 deletions(-) diff --git a/CW1/460cw1_2020.ipynb b/CW1/460cw1_2020.ipynb index 2f9265d..4a437f0 100644 --- a/CW1/460cw1_2020.ipynb +++ b/CW1/460cw1_2020.ipynb @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -127,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -211,9 +211,9 @@ " x_unfold = F.unfold(x, kernel_size = self.kernel_size, padding = self.padding, stride = self.stride)\n", "\n", " if self.bias:\n", - " out_unfold = x_unfold.transpose(1, 2).matmul(self.w.view(self.w.size(0), -1).t()).transpose(1, 2) + self.b.view(-1, 1)\n", + " out_unfold = x_unfold.permute(0, 2, 1).matmul(self.w.view(self.w.size(0), -1).t()).permute(0, 2, 1) + self.b.view(-1, 1)\n", " else:\n", - " out_unfold = x_unfold.transpose(1, 2).matmul(self.w.view(self.w.size(0), -1).t()).transpose(1, 2)\n", + " out_unfold = x_unfold.permute(0, 2, 1).matmul(self.w.view(self.w.size(0), -1).t()).permute(0, 2, 1)\n", " out = out_unfold.view(x.shape[0], self.out_channels, out_H, out_W)\n", "\n", " # print(nn.functional.conv2d(x, self.w, self.b, padding=1))\n", @@ -232,7 +232,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -250,7 +250,7 @@ " ########################################################################\n", " # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", "\n", - " if type(kernel_size) is tuple :\n", + " if type(kernel_size) is tuple :\n", " self.kernel_size = kernel_size\n", " else:\n", " self.kernel_size = (kernel_size, kernel_size)\n", @@ -272,20 +272,36 @@ " ########################################################################\n", " # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", "\n", - " pass\n", + " H_k, W_k = self.kernel_size\n", + "\n", + " x_unfold = F.unfold(x, kernel_size = self.kernel_size, stride = self.kernel_size)\n", + "\n", + " x_reshape = x_unfold.view(x.shape[0], x.shape[1], self.kernel_size[0] * self.kernel_size[1], -1)\n", + "\n", + " x_max = x_reshape.max(axis = 2)[0]\n", "\n", + " H_out = x.shape[2] // H_k\n", + " W_out = x.shape[3] // W_k\n", + "\n", + " out = x_max.view(x.shape[0], x.shape[1], H_out, W_out)\n", + " # print(out == nn.functional.max_pool2d(x, kernel_size = self.kernel_size))\n", " # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", " ########################################################################\n", " # END OF YOUR CODE #\n", " ########################################################################\n", "\n", - " return out" + " return out\n", + "# inputs = torch.rand(3, 3, 23, 23)\n", + "# maxpool = MaxPool2d(kernel_size=5)\n", + "# out2 = maxpool.forward(inputs)" ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, + "execution_count": 5, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "class Linear(nn.Module):\n", @@ -303,7 +319,12 @@ " ########################################################################\n", " # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", "\n", - " pass\n", + " self.weight = torch.randn((in_channels, out_channels), requires_grad = True) * torch.sqrt(torch.tensor(1.0/in_channels))\n", + "\n", + " if bias:\n", + " self.bias = torch.randn(out_channels, requires_grad = True) * torch.sqrt(torch.tensor(1.0/in_channels))\n", + " else:\n", + " self.bias = None\n", "\n", " # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", " ########################################################################\n", @@ -324,19 +345,26 @@ " ########################################################################\n", " # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", "\n", - " pass\n", - "\n", + " if self.bias != None:\n", + " out = torch.matmul(x, self.weight) + self.bias\n", + " else:\n", + " out = torch.matmul(x, self.weight)\n", + " # nn.functional.linear(x,self.weight, self.bias)\n", + " # print(out == nn.functional.linear(x,self.weight.t(), self.bias))\n", " # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", " ########################################################################\n", " # END OF YOUR CODE #\n", " ########################################################################\n", "\n", - " return out" + " return out\n", + "# inputs = torch.rand(3, 3, 23, 23)\n", + "# linear = Linear(in_channels=23, out_channels = 4)\n", + "# out2 = linear.forward(inputs)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -363,7 +391,13 @@ " ########################################################################\n", " # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", "\n", - " pass\n", + " self.num_features = num_features\n", + " self.eps = eps\n", + " self.momentum = momentum\n", + " self.gamma = torch.ones(num_features, requires_grad = True)\n", + " self.beta = torch.zeros(num_features, requires_grad = True)\n", + " self.running_mean = torch.zeros(num_features)\n", + " self.running_var = torch.ones(num_features)\n", "\n", "\n", " # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", @@ -385,15 +419,28 @@ " # (be aware of the difference for training and testing) #\n", " ########################################################################\n", " # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", - "\n", - " pass\n", - "\n", + " mean = x.mean([0, 2, 3])\n", + " var = x.var([0, 2, 3])\n", + " self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean\n", + " running_var = (1 - self.momentum) * self.running_var + self.momentum * var\n", + " var = x.var([0,2,3], unbiased = False)\n", + " x = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1, 1) * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)\n", + "\n", + " # nn.functional.BatchNorm2d(x, self.num_features)\n", + " # print(x)\n", " # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n", " ########################################################################\n", " # END OF YOUR CODE #\n", " ########################################################################\n", "\n", - " return x" + " return x\n", + "# inputs = torch.rand(3, 3, 23, 23)\n", + "# batch = BatchNorm2d(3)\n", + "# test = nn.BatchNorm2d(3)\n", + "# out = test.forward(inputs)\n", + "# out2 = batch.forward(inputs)\n", + "# print(test.running_mean)\n", + "# print(batch.running_mean)" ] }, {