前回のプログラムを要素数492ほどで試してみたら、ネットブック(Acer Aspire One HAPPY2)では、1ステップで2.55秒もかかりました。お、おう。ネットブックで解析を回すつもりは無いとはいえ、最低でも1万ステップくらい回すことを踏まえるとさすがに遅い。ノートPC(Dell Inspiron14 core i3-4010U)でも0.63秒で4倍くらい速いものの、全然遅い。勉強用とはいえ、5000要素で1万ステップくらいは1分以内でやってほしい。1ステップで0.005秒くらいとなると、100倍以上の高速化が必要だ。
Pythonが他のプログラムに比べても演算の処理速度が遅いことは知っていたので、最終段階で高速化しようと思っていたけども、現時点で、高速化しないと色々遊んでみることもままならなさそう・・・。っということで、Cython化してみました。
Pythonの高速化の手段はいろいろあって、今回ただの演算なので、Numpyとかの演算ライブラリを使うのが有用な手段だけども、プログラムの可読性が悪くなるので、Cythonに挑戦してみました。
Cython入門 静的型付けでコードを高速化する
http://omake.accense.com/static/doc-ja/cython/src/quickstart/cythonize.html
結果から
ステップ0:Python 1ステップ2.55秒
ステップ1:ソースを何も変えずにCython化 1ステップ2.01秒(1.27倍)
ステップ2:ループ仕様を変更 1ステップ1.82秒(1.40倍)
ここまではCythonインストール以外の手間はほぼないので、うれしい高速化なのだけど、目指すところはその程度じゃない。
ステップ3:クラスを全てcdef化してみる⇒上手くコーディングできない
この辺になってくると、Cythonの仕様を理解しなくちゃいけなくなってきました。条件設定や画面表示などは通常のPythonにしたい都合上、Pythonから見える見えない問題の解決が難しくて、Cython仕様に合わせて一から作ってみようと思いました。
で、まだ全然書ききれていないのだけど、形になってきたので、とりあえずどれだけ速くなりそうか見てみたい。
ステップ4:Cythonで作り直し。1ステップ0.062秒(41倍)
おおー。100倍とは行かないまでも、随分と速くなりました。いや、まぁプログラムを省略している部分が多いんだけど。
体感的にはこんな感じ(ノートPC)
元のPythonにおける要素データは、要素クラスのインスタンスをそのままリストに代入していく方式。プログラムがとても分かりやすくなるのだけど、Cythonでクラスインスタンスを配列に格納したりする方法が分からず、結局、構造体の配列にすることにしました。数値配列(行列)にしてしまうと可読性が悪くなってしまうし、それならば、Numpyなどのライブラリ使うほうが、速度的にも、利便的にも良くて、Cythonを使う意味がなくなっちゃうので。それに伴い、クラスは使わないように変更(オブジェクト指向は断念)。
まだ、ちゃん動いているのかどうかの検証も、Python版では計算している処理もしてなかったりするので、ちゃんとした比較ではないけども、Pythonの書き方+ちょっとした知識だけで高速なモジュールが作れるのは、ありがたい。
dem.pyx (dem.pyd)
# -*- coding: utf-8 -*- #Cython仕様で一から作ってみる from libc.math cimport sqrt,sin,cos,atan2,fabs from cpython cimport bool DEF MAX_PAR_NO = 1000 #最大粒子数 DEF MAX_LINE_NO = 10 #最大線要素数 DEF MAX_ELEM_NO = MAX_PAR_NO + MAX_LINE_NO #最大要素数 cdef float PI = 3.1415926535 #円周率 cdef float G = 9.80665 #重力加速度 cdef struct Particle: #要素共通 int etype #要素タイプ int n #要素No. float r#半径 float x#X座標 float y#Y座標 float a#角度 float dx#X方向増加量 float dy#Y方向増加量 float da#角度増加量 float vx#X方向速度 float vy#Y方向速度 float va#角速度 float fy float fx float fm float en[MAX_ELEM_NO]#弾性力(直方向) float es[MAX_ELEM_NO]#弾性力(せん断方向) #粒子専用 float m# 質量 float Ir#慣性モーメント cdef struct Line: #要素共通 int etype #要素タイプ int n #要素No. float r#半径 float x#X座標 float y#Y座標 float a#角度 float dx#X方向増加量 float dy#Y方向増加量 float da#角度増加量 float vx#X方向速度 float vy#Y方向速度 float va#角速度 float fy float fx float fm float en[MAX_ELEM_NO]#弾性力(直方向) float es[MAX_ELEM_NO]#弾性力(せん断方向) #線要素専用 float x1 float y1 float x2 float y2 cdef struct Interface: float kn #弾性係数(法線方向) float etan#粘性係数(法線方向) float ks#弾性係数(せん断方向) float etas#弾性係数(せん断方向) float frc#摩擦係数 cdef Particle pe[MAX_ELEM_NO] cdef Line le[MAX_LINE_NO] cdef float rho = 10 #粒子間密度 cdef int peCount #粒子の数 cdef int leCount #線要素の数 cdef float dt = 0.001 #⊿t cdef int st = 0 #ステップ cdef void nextStep(): resetForce() calcForce() updateCoord() global st st += 1 cdef void resetForce(): cdef int i for i in range(peCount): pe[i].fx = 0 pe[i].fy = 0 pe[i].fm = 0 cdef Interface interface(int etype1,int etype2): cdef Interface inf if etype1 == 1 and etype2 == 1: inf.kn = 10**6 #弾性係数(法線方向) inf.etan = 10000 #粘性係数(法線方向) inf.ks = 10*6 #弾性係数(せん断方向) inf.etas = 10000 #粘性係数(せん断方向) inf.frc = 10*6 #摩擦係数 return inf cdef void calcForce(): #2粒子間の接触判定 cdef float lx,ly,ld,cos_a,sin_a cdef int i,j for i in range(peCount): for j in range(i+1,peCount): lx = pe[i].x - pe[j].x ly = pe[i].y - pe[j].y ld = (lx**2+ly**2)**0.5 if (pe[i].r+pe[j].r)>ld: cos_a = lx/ld sin_a = ly/ld forcePar2par(i,j,cos_a,sin_a) else: pe[i].en[j] = 0.0 pe[i].es[j] = 0.0 #粒子と線の接触判定 cdef bool hit cdef Particle p cdef Line l for i in range(peCount): for j in range(leCount): hit = False p = pe[i] l = le[j] th0 = atan2(l.y2-l.y1, l.x2-l.x1) th1 = atan2(p.y-l.y1, p.x-l.x1) a = sqrt((p.x-l.x1)**2+(p.y-l.y1)**2) d = fabs(a*sin(th1-th0)) if d < p.r: b = sqrt((p.x-l.x2)**2+(p.y-l.y2)**2) s = sqrt((l.x2-l.x1)**2+(l.y2-l.y1)**2) if a < s and b < s: s1 = sqrt(a**2-d**2) x = l.x1 + s1*cos(th0) y = l.y1 + s1*sin(th0) hit = True elif a < b and a < p.r: x = l.x1 y = l.y1 hit = True elif b < p.r: x = l.x2 y = l.y2 hit = True if hit: lx = p.x - x ly = p.y - y ld = sqrt(lx**2+ly**2) cos_a = lx/ld sin_a = ly/ld forceLine2par(p,l,cos_a,sin_a) else: p.en[l.n] = 0.0 p.es[l.n] = 0.0 #外力 for i in range(peCount): pe[i].fy += -G*pe[i].m #重力 cdef void forcePar2par(int i,int j,float cos_a, float sin_a): cdef float un,us,vn,vs,hn,hs #相対的変位増分 un = +(pe[i].dx-pe[j].dx)*cos_a+(pe[i].dy-pe[j].dy)*sin_a us = -(pe[i].dx-pe[j].dx)*sin_a+(pe[i].dy-pe[j].dy)*cos_a+(pe[i].r*pe[i].da+pe[j].r*pe[j].da) #相対的速度増分 vn = +(pe[i].vx-pe[j].vx)*cos_a+(pe[i].vy-pe[j].vy)*sin_a vs = -(pe[i].vx-pe[j].vx)*sin_a+(pe[i].vy-pe[j].vy)*cos_a+(pe[i].r*pe[i].va+pe[j].r*pe[j].va) inf = interface(pe[i].etype,pe[j].etype) #合力(局所座標系) pe[i].en[pe[j].n] += inf.kn*un pe[i].es[pe[j].n] += inf.ks*us hn = pe[i].en[pe[j].n] + inf.etan*vn hs = pe[i].es[pe[j].n] + inf.etas*vs if hn <= 0.0: #法線力がなければ、せん断力は0 hs = 0.0 elif fabs(hs) >= inf.frc*hn: #摩擦力以上のせん断力は働かない hs = inf.frc*fabs(hn)*hs/fabs(hs) #全体合力(全体座標系) pe[i].fx += -hn*cos_a + hs*sin_a pe[i].fy += -hn*sin_a - hs*cos_a pe[i].fm -= pe[i].r*hs pe[j].fx += hn*cos_a - hs*sin_a pe[j].fy += hn*sin_a + hs*cos_a pe[j].fm -= pe[j].r*hs cdef void forceLine2par(Particle p1,Line l,float cos_a, float sin_a): cdef float un,us,vn,vs,hn,hs #相対的変位増分 un = +p1.dx*cos_a+p1.dy*sin_a us = -p1.dx*sin_a+p1.dy*cos_a+p1.r*p1.da #相対的速度増分 vn = +p1.vx*cos_a+p1.vy*sin_a vs = -p1.vx*sin_a+p1.vy*cos_a+p1.r*p1.va inf = interface(p1.etype,l.etype) #合力(局所座標系) pe[p1.n].en[l.n] += inf.kn*un pe[p1.n].es[l.n] += inf.ks*us hn = p1.en[l.n] + inf.etan*vn hs = p1.es[l.n] + inf.etas*vs if hn <= 0.0: #法線力がなければ、せん断力は0 hs = 0.0 elif abs(hs) >= inf.frc*hn: #摩擦力以上のせん断力は働かない hs = inf.frc*fabs(hn)*hs/fabs(hs) #全体合力(全体座標系) pe[p1.n].fx += -hn*cos_a + hs*sin_a pe[p1.n].fy += -hn*sin_a - hs*cos_a pe[p1.n].fm -= p1.r*hs cdef void updateCoord(): cdef float ax,ay,aa cdef int i for i in range(peCount): #位置更新(オイラー差分) ax = pe[i].fx/pe[i].m ay = pe[i].fy/pe[i].m aa = pe[i].fm/pe[i].Ir pe[i].vx += ax*dt pe[i].vy += ay*dt pe[i].va += aa*dt pe[i].dx = pe[i].vx*dt pe[i].dy = pe[i].vy*dt pe[i].da = pe[i].va*dt pe[i].x += pe[i].dx pe[i].y += pe[i].dy pe[i].a += pe[i].da # ------------------------- # Pythonからの設定用 # ------------------------- def setDeltaTime(sec): global dt dt = sec def setNumberOfParticle(n): global peCount peCount = n def numberOfParticle(): return peCount def setParticle(pe_no,pe_obj): pe[pe_no].x = pe_obj.x pe[pe_no].y = pe_obj.y def particle(pe_no,pe_obj): pe_obj.x = pe[pe_no].x pe_obj.y = pe[pe_no].y pe_obj.a = pe[pe_no].a return pe_obj def setLine(l_no,l_obj): le[l_no].x1 = l_obj.x1 le[l_no].y1 = l_obj.y1 le[l_no].x2 = l_obj.x2 le[l_no].y2 = l_obj.y2 return l_obj def setNumberOfLine(n): global leCount leCount = n def line(l_no,l_obj): l_obj.x1 = le[l_no].x1 l_obj.y1 = le[l_no].y1 l_obj.x2 = le[l_no].x2 l_obj.y2 = le[l_no].y2 return l_no def initialize(): cdef int i,j for i in range(MAX_ELEM_NO): pe[i].x = 0 pe[i].y = 0 pe[i].r = 0 pe[i].a = 0 pe[i].dx = 0 pe[i].dy = 0 pe[i].da = 0 pe[i].vx = 0 pe[i].vy = 0 pe[i].va = 0 pe[i].fx = 0 pe[i].fy = 0 pe[i].fm = 0 pe[i].m = 0 pe[i].Ir = 0 for j in range(MAX_ELEM_NO): pe[i].en[j] = 0 pe[i].es[j] = 0 def setup(): cdef int i #粒子要素 for i in range(peCount): pe[i].n = i pe[i].r = 5.0 pe[i].m = 4.0/3.0*PI*rho*pe[i].r**3 # 質量 pe[i].Ir = PI*rho*pe[i].r**4/2.0 #慣性モーメント #線要素 for i in range(leCount): le[i].n = peCount+i def step(): return st def calcStep(int n=1): cdef int i for i in range(n): nextStep()
dem_ui.py
# -*- coding: utf-8 -*- print u'読み込み中...', import sys import math import random import time import Tkinter import dem from PIL import ImageGrab class Element(object): def __init__(self): self.n = 0 #要素No. self.r = 0 #半径 self.x = 0 #X座標 self.y = 0 #Y座標 self.a = 0 #角度 self.dx = 0 #X方向増加量 self.dy = 0 #Y方向増加量 self.da = 0 #角度増加量 self.vx = 0 #X方向速度 self.vy = 0 #Y方向速度 self.va = 0 #角速度 self.fy = 0 self.fx = 0 self.fm = 0 self.en = [] #弾性力(直方向) self.es = [] #弾性力(せん断方向) class Particle(Element): def __init__(self,x=0,y=0,vx=0,vy=0): super(Particle,self).__init__() self.type = 1 self.x = x #X座標 self.y = y #Y座標 self.vx = vx #X方向速度 self.vy = vy #Y方向速度 rho = 10 self.r = 5 #半径 self.m = 4.0/3.0*math.pi*rho*self.r**3 # 質量 self.loop_nr = math.pi*rho*self.r**4/2.0 #慣性モーメント class Line(Element): def __init__(self,x1,y1,x2,y2): super(Line,self).__init__() self.type = 2 self.x1 = x1 self.y1 = y1 self.x2 = x2 self.y2 = y2 class DEM_UI: def __init__(self): self._pars = [] self._lines = [] self._setup() def _setup(self): dem.initialize() dem.setDeltaTime(0.01) # set lines self._lines = [] self._lines.append(Line(5,5,295,5)) self._lines.append(Line(5,195,5,5)) self._lines.append(Line(295,195,295,5)) self._lines.append(Line(5,195,295,195)) self._lines.append(Line(100,100,300,150)) self._lines.append(Line(10,80,160,50)) dem.setNumberOfLine(len(self._lines)) for i,l in enumerate(self._lines): dem.setLine(i,l) # set particle max_n = 100 self.parCount = 0 for x in range(40,290,2): for y in range(145,195,2): if self._hitParticle(x,y,5): continue if self._hitLine(x,y,5): continue p = Particle(x,y) dem.setParticle(self.parCount,p) self.parCount += 1 dem.setNumberOfParticle(self.parCount) dem.setup() print(u'完了') print(u'粒子要素数: %d ' % self.parCount) def _hitParticle(self,x,y,r): hit = False for p in self.particles(): lx = p.x - x ly = p.y - y ld = (lx**2+ly**2)**0.5 if (p.r+r)>=ld: hit = True break return hit def _hitLine(self,px,py,pr): hit = False for l in self._lines: th0 = math.atan2(l.y2-l.y1,l.x2-l.x1) th1 = math.atan2(py-l.y1,px-l.x1) a = math.sqrt((px-l.x1)**2+(py-l.y1)**2) d = abs(a*math.sin(th1-th0)) if d < pr: b = math.sqrt((px-l.x2)**2+(py-l.y2)**2) s = math.sqrt((l.x2-l.x1)**2+(l.y2-l.y1)**2) if a < s and b < s: hit = True elif a < b and a < pr: hit = True elif b < pr: hit = True if hit: break return hit def particles(self): pars = [] for i in range(self.parCount): p = dem.particle(i,Particle()) pars.append(p) return pars def lines(self): return self._lines class Window(Tkinter.Tk): def __init__(self): self.loop_n = 1 print u'初期設定中...', Tkinter.Tk.__init__(self) self.canvas = Tkinter.Canvas(self, bg="white") self.canvas.pack(fill=Tkinter.BOTH,expand=True) self.geometry('300x200') self.title('DEM') self.dem_ui = DEM_UI() for l in self.dem_ui.lines(): xy = self.viewCoord([l.x1,l.y1,l.x2,l.y2]) self.canvas.create_line(xy,width=1) self.redraw() self.update_idletasks() print(u'解析開始') def calcloop(self): dem.calcStep(5) if self.loop_n == 2: self.saveCalcTime('start') if self.loop_n % 1 == 0: print('Step %d' % dem.step()) if self.loop_n % 2 == 0: self.redraw() if self.loop_n % 4 == 0: #self.saveImage() pass if self.loop_n >= 1000: self.saveCalcTime('finish') print(u'解析終了.設定最大ループに達しました') else: self.after(0,self.calcloop) self.update_idletasks() self.loop_n += 1 def redraw(self): self.canvas.delete('elem') h = 200 for p in self.dem_ui.particles(): x1,y1 = self.viewCoord([p.x-p.r,p.y-p.r]) x2,y2 = self.viewCoord([p.x+p.r,p.y+p.r]) self.canvas.create_oval(x1,y1,x2,y2,tags='elem') x1,y1 = self.viewCoord([p.x,p.y]) x2,y2 = self.viewCoord([p.x+p.r*math.cos(p.a), p.y+p.r*math.sin(p.a)]) self.canvas.create_line(x1,y1,x2,y2,tags='elem') def viewCoord(self,coords,offset=(0,0)): s = 1.0 # 表示倍率 h = 200 #表示画面高さ w = 300 #表示画面幅 x_offset = 0#int(w/2) y_offset = 0#int(h/2) xy_list = [] for i in range(0,len(coords),2): x = round(s*coords[i])+x_offset y = round(h-s*coords[i+1])-y_offset x = x + offset[0] y = y + offset[1] xy_list.append(x) xy_list.append(y) return xy_list def saveCalcTime(self,option): if option == 'start': self.st_time = time.time() self.st_step = dem.step() elif option == 'finish': now = time.time() dt = now-self.st_time ds = dem.step() - self.st_step +1 f = open('calc_time.txt','w') f.write('START STEP %d\n' % self.st_step) f.write('START TIME {0}\n'.format(self.st_time)) f.write('END STEP %d\n' % dem.step()) f.write('END TIME {0}\n'.format(now)) f.write('DIFF STEP %d \n' % ds) f.write('DIFF TIME {0}\n'.format(dt)) f.write('ONE STEP TIME {0}'.format(dt/ds)) f.close() def saveImage(self): filepath = 'c://Temp/dem/capture%05d.png' % dem.step() img = ImageGrab.grab() s,x,y = self.geometry().split('+') w,h = s.split('x') w,h,x,y = map(int,[w,h,x,y]) x += 8 y += 30 img = img.crop((x,y,x+w,y+h)) img.save(filepath) def main(): w = Window() w.after(0,w.calcloop) w.mainloop() print u'完了' if __name__ == '__main__': main()