DEMの勉強4

前回のプログラムを要素数492ほどで試してみたら、ネットブック(Acer Aspire One HAPPY2)では、1ステップで2.55秒もかかりました。お、おう。ネットブックで解析を回すつもりは無いとはいえ、最低でも1万ステップくらい回すことを踏まえるとさすがに遅い。ノートPC(Dell Inspiron14 core i3-4010U)でも0.63秒で4倍くらい速いものの、全然遅い。勉強用とはいえ、5000要素で1万ステップくらいは1分以内でやってほしい。1ステップで0.005秒くらいとなると、100倍以上の高速化が必要だ。

 

Pythonが他のプログラムに比べても演算の処理速度が遅いことは知っていたので、最終段階で高速化しようと思っていたけども、現時点で、高速化しないと色々遊んでみることもままならなさそう・・・。っということで、Cython化してみました。

 

Pythonの高速化の手段はいろいろあって、今回ただの演算なので、Numpyとかの演算ライブラリを使うのが有用な手段だけども、プログラムの可読性が悪くなるので、Cythonに挑戦してみました。

 

Cython入門 静的型付けでコードを高速化する
http://omake.accense.com/static/doc-ja/cython/src/quickstart/cythonize.html

 

結果から

ステップ0:Python 1ステップ2.55秒
ステップ1:ソースを何も変えずにCython化 1ステップ2.01秒(1.27倍)
ステップ2:ループ仕様を変更 1ステップ1.82秒(1.40倍)

ここまではCythonインストール以外の手間はほぼないので、うれしい高速化なのだけど、目指すところはその程度じゃない。

 

ステップ3:クラスを全てcdef化してみる⇒上手くコーディングできない

この辺になってくると、Cythonの仕様を理解しなくちゃいけなくなってきました。条件設定や画面表示などは通常のPythonにしたい都合上、Pythonから見える見えない問題の解決が難しくて、Cython仕様に合わせて一から作ってみようと思いました。

で、まだ全然書ききれていないのだけど、形になってきたので、とりあえずどれだけ速くなりそうか見てみたい。

 

ステップ4:Cythonで作り直し。1ステップ0.062秒(41倍)

おおー。100倍とは行かないまでも、随分と速くなりました。いや、まぁプログラムを省略している部分が多いんだけど。

体感的にはこんな感じ(ノートPC)

 

元のPythonにおける要素データは、要素クラスのインスタンスをそのままリストに代入していく方式。プログラムがとても分かりやすくなるのだけど、Cythonでクラスインスタンスを配列に格納したりする方法が分からず、結局、構造体の配列にすることにしました。数値配列(行列)にしてしまうと可読性が悪くなってしまうし、それならば、Numpyなどのライブラリ使うほうが、速度的にも、利便的にも良くて、Cythonを使う意味がなくなっちゃうので。それに伴い、クラスは使わないように変更(オブジェクト指向は断念)。

 

まだ、ちゃん動いているのかどうかの検証も、Python版では計算している処理もしてなかったりするので、ちゃんとした比較ではないけども、Pythonの書き方+ちょっとした知識だけで高速なモジュールが作れるのは、ありがたい。

 

dem.pyx (dem.pyd)

# -*- coding: utf-8 -*-

#Cython仕様で一から作ってみる

from libc.math cimport sqrt,sin,cos,atan2,fabs
from cpython cimport bool

DEF MAX_PAR_NO = 1000 #最大粒子数
DEF MAX_LINE_NO = 10 #最大線要素数
DEF MAX_ELEM_NO = MAX_PAR_NO + MAX_LINE_NO #最大要素数

cdef float PI =  3.1415926535 #円周率
cdef float G = 9.80665 #重力加速度

cdef struct Particle:
    
    #要素共通
    int etype #要素タイプ
    int n #要素No.
    float r#半径
    float x#X座標
    float y#Y座標
    float a#角度
    float dx#X方向増加量
    float dy#Y方向増加量
    float da#角度増加量
    float vx#X方向速度
    float vy#Y方向速度
    float va#角速度
    float fy
    float fx
    float fm
    float en[MAX_ELEM_NO]#弾性力(直方向)
    float es[MAX_ELEM_NO]#弾性力(せん断方向)
    
    #粒子専用
    float m# 質量
    float Ir#慣性モーメント

