Pythonで手書き数字のORCにトライ(その3)

以前、PyTorchで全結合のニューラルネットワークで、手書き数字認識を行い、自分の文字の正解率6~8割くらいでした。その後、線化処理などを加えた全結合のニューラルネットワークで試すと、9割くらいになりました。

今回は、畳み込みニューラルネットワーク(CNN)を試しました。モデルは、1998年に考案された元祖CNN、LeNetを模したモデルでやってみました。

結果、NMISTのテストデータでは99%以上の正解率、自分の文字の正解率は96%でした。すげぇー。

ただ、文字を中心に配置してサイズ調整する処理を加えない場合は、正解率は82%に落ちてしまったので、画像の前処理が肝となりそうです。

 

 

モデル

モデルは下記のサイトのものをコピペしました。

 

PyTorchでシンプルな畳み込みニューラルネットワークを作ろう
https://qiita.com/sudamasahiko/items/fd6a52f958f3f9013f0f
 

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4*4*50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

(全ソースコードは最後に)

上記の継承クラスを書き換えただけで、ほぼ前回のソースコードで回すことができました。PyTorchすごい。

 

結果

MNISTのtestデータの結果

0:99.80 %
1:99.47 %
2:99.52 %
3:99.31 %
4:98.88 %
5:99.22 %
6:99.06 %
7:98.54 %
8:99.59 %
9:99.41 %
total: 99.28 %

素晴らしい正解率。4が9に、7が1になることがあるのは相変わらずでしたが、それでも98%以上。

 

さて、自分の文字は?

0(100%): 0 0 0 0 0 0 0 0 0 0 
1( 90%): 1 1 1 2 1 1 1 1 1 1 
2(100%): 2 2 2 2 2 2 2 2 2 2 
3(100%): 3 3 3 3 3 3 3 3 3 3 
4(100%): 4 4 4 4 4 4 4 4 4 4 
5(100%): 5 5 5 5 5 5 5 5 5 5 
6( 90%): 6 6 6 6 6 8 6 6 6 6 
7(100%): 7 7 7 7 7 7 7 7 7 7 
8(100%): 8 8 8 8 8 8 8 8 8 8 
9( 80%): 9 9 9 9 9 9 9 9 7 7 
total: 96.00 %

おおー、素晴らしい。
これは2値化処理および文字を中央に配置し、さらに縮小する際、文字サイズを幅の80%程度に調整した画像での判別です。調整サイズを71%にしても、正解率は変わず(前は正解率が大幅に変化)で、確かに位置ずれに強くなっているようです。

ただし、中央寄せもサイズ調整もしていないテキトーに切り出した元の画像だと正解率は、8割くらいになりました。

0( 60%): 6 0 8 0 0 4 6 0 0 0 
1( 60%): 7 1 1 1 1 7 1 1 0 8 
2( 80%): 8 2 2 2 2 6 2 2 2 2 
3( 90%): 3 3 3 3 3 3 3 3 3 0 
4(100%): 4 4 4 4 4 4 4 4 4 4 
5(100%): 5 5 5 5 5 5 5 5 5 5 
6( 60%): 6 6 6 8 6 8 5 6 2 6 
7(100%): 7 7 7 7 7 7 7 7 7 7 
8( 90%): 8 7 8 8 8 8 8 8 8 8 
9( 80%): 9 9 9 9 8 9 9 9 9 8 
total: 82.00 %

調整された画像で学習しているので、当然といえば当然かもしれません。

そうなると綺麗に文字の領域を判別できるかカギになります。今のサンプル画像は綺麗に2値化できているので領域検出がたやすいですが、ノイズが入るとたちまちダメになりそう。物体検出のCNNなども試してみたくなります。OCRによる実務の効率化という当初目的から、意識が離れていっている気がしますが。。。
PyTorch、流行りの技術を手軽に試すことができて楽しいので、試すことが自体が目的になっちゃいますね^^;

 

さて、このCNN、Cythonで実装したいとは思うのですが、アルゴリズムはなんとなく理解しましたが、完全自力は無理そうなのでサンプルコードが必要です。良い書籍等が見つかれば挑戦してみることにします。

 
 

