master
YunMao 3 years ago
parent c697e42da4
commit 4ba81f201b

@ -43,9 +43,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 1,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: torch in c:\\users\\yunmao\\.conda\\envs\\ml\\lib\\site-packages (1.7.1)\nRequirement already satisfied: torchvision in c:\\users\\yunmao\\.conda\\envs\\ml\\lib\\site-packages (0.8.2)\nRequirement already satisfied: typing-extensions in c:\\users\\yunmao\\.conda\\envs\\ml\\lib\\site-packages (from torch) (3.7.4.3)\nRequirement already satisfied: numpy in c:\\users\\yunmao\\.conda\\envs\\ml\\lib\\site-packages (from torch) (1.19.5)\nRequirement already satisfied: pillow>=4.1.1 in c:\\users\\yunmao\\.conda\\envs\\ml\\lib\\site-packages (from torchvision) (8.1.0)\n"
]
}
],
"source": [
"!pip install torch torchvision"
]
@ -119,7 +127,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -160,21 +168,23 @@
"\n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
" self.kernel_size = kernel_size\n",
" self.stride = stride\n",
" self.padding = padding\n",
"\n",
" self.w = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))\n",
" self.w.data.normal_(-0.1, 0.1)\n",
" if type(kernel_size) is tuple :\n",
" self.kernel_size = kernel_size\n",
" else:\n",
" self.kernel_size = (kernel_size, kernel_size)\n",
"\n",
" # Good practice is to start your weights in the range of [-y, y] where y=1/sqrt(n) (n is the number of inputs to a given neuron).\n",
" self.w = torch.randn((out_channels, in_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad = True) * torch.sqrt(torch.tensor(1.0/in_channels))\n",
"\n",
" if bias:\n",
" self.b = nn.Parameter(torch.Tensor(outchannel, ))\n",
" self.b.data.normal_(-0.1, 0.1)\n",
" self.b = torch.randn(out_channels, requires_grad = True) * torch.sqrt(torch.tensor(1.0/in_channels))\n",
" else:\n",
" self.b = None\n",
"\n",
"\n",
" self.bias = bias\n",
"\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
" ########################################################################\n",
@ -194,14 +204,30 @@
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" pass\n",
" #Calculate the output shape\n",
" out_H = (x.shape[2] + 2*self.padding - self.kernel_size[0]) // self.stride + 1\n",
" out_W = (x.shape[3] + 2*self.padding - self.kernel_size[1]) // self.stride + 1\n",
"\n",
" x_unfold = F.unfold(x, kernel_size = self.kernel_size, padding = self.padding, stride = self.stride)\n",
"\n",
" if self.bias:\n",
" out_unfold = x_unfold.transpose(1, 2).matmul(self.w.view(self.w.size(0), -1).t()).transpose(1, 2) + self.b.view(-1, 1)\n",
" else:\n",
" out_unfold = x_unfold.transpose(1, 2).matmul(self.w.view(self.w.size(0), -1).t()).transpose(1, 2)\n",
" out = out_unfold.view(x.shape[0], self.out_channels, out_H, out_W)\n",
"\n",
" # print(nn.functional.conv2d(x, self.w, self.b, padding=1))\n",
"\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
" ########################################################################\n",
" # END OF YOUR CODE #\n",
" ########################################################################\n",
"\n",
" return out"
" return out\n",
"# inputs = torch.rand(3, 3, 24, 24)\n",
"# conv2 = Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3),stride=1, padding=1)\n",
"# out2 = conv2.forward(inputs)\n",
"# print(out2)"
]
},
{
@ -224,7 +250,10 @@
" ########################################################################\n",
" # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
"\n",
" pass\n",
" if type(kernel_size) is tuple :\n",
" self.kernel_size = kernel_size\n",
" else:\n",
" self.kernel_size = (kernel_size, kernel_size)\n",
"\n",
" # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****\n",
" ########################################################################\n",
@ -791,9 +820,13 @@
],
"metadata": {
"kernelspec": {
"display_name": "pyg",
"language": "python",
"name": "pyg"
"name": "python3",
"display_name": "Python 3.9.1 64-bit ('ml': conda)",
"metadata": {
"interpreter": {
"hash": "4387a1793e0dcd9f77e6d87a719a21277afc2df44face861a4961f7859ab8d21"
}
}
},
"language_info": {
"codemirror_mode": {
@ -805,7 +838,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.9.1-final"
}
},
"nbformat": 4,

@ -0,0 +1,112 @@
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"# Unfold 函数"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,\n 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.,\n 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35.,\n 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.,\n 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.,\n 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71.,\n 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83.,\n 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95.,\n 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., 107.,\n 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119.,\n 120., 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.,\n 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143.,\n 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., 154., 155.,\n 156., 157., 158., 159., 160., 161., 162., 163., 164., 165., 166., 167.,\n 168., 169., 170., 171., 172., 173., 174., 175., 176., 177., 178., 179.,\n 180., 181., 182., 183., 184., 185., 186., 187., 188., 189., 190., 191.,\n 192., 193., 194., 195., 196., 197., 198., 199., 200., 201., 202., 203.,\n 204., 205., 206., 207., 208., 209., 210., 211., 212., 213., 214., 215.,\n 216., 217., 218., 219., 220., 221., 222., 223., 224., 225., 226., 227.,\n 228., 229., 230., 231., 232., 233., 234., 235., 236., 237., 238., 239.,\n 240., 241., 242., 243., 244., 245., 246., 247., 248., 249., 250., 251.,\n 252., 253., 254., 255., 256., 257., 258., 259., 260., 261., 262., 263.,\n 264., 265., 266., 267., 268., 269., 270., 271., 272., 273., 274., 275.,\n 276., 277., 278., 279., 280., 281., 282., 283., 284., 285., 286., 287.,\n 288., 289., 290., 291., 292., 293., 294., 295., 296., 297., 298., 299.,\n 300., 301., 302., 303., 304., 305., 306., 307., 308., 309., 310., 311.,\n 312., 313., 314., 315., 316., 317., 318., 319., 320., 321., 322., 323.,\n 324., 325., 326., 327., 328., 329., 330., 331., 332., 333., 334., 335.,\n 336., 337., 338., 339., 340., 341., 342., 343., 344., 345., 346., 347.,\n 348., 349., 350., 351., 352., 353., 354., 355., 356., 357., 358., 359.,\n 360., 361., 362., 363., 364., 365., 366., 367., 368., 369., 370., 371.,\n 372., 373., 374., 375., 376., 377., 378., 379., 380., 381., 382., 383.,\n 384., 385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395.,\n 396., 397., 398., 399., 400., 401., 402., 403., 404., 405., 406., 407.,\n 408., 409., 410., 411., 412., 413., 414., 415., 416., 417., 418., 419.,\n 420., 421., 422., 423., 424., 425., 426., 427., 428., 429., 430., 431.,\n 432., 433., 434., 435., 436., 437., 438., 439., 440., 441., 442., 443.,\n 444., 445., 446., 447., 448., 449., 450., 451., 452., 453., 454., 455.,\n 456., 457., 458., 459., 460., 461., 462., 463., 464., 465., 466., 467.,\n 468., 469., 470., 471., 472., 473., 474., 475., 476., 477., 478., 479.,\n 480., 481., 482., 483., 484., 485., 486., 487., 488., 489., 490., 491.,\n 492., 493., 494., 495., 496., 497., 498., 499., 500., 501., 502., 503.,\n 504., 505., 506., 507., 508., 509., 510., 511., 512., 513., 514., 515.,\n 516., 517., 518., 519., 520., 521., 522., 523., 524., 525., 526., 527.,\n 528., 529., 530., 531., 532., 533., 534., 535., 536., 537., 538., 539.,\n 540., 541., 542., 543., 544., 545., 546., 547., 548., 549., 550., 551.,\n 552., 553., 554., 555., 556., 557., 558., 559., 560., 561., 562., 563.,\n 564., 565., 566., 567., 568., 569., 570., 571., 572., 573., 574., 575.,\n 576., 577., 578., 579., 580., 581., 582., 583., 584., 585., 586., 587.,\n 588., 589., 590., 591., 592., 593., 594., 595., 596., 597., 598., 599.,\n 600., 601., 602., 603., 604., 605., 606., 607., 608., 609., 610., 611.,\n 612., 613., 614., 615., 616., 617., 618., 619., 620., 621., 622., 623.,\n 624., 625., 626., 627., 628., 629., 630., 631., 632., 633., 634., 635.,\n 636., 637., 638., 639., 640., 641., 642., 643., 644., 645., 646., 647.,\n 648., 649., 650., 651., 652., 653., 654., 655., 656., 657., 658., 659.,\n 660., 661., 662., 663., 664., 665., 666., 667., 668., 669., 670., 671.,\n 672., 673., 674.])\n"
]
}
],
"source": [
"import torch\n",
"from torch.nn import functional as f\n",
"\n",
"x = torch.arange(0, 1*3*15*15).float()\n",
"print(x)\n",
"\n"
]
},
{
"source": [
"Unfold函数的输入数据是思维的但输出是三维的。假设输入数据是\\[B, C, H, W\\], 那么输出的数据是\\[B, C\\*kH\\*KW, L\\]\n",
"\n",
"L = (H - kH + 1) \\* (W - kW + 1)"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[ 0., 1., 2., ..., 190., 191., 192.],\n [ 1., 2., 3., ..., 191., 192., 193.],\n [ 2., 3., 4., ..., 192., 193., 194.],\n ...,\n [480., 481., 482., ..., 670., 671., 672.],\n [481., 482., 483., ..., 671., 672., 673.],\n [482., 483., 484., ..., 672., 673., 674.]]])\ntorch.Size([1, 27, 169])\ntensor([[[ 0., 1., 2., ..., 480., 481., 482.],\n [ 1., 2., 3., ..., 481., 482., 483.],\n [ 2., 3., 4., ..., 482., 483., 484.],\n ...,\n [190., 191., 192., ..., 670., 671., 672.],\n [191., 192., 193., ..., 671., 672., 673.],\n [192., 193., 194., ..., 672., 673., 674.]]])\n"
]
}
],
"source": [
"x = x.view(1,3,15,15)\n",
"x1 = f.unfold(x, kernel_size=3, dilation=1, stride=1)\n",
"print(x1)\n",
"print(x1.shape)\n",
"print(x1.permute(0, 2, 1))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[[[ 0., 1., 2.],\n [ 15., 16., 17.],\n [ 30., 31., 32.]],\n\n [[225., 226., 227.],\n [240., 241., 242.],\n [255., 256., 257.]],\n\n [[450., 451., 452.],\n [465., 466., 467.],\n [480., 481., 482.]]],\n\n\n [[[ 1., 2., 3.],\n [ 16., 17., 18.],\n [ 31., 32., 33.]],\n\n [[226., 227., 228.],\n [241., 242., 243.],\n [256., 257., 258.]],\n\n [[451., 452., 453.],\n [466., 467., 468.],\n [481., 482., 483.]]],\n\n\n [[[ 2., 3., 4.],\n [ 17., 18., 19.],\n [ 32., 33., 34.]],\n\n [[227., 228., 229.],\n [242., 243., 244.],\n [257., 258., 259.]],\n\n [[452., 453., 454.],\n [467., 468., 469.],\n [482., 483., 484.]]],\n\n\n ...,\n\n\n [[[190., 191., 192.],\n [205., 206., 207.],\n [220., 221., 222.]],\n\n [[415., 416., 417.],\n [430., 431., 432.],\n [445., 446., 447.]],\n\n [[640., 641., 642.],\n [655., 656., 657.],\n [670., 671., 672.]]],\n\n\n [[[191., 192., 193.],\n [206., 207., 208.],\n [221., 222., 223.]],\n\n [[416., 417., 418.],\n [431., 432., 433.],\n [446., 447., 448.]],\n\n [[641., 642., 643.],\n [656., 657., 658.],\n [671., 672., 673.]]],\n\n\n [[[192., 193., 194.],\n [207., 208., 209.],\n [222., 223., 224.]],\n\n [[417., 418., 419.],\n [432., 433., 434.],\n [447., 448., 449.]],\n\n [[642., 643., 644.],\n [657., 658., 659.],\n [672., 673., 674.]]]]])\n"
]
}
],
"source": [
"B, C_kh_kw, L = x1.size()\n",
"x1 = x1.permute(0, 2, 1)\n",
"x1 = x1.view(B, L, -1, 3, 3)\n",
"print(x1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
]
}
Loading…
Cancel
Save