cdef struct Line:
    
    #要素共通
    int etype #要素タイプ
    int n #要素No.
    float r#半径
    float x#X座標
    float y#Y座標
    float a#角度
    float dx#X方向増加量
    float dy#Y方向増加量
    float da#角度増加量
    float vx#X方向速度
    float vy#Y方向速度
    float va#角速度
    float fy
    float fx
    float fm
    float en[MAX_ELEM_NO]#弾性力(直方向)
    float es[MAX_ELEM_NO]#弾性力(せん断方向)
    
    #線要素専用
    float x1
    float y1
    float x2
    float y2

cdef struct Interface:
    
    float kn #弾性係数(法線方向)
    float etan#粘性係数(法線方向)
    float ks#弾性係数(せん断方向)
    float etas#弾性係数(せん断方向)
    float frc#摩擦係数

cdef Particle pe[MAX_ELEM_NO]
cdef Line le[MAX_LINE_NO]

cdef float rho = 10 #粒子間密度
cdef int peCount #粒子の数
cdef int leCount #線要素の数
cdef float dt = 0.001 #⊿t
cdef int st = 0 #ステップ

cdef void nextStep():
    resetForce()
    calcForce()
    updateCoord()
    global st
    st += 1

cdef void resetForce():
    cdef int i
    for i in range(peCount):
        pe[i].fx = 0
        pe[i].fy = 0
        pe[i].fm = 0

cdef Interface interface(int etype1,int etype2):
    cdef Interface inf
    if etype1 == 1 and etype2 == 1:
        inf.kn = 10**6 #弾性係数(法線方向)
        inf.etan = 10000 #粘性係数(法線方向)
        inf.ks = 10*6 #弾性係数(せん断方向)
        inf.etas = 10000 #粘性係数(せん断方向)
        inf.frc = 10*6 #摩擦係数
    return inf

cdef void calcForce():
    #2粒子間の接触判定
    cdef float lx,ly,ld,cos_a,sin_a
    cdef int i,j
    for i in range(peCount):
        for j in range(i+1,peCount):
            lx = pe[i].x - pe[j].x
            ly = pe[i].y - pe[j].y
            ld = (lx**2+ly**2)**0.5
            
            if (pe[i].r+pe[j].r)>ld:
                cos_a = lx/ld
                sin_a = ly/ld
                forcePar2par(i,j,cos_a,sin_a)
            else:
                pe[i].en[j] = 0.0
                pe[i].es[j] = 0.0
    
    #粒子と線の接触判定
    cdef bool hit
    cdef Particle p
    cdef Line l
    for i in range(peCount):
        for j in range(leCount):
            hit = False
            p = pe[i]
            l = le[j]
            th0 = atan2(l.y2-l.y1, l.x2-l.x1)
            th1 = atan2(p.y-l.y1, p.x-l.x1)
            a = sqrt((p.x-l.x1)**2+(p.y-l.y1)**2)
            d = fabs(a*sin(th1-th0))
            if d < p.r:
                b = sqrt((p.x-l.x2)**2+(p.y-l.y2)**2)
                s = sqrt((l.x2-l.x1)**2+(l.y2-l.y1)**2)
                if a < s and b < s:
                    s1 = sqrt(a**2-d**2)
                    x = l.x1 + s1*cos(th0)
                    y = l.y1 + s1*sin(th0)
                    hit = True
                elif a < b and a < p.r:
                    x = l.x1
                    y = l.y1
                    hit = True
                elif b < p.r:
                    x = l.x2
                    y = l.y2
                    hit = True       
            if hit:
                lx = p.x - x
                ly = p.y - y
                ld = sqrt(lx**2+ly**2)
                cos_a = lx/ld
                sin_a = ly/ld
                forceLine2par(p,l,cos_a,sin_a)
            else:
                p.en[l.n] = 0.0
                p.es[l.n] = 0.0
    #外力
    for i in range(peCount):
        pe[i].fy += -G*pe[i].m #重力

