PyTorchで作成した手書き数字認識プログラムを手軽に使えるようにライブラリ化できないかなぁっという思いで、以前、PyTorchで作成した全結合のニューラルネットワークの学習したパラメータを、Cythonで作成した自前のニューラルネットワークに読み込ませるといったことをしました。
今回は、それを畳み込みニューラルネットワーク(CNN)でやってみました。
元となるPyTorchのモデル
class MyNet(torch.nn.Module): def __init__(self): super(MyNet, self).__init__() self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) self.fc1 = torch.nn.Linear(4*4*50, 500) self.fc2 = torch.nn.Linear(500, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1)
参照元:
PyTorchでシンプルな畳み込みニューラルネットワークを作ろう
https://qiita.com/sudamasahiko/items/fd6a52f958f3f9013f0f
以前、このモデルで、手書き数字の判別をさせてみると正解率95%以上という結果でした。
そこで、前回、勉強がてら、PyTorchで学習したパラメータを利用して、生Pythonで実装した畳み込み演算で画像判別をするということを行いました。今回、これをCython化して高速化しました。
結果、生Pythonでは、1文字1.609秒かかる処理が0.062秒になり、およそ25倍高速化しました。まずまずですが、まだ少し物足りない・・・かな。
過程を書くと、
・1文字もコードを変えずにCython化すると0.929秒(1.7倍)
・クラスをcdef化し、def内整数と実数をcdefで宣言すると0.311秒(5.2倍)
・Cython内のみで使用する配列をポインタ化すると0.062秒(26.0倍)
となり、ポインタ化がすごく効いてきます。
今回のソースは、PyTorchでのモデル定義をなるべくそのまま使えるような形を意識しており、Python側でモデルを定義できます。それゆえ、各関数の出力データは、生Pythonで扱えるlistオブジェクトで返してます。そのlistの生成に時間がかかっている気がします。
ちなみにlistの代わりにnumpyの配列(メモリービュー)も試しましたが、あまり速くならず、というかlistより遅くなりました。listと違い、要素の型を指定するので高速化するはずなのですが、処理ごとに配列を生成する書き方をしているので、生成に時間がかかってしまっているのかなと。モジュールをインポートした際に、必要な配列を用意しておく書き方にすれば、たぶんもうちょっと高速化するのだろうけど、今回はここまでにします。
そもそも、手書き数字認識ライブラリを作ることを目的とするなら、モデルの定義をPython側で柔軟にできるようにする必要もなく、Cython側で固定してしまって、全てのデータを配列なりポインタなりで扱えば、もっと大幅に高速化するはずです。その方向性でのを次回試すことにします。
追記:さらに約6倍高速化しました。
前処理:PyTorchモデルのパラメータを保存するスクリプト
import pickle import torch import torch.nn.functional as F class MyLeNet(torch.nn.Module): def __init__(self): super(MyLeNet, self).__init__() self.conv1 = torch.nn.Conv2d(1, 20, 5) self.conv2 = torch.nn.Conv2d(20, 50, 5) self.fc1 = torch.nn.Linear(4*4*50, 500) self.fc2 = torch.nn.Linear(500, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) def main(): torch_cnn = MyLeNet() torch_cnn.load_state_dict(torch.load('my_mnist_model.pth')) state = torch_cnn.state_dict() new_state = {} state_names = ['conv1.weight','conv1.bias', 'conv2.weight','conv2.bias', 'fc1.weight','fc1.bias', 'fc2.weight','fc2.bias'] for sn in state_names: new_state[sn] = to_float(state[sn]) f=open('param.dat','wb') pickle.dump(new_state,f) f.close() def to_float(data): ls = [] for d in data: if is_number(d): ls.append(float(d)) else: ls.append(to_float(d)) return ls def is_number(data): try: float(data) return True except: return False if __name__ == '__main__': main()
Python:自前CNNモジュールで数字判別するスクリプト
import math,time,pickle from PIL import Image, ImageOps import fixed_cnn as nn class Network: def __init__(self): self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 50, 5) self.fc1 = nn.Linear(4*4*50, 500) self.fc2 = nn.Linear(500, 10) def __call__(self,x): return self.forward(x) def set_param(self,state): self.conv1.set_filters(state['conv1.weight']) self.conv1.set_bias(state['conv1.bias']) self.conv2.set_filters(state['conv2.weight']) self.conv2.set_bias(state['conv2.bias']) self.fc1.set_weights(state['fc1.weight']) self.fc1.set_bias(state['fc1.bias']) self.fc2.set_weights(state['fc2.weight']) self.fc2.set_bias(state['fc2.bias']) def forward(self, x): x = nn.relu(self.conv1(x)) x = nn.max_pool2d(x, 2, 2) x = nn.relu(self.conv2(x)) x = nn.max_pool2d(x, 2, 2) x = nn.trans3dto1d(x) #x = x.view(-1, 4*4*50) x = nn.relu(self.fc1(x)) x = self.fc2(x) return nn.log_softmax(x) def max_ch(self,data_1d): max_value = -math.inf ch = 0 for i,value in enumerate(data_1d): if value > max_value: max_value = value ch = i return ch def main(): start_time = time.time() mynet = Network() f = open('param.dat','rb') state = pickle.load(f) f.close() mynet.set_param(state) set_time = time.time() - start_time start_time = time.time() for i in range(10): image_path = '../img4/%s-0.png' % i img = Image.open(image_path) img = img.convert('L').resize((28,28)) img = ImageOps.invert(img) data = nn.to_data(img) output = mynet(data) result = mynet.max_ch(output) print('result',result,image_path) calc_time = (time.time() - start_time)/10.0 print('set_time',set_time) print('calc_time',calc_time) if __name__ == '__main__': main()
自前CNNモジュール
ダウンロードはこちらから。
(動作確認環境:Windows 64bit、Python3.8)
ソース
Cython:fixed_cnn.pyx
import math,copy from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free cdef class Conv2d: cdef int in_channels cdef int out_channels cdef int kernel_size cdef int input_w cdef int input_h cdef double ****filters cdef double *bias cdef double ***input def __init__(self,in_channels, out_channels, kernel_size): cdef int i,j,k self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.filters = <double****> PyMem_Malloc(out_channels * sizeof(double*)) for i in range(out_channels): self.filters[i] = <double***> PyMem_Malloc(in_channels * sizeof(double*)) for j in range(in_channels): self.filters[i][j] = <double**> PyMem_Malloc(kernel_size * sizeof(double*)) for k in range(kernel_size): self.filters[i][j][k] = <double*> PyMem_Malloc(kernel_size * sizeof(double)) self.bias = <double*> PyMem_Malloc(out_channels * sizeof(double)) def __call__(self,data_3d): self.set_input(data_3d) return self.feature_maps() def set_input(self,data_3d): cdef int i,j,k self.input_h = len(data_3d[0]) self.input_w = len(data_3d[0][0]) self.input = <double***> PyMem_Malloc(self.in_channels * sizeof(double*)) for i in range(self.in_channels): self.input[i] = <double**> PyMem_Malloc(self.input_h * sizeof(double*)) for j in range(self.input_h): self.input[i][j] = <double*> PyMem_Malloc(self.input_w * sizeof(double)) for i in range(self.in_channels): for j in range(self.input_h): for k in range(self.input_w): self.input[i][j][k] = data_3d[i][j][k] def set_filters(self,data_4d): cdef int i,j,k,l for i in range(self.out_channels): for j in range(self.in_channels): for k in range(self.kernel_size): for l in range(self.kernel_size): self.filters[i][j][k][l] = data_4d[i][j][k][l] def set_bias(self,data_1d): cdef int i for i in range(self.out_channels): self.bias[i] = float(data_1d[i]) def feature_maps(self): cdef int map_h,map_w,ch0,ch1,i,j,k,l,x,y cdef double v,weight map_h = self.input_h - self.kernel_size + 1 map_w = self.input_w - self.kernel_size + 1 fmaps = multi_list(self.out_channels,map_h,map_w) for ch1 in range(self.out_channels): for i in range(map_h): for j in range(map_w): v = 0.0 for ch0 in range(self.in_channels): for k in range(self.kernel_size): for l in range(self.kernel_size): y = i + k x = j + l weight = self.filters[ch1][ch0][k][l] v += self.input[ch0][y][x] * weight v += self.bias[ch1] fmaps[ch1][i][j] = v return fmaps def multi_list(*element_nums): cdef int dim,i,j dim = len(element_nums) data = [0 for j in range(element_nums[dim-1])] for i in range(2,dim+1): data = [ copy.deepcopy(data) for j in range(element_nums[dim-i])] return data def shape(data): elem_nums = [] while True: try: elem_nums.append(len(data)) data = data[0] except: break return elem_nums def trans3dto1d(data): new_data = [] for lines in data: for line in lines: for value in line: new_data.append(value) return new_data def max_pool2d(data,int kernel_size,int stride): cdef int in_ch,in_h,in_w,out_h,out_w,ch,i,j,k,l,x,y cdef double v,v_max,inf in_ch = len(data) in_h = len(data[0]) in_w = len(data[0][0]) out_h = math.ceil((in_h - (kernel_size-1))/stride) out_w = math.ceil((in_w - (kernel_size-1))/stride) out_data = multi_list(in_ch,out_h,out_w) for ch in range(in_ch): for i in range(out_h): for j in range(out_w): v_max = 0.0 for k in range(kernel_size): for l in range(kernel_size): y = i*stride + k x = j*stride + l v = data[ch][y][x] if v > v_max: v_max = v out_data[ch][i][j] = v_max return out_data def relu(data): cdef int dim,i,j,k cdef double value shp = shape(data) dim = len(shp) if dim == 1: for i in range(shp[0]): value = data[i] if value < 0.0: data[i] = 0.0 return data elif dim == 3: for i in range(shp[0]): for j in range(shp[1]): for k in range(shp[2]): value = data[i][j][k] if value < 0.0: data[i][j][k] = 0.0 return data else: print('Error: Unsupported dimension. dim=%d' % dim) exit() def log_softmax(data_1d): cdef double xmax,e_sum,e_x,x xmax = max(data_1d) e_sum = 0.0 e_xs = [] for x in data_1d: e_x = math.exp(x - xmax) e_xs.append(e_x) e_sum += e_x result = [] for e_x in e_xs: result.append(e_x/e_sum) return result cdef class Linear: cdef int in_features cdef int out_features cdef list values cdef double *bias cdef double **weights def __init__(self,in_features,out_features): cdef int i self.in_features = in_features self.out_features = out_features self.values = multi_list(out_features) #self.bias = multi_list(out_features) #self.weights = multi_list(out_features,in_features) self.bias = <double*> PyMem_Malloc(out_features * sizeof(double)) self.weights = <double**> PyMem_Malloc(out_features * sizeof(double*)) for i in range(out_features): self.weights[i] = <double*> PyMem_Malloc(in_features * sizeof(double)) def set_weights(self,data_2d): cdef int i,j for i in range(self.out_features): for j in range(self.in_features): self.weights[i][j] = data_2d[i][j] def set_bias(self,data_1d): cdef int i for i in range(self.out_features): self.bias[i] = data_1d[i] def __call__(self,data_1d): cdef int i,j cdef double value,pvalue values = [] for i in range(self.out_features): value = 0.0 for j in range(self.in_features): pvalue = data_1d[j] value += pvalue * self.weights[i][j] value += self.bias[i] values.append(value) return values def to_data(pil_img): cdef int i,j,w,h w,h = pil_img.size new_data = multi_list(1,h,w) for i in range(w): for j in range(h): new_data[0][j][i] = pil_img.getpixel((i,j)) return new_data