Last time, I tried out a convolutional neural network (CNN) that classifies handwritten digits with PyTorch. The model is a LeNet-style network, the original CNN, and this time I wrote its convolution operations in plain Python.
Only the outer structure is shaped like a neural network; this is probably quite different from CNN code that treats each image pixel as a neuron and makes the convolution filters learnable.
My original goal is to build my own library in Cython using the parameters of a CNN model trained with PyTorch, so this post is study code for checking whether my understanding of the processing is correct.
Since I have no intention of implementing training (backpropagation), I simply coded the convolution operation in what I believe is the standard way. If training were the goal, I imagine there would be more appropriate approaches, such as modeling neurons as classes.
The real work starts here, with speeding this up in Cython, but I'm posting it because it may be useful to people who want to understand how the convolution operation is actually computed, or how data flows through a CNN. Honestly, I'm mostly posting it because I was happy I managed it without the help of any books ^^;
Update: I have now ported it to Cython.
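Before the code, here is a quick sanity check of the tensor sizes, a back-of-the-envelope sketch of my own (assuming the 28x28 MNIST input and the layer sizes used below): each 5x5 convolution with stride 1 and no padding shrinks the image by 4 pixels per side, and each 2x2 max pool halves it, which is why fc1 takes 4*4*50 inputs.
# Size check (assumes a 28x28 MNIST image and the layer sizes of the network below)
h = w = 28
h, w = h - 5 + 1, w - 5 + 1   # conv1: 5x5 kernel, stride 1, no padding -> 24x24 (20 channels)
h, w = h // 2, w // 2         # max_pool2d(2, 2)                        -> 12x12
h, w = h - 5 + 1, w - 5 + 1   # conv2: 5x5 kernel                       -> 8x8 (50 channels)
h, w = h // 2, w // 2         # max_pool2d(2, 2)                        -> 4x4
print(h * w * 50)             # 800 = 4*4*50, the in_features of fc1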
import math,random,copy,time
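# A convolution layer built from plain Python lists. It mirrors the call interface of
# torch.nn.Conv2d but implements the forward pass only (stride 1, no padding, no backprop).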
class Conv2d:
def __init__(self,in_channels, out_channels, kernel_size):
self.in_channels = ch0 = in_channels
self.out_channels = ch1 = out_channels
self.kernel_size = sz = kernel_size
self.filters = multi_list(ch1,ch0,sz,sz)
self.bias = multi_list(ch1)
self.input = None
def __call__(self,data_3d):
self.set_input(data_3d)
return self.feature_maps()
def set_input(self,data_3d):
ch = len(data_3d) # = in_channels
self.input_h = h = len(data_3d[0])
self.input_w = w = len(data_3d[0][0])
self.input = multi_list(ch,h,w)
for ch,lines in enumerate(data_3d):
for h,line in enumerate(lines):
for w,value in enumerate(line):
self.input[ch][h][w] = float(value)
def set_filters(self,data_4d):
for n,data in enumerate(data_4d):
for ch,lines in enumerate(data):
for h,line in enumerate(lines):
for w,value in enumerate(line):
self.filters[n][ch][h][w] = float(value)
def set_bias(self,data_1d):
for i in range(self.out_channels):
self.bias[i] = float(data_1d[i])
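    # Core convolution loop: stride 1, no padding, so each feature map is
    # (input_h - kernel_size + 1) x (input_w - kernel_size + 1).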
def feature_maps(self):
in_ch = self.in_channels
out_ch = self.out_channels
size = self.kernel_size
map_w = self.input_w - size + 1
map_h = self.input_h - size + 1
fmaps = multi_list(out_ch,map_h,map_w)
for ch1 in range(self.out_channels):
for i in range(map_h):
for j in range(map_w):
v = 0.0
for ch0 in range(self.in_channels):
for k in range(size):
for l in range(size):
y = i + k
x = j + l
weight = self.filters[ch1][ch0][k][l]
v += self.input[ch0][y][x] * weight
v += self.bias[ch1]
fmaps[ch1][i][j] = v
return fmaps
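# Build a nested list of zeros with the given shape, e.g. multi_list(2,3) -> [[0,0,0],[0,0,0]].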
def multi_list(*shape):
dim = len(shape)
data = [0 for j in range(shape[dim-1])]
for i in range(2,dim+1):
data = [ copy.deepcopy(data) for j in range(shape[dim-i])]
return data
def shape(data):
    # Return the nested dimensions of the data, e.g. a 3-channel 28x28 input gives [3, 28, 28].
    elem_nums = []
    while True:
        try:
            elem_nums.append(len(data))
            data = data[0]
        except (TypeError, IndexError):
            break
    return elem_nums
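# Flatten a (channel, height, width) nested list into a 1-D list, matching the ordering of x.view(-1, 4*4*50).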
def trans3dto1d(data):
new_data = []
for lines in data:
for line in lines:
for value in line:
new_data.append(value)
return new_data
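# Max pooling: slide a kernel_size x kernel_size window by `stride` and keep the largest value in each window.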
def max_pool2d(data,kernel_size,stride):
in_ch = len(data)
in_h = len(data[0])
in_w = len(data[0][0])
out_h = math.ceil((in_h - (kernel_size-1))/stride)
out_w = math.ceil((in_w - (kernel_size-1))/stride)
out_data = multi_list(in_ch,out_h,out_w)
for ch in range(in_ch):
for i in range(out_h):
for j in range(out_w):
v_max = -math.inf
for k in range(kernel_size):
for l in range(kernel_size):
y = i*stride + k
x = j*stride + l
v = data[ch][y][x]
if v > v_max:
v_max = v
out_data[ch][i][j] = v_max
return out_data
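# In-place ReLU: clamp negative values to 0. Handles 1-D (after Linear) and 3-D (after Conv2d) data.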
def relu(data):
dim = len(shape(data))
if dim == 1:
for i,value in enumerate(data):
if value < 0.0:
data[i] = 0.0
return data
elif dim == 3:
for ch,lines in enumerate(data):
for h,line in enumerate(lines):
                for w,value in enumerate(line):
                    if value < 0.0:
                        data[ch][h][w] = 0.0
        return data
    else:
        raise ValueError('Unsupported dimension: dim=%d' % dim)
def log_softmax(data_1d):
    # Numerically stable log-softmax: log_softmax(x_i) = (x_i - xmax) - log(sum_j exp(x_j - xmax))
    xmax = max(data_1d)
    e_sum = 0.0
    shifted = []
    for x in data_1d:
        s = x - xmax
        shifted.append(s)
        e_sum += math.exp(s)
    log_e_sum = math.log(e_sum)
    result = []
    for s in shifted:
        result.append(s - log_e_sum)
    return result
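# Fully connected layer: out[i] = sum_j(W[i][j] * in[j]) + bias[i], mirroring torch.nn.Linear.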
class Linear:
def __init__(self,in_features,out_features):
self.in_features = in_features
self.out_features = out_features
self.values = multi_list(out_features)
self.bias = multi_list(out_features)
self.weights = multi_list(out_features,in_features)
def set_weights(self,data_2d):
for i in range(self.out_features):
for j in range(self.in_features):
self.weights[i][j] = float(data_2d[i][j])
def set_bias(self,data_1d):
for i in range(self.out_features):
self.bias[i] = float(data_1d[i])
def __call__(self,data_1d):
values = []
for i in range(self.out_features):
value = 0.0
for j in range(self.in_features):
pvalue = float(data_1d[j])
value += pvalue * self.weights[i][j]
value += self.bias[i]
values.append(value)
return values
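# Pure-Python counterpart of the PyTorch MyLeNet defined in main(); forward() applies the same layers in the same order.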
class Network:
def __init__(self):
self.conv1 = Conv2d(1, 20, 5)
self.conv2 = Conv2d(20, 50, 5)
self.fc1 = Linear(4*4*50, 500)
self.fc2 = Linear(500, 10)
def __call__(self,x):
return self.forward(x)
def set_param(self,state):
self.conv1.set_filters(state['conv1.weight'])
self.conv1.set_bias(state['conv1.bias'])
self.conv2.set_filters(state['conv2.weight'])
self.conv2.set_bias(state['conv2.bias'])
self.fc1.set_weights(state['fc1.weight'])
self.fc1.set_bias(state['fc1.bias'])
self.fc2.set_weights(state['fc2.weight'])
self.fc2.set_bias(state['fc2.bias'])
def forward(self, x):
x = relu(self.conv1(x))
x = max_pool2d(x, 2, 2)
x = relu(self.conv2(x))
x = max_pool2d(x, 2, 2)
x = trans3dto1d(x) #x = x.view(-1, 4*4*50)
x = relu(self.fc1(x))
x = self.fc2(x)
return log_softmax(x)
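    # Return the index of the largest output value, i.e. the predicted digit.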
def max_ch(self,data_1d):
max_value = -math.inf
ch = 0
for i,value in enumerate(data_1d):
if value > max_value:
max_value = value
ch = i
return ch
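# Load the parameters of the PyTorch-trained model into the pure-Python network and classify one image with it.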
def main():
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from PIL import Image, ImageOps
class MyLeNet(torch.nn.Module):
def __init__(self):
super(MyLeNet, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5)
self.conv2 = torch.nn.Conv2d(20, 50, 5)
self.fc1 = torch.nn.Linear(4*4*50, 500)
self.fc2 = torch.nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
start_time = time.time()
mynet = Network()
torch_cnn = MyLeNet()
torch_cnn.load_state_dict(torch.load('my_mnist_model.pth'))
state = torch_cnn.state_dict()
mynet.set_param(state)
print('set state!')
set_time = time.time() - start_time
start_time = time.time()
image_path = '../img4/0-0.png'
img = Image.open(image_path)
img = img.convert('L').resize((28,28))
img = ImageOps.invert(img)
trans = transforms.Compose([transforms.ToTensor()])
data = trans(img)
output = mynet(data)
result = mynet.max_ch(output)
print('result',result,image_path)
calc_time = time.time() - start_time
print('set_time',set_time)
print('calc_time',calc_time)
if __name__ == '__main__':
main()
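To double-check that the pure-Python forward pass really agrees with PyTorch, one possible sketch (not part of the script above; it assumes torch is imported and that the objects created in main(), i.e. mynet, torch_cnn, and the preprocessed tensor data, are available) is to run the same image through both networks and compare the predicted class:
# Sketch: compare the pure-Python prediction with PyTorch's, assuming mynet, torch_cnn
# and the tensor `data` from main() are in scope.
torch_cnn.eval()
with torch.no_grad():
    torch_out = torch_cnn(data.unsqueeze(0))   # add the batch dimension -> shape (1, 1, 28, 28)
torch_result = int(torch_out.argmax(dim=1))
my_result = mynet.max_ch(mynet(data))
print('pytorch:', torch_result, 'pure python:', my_result)   # the two should agree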