cdef void forcePar2par(int i,int j,float cos_a, float sin_a):
    cdef float un,us,vn,vs,hn,hs
    
    #相対的変位増分
    un = +(pe[i].dx-pe[j].dx)*cos_a+(pe[i].dy-pe[j].dy)*sin_a
    us = -(pe[i].dx-pe[j].dx)*sin_a+(pe[i].dy-pe[j].dy)*cos_a+(pe[i].r*pe[i].da+pe[j].r*pe[j].da)
    #相対的速度増分
    vn = +(pe[i].vx-pe[j].vx)*cos_a+(pe[i].vy-pe[j].vy)*sin_a
    vs = -(pe[i].vx-pe[j].vx)*sin_a+(pe[i].vy-pe[j].vy)*cos_a+(pe[i].r*pe[i].va+pe[j].r*pe[j].va)
    
    inf = interface(pe[i].etype,pe[j].etype)
    
    #合力(局所座標系)
    pe[i].en[pe[j].n] += inf.kn*un
    pe[i].es[pe[j].n] += inf.ks*us
    hn = pe[i].en[pe[j].n] + inf.etan*vn
    hs = pe[i].es[pe[j].n] + inf.etas*vs
    
    if hn <= 0.0:
        #法線力がなければ、せん断力は0
        hs = 0.0
    elif fabs(hs) >= inf.frc*hn:
        #摩擦力以上のせん断力は働かない
        hs = inf.frc*fabs(hn)*hs/fabs(hs)     
    
    #全体合力(全体座標系)
    pe[i].fx += -hn*cos_a + hs*sin_a
    pe[i].fy += -hn*sin_a - hs*cos_a
    pe[i].fm -=  pe[i].r*hs
    pe[j].fx +=  hn*cos_a - hs*sin_a
    pe[j].fy +=  hn*sin_a + hs*cos_a
    pe[j].fm -=  pe[j].r*hs

cdef void forceLine2par(Particle p1,Line l,float cos_a, float sin_a):
    cdef float un,us,vn,vs,hn,hs
    
    #相対的変位増分
    un = +p1.dx*cos_a+p1.dy*sin_a
    us = -p1.dx*sin_a+p1.dy*cos_a+p1.r*p1.da
    #相対的速度増分
    vn = +p1.vx*cos_a+p1.vy*sin_a
    vs = -p1.vx*sin_a+p1.vy*cos_a+p1.r*p1.va
    
    inf = interface(p1.etype,l.etype)
    
    #合力(局所座標系)
    pe[p1.n].en[l.n] += inf.kn*un
    pe[p1.n].es[l.n] += inf.ks*us
    hn = p1.en[l.n] + inf.etan*vn
    hs = p1.es[l.n] + inf.etas*vs
    
    if hn <= 0.0:
        #法線力がなければ、せん断力は0
        hs = 0.0
    elif abs(hs) >= inf.frc*hn:
        #摩擦力以上のせん断力は働かない
        hs = inf.frc*fabs(hn)*hs/fabs(hs)    
    
    #全体合力(全体座標系)
    pe[p1.n].fx += -hn*cos_a + hs*sin_a
    pe[p1.n].fy += -hn*sin_a - hs*cos_a
    pe[p1.n].fm -=  p1.r*hs

cdef void updateCoord():
    cdef float ax,ay,aa
    cdef int i
    for i in range(peCount):
        #位置更新(オイラー差分)
        ax = pe[i].fx/pe[i].m
        ay = pe[i].fy/pe[i].m
        aa = pe[i].fm/pe[i].Ir
        pe[i].vx += ax*dt
        pe[i].vy += ay*dt
        pe[i].va += aa*dt
        pe[i].dx = pe[i].vx*dt
        pe[i].dy = pe[i].vy*dt
        pe[i].da = pe[i].va*dt
        pe[i].x += pe[i].dx
        pe[i].y += pe[i].dy
        pe[i].a += pe[i].da

# -------------------------
# Pythonからの設定用
# -------------------------
    
def setDeltaTime(sec):
    global dt
    dt = sec

def setNumberOfParticle(n):
    global peCount
    peCount = n

def numberOfParticle():
    return peCount

def setParticle(pe_no,pe_obj):
    pe[pe_no].x = pe_obj.x
    pe[pe_no].y = pe_obj.y 

def particle(pe_no,pe_obj):
    pe_obj.x = pe[pe_no].x
    pe_obj.y = pe[pe_no].y
    pe_obj.a = pe[pe_no].a
    return pe_obj    

def setLine(l_no,l_obj):
    le[l_no].x1 = l_obj.x1
    le[l_no].y1 = l_obj.y1
    le[l_no].x2 = l_obj.x2
    le[l_no].y2 = l_obj.y2
    return l_obj

def setNumberOfLine(n):
    global leCount
    leCount = n

def line(l_no,l_obj):
    l_obj.x1 = le[l_no].x1
    l_obj.y1 = le[l_no].y1
    l_obj.x2 = le[l_no].x2
    l_obj.y2 = le[l_no].y2
    return l_no

