読者です 読者をやめる 読者になる 読者になる

わらばんし仄聞記

南の国で引きこもってるWeb屋さん

pythonでJIT(64bit版Linux環境) part2

jit python

pythonでJIT part1は非常に単純な例だったので、もう少し複雑にしてみることに。

目標

アルファベットのA~Zをループで取得し、それぞれをputchar関数を使って出力。終わったら改行文字の出力をするネイティブコードをJITで実行する。

アセンブラでのコード

目標の出力を得るアセンブラのコードを書いてみる

  • sample.s
.intel_syntax noprefix
.globl main

main:
    mov     r12, 26
    mov     rbx, 0x41

loop:
    cmp     r12, 0
    jz      end

    mov     rdi, rbx
    call    putchar

    inc     rbx
    dec     r12
    jmp     loop

end:
    mov     rdi, 0xa
    call    putchar

コンパイルして実行してみる

$ gcc sample.s
$ ./a.out
ABCDEFGHIJKLMNOPQRSTUVWXYZ

ネイティブコードを得る

上記のコードをJITで動かす為に少し調整し、ネイティブコードを得る

  • base.s
.intel_syntax noprefix
.globl main

main:
    push    r12
    push    r13
    push    rbx

    mov     r13, 0x123456789abcdef0
    mov     r12, 26
    mov     rbx, 0x41

loop:
    cmp     r12, 0
    jz      end

    mov     rdi, rbx
    call    r13

    inc     rbx
    dec     r12
    jmp     loop

end:
    mov     rdi, 0xa
    call    r13

    pop     rbx
    pop     r13
    pop     r12
    ret

異なる点としては

  • r12,r13,rbxをpushとpopしている(pushした逆順にpopすること)
  • r13に64bit分のスペースを確保させておいている
  • callでputcharでなくr13を呼んでいる

となり、r13には後ほどputchar関数のアドレスを入れる為のダミーの値を入れている。

勿論これを単純にコンパイルして実行すると、Segmentation faultとなって怒られる。retで終わってるので、このコードが復帰する場所がないと辻褄が合わないので。単体で実行したければexitで終わらなければならなく、また、その場合は復帰時を考慮しての為に値を退避させていたpushとpopの処理も必要ない。
このコードからネイティブコードを得るには

$ as base.s
$ objdump -d -M intel a.out

とする。asの変わりにgccでも問題ないが、その場合は今回やりたいことには余分なヘッダ情報なども出てくるので、必要な所だけ出すにはasコマンドの方が都合がいい。

この結果より

a.out:     file format elf64-x86-64


Disassembly of section .text:

0000000000000000 <main>:
   0:   41 54                   push   r12
   2:   41 55                   push   r13
   4:   53                      push   rbx
   5:   49 bd f0 de bc 9a 78    movabs r13,0x123456789abcdef0
   c:   56 34 12 
   f:   49 c7 c4 1a 00 00 00    mov    r12,0x1a
  16:   48 c7 c3 41 00 00 00    mov    rbx,0x41

000000000000001d <loop>:
  1d:   49 83 fc 00             cmp    r12,0x0
  21:   74 0e                   je     31 <end>
  23:   48 89 df                mov    rdi,rbx
  26:   41 ff d5                call   r13
  29:   48 ff c3                inc    rbx
  2c:   49 ff cc                dec    r12
  2f:   eb ec                   jmp    1d <loop>

0000000000000031 <end>:
  31:   48 c7 c7 0a 00 00 00    mov    rdi,0xa
  38:   41 ff d5                call   r13
  3b:   5b                      pop    rbx
  3c:   41 5d                   pop    r13
  3e:   41 5c                   pop    r12
  40:   c3                      ret 

を得る。

pythonJIT

さて本題。
基本的にはpart1のコードと同様で、ネイティブコードの部分を置き換える

import sys, struct
from ctypes import *

libc = cdll.LoadLibrary("libc.so.6")
free = libc.free
printf = libc.printf
putchar = libc.putchar

mmap = libc.mmap
mmap.restype = c_void_p
munmap = libc.munmap
munmap.argtype = [c_void_p, c_size_t]

