forked from parlance/ctcdecode
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
56 lines (47 loc) · 1.77 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python
import multiprocessing.pool
import os
from setuptools import setup, find_packages, distutils
from torch.utils.cpp_extension import BuildExtension
# Directory that contains this setup script; not referenced anywhere else in
# this file.
this_file = os.path.split(__file__)[0]
# Monkey-patch for parallel compilation.
# See: https://stackoverflow.com/a/13176803
def parallelCCompile(self,
                     sources,
                     output_dir=None,
                     macros=None,
                     include_dirs=None,
                     debug=0,
                     extra_preargs=None,
                     extra_postargs=None,
                     depends=None):
    """Compile ``sources`` on a four-thread pool; drop-in for ``CCompiler.compile``.

    This function is assigned to ``distutils.ccompiler.CCompiler.compile``
    further down in this script, so ``self`` is a distutils C compiler.
    Preparation is delegated to the compiler's own ``_setup_compile`` and
    ``_get_cc_args`` helpers. Each object file is then built with
    ``self._compile`` on a worker thread instead of one after another.

    Args:
        self: The compiler instance whose ``compile`` method this replaces.
        sources: Source file paths to compile.
        output_dir, macros, include_dirs, debug, extra_preargs,
        extra_postargs, depends: Passed through unchanged to
            ``_setup_compile`` / ``_get_cc_args``, with the same meaning as
            in ``CCompiler.compile``.

    Returns:
        The ``objects`` list produced by ``_setup_compile``. Objects that
        have no entry in the returned ``build`` mapping are not compiled
        here but are still included.

    Raises:
        The first exception raised by ``self._compile``, in ``objects``
        order. The pool is then terminated, so jobs that have not started
        yet are dropped.
    """
    # These two statements mirror distutils.ccompiler.CCompiler.compile.
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

    def _single_compile(obj):
        # Skip objects with no entry in ``build``, as distutils' own
        # sequential loop does.
        try:
            src, ext = build[obj]
        except KeyError:
            return
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

    # The pool must be shut down explicitly; otherwise its worker threads
    # outlive this call. The context manager calls terminate() on exit, on
    # both the success and the error path.
    with multiprocessing.pool.ThreadPool(4) as thread_pool:
        # imap is lazy: consuming it into a list waits for every job and
        # re-raises the first compiler error here, in the calling thread.
        list(thread_pool.imap(_single_compile, objects))
    return objects
# Replace CCompiler.compile with the thread-pooled parallelCCompile defined
# above. The assignment is made on the base class, so it affects every
# compiler subclass that inherits compile() rather than overriding it.
# NOTE(review): torch's BuildExtension may install its own compile wrapper on
# the compiler instance (e.g. for ninja builds), which would bypass this
# patch. Confirm against the torch version in use.
distutils.ccompiler.CCompiler.compile = parallelCCompile
# Local build.py helper module (not shown here). It must expose `extension`,
# which is passed to ext_modules below. It is imported only after the patch
# above; keep that order unless build.py is confirmed to do no compilation at
# import time.
# NOTE(review): this relies on this script's directory being on sys.path ahead
# of any installed package that is also named `build` (e.g. pypa/build).
import build
# Distribution metadata and build configuration for the ctcdecode package.
setup(
    name="ctcdecode",
    version="0.4",
    description="CTC Decoder for PyTorch based on Paddle Paddle's implementation",
    url="https://github.com/parlance/ctcdecode",
    author="Ryan Leary",
    # NOTE(review): "[email protected]" looks like a placeholder inserted by a
    # web page that hides email addresses, not a real address. Confirm the
    # intended author_email.
    author_email="[email protected]",
    # Exclude a top-level package named "build" from the distribution. Its
    # subpackages would additionally need a "build.*" pattern to be excluded.
    packages=find_packages(exclude=["build"]),
    # Extension object(s) defined by the local build.py module.
    ext_modules = [build.extension],
    # Use torch.utils.cpp_extension.BuildExtension as the build_ext command.
    cmdclass={'build_ext': BuildExtension}
)