-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathstyle_subnet.py
More file actions
111 lines (86 loc) · 3.62 KB
/
Copy pathstyle_subnet.py
File metadata and controls
111 lines (86 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from layer_utils import *
# dimensions of image [batch_size, channels, height, width]
class StyleSubnet(nn.Module):
    """Two-branch encoder/decoder that produces a stylised image.

    The input is encoded twice:

    * an RGB branch that sees all three input channels;
    * an L branch that sees a single channel computed by a fixed 1x1
      convolution with luma weights (0.299, 0.587, 0.114).

    Each branch ends with 64 channels. The two feature maps are concatenated
    to 128 channels, passed through three more residual blocks and decoded
    back to 3 channels.

    The output is clamped per channel to the range that maps back to [0, 1]
    after de-normalising with ImageNet statistics (see ``_MEAN``/``_STD``).
    NOTE(review): this implies inputs are ImageNet-normalised as well, so
    the luma conv operates on normalised values -- confirm with the data
    pipeline.
    """

    # Luma weights for the R, G, B input channels.
    _GRAY_WEIGHTS = (0.299, 0.587, 0.114)
    # Per-channel ImageNet normalisation statistics used by the output clamp.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        # Fixed RGB -> single-channel conversion (see _make_gray_conv).
        self.togray = self._make_gray_conv()

        # RGB block: 3 -> 64 channels, spatial size reduced by two stride-2 convs.
        self.rgb_conv1 = ConvLayer(3, 16, kernel_size=9, stride=1)
        self.rgb_in1 = InstanceNormalization(16)
        self.rgb_conv2 = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.rgb_in2 = InstanceNormalization(32)
        self.rgb_conv3 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.rgb_in3 = InstanceNormalization(64)
        self.rgb_res1 = ResidualBlock(64)
        self.rgb_res2 = ResidualBlock(64)
        self.rgb_res3 = ResidualBlock(64)

        # L block: same layout as the RGB block but with a 1-channel input.
        self.l_conv1 = ConvLayer(1, 16, kernel_size=9, stride=1)
        self.l_in1 = InstanceNormalization(16)
        self.l_conv2 = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.l_in2 = InstanceNormalization(32)
        self.l_conv3 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.l_in3 = InstanceNormalization(64)
        self.l_res1 = ResidualBlock(64)
        self.l_res2 = ResidualBlock(64)
        self.l_res3 = ResidualBlock(64)

        # Residual layers on the concatenated 64 + 64 = 128 channels.
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        self.res6 = ResidualBlock(128)

        # Decoder: 128 -> 64 -> 32 channels via resize-convs, then 32 -> 3.
        self.rezconv1 = ResizeConvLayer(128, 64, kernel_size=3, stride=1)
        self.in4 = InstanceNormalization(64)
        self.rezconv2 = ResizeConvLayer(64, 32, kernel_size=3, stride=1)
        self.in5 = InstanceNormalization(32)
        self.rezconv3 = ConvLayer(32, 3, kernel_size=3, stride=1)

        # Non-linearities
        self.relu = nn.ReLU()

    @classmethod
    def _make_gray_conv(cls):
        """Return a 1x1 ``nn.Conv2d`` computing the luma-weighted channel sum.

        ``nn.Conv2d`` initialises its bias randomly, which would add an
        arbitrary constant to the grey image. The bias is therefore zeroed.
        It is still kept as a parameter so the state_dict keys
        (``togray.weight`` / ``togray.bias``) stay unchanged and existing
        checkpoints keep loading.
        """
        conv = nn.Conv2d(3, 1, kernel_size=1, stride=1)
        conv.weight = nn.Parameter(
            torch.tensor(cls._GRAY_WEIGHTS).reshape(1, 3, 1, 1)
        )
        with torch.no_grad():
            conv.bias.zero_()
        return conv

    @classmethod
    def _clamp_output(cls, y):
        """Clamp every image in ``y`` to the normalised image of [0, 1].

        Channel ``c`` is limited to
        ``[(0 - _MEAN[c]) / _STD[c], (1 - _MEAN[c]) / _STD[c]]`` for all batch
        elements. The clamp is out-of-place: ``y`` itself is not modified and
        a new tensor is returned.
        """
        channels = [
            y[:, c:c + 1].clamp((0 - mean) / std, (1 - mean) / std)
            for c, (mean, std) in enumerate(zip(cls._MEAN, cls._STD))
        ]
        return torch.cat(channels, dim=1)

    def forward(self, x):
        """Stylise ``x``.

        Args:
            x: image batch; presumably shaped (batch, 3, height, width) and
                ImageNet-normalised -- TODO confirm with the caller.

        Returns:
            ``(y, resized_input_img)``. ``y`` is the 3-channel output, clamped
            per channel by ``_clamp_output``. ``resized_input_img`` is a copy
            of ``x`` used as the content target. Input downsampling is
            currently disabled, so it has the same size as ``x``.
        """
        # Content target; cloned so it is independent of the input tensor.
        resized_input_img = x.clone()

        # RGB and single-channel inputs. The grey conversion is a fixed
        # preprocessing step, so it is kept out of the autograd graph.
        x_rgb = x
        with torch.no_grad():
            x_l = self.togray(x)

        # RGB Block
        y_rgb = self.relu(self.rgb_in1(self.rgb_conv1(x_rgb)))
        y_rgb = self.relu(self.rgb_in2(self.rgb_conv2(y_rgb)))
        y_rgb = self.relu(self.rgb_in3(self.rgb_conv3(y_rgb)))
        y_rgb = self.rgb_res1(y_rgb)
        y_rgb = self.rgb_res2(y_rgb)
        y_rgb = self.rgb_res3(y_rgb)

        # L Block
        y_l = self.relu(self.l_in1(self.l_conv1(x_l)))
        y_l = self.relu(self.l_in2(self.l_conv2(y_l)))
        y_l = self.relu(self.l_in3(self.l_conv3(y_l)))
        y_l = self.l_res1(y_l)
        y_l = self.l_res2(y_l)
        y_l = self.l_res3(y_l)

        # Concatenate both branches along the channel dimension.
        y = torch.cat((y_rgb, y_l), 1)

        # Residuals
        y = self.res4(y)
        y = self.res5(y)
        y = self.res6(y)

        # Decoding
        y = self.relu(self.in4(self.rezconv1(y)))
        y = self.relu(self.in5(self.rezconv2(y)))
        y = self.rezconv3(y)

        return self._clamp_output(y), resized_input_img