Last time, I tried out a convolutional neural network (CNN) for handwritten digit recognition with PyTorch. The model was a LeNet-style network, the original CNN, and this time I have written its convolution operation in plain Python.

Only the outer framing takes the shape of a neural network here; it is probably quite different from a CNN implementation that treats each image pixel as a neuron and makes the convolution filters learnable.

My original goal is to build my own library in Cython using the parameters of a CNN model trained with PyTorch, so this is study code written to confirm that the processing as I understand it is correct.

Since I have no intention of implementing training (backpropagation), I simply coded the convolution operation in what I believe is the standard way. If training were the goal, I imagine there would be more appropriate approaches, such as turning neurons into classes.

The real work starts from here, speeding things up with Cython, but I am posting this in the hope that it helps people who want to understand what the convolution operation actually does, or how data flows through a CNN. To be honest, I am mostly posting it because I was happy that I managed it without the help of any books ^^;
Update: I have since ported it to Cython.
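Before the full listing, here is a minimal sketch (not part of the script below) of how a "valid" convolution with stride 1 and no padding computes its output; the 3x3 input, 2x2 kernel, and bias are made-up numbers for illustration, but the loops mirror what `Conv2d.feature_maps` does for each output channel.

```python
# Minimal illustration of a "valid" convolution (stride 1, no padding).
# The 3x3 input, 2x2 kernel, and bias are made-up numbers for demonstration.
inp = [[1.0, 2.0, 0.0],
       [0.0, 1.0, 3.0],
       [2.0, 1.0, 1.0]]
kernel = [[1.0, 0.0],
          [0.0, -1.0]]
bias = 0.5

size = len(kernel)
out_h = len(inp) - size + 1      # 3 - 2 + 1 = 2
out_w = len(inp[0]) - size + 1   # 3 - 2 + 1 = 2

fmap = [[0.0] * out_w for _ in range(out_h)]
for i in range(out_h):
    for j in range(out_w):
        v = 0.0
        for k in range(size):
            for l in range(size):
                v += inp[i + k][j + l] * kernel[k][l]  # multiply the patch by the kernel
        fmap[i][j] = v + bias                          # add the bias once per output pixel

print(fmap)  # [[0.5, -0.5], [-0.5, 0.5]]
```

The same size rule (output = input − kernel + 1, then halved by each 2x2 max pooling) is what turns the 28x28 MNIST image into 24x24 → 12x12 → 8x8 → 4x4 feature maps, which is where the `4*4*50` input size of `fc1` comes from.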
```python
import math, random, copy, time


class Conv2d:
    # Plain-Python 2D convolution layer: "valid" convolution, stride 1, no padding.
    def __init__(self, in_channels, out_channels, kernel_size):
        self.in_channels = ch0 = in_channels
        self.out_channels = ch1 = out_channels
        self.kernel_size = sz = kernel_size
        self.filters = multi_list(ch1, ch0, sz, sz)
        self.bias = multi_list(ch1)
        self.input = None

    def __call__(self, data_3d):
        self.set_input(data_3d)
        return self.feature_maps()

    def set_input(self, data_3d):
        ch = len(data_3d)  # = in_channels
        self.input_h = h = len(data_3d[0])
        self.input_w = w = len(data_3d[0][0])
        self.input = multi_list(ch, h, w)
        for ch, lines in enumerate(data_3d):
            for h, line in enumerate(lines):
                for w, value in enumerate(line):
                    self.input[ch][h][w] = float(value)

    def set_filters(self, data_4d):
        for n, data in enumerate(data_4d):
            for ch, lines in enumerate(data):
                for h, line in enumerate(lines):
                    for w, value in enumerate(line):
                        self.filters[n][ch][h][w] = float(value)

    def set_bias(self, data_1d):
        for i in range(self.out_channels):
            self.bias[i] = float(data_1d[i])

    def feature_maps(self):
        in_ch = self.in_channels
        out_ch = self.out_channels
        size = self.kernel_size
        # Output size of a valid convolution: input - kernel + 1
        map_w = self.input_w - size + 1
        map_h = self.input_h - size + 1
        fmaps = multi_list(out_ch, map_h, map_w)
        for ch1 in range(self.out_channels):
            for i in range(map_h):
                for j in range(map_w):
                    v = 0.0
                    for ch0 in range(self.in_channels):
                        for k in range(size):
                            for l in range(size):
                                y = i + k
                                x = j + l
                                weight = self.filters[ch1][ch0][k][l]
                                v += self.input[ch0][y][x] * weight
                    v += self.bias[ch1]
                    fmaps[ch1][i][j] = v
        return fmaps


def multi_list(*shape):
    # Build a nested list of zeros with the given shape, e.g. multi_list(2, 3) -> 2x3.
    dim = len(shape)
    data = [0 for j in range(shape[dim - 1])]
    for i in range(2, dim + 1):
        data = [copy.deepcopy(data) for j in range(shape[dim - i])]
    return data


def shape(data):
    # Return the nested-list shape as a list of lengths, e.g. [ch, h, w].
    elem_nums = []
    while True:
        try:
            elem_nums.append(len(data))
            data = data[0]
        except:
            break
    return elem_nums


def trans3dto1d(data):
    # Flatten a (ch, h, w) nested list into a 1-D list (pure-Python version of x.view(...)).
    new_data = []
    for lines in data:
        for line in lines:
            for value in line:
                new_data.append(value)
    return new_data


def max_pool2d(data, kernel_size, stride):
    in_ch = len(data)
    in_h = len(data[0])
    in_w = len(data[0][0])
    out_h = math.ceil((in_h - (kernel_size - 1)) / stride)
    out_w = math.ceil((in_w - (kernel_size - 1)) / stride)
    out_data = multi_list(in_ch, out_h, out_w)
    for ch in range(in_ch):
        for i in range(out_h):
            for j in range(out_w):
                v_max = -math.inf
                for k in range(kernel_size):
                    for l in range(kernel_size):
                        y = i * stride + k
                        x = j * stride + l
                        v = data[ch][y][x]
                        if v > v_max:
                            v_max = v
                out_data[ch][i][j] = v_max
    return out_data


def relu(data):
    dim = len(shape(data))
    if dim == 1:
        for i, value in enumerate(data):
            if value < 0.0:
                data[i] = 0.0
        return data
    elif dim == 3:
        for ch, lines in enumerate(data):
            for h, line in enumerate(lines):
                for w, value in enumerate(line):
                    value = data[ch][h][w]
                    if value < 0.0:
                        data[ch][h][w] = 0.0
        return data
    else:
        print('Error: Unsupported dimension.\ndim=%d' % dim)
        exit()


def log_softmax(data_1d):
    # NOTE: this actually returns the softmax, not its log;
    # the predicted class (argmax) is the same either way.
    xmax = max(data_1d)
    e_sum = 0.0
    e_xs = []
    for x in data_1d:
        e_x = math.exp(x - xmax)
        e_xs.append(e_x)
        e_sum += e_x
    result = []
    for e_x in e_xs:
        result.append(e_x / e_sum)
    return result


class Linear:
    # Fully connected layer: y = Wx + b.
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features
        self.values = multi_list(out_features)
        self.bias = multi_list(out_features)
        self.weights = multi_list(out_features, in_features)

    def set_weights(self, data_2d):
        for i in range(self.out_features):
            for j in range(self.in_features):
                self.weights[i][j] = float(data_2d[i][j])

    def set_bias(self, data_1d):
        for i in range(self.out_features):
            self.bias[i] = float(data_1d[i])

    def __call__(self, data_1d):
        values = []
        for i in range(self.out_features):
            value = 0.0
            for j in range(self.in_features):
                pvalue = float(data_1d[j])
                value += pvalue * self.weights[i][j]
            value += self.bias[i]
            values.append(value)
        return values


class Network:
    # Pure-Python counterpart of the LeNet-style PyTorch model defined in main().
    def __init__(self):
        self.conv1 = Conv2d(1, 20, 5)
        self.conv2 = Conv2d(20, 50, 5)
        self.fc1 = Linear(4*4*50, 500)
        self.fc2 = Linear(500, 10)

    def __call__(self, x):
        return self.forward(x)

    def set_param(self, state):
        self.conv1.set_filters(state['conv1.weight'])
        self.conv1.set_bias(state['conv1.bias'])
        self.conv2.set_filters(state['conv2.weight'])
        self.conv2.set_bias(state['conv2.bias'])
        self.fc1.set_weights(state['fc1.weight'])
        self.fc1.set_bias(state['fc1.bias'])
        self.fc2.set_weights(state['fc2.weight'])
        self.fc2.set_bias(state['fc2.bias'])

    def forward(self, x):
        x = relu(self.conv1(x))
        x = max_pool2d(x, 2, 2)
        x = relu(self.conv2(x))
        x = max_pool2d(x, 2, 2)
        x = trans3dto1d(x)  # x = x.view(-1, 4*4*50)
        x = relu(self.fc1(x))
        x = self.fc2(x)
        return log_softmax(x)

    def max_ch(self, data_1d):
        # Index of the largest output value, i.e. the predicted digit.
        max_value = -math.inf
        ch = 0
        for i, value in enumerate(data_1d):
            if value > max_value:
                max_value = value
                ch = i
        return ch


def main():
    import torch
    import torch.nn.functional as F
    import torchvision
    from torchvision import datasets, transforms
    from PIL import Image, ImageOps

    class MyLeNet(torch.nn.Module):
        def __init__(self):
            super(MyLeNet, self).__init__()
            self.conv1 = torch.nn.Conv2d(1, 20, 5)
            self.conv2 = torch.nn.Conv2d(20, 50, 5)
            self.fc1 = torch.nn.Linear(4*4*50, 500)
            self.fc2 = torch.nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4*4*50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    # Copy the trained PyTorch parameters into the pure-Python network.
    start_time = time.time()
    mynet = Network()
    torch_cnn = MyLeNet()
    torch_cnn.load_state_dict(torch.load('my_mnist_model.pth'))
    state = torch_cnn.state_dict()
    mynet.set_param(state)
    print('set state!')
    set_time = time.time() - start_time

    # Run one image through the pure-Python forward pass.
    start_time = time.time()
    image_path = '../img4/0-0.png'
    img = Image.open(image_path)
    img = img.convert('L').resize((28, 28))
    img = ImageOps.invert(img)
    trans = transforms.Compose([transforms.ToTensor()])
    data = trans(img)
    output = mynet(data)
    result = mynet.max_ch(output)
    print('result', result, image_path)
    calc_time = time.time() - start_time

    print('set_time', set_time)
    print('calc_time', calc_time)


if __name__ == '__main__':
    main()
```
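As a rough way to check that the pure-Python forward pass matches PyTorch, something like the sketch below (my own addition, not part of the original script) can be run against the classes in the listing above; it assumes `MyLeNet` has been lifted out of `main()` so it is reachable at module level, and that `my_mnist_model.pth` and the sample image exist at the same paths used in `main()`.

```python
# Verification sketch: run the same image through the PyTorch model and through
# the pure-Python Network, then compare the predicted digits.
# Assumes MyLeNet, Network are defined at module level and the files below exist.
import torch
from PIL import Image, ImageOps
from torchvision import transforms

torch_cnn = MyLeNet()
torch_cnn.load_state_dict(torch.load('my_mnist_model.pth'))
torch_cnn.eval()

mynet = Network()
mynet.set_param(torch_cnn.state_dict())

img = ImageOps.invert(Image.open('../img4/0-0.png').convert('L').resize((28, 28)))
data = transforms.ToTensor()(img)          # tensor of shape (1, 28, 28)

with torch.no_grad():
    torch_pred = torch_cnn(data.unsqueeze(0)).argmax(dim=1).item()
my_pred = mynet.max_ch(mynet(data))

print('torch:', torch_pred, 'mine:', my_pred)  # the two predictions should agree
```

Since the `log_softmax` above actually returns the softmax, the raw output values will not be numerically identical to PyTorch's log-probabilities, but the predicted digit should be the same.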