学習部:下記のサイトのソースをベースにモデルを変更したものです。
https://rightcode.co.jp/blog/information-technology/pytorch-mnist-learning

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4*4*50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def load_MNIST(batch=100, intensity=1.0):
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(r'C:\Temp',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Lambda(lambda x: x * intensity)
                       ])),
        batch_size=batch,
        shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(r'C:\Temp',
                       train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Lambda(lambda x: x * intensity)
                       ])),
        batch_size=batch,
        shuffle=True)

    return {'train': train_loader, 'test': test_loader}


def main():
    # 学習回数
    epoch = 20

    # 学習結果の保存用
    history = {
        'train_loss': [],
        'test_loss': [],
        'test_acc': [],
    }

    # ネットワークを構築
    net: torch.nn.Module = MyNet()

    # MNISTのデータローダーを取得
    loaders = load_MNIST()

    optimizer = torch.optim.Adam(params=net.parameters(), lr=0.001)

    for e in range(epoch):

        """ Training Part"""
        loss = None
        # 学習開始 (再開)
        net.train(True)  # 引数は省略可能
        for i, (data, target) in enumerate(loaders['train']):

            optimizer.zero_grad()
            output = net(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                print('Training log: {} epoch ({} / 60000 train. data). Loss: {}'.format(e+1,
                                                                                         (i+1)*100,
                                                                                         loss.item())
                      )

        history['train_loss'].append(loss)

        """ Test Part """
        # 学習のストップ
        net.eval()  # または net.train(False) でも良い
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for data, target in loaders['test']:
                output = net(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= 10000

        print('Test loss (avg): {}, Accuracy: {}'.format(test_loss,
                                                         correct / 10000))

        history['test_loss'].append(test_loss)
        history['test_acc'].append(correct / 10000)

    #モデルの保存
    torch.save(net.state_dict(), 'my_mnist_model.pth')

    # 結果の出力と描画
    print(history)
    plt.figure()
    plt.plot(range(1, epoch+1), history['train_loss'], label='train_loss')
    plt.plot(range(1, epoch+1), history['test_loss'], label='test_loss')
    plt.xlabel('epoch')
    plt.legend()
    plt.savefig('loss.png')

    plt.figure()
    plt.plot(range(1, epoch+1), history['test_acc'])
    plt.title('test accuracy')
    plt.xlabel('epoch')
    plt.savefig('test_acc.png')

if __name__ == '__main__':
    main()

 

学習済みモデルによる判別

# -*- coding: utf-8 -*-

import os
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from PIL import Image, ImageOps


class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4*4*50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def predic(data):
    net: torch.nn.Module = MyNet()
    net.load_state_dict(torch.load('my_mnist_model.pth'))
    net = net.eval()

    output = net(data)
    _, prediction = torch.max(output, 1)
    print('result=' + str(prediction[0].item()))

def image_loader(path,invert=True):
    image = Image.open(path)
    image = image.convert('L').resize((28,28))
    if invert:
        image = ImageOps.invert(image)

    transform=transforms.Compose([
                                transforms.ToTensor(),
                                #transforms.Normalize((0.5,), (0.5,))
                           ])
    data = transform(image)
    data = data.unsqueeze(0)
    return data

def test():
    net = MyNet()
    net.load_state_dict(torch.load('my_mnist_model.pth'))
    net = net.eval()

    total_n = total_c = 0.0
    for i in range(10):
        path = './img/'
        files = os.listdir(path)
        flist = [f for f in files if os.path.isfile(os.path.join(path, f))]

        n = c = 0
        result = ''
        for j in range(10):
            f = str(i)+'-'+str(j)+'.png'
            filepath = os.path.join(path,f)
            data = image_loader(filepath)
            output = net(data)
            _, prediction = torch.max(output, 1)
            # 結果を出力
            re = str(prediction[0].item())
            if str(i) == re:
                c += 1
            n += 1
            result += re+' '
        per = float(c)/float(n)*100
        total_c += c
        total_n += n
        print('%d(%d%%): %s' % (i,per,result))
    per = total_c/total_n*100
    print('total: %0.2f %%' % per)

def main():
    img = image_loader('img-path.png')
    predic(img)

if __name__ == '__main__':
    main()
Updated: 2021年1月11日 — 20:25

コメントを残す

メールアドレスが公開されることはありません。