PROT_READ       = 1
PROT_WRITE      = 2
PROT_EXEC       = 4
MAP_PRIVATE     = 2
MAP_ANONYMOUS   = 0x20

def conv64(dw):
    return map(ord, struct.pack("<q" if dw < 0 else "<Q", dw))

codes = (c_ubyte * 128) (
    0x41, 0x54,                                 # push  r12
    0x41, 0x55,                                 # push  r13
    0x53,                                       # push  rbx
    0x49, 0xbd, 0x00, 0x00, 0x00, 0x00, 0x00,   # mov   r13, (long)
    0x00, 0x00, 0x00,
    0x49, 0xc7, 0xc4, 0x1a, 0x00, 0x00, 0x00,   # mov   r12, 0x1a
    0x48, 0xc7, 0xc3, 0x41, 0x00, 0x00, 0x00,   # mov   rbx, 0x41
    0x49, 0x83, 0xfc, 0x00,                     # cmp   r12, 0
    0x74, 0x0e,                                 # je    <end>
    0x48, 0x89, 0xdf,                           # mov   rdi, rbx
    0x41, 0xff, 0xd5,                           # call  r13
    0x48, 0xff, 0xc3,                           # inc   rbx
    0x49, 0xff, 0xcc,                           # dec   r12
    0xeb, 0xec,                                 # jmp   <loop>
    0x48, 0xc7, 0xc7, 0x0a, 0x00, 0x00, 0x00,   # mov   rdi, 0xa
    0x41, 0xff, 0xd5,                           # call  r13
    0x5b,                                       # pop   rbx
    0x41, 0x5d,                                 # pop   r13
    0x41, 0x5c,                                 # pop   r12
    0xc3,                                       # ret
)

buflen = len(codes)
p = mmap(
    0, buflen,
    PROT_READ | PROT_WRITE | PROT_EXEC,
    MAP_PRIVATE | MAP_ANONYMOUS,
    -1, 0
)

getaddr = CFUNCTYPE(c_void_p, c_void_p)(lambda p: p)
f       = CFUNCTYPE(c_void_p)(p)

codes[7:15] = conv64(getaddr(putchar))
memmove(p, addressof(codes), buflen)

f() 

munmap(p, buflen)

基本的には先に出力したネイティブコードを踏襲してますが、r13にmovしてた値を一応0x00で埋めてあります。

ポイント

  • putcharのアドレス取得方法

コードを見るとaddressof(putchar)とはせず、単に受け取ったポインタを返すだけのgetaddrを使っています。これはおそらく、pythonがlibcを読み込んだ時に各関数の実体を参集するリストを作成し、addressof(putchar)とすると、そのリストのputcharについての参照を持つ箇所を取得してしまい、putcharそのものを取得出来ないためと思われます(教えて頂いたので、ちゃんと自分で検証まではしてません)。
試しに

print("putchar address = %s" % hex(addressof(putchar)))
print("putchar address = %s" % hex(addressof(libc.putchar)))
libc.printf("putchar address = %p\n", putchar)
print("putchar address = %s" % hex(getaddr(putchar)))

とすると、結果は

putchar address = 0x7733e0
putchar address = 0x7733e0
putchar address = 0x7fc64fb0fbd0
putchar address = 0x7fc64fb0fbd0

となります。

  • putchar関数の埋め込み

ネイティブコードへputchar関数のアドレスを埋め込む際

codes[7:15] = conv64(getaddr(putchar))

としています。conv64はコードの上部で定義しているように、渡した値をリトルエンディアンな1バイトずつの配列にして返しています。それをcodesのputcharを埋め込みたい箇所に代入していることになります。

  • ジャンプ命令の位置指定

たとえば

0x74, 0x0e,                                 # je    

としているジャンプ命令。この後半1バイトの0x0eは、この次の命令の先頭アドレスからジャンプ先命令の先頭アドレスとのオフセットになります。

結果

$ python python_jit_2.py
ABCDEFGHIJKLMNOPQRSTUVWXYZ

無事目標を得られました