def initialize():
    cdef int i,j
    for i in range(MAX_ELEM_NO):
        pe[i].x = 0
        pe[i].y = 0
        pe[i].r = 0
        pe[i].a = 0
        pe[i].dx = 0
        pe[i].dy = 0
        pe[i].da = 0
        pe[i].vx = 0
        pe[i].vy = 0
        pe[i].va = 0
        pe[i].fx = 0
        pe[i].fy = 0
        pe[i].fm = 0
        pe[i].m = 0
        pe[i].Ir = 0
        for j in range(MAX_ELEM_NO):
            pe[i].en[j] = 0
            pe[i].es[j] = 0
            
def setup():
    cdef int i
    #粒子要素
    for i in range(peCount):
        pe[i].n = i
        pe[i].r = 5.0
        pe[i].m = 4.0/3.0*PI*rho*pe[i].r**3 # 質量
        pe[i].Ir = PI*rho*pe[i].r**4/2.0 #慣性モーメント
    #線要素
    for i in range(leCount):
        le[i].n = peCount+i


def step():
    return st

def calcStep(int n=1):
    cdef int i
    for i in range(n):
        nextStep()

 

dem_ui.py

# -*- coding: utf-8 -*-

print u'読み込み中...',

import sys
import math
import random
import time
import Tkinter

import dem

from PIL import ImageGrab

class Element(object):
    def __init__(self):
        self.n = 0 #要素No.
        self.r = 0 #半径
        self.x = 0 #X座標
        self.y = 0 #Y座標
        self.a = 0 #角度
        self.dx = 0 #X方向増加量
        self.dy = 0 #Y方向増加量
        self.da = 0 #角度増加量
        self.vx = 0 #X方向速度
        self.vy = 0 #Y方向速度
        self.va = 0 #角速度
        self.fy = 0
        self.fx = 0
        self.fm = 0
        
        self.en = [] #弾性力(直方向)
        self.es = [] #弾性力(せん断方向)

class Particle(Element):
    def __init__(self,x=0,y=0,vx=0,vy=0):
        super(Particle,self).__init__()
        self.type = 1
        
        self.x = x #X座標
        self.y = y #Y座標
        self.vx = vx #X方向速度
        self.vy = vy #Y方向速度
        
        rho = 10
        self.r = 5 #半径
        self.m = 4.0/3.0*math.pi*rho*self.r**3 # 質量
        self.loop_nr = math.pi*rho*self.r**4/2.0 #慣性モーメント

class Line(Element):
    def __init__(self,x1,y1,x2,y2):
        super(Line,self).__init__()
        self.type = 2
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2

class DEM_UI:
    def __init__(self):
        self._pars = []
        self._lines = []
        
        self._setup()
        
    def _setup(self):
        dem.initialize()
        dem.setDeltaTime(0.01)
        
        # set lines
        self._lines = []
        self._lines.append(Line(5,5,295,5))
        self._lines.append(Line(5,195,5,5))
        self._lines.append(Line(295,195,295,5))
        self._lines.append(Line(5,195,295,195))
        self._lines.append(Line(100,100,300,150))
        self._lines.append(Line(10,80,160,50))
        dem.setNumberOfLine(len(self._lines))
        for i,l in enumerate(self._lines):
            dem.setLine(i,l)
        
        # set particle
        max_n = 100
        self.parCount = 0
        for x in range(40,290,2):
            for y in range(145,195,2):
                if self._hitParticle(x,y,5):
                    continue
                if self._hitLine(x,y,5):
                    continue
                p = Particle(x,y)
                dem.setParticle(self.parCount,p)
                self.parCount += 1
        dem.setNumberOfParticle(self.parCount)
        
        dem.setup()
        
        print(u'完了')
        print(u'粒子要素数: %d ' % self.parCount)

    def _hitParticle(self,x,y,r):
        hit = False
        for p in self.particles():
            lx = p.x - x
            ly = p.y - y
            ld = (lx**2+ly**2)**0.5
            if (p.r+r)>=ld:
                hit = True
                break
        return hit
    
    def _hitLine(self,px,py,pr):
        hit = False
        for l in self._lines:
            th0 = math.atan2(l.y2-l.y1,l.x2-l.x1)
            th1 = math.atan2(py-l.y1,px-l.x1)
            a = math.sqrt((px-l.x1)**2+(py-l.y1)**2)
            d = abs(a*math.sin(th1-th0))
            if d < pr:
                b = math.sqrt((px-l.x2)**2+(py-l.y2)**2)
                s = math.sqrt((l.x2-l.x1)**2+(l.y2-l.y1)**2)
                if a < s and b < s:
                    hit = True
                elif a < b and a < pr:
                    hit = True
                elif b < pr:
                    hit = True
                if hit:
                    break
        return hit
    
    def particles(self):
        pars = []
        for i in range(self.parCount):
            p = dem.particle(i,Particle())
            pars.append(p)
        return pars
    
    def lines(self):
        return self._lines
    
