Pythonで画像の畳み込みニューラルネットワーク処理のコーディング(学習なし)

前回、PyTorchで手書き数字の識別する畳み込みニューラルネットワーク(CNN)を試しました。モデルは元祖CNNであるLeNet形で、今回はその畳み込み演算をPythonで書いてみました。

 

外枠だけニューラルネットワークの形にしてはいますが、画像の1画素をニューロンとみなして畳み込み演算のフィルタを機械学習可能にしたCNNのコーディングとは、たぶん別物です。

元々の目的がPyTorchで学習したCNNモデルのパラメータ使ってCythonで自前のライブラリを作ることで、今回は理解している処理が正しいのか確認するための勉強コーディングになります。

そのため学習(逆伝播)を実装するつもりがないので、一般的(だと思う)な畳み込み演算をそのままコーディングしています。おそらく学習するんであれば、ニューロンのクラス化等々、相応のやり方があるんだろうなぁと想像します。

 

ここからCythonで高速化することが、本番ですが、畳み込み演算の実際の処理を理解したい人とか、CNNの処理の流れ理解したい人とかには役立つかなぁと思って投稿します。まぁ、ホントはなんの書物等の手助けもなく出来たことが嬉しかったんで投稿しとこうってだけなのですが^^;

 

import math,random,copy,time

class Conv2d:
    def __init__(self,in_channels, out_channels, kernel_size):
        self.in_channels = ch0 = in_channels
        self.out_channels = ch1 = out_channels
        self.kernel_size = sz = kernel_size

        self.filters = multi_list(ch1,ch0,sz,sz)
        self.bias = multi_list(ch1)
        self.input = None

    def __call__(self,data_3d):
        self.set_input(data_3d)
        return self.feature_maps()

    def set_input(self,data_3d):
        ch = len(data_3d) # = in_channels
        self.input_h = h = len(data_3d[0])
        self.input_w = w = len(data_3d[0][0])
        self.input = multi_list(ch,h,w)
        for ch,lines in enumerate(data_3d):
            for h,line in enumerate(lines):
                for w,value in enumerate(line):
                    self.input[ch][h][w] = float(value)

    def set_filters(self,data_4d):
        for n,data in enumerate(data_4d):
            for ch,lines in enumerate(data):
                for h,line in enumerate(lines):
                    for w,value in enumerate(line):
                        self.filters[n][ch][h][w] = float(value)

    def set_bias(self,data_1d):
        for i in range(self.out_channels):
            self.bias[i] = float(data_1d[i])

    def feature_maps(self):
        in_ch = self.in_channels
        out_ch = self.out_channels
        size = self.kernel_size
        map_w = self.input_w - size + 1
        map_h = self.input_h - size + 1
        fmaps = multi_list(out_ch,map_h,map_w)
        for ch1 in range(self.out_channels):
            for i in range(map_h):
                for j in range(map_w):
                    v = 0.0
                    for ch0 in range(self.in_channels):
                        for k in range(size):
                            for l in range(size):
                                y = i + k
                                x = j + l
                                weight = self.filters[ch1][ch0][k][l]
                                v += self.input[ch0][y][x] * weight
                    v += self.bias[ch1]
                    fmaps[ch1][i][j] = v
        return fmaps

def multi_list(*shape):
    dim = len(shape)
    data = [0 for j in range(shape[dim-1])]
    for i in range(2,dim+1):
        data = [ copy.deepcopy(data) for j in range(shape[dim-i])]
    return data

def shape(data):
    elem_nums = []
    while True:
        try:
            elem_nums.append(len(data))
            data = data[0]
        except:
            break
    return elem_nums

def trans3dto1d(data):
    new_data = []
    for lines in data:
        for line in lines:
            for value in line:
                new_data.append(value)
    return new_data

def max_pool2d(data,kernel_size,stride):
    in_ch = len(data)
    in_h = len(data[0])
    in_w = len(data[0][0])
    out_h = math.ceil((in_h - (kernel_size-1))/stride)
    out_w = math.ceil((in_w - (kernel_size-1))/stride)
    out_data = multi_list(in_ch,out_h,out_w)
    for ch in range(in_ch):
        for i in range(out_h):
            for j in range(out_w):
                v_max = -math.inf
                for k in range(kernel_size):
                    for l in range(kernel_size):
                        y = i*stride + k
                        x = j*stride + l
                        v = data[ch][y][x]
                        if v > v_max:
                            v_max = v
                out_data[ch][i][j] = v_max
    return out_data

