PyTorchの畳み込みニューラルネットワークのライブラリ化の試み

PyTorchで作成した手書き数字認識プログラムを手軽に使えるようにライブラリ化できないかなぁっという思いで、以前、PyTorchで作成した全結合のニューラルネットワークの学習したパラメータを、Cythonで作成した自前のニューラルネットワークに読み込ませるといったことをしました。

今回は、それを畳み込みニューラルネットワーク(CNN)でやってみました。

 

元となるPyTorchのモデル

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4*4*50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

参照元:
PyTorchでシンプルな畳み込みニューラルネットワークを作ろう
https://qiita.com/sudamasahiko/items/fd6a52f958f3f9013f0f

 

以前、このモデルで、手書き数字の判別をさせてみると正解率95%以上という結果でした。
そこで、前回、勉強がてら、PyTorchで学習したパラメータを利用して、生Pythonで実装した畳み込み演算で画像判別をするということを行いました。今回、これをCython化して高速化しました。

結果、生Pythonでは、1文字1.609秒かかる処理が0.062秒になり、およそ25倍高速化しました。まずまずですが、まだ少し物足りない・・・かな。

 

過程を書くと、
・1文字もコードを変えずにCython化すると0.929秒(1.7倍)
・クラスをcdef化し、def内整数と実数をcdefで宣言すると0.311秒(5.2倍)
・Cython内のみで使用する配列をポインタ化すると0.062秒(26.0倍)
となり、ポインタ化がすごく効いてきます。

 

今回のソースは、PyTorchでのモデル定義をなるべくそのまま使えるような形を意識しており、Python側でモデルを定義できます。それゆえ、各関数の出力データは、生Pythonで扱えるlistオブジェクトで返してます。そのlistの生成に時間がかかっている気がします。

 

ちなみにlistの代わりにnumpyの配列(メモリービュー)も試しましたが、あまり速くならず、というかlistより遅くなりました。listと違い、要素の型を指定するので高速化するはずなのですが、処理ごとに配列を生成する書き方をしているので、生成に時間がかかってしまっているのかなと。モジュールをインポートした際に、必要な配列を用意しておく書き方にすれば、たぶんもうちょっと高速化するのだろうけど、今回はここまでにします。

 

そもそも、手書き数字認識ライブラリを作ることを目的とするなら、モデルの定義をPython側で柔軟にできるようにする必要もなく、Cython側で固定してしまって、全てのデータを配列なりポインタなりで扱えば、もっと大幅に高速化するはずです。その方向性でのを次回試すことにします。
追記:さらに約6倍高速化しました。

 

 

前処理:PyTorchモデルのパラメータを保存するスクリプト

import pickle

import torch
import torch.nn.functional as F