class Window(Tkinter.Tk):      
    
    def __init__(self):
        self.loop_n = 1
        
        print u'初期設定中...',
        Tkinter.Tk.__init__(self)
        self.canvas = Tkinter.Canvas(self, bg="white")
        self.canvas.pack(fill=Tkinter.BOTH,expand=True)
        self.geometry('300x200')
        self.title('DEM')
        
        self.dem_ui = DEM_UI()
        
        for l in self.dem_ui.lines():
            xy = self.viewCoord([l.x1,l.y1,l.x2,l.y2])
            self.canvas.create_line(xy,width=1)
        
        self.redraw()
        self.update_idletasks()
        
        print(u'解析開始')

    def calcloop(self):        
        dem.calcStep(5)
        if self.loop_n == 2:
            self.saveCalcTime('start')
        if self.loop_n % 1 == 0:
            print('Step %d' % dem.step())
        if self.loop_n % 2 == 0:
            self.redraw()
        if self.loop_n % 4 == 0:
            #self.saveImage()
            pass
        if self.loop_n >= 1000:
            self.saveCalcTime('finish')
            print(u'解析終了.設定最大ループに達しました')
        else:
            self.after(0,self.calcloop)
        self.update_idletasks()
        self.loop_n += 1
 
    def redraw(self):
        self.canvas.delete('elem')
        h = 200
        for p in self.dem_ui.particles():
            x1,y1 = self.viewCoord([p.x-p.r,p.y-p.r])
            x2,y2 = self.viewCoord([p.x+p.r,p.y+p.r])
            self.canvas.create_oval(x1,y1,x2,y2,tags='elem')
            
            x1,y1 = self.viewCoord([p.x,p.y])
            x2,y2 = self.viewCoord([p.x+p.r*math.cos(p.a),
                                    p.y+p.r*math.sin(p.a)])
            self.canvas.create_line(x1,y1,x2,y2,tags='elem')
        
    def viewCoord(self,coords,offset=(0,0)):
        s = 1.0 # 表示倍率
        h = 200 #表示画面高さ
        w = 300 #表示画面幅
        x_offset = 0#int(w/2)
        y_offset = 0#int(h/2)
        xy_list = []
        for i in range(0,len(coords),2):
            x = round(s*coords[i])+x_offset
            y = round(h-s*coords[i+1])-y_offset
            x = x + offset[0]
            y = y + offset[1]
            xy_list.append(x)
            xy_list.append(y)
        return xy_list

    def saveCalcTime(self,option):
        if option == 'start':
            self.st_time = time.time()
            self.st_step = dem.step()
        elif option == 'finish':
            now = time.time()
            dt = now-self.st_time
            ds = dem.step() - self.st_step +1
            f = open('calc_time.txt','w')
            f.write('START STEP %d\n' % self.st_step)
            f.write('START TIME {0}\n'.format(self.st_time))
            f.write('END STEP %d\n' % dem.step())
            f.write('END TIME {0}\n'.format(now))
            f.write('DIFF STEP %d \n' % ds)
            f.write('DIFF TIME {0}\n'.format(dt))
            f.write('ONE STEP TIME {0}'.format(dt/ds))
            f.close()
        
    def saveImage(self):
        filepath = 'c://Temp/dem/capture%05d.png' % dem.step()
        img = ImageGrab.grab()
        s,x,y = self.geometry().split('+')
        w,h = s.split('x')
        w,h,x,y = map(int,[w,h,x,y])
        x += 8
        y += 30
        img = img.crop((x,y,x+w,y+h))
        img.save(filepath)

def main():
    w = Window()
    w.after(0,w.calcloop)
    w.mainloop()
        
print u'完了'

if __name__ == '__main__':
    main()


 

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です