PyTorchで作成した手書き数字認識プログラムを手軽に使えるようにライブラリ化できないかなぁっという思いで、以前、PyTorchで作成した全結合のニューラルネットワークの学習したパラメータを、Cythonで作成した自前のニューラルネットワークに読み込ませるといったことをしました。
今回は、それを畳み込みニューラルネットワーク(CNN)でやってみました。
元となるPyTorchのモデル
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4*4*50, 500)
self.fc2 = torch.nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
参照元:
PyTorchでシンプルな畳み込みニューラルネットワークを作ろう
https://qiita.com/sudamasahiko/items/fd6a52f958f3f9013f0f
以前、このモデルで、手書き数字の判別をさせてみると正解率95%以上という結果でした。
そこで、前回、勉強がてら、PyTorchで学習したパラメータを利用して、生Pythonで実装した畳み込み演算で画像判別をするということを行いました。今回、これをCython化して高速化しました。
結果、生Pythonでは、1文字1.609秒かかる処理が0.062秒になり、およそ25倍高速化しました。まずまずですが、まだ少し物足りない・・・かな。
過程を書くと、
・1文字もコードを変えずにCython化すると0.929秒(1.7倍)
・クラスをcdef化し、def内整数と実数をcdefで宣言すると0.311秒(5.2倍)
・Cython内のみで使用する配列をポインタ化すると0.062秒(26.0倍)
となり、ポインタ化がすごく効いてきます。
今回のソースは、PyTorchでのモデル定義をなるべくそのまま使えるような形を意識しており、Python側でモデルを定義できます。それゆえ、各関数の出力データは、生Pythonで扱えるlistオブジェクトで返してます。そのlistの生成に時間がかかっている気がします。
ちなみにlistの代わりにnumpyの配列(メモリービュー)も試しましたが、あまり速くならず、というかlistより遅くなりました。listと違い、要素の型を指定するので高速化するはずなのですが、処理ごとに配列を生成する書き方をしているので、生成に時間がかかってしまっているのかなと。モジュールをインポートした際に、必要な配列を用意しておく書き方にすれば、たぶんもうちょっと高速化するのだろうけど、今回はここまでにします。
そもそも、手書き数字認識ライブラリを作ることを目的とするなら、モデルの定義をPython側で柔軟にできるようにする必要もなく、Cython側で固定してしまって、全てのデータを配列なりポインタなりで扱えば、もっと大幅に高速化するはずです。その方向性でのを次回試すことにします。
追記:さらに約6倍高速化しました。
前処理:PyTorchモデルのパラメータを保存するスクリプト
import pickle
import torch
import torch.nn.functional as F
class MyLeNet(torch.nn.Module):
def __init__(self):
super(MyLeNet, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5)
self.conv2 = torch.nn.Conv2d(20, 50, 5)
self.fc1 = torch.nn.Linear(4*4*50, 500)
self.fc2 = torch.nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def main():
torch_cnn = MyLeNet()
torch_cnn.load_state_dict(torch.load('my_mnist_model.pth'))
state = torch_cnn.state_dict()
new_state = {}
state_names = ['conv1.weight','conv1.bias',
'conv2.weight','conv2.bias',
'fc1.weight','fc1.bias',
'fc2.weight','fc2.bias']
for sn in state_names:
new_state[sn] = to_float(state[sn])
f=open('param.dat','wb')
pickle.dump(new_state,f)
f.close()
def to_float(data):
ls = []
for d in data:
if is_number(d):
ls.append(float(d))
else:
ls.append(to_float(d))
return ls
def is_number(data):
try:
float(data)
return True
except:
return False
if __name__ == '__main__':
main()
Python:自前CNNモジュールで数字判別するスクリプト
import math,time,pickle
from PIL import Image, ImageOps
import fixed_cnn as nn
class Network:
def __init__(self):
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 50, 5)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def __call__(self,x):
return self.forward(x)
def set_param(self,state):
self.conv1.set_filters(state['conv1.weight'])
self.conv1.set_bias(state['conv1.bias'])
self.conv2.set_filters(state['conv2.weight'])
self.conv2.set_bias(state['conv2.bias'])
self.fc1.set_weights(state['fc1.weight'])
self.fc1.set_bias(state['fc1.bias'])
self.fc2.set_weights(state['fc2.weight'])
self.fc2.set_bias(state['fc2.bias'])
def forward(self, x):
x = nn.relu(self.conv1(x))
x = nn.max_pool2d(x, 2, 2)
x = nn.relu(self.conv2(x))
x = nn.max_pool2d(x, 2, 2)
x = nn.trans3dto1d(x) #x = x.view(-1, 4*4*50)
x = nn.relu(self.fc1(x))
x = self.fc2(x)
return nn.log_softmax(x)
def max_ch(self,data_1d):
max_value = -math.inf
ch = 0
for i,value in enumerate(data_1d):
if value > max_value:
max_value = value
ch = i
return ch
def main():
start_time = time.time()
mynet = Network()
f = open('param.dat','rb')
state = pickle.load(f)
f.close()
mynet.set_param(state)
set_time = time.time() - start_time
start_time = time.time()
for i in range(10):
image_path = '../img4/%s-0.png' % i
img = Image.open(image_path)
img = img.convert('L').resize((28,28))
img = ImageOps.invert(img)
data = nn.to_data(img)
output = mynet(data)
result = mynet.max_ch(output)
print('result',result,image_path)
calc_time = (time.time() - start_time)/10.0
print('set_time',set_time)
print('calc_time',calc_time)
if __name__ == '__main__':
main()
自前CNNモジュール
ダウンロードはこちらから。
(動作確認環境:Windows 64bit、Python3.8)
ソース
Cython:fixed_cnn.pyx
import math,copy
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
cdef class Conv2d:
cdef int in_channels
cdef int out_channels
cdef int kernel_size
cdef int input_w
cdef int input_h
cdef double ****filters
cdef double *bias
cdef double ***input
def __init__(self,in_channels, out_channels, kernel_size):
cdef int i,j,k
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.filters = <double****> PyMem_Malloc(out_channels * sizeof(double*))
for i in range(out_channels):
self.filters[i] = <double***> PyMem_Malloc(in_channels * sizeof(double*))
for j in range(in_channels):
self.filters[i][j] = <double**> PyMem_Malloc(kernel_size * sizeof(double*))
for k in range(kernel_size):
self.filters[i][j][k] = <double*> PyMem_Malloc(kernel_size * sizeof(double))
self.bias = <double*> PyMem_Malloc(out_channels * sizeof(double))
def __call__(self,data_3d):
self.set_input(data_3d)
return self.feature_maps()
def set_input(self,data_3d):
cdef int i,j,k
self.input_h = len(data_3d[0])
self.input_w = len(data_3d[0][0])
self.input = <double***> PyMem_Malloc(self.in_channels * sizeof(double*))
for i in range(self.in_channels):
self.input[i] = <double**> PyMem_Malloc(self.input_h * sizeof(double*))
for j in range(self.input_h):
self.input[i][j] = <double*> PyMem_Malloc(self.input_w * sizeof(double))
for i in range(self.in_channels):
for j in range(self.input_h):
for k in range(self.input_w):
self.input[i][j][k] = data_3d[i][j][k]
def set_filters(self,data_4d):
cdef int i,j,k,l
for i in range(self.out_channels):
for j in range(self.in_channels):
for k in range(self.kernel_size):
for l in range(self.kernel_size):
self.filters[i][j][k][l] = data_4d[i][j][k][l]
def set_bias(self,data_1d):
cdef int i
for i in range(self.out_channels):
self.bias[i] = float(data_1d[i])
def feature_maps(self):
cdef int map_h,map_w,ch0,ch1,i,j,k,l,x,y
cdef double v,weight
map_h = self.input_h - self.kernel_size + 1
map_w = self.input_w - self.kernel_size + 1
fmaps = multi_list(self.out_channels,map_h,map_w)
for ch1 in range(self.out_channels):
for i in range(map_h):
for j in range(map_w):
v = 0.0
for ch0 in range(self.in_channels):
for k in range(self.kernel_size):
for l in range(self.kernel_size):
y = i + k
x = j + l
weight = self.filters[ch1][ch0][k][l]
v += self.input[ch0][y][x] * weight
v += self.bias[ch1]
fmaps[ch1][i][j] = v
return fmaps
def multi_list(*element_nums):
cdef int dim,i,j
dim = len(element_nums)
data = [0 for j in range(element_nums[dim-1])]
for i in range(2,dim+1):
data = [ copy.deepcopy(data) for j in range(element_nums[dim-i])]
return data
def shape(data):
elem_nums = []
while True:
try:
elem_nums.append(len(data))
data = data[0]
except:
break
return elem_nums
def trans3dto1d(data):
new_data = []
for lines in data:
for line in lines:
for value in line:
new_data.append(value)
return new_data
def max_pool2d(data,int kernel_size,int stride):
cdef int in_ch,in_h,in_w,out_h,out_w,ch,i,j,k,l,x,y
cdef double v,v_max,inf
in_ch = len(data)
in_h = len(data[0])
in_w = len(data[0][0])
out_h = math.ceil((in_h - (kernel_size-1))/stride)
out_w = math.ceil((in_w - (kernel_size-1))/stride)
out_data = multi_list(in_ch,out_h,out_w)
for ch in range(in_ch):
for i in range(out_h):
for j in range(out_w):
v_max = 0.0
for k in range(kernel_size):
for l in range(kernel_size):
y = i*stride + k
x = j*stride + l
v = data[ch][y][x]
if v > v_max:
v_max = v
out_data[ch][i][j] = v_max
return out_data
def relu(data):
cdef int dim,i,j,k
cdef double value
shp = shape(data)
dim = len(shp)
if dim == 1:
for i in range(shp[0]):
value = data[i]
if value < 0.0:
data[i] = 0.0
return data
elif dim == 3:
for i in range(shp[0]):
for j in range(shp[1]):
for k in range(shp[2]):
value = data[i][j][k]
if value < 0.0:
data[i][j][k] = 0.0
return data
else:
print('Error: Unsupported dimension. dim=%d' % dim)
exit()
def log_softmax(data_1d):
cdef double xmax,e_sum,e_x,x
xmax = max(data_1d)
e_sum = 0.0
e_xs = []
for x in data_1d:
e_x = math.exp(x - xmax)
e_xs.append(e_x)
e_sum += e_x
result = []
for e_x in e_xs:
result.append(e_x/e_sum)
return result
cdef class Linear:
cdef int in_features
cdef int out_features
cdef list values
cdef double *bias
cdef double **weights
def __init__(self,in_features,out_features):
cdef int i
self.in_features = in_features
self.out_features = out_features
self.values = multi_list(out_features)
#self.bias = multi_list(out_features)
#self.weights = multi_list(out_features,in_features)
self.bias = <double*> PyMem_Malloc(out_features * sizeof(double))
self.weights = <double**> PyMem_Malloc(out_features * sizeof(double*))
for i in range(out_features):
self.weights[i] = <double*> PyMem_Malloc(in_features * sizeof(double))
def set_weights(self,data_2d):
cdef int i,j
for i in range(self.out_features):
for j in range(self.in_features):
self.weights[i][j] = data_2d[i][j]
def set_bias(self,data_1d):
cdef int i
for i in range(self.out_features):
self.bias[i] = data_1d[i]
def __call__(self,data_1d):
cdef int i,j
cdef double value,pvalue
values = []
for i in range(self.out_features):
value = 0.0
for j in range(self.in_features):
pvalue = data_1d[j]
value += pvalue * self.weights[i][j]
value += self.bias[i]
values.append(value)
return values
def to_data(pil_img):
cdef int i,j,w,h
w,h = pil_img.size
new_data = multi_list(1,h,w)
for i in range(w):
for j in range(h):
new_data[0][j][i] = pil_img.getpixel((i,j))
return new_data
