Skip to content

2014 08 18 python multiprocessing

Jaesoo Lim edited this page Aug 18, 2014 · 4 revisions

python 스크립트를 이용하여 많은 데이터를 처리하려면 hadoop이나 parallel을 이용하면 별도로 스크립트를 수정하지 않고도 사용할 수 있습니다. 그러나 python의 multiprocessing 모듈을 이용하면 이러한 툴들 못지 않게 편리하게 병렬 처리를 하실 수 있습니다. 단, hadoop 처럼 여러 장비에 걸쳐 병렬로 수행하는 것은 multiprocessing 모듈로도 불가능합니다.

import datetime
import multiprocessing
import sys
import time

def square(num):
  time.sleep(0.1)
  return num * num

def main():
  pool = multiprocessing.Pool(2)
  for result in pool.map(square, range(100)):
    print result

if __name__ == '__main__':
  start_time = datetime.datetime.now()
  main()
  print >> sys.stderr, 'Elapsed time:', (datetime.datetime.now() - start_time)

위 코드는 1부터 100까지 제곱을 계산하는 square() 함수를 2개의 프로세서를 이용하여 적용하고 있습니다. 그냥 계산하면 순식간에 끝나니까 제곱을 계산할 때마다 0.1초를 sleep하도록 했습니다. 실행해 보면 5초가 약간 넘게 걸리면 끝나는 것을 확인하실 수 있습니다. (100번의 연산을 0.1초씩 걸려 2개의 프로세서가 수행하니 5초입니다.)

참고로, Pool.map() 메서드는 입력과 출력의 순서가 일치됩니다. 즉, 위에서 square() 함수가 수행되는 순서가 자식 프로세스에서 어떻게 되든 result로 들어오는 결과의 순서는 입력의 순서와 일치됩니다.

그런데, map() 메서드에 준 sqare() 함수가 수행되다 죽어버리면 어떻게 될까요? 아래 코드는 sqarue() 함수에 임의로 입력값 10에 대해 비정상 종료하도록 exit 코드를 추가했습니다.

def square(num):
  if num == 10: sys.exit(-1)
  time.sleep(0.1)
  return num * num

애석하게도 이렇게 하면 Pool.map() 메서드는 영원히 결과를 기다립니다. 물론 11부터 99까지의 입력에 대해서는 백그라운드로 다 돌려 놓고도 말이죠. 더욱 난감한 것은 Ctrl+C로도 죽지 않아서 부모 프로세스를 kill로 죽여야 합니다. 위의 예에서는 2개의 프로세서 풀을 사용했으므로 3개의 프로세스가 돌아가게 되고 2개의 자식 프로세스를 갖는 부모 프로세스를 찾아서 죽여야 합니다. ^^;

이런 상황은 ctypes 모듈을 통해 C로 작성된 공유 라이브러리의 API를 호출하는 경우 통제가 불가능한 상황이 발생할 수도 있습니다. 라이브러리를 호출하다 segmentation fault로 죽었는데 결과를 계속해서 기다리는 경우입니다.

물론 방법은 있습니다. 아래 코드를 한번 보시죠.

import datetime
import multiprocessing
import sys
import time

def square(num):
  if num == 10: sys.exit(-1)
  time.sleep(0.1)
  return num * num

def main():
  pool = multiprocessing.Pool(2)
  results = [pool.apply_async(square, (num,)) for num in range(100)]
  for idx, result in enumerate(results):
    try:
      print result.get(timeout=3)
    except multiprocessing.TimeoutError:
      print 'Failed at:', idx
      continue

if __name__ == '__main__':
  start_time = datetime.datetime.now()
  main()
  print >> sys.stderr, 'Elapsed time:', (datetime.datetime.now() - start_time)

Pool.apply_async() 메서드는 이름에서도 알 수 있듯이 결과를 asynchronous하게 받습니다. 즉, 결과의 객체 리스트 results는 즉시 생성이 되고 결과 객체의 get() 메서드를 호출할 때 실제 결과를 가져오게 됩니다. 이때 time out을 설정하면 지정한 시간이 지나면 TimeoutError 예외가 발생하고 설정하지 않으면 영원히 기다리게 됩니다.

위 코드는 입력값 10에 대해 3초의 time out이 지나더라도 백그라운드로 미리 계산을 수행하고 있으므로, 동일하게 5초만에 모든 계산이 끝납니다.

참고로, apply_async() 메서드의 두번째 파라미터는 첫번째 파라미터 squre() 함수에 전달할 파라미터입니다. 반드시 튜플로 전달해야 해서 하나의 파라미터를 전달 하더라도 튜플을 만들어서 넣어 줘야 합니다.

Clone this wiki locally