class MyLeNet(torch.nn.Module):
    def __init__(self):
        super(MyLeNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5)
        self.conv2 = torch.nn.Conv2d(20, 50, 5)
        self.fc1 = torch.nn.Linear(4*4*50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def main():
    torch_cnn = MyLeNet()
    torch_cnn.load_state_dict(torch.load('my_mnist_model.pth'))
    state = torch_cnn.state_dict()

    new_state = {}
    state_names = ['conv1.weight','conv1.bias',
                   'conv2.weight','conv2.bias',
                   'fc1.weight','fc1.bias',
                   'fc2.weight','fc2.bias']
    for sn in state_names:
        new_state[sn] = to_float(state[sn])

    f=open('param.dat','wb')
    pickle.dump(new_state,f)
    f.close()

def to_float(data):
    ls = []
    for d in data:
        if is_number(d):
            ls.append(float(d))
        else:
            ls.append(to_float(d))
    return ls

def is_number(data):
    try:
        float(data)
        return True
    except:
        return False

if __name__ == '__main__':
    main()

 

Python:自前CNNモジュールで数字判別するスクリプト

import math,time,pickle

from PIL import Image, ImageOps

import fixed_cnn as nn

class Network:
    def __init__(self):
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def __call__(self,x):
        return self.forward(x)

    def set_param(self,state):
        self.conv1.set_filters(state['conv1.weight'])
        self.conv1.set_bias(state['conv1.bias'])
        self.conv2.set_filters(state['conv2.weight'])
        self.conv2.set_bias(state['conv2.bias'])
        self.fc1.set_weights(state['fc1.weight'])
        self.fc1.set_bias(state['fc1.bias'])
        self.fc2.set_weights(state['fc2.weight'])
        self.fc2.set_bias(state['fc2.bias'])

    def forward(self, x):
        x = nn.relu(self.conv1(x))
        x = nn.max_pool2d(x, 2, 2)
        x = nn.relu(self.conv2(x))
        x = nn.max_pool2d(x, 2, 2)
        x = nn.trans3dto1d(x) #x = x.view(-1, 4*4*50)
        x = nn.relu(self.fc1(x))
        x = self.fc2(x)
        return nn.log_softmax(x)

    def max_ch(self,data_1d):
        max_value = -math.inf
        ch = 0
        for i,value in enumerate(data_1d):
            if value > max_value:
                max_value = value
                ch = i
        return ch

def main():
    start_time = time.time()

    mynet = Network()

    f = open('param.dat','rb')
    state = pickle.load(f)
    f.close()

    mynet.set_param(state)
    set_time = time.time() - start_time

    start_time = time.time()
    for i in range(10):
        image_path = '../img4/%s-0.png' % i
        img = Image.open(image_path)
        img = img.convert('L').resize((28,28))
        img = ImageOps.invert(img)
        data = nn.to_data(img)

        output = mynet(data)
        result = mynet.max_ch(output)
        print('result',result,image_path)
    calc_time = (time.time() - start_time)/10.0

    print('set_time',set_time)
    print('calc_time',calc_time)

if __name__ == '__main__':
    main()

 

自前CNNモジュール

ダウンロードはこちらから。
(動作確認環境:Windows 64bit、Python3.8)

ソース
Cython:fixed_cnn.pyx

import math,copy

from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free

cdef class Conv2d:
    cdef int in_channels
    cdef int out_channels
    cdef int kernel_size
    cdef int input_w
    cdef int input_h
    cdef double ****filters
    cdef double *bias
    cdef double ***input

    def __init__(self,in_channels, out_channels, kernel_size):
        cdef int i,j,k
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        self.filters = <double****> PyMem_Malloc(out_channels * sizeof(double*))
        for i in range(out_channels):
            self.filters[i] = <double***> PyMem_Malloc(in_channels * sizeof(double*))
            for j in range(in_channels):
                self.filters[i][j] = <double**> PyMem_Malloc(kernel_size * sizeof(double*))
                for k in range(kernel_size):
                    self.filters[i][j][k] = <double*> PyMem_Malloc(kernel_size * sizeof(double))
        self.bias = <double*> PyMem_Malloc(out_channels * sizeof(double))

    def __call__(self,data_3d):
        self.set_input(data_3d)
        return self.feature_maps()

    def set_input(self,data_3d):
        cdef int i,j,k
        self.input_h = len(data_3d[0])
        self.input_w = len(data_3d[0][0])

        self.input = <double***> PyMem_Malloc(self.in_channels * sizeof(double*))
        for i in range(self.in_channels):
            self.input[i] = <double**> PyMem_Malloc(self.input_h * sizeof(double*))
            for j in range(self.input_h):
                self.input[i][j] = <double*> PyMem_Malloc(self.input_w * sizeof(double))

        for i in range(self.in_channels):
            for j in range(self.input_h):
                for k in range(self.input_w):
                    self.input[i][j][k] = data_3d[i][j][k]

    def set_filters(self,data_4d):
        cdef int i,j,k,l
        for i in range(self.out_channels):
            for j in range(self.in_channels):
                for k in range(self.kernel_size):
                    for l in range(self.kernel_size):
                        self.filters[i][j][k][l] = data_4d[i][j][k][l]

    def set_bias(self,data_1d):
        cdef int i
        for i in range(self.out_channels):
            self.bias[i] = float(data_1d[i])

    def feature_maps(self):
        cdef int map_h,map_w,ch0,ch1,i,j,k,l,x,y
        cdef double v,weight
        map_h = self.input_h - self.kernel_size + 1
        map_w = self.input_w - self.kernel_size + 1
        fmaps = multi_list(self.out_channels,map_h,map_w)
        for ch1 in range(self.out_channels):
            for i in range(map_h):
                for j in range(map_w):
                    v = 0.0
                    for ch0 in range(self.in_channels):
                        for k in range(self.kernel_size):
                            for l in range(self.kernel_size):
                                y = i + k
                                x = j + l
                                weight = self.filters[ch1][ch0][k][l]
                                v += self.input[ch0][y][x] * weight
                    v += self.bias[ch1]
                    fmaps[ch1][i][j] = v
        return fmaps

def multi_list(*element_nums):
    cdef int dim,i,j
    dim = len(element_nums)
    data = [0 for j in range(element_nums[dim-1])]
    for i in range(2,dim+1):
        data = [ copy.deepcopy(data) for j in range(element_nums[dim-i])]
    return data

def shape(data):
    elem_nums = []
    while True:
        try:
            elem_nums.append(len(data))
            data = data[0]
        except:
            break
    return elem_nums

def trans3dto1d(data):
    new_data = []
    for lines in data:
        for line in lines:
            for value in line:
                new_data.append(value)
    return new_data

def max_pool2d(data,int kernel_size,int stride):
    cdef int in_ch,in_h,in_w,out_h,out_w,ch,i,j,k,l,x,y
    cdef double v,v_max,inf
    in_ch = len(data)
    in_h = len(data[0])
    in_w = len(data[0][0])
    out_h = math.ceil((in_h - (kernel_size-1))/stride)
    out_w = math.ceil((in_w - (kernel_size-1))/stride)
    out_data = multi_list(in_ch,out_h,out_w)
    for ch in range(in_ch):
        for i in range(out_h):
            for j in range(out_w):
                v_max = 0.0
                for k in range(kernel_size):
                    for l in range(kernel_size):
                        y = i*stride + k
                        x = j*stride + l
                        v = data[ch][y][x]
                        if v > v_max:
                            v_max = v
                out_data[ch][i][j] = v_max
    return out_data

def relu(data):
    cdef int dim,i,j,k
    cdef double value
    shp = shape(data)
    dim = len(shp)
    if dim == 1:
        for i in range(shp[0]):
            value = data[i]
            if value < 0.0:
                data[i] = 0.0
        return data
    elif dim == 3:
        for i in range(shp[0]):
            for j in range(shp[1]):
                for k in range(shp[2]):
                    value = data[i][j][k]
                    if value < 0.0:
                        data[i][j][k] = 0.0
        return data
    else:
        print('Error: Unsupported dimension. dim=%d' % dim)
        exit()

def log_softmax(data_1d):
    cdef double xmax,e_sum,e_x,x
    xmax = max(data_1d)
    e_sum = 0.0
    e_xs = []
    for x in data_1d:
        e_x = math.exp(x - xmax)
        e_xs.append(e_x)
        e_sum += e_x
    result = []
    for e_x in e_xs:
        result.append(e_x/e_sum)
    return result

cdef class Linear:
    cdef int in_features
    cdef int out_features
    cdef list values
    cdef double *bias
    cdef double **weights

    def __init__(self,in_features,out_features):
        cdef int i
        self.in_features = in_features
        self.out_features = out_features
        self.values = multi_list(out_features)
        #self.bias = multi_list(out_features)
        #self.weights = multi_list(out_features,in_features)

        self.bias = <double*> PyMem_Malloc(out_features * sizeof(double))
        self.weights = <double**> PyMem_Malloc(out_features * sizeof(double*))
        for i in range(out_features):
            self.weights[i] = <double*> PyMem_Malloc(in_features * sizeof(double))

    def set_weights(self,data_2d):
        cdef int i,j
        for i in range(self.out_features):
            for j in range(self.in_features):
                self.weights[i][j] = data_2d[i][j]

    def set_bias(self,data_1d):
        cdef int i
        for i in range(self.out_features):
            self.bias[i] = data_1d[i]

    def __call__(self,data_1d):
        cdef int i,j
        cdef double value,pvalue
        values = []
        for i in range(self.out_features):
            value = 0.0
            for j in range(self.in_features):
                pvalue = data_1d[j]
                value += pvalue * self.weights[i][j]
            value += self.bias[i]
            values.append(value)
        return values

def to_data(pil_img):
    cdef int i,j,w,h
    w,h = pil_img.size
    new_data = multi_list(1,h,w)
    for i in range(w):
        for j in range(h):
            new_data[0][j][i] = pil_img.getpixel((i,j))
    return new_data

 

Updated: 2022年8月27日 — 08:23

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です