def relu(data):
    dim = len(shape(data))
    if dim == 1:
        for i,value in enumerate(data):
            if value < 0.0:
                data[i] = 0.0
        return data
    elif dim == 3:
        for ch,lines in enumerate(data):
            for h,line in enumerate(lines):
                for w,value in enumerate(line):
                    value = data[ch][h][w]
                    if value < 0.0:
                        data[ch][h][w] = 0.0
        return data
    else:
        print('Error: Unsupported dimension. dim=%d' % dim)
        exit()

def log_softmax(data_1d):
    xmax = max(data_1d)
    e_sum = 0.0
    e_xs = []
    for x in data_1d:
        e_x = math.exp(x - xmax)
        e_xs.append(e_x)
        e_sum += e_x
    result = []
    for e_x in e_xs:
        result.append(e_x/e_sum)
    return result

class Linear:
    def __init__(self,in_features,out_features):
        self.in_features = in_features
        self.out_features = out_features
        self.values = multi_list(out_features)
        self.bias = multi_list(out_features)
        self.weights = multi_list(out_features,in_features)

    def set_weights(self,data_2d):
        for i in range(self.out_features):
            for j in range(self.in_features):
                self.weights[i][j] = float(data_2d[i][j])

    def set_bias(self,data_1d):
        for i in range(self.out_features):
            self.bias[i] = float(data_1d[i])

    def __call__(self,data_1d):
        values = []
        for i in range(self.out_features):
            value = 0.0
            for j in range(self.in_features):
                pvalue = float(data_1d[j])
                value += pvalue * self.weights[i][j]
            value += self.bias[i]
            values.append(value)
        return values

class Network:
    def __init__(self):
        self.conv1 = Conv2d(1, 20, 5)
        self.conv2 = Conv2d(20, 50, 5)
        self.fc1 = Linear(4*4*50, 500)
        self.fc2 = Linear(500, 10)

    def __call__(self,x):
        return self.forward(x)

    def set_param(self,state):
        self.conv1.set_filters(state['conv1.weight'])
        self.conv1.set_bias(state['conv1.bias'])
        self.conv2.set_filters(state['conv2.weight'])
        self.conv2.set_bias(state['conv2.bias'])
        self.fc1.set_weights(state['fc1.weight'])
        self.fc1.set_bias(state['fc1.bias'])
        self.fc2.set_weights(state['fc2.weight'])
        self.fc2.set_bias(state['fc2.bias'])

    def forward(self, x):
        x = relu(self.conv1(x))
        x = max_pool2d(x, 2, 2)
        x = relu(self.conv2(x))
        x = max_pool2d(x, 2, 2)
        x = trans3dto1d(x) #x = x.view(-1, 4*4*50)
        x = relu(self.fc1(x))
        x = self.fc2(x)
        return log_softmax(x)

    def max_ch(self,data_1d):
        max_value = -math.inf
        ch = 0
        for i,value in enumerate(data_1d):
            if value > max_value:
                max_value = value
                ch = i
        return ch

def main():
    import torch
    import torch.nn.functional as F
    import torchvision
    from torchvision import datasets, transforms
    from PIL import Image, ImageOps

    class MyLeNet(torch.nn.Module):
        def __init__(self):
            super(MyLeNet, self).__init__()
            self.conv1 = torch.nn.Conv2d(1, 20, 5)
            self.conv2 = torch.nn.Conv2d(20, 50, 5)
            self.fc1 = torch.nn.Linear(4*4*50, 500)
            self.fc2 = torch.nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4*4*50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    start_time = time.time()

    mynet = Network()

    torch_cnn = MyLeNet()
    torch_cnn.load_state_dict(torch.load('my_mnist_model.pth'))

    state = torch_cnn.state_dict()
    mynet.set_param(state)
    print('set state!')
    set_time = time.time() - start_time
    start_time = time.time()

    image_path = '../img4/0-0.png'
    img = Image.open(image_path)
    img = img.convert('L').resize((28,28))
    img = ImageOps.invert(img)
    trans = transforms.Compose([transforms.ToTensor()])
    data = trans(img)

    output = mynet(data)
    result = mynet.max_ch(output)
    print('result',result,image_path)
    calc_time = time.time() - start_time

    print('set_time',set_time)
    print('calc_time',calc_time)

if __name__ == '__main__':
    main()
Updated: 2021年2月7日 — 02:37

コメントを残す

メールアドレスが公開されることはありません。