Simplify and speed up multiprocessing

This commit is contained in:
Anthony Sottile 2016-11-22 15:43:05 -08:00
parent 348722d77b
commit 109f5f8888
4 changed files with 64 additions and 105 deletions

View file

@ -1,7 +1,9 @@
"""Checker Manager and Checker classes.""" """Checker Manager and Checker classes."""
import collections
import errno
import logging
import os
import signal
import sys
import tokenize
@ -10,11 +12,6 @@ try:
except ImportError:
multiprocessing = None
try:
import Queue as queue
except ImportError:
import queue
from flake8 import defaults
from flake8 import exceptions
from flake8 import processor
@ -76,10 +73,8 @@ class Manager(object):
self.options = style_guide.options
self.checks = checker_plugins
self.jobs = self._job_count()
self.process_queue = None
self.results_queue = None
self.statistics_queue = None
self.using_multiprocessing = self.jobs > 1
self.pool = None
self.processes = []
self.checkers = []
self.statistics = {
@ -91,9 +86,7 @@ class Manager(object):
if self.using_multiprocessing:
try:
self.process_queue = multiprocessing.Queue() self.pool = multiprocessing.Pool(self.jobs, _pool_init)
self.results_queue = multiprocessing.Queue()
self.statistics_queue = multiprocessing.Queue()
except OSError as oserr:
if oserr.errno not in SERIAL_RETRY_ERRNOS:
raise
@ -104,35 +97,11 @@ class Manager(object):
while not q.empty():
q.get_nowait()
def _force_cleanup(self):
if self.using_multiprocessing:
for proc in self.processes:
proc.join(0.2)
self._cleanup_queue(self.process_queue)
self._cleanup_queue(self.results_queue)
self._cleanup_queue(self.statistics_queue)
def _process_statistics(self): def _process_statistics(self):
all_statistics = self.statistics for checker in self.checkers:
if self.using_multiprocessing:
total_number_of_checkers = len(self.checkers)
statistics_gathered = 0
while statistics_gathered < total_number_of_checkers:
try:
statistics = self.statistics_queue.get(block=False)
statistics_gathered += 1
except queue.Empty:
break
for statistic in defaults.STATISTIC_NAMES:
all_statistics[statistic] += statistics[statistic] self.statistics[statistic] += checker.statistics[statistic]
else: self.statistics['files'] += len(self.checkers)
statistics_generator = (checker.statistics
for checker in self.checkers)
for statistics in statistics_generator:
for statistic in defaults.STATISTIC_NAMES:
all_statistics[statistic] += statistics[statistic]
all_statistics['files'] += len(self.checkers)
def _job_count(self):
# type: () -> int
@ -189,19 +158,6 @@ class Manager(object):
# it to an integer
return int(jobs)
def _results(self):
seen_done = 0
LOG.info('Retrieving results')
while True:
result = self.results_queue.get()
if result == 'DONE':
seen_done += 1
if seen_done >= self.jobs:
break
continue
yield result
def _handle_results(self, filename, results):
style_guide = self.style_guide
reported_results_count = 0
@ -282,12 +238,15 @@ class Manager(object):
is_stdin)
checks = self.checks.to_dictionary()
self.checkers = [ checkers = (
FileChecker(filename, checks, self.options)
for argument in paths
for filename in utils.filenames_from(argument,
self.is_path_excluded)
if should_create_file_checker(filename, argument)
)
self.checkers = [
checker for checker in checkers if checker.should_process
] ]
LOG.info('Checking %d files', len(self.checkers))
@ -311,32 +270,36 @@ class Manager(object):
results_found += len(results)
return (results_found, results_reported)
def _force_cleanup(self):
if self.pool is not None:
self.pool.terminate()
self.pool.join()
def run_parallel(self): def run_parallel(self):
"""Run the checkers in parallel.""" """Run the checkers in parallel."""
LOG.info('Starting %d process workers', self.jobs) final_results = collections.defaultdict(list)
for i in range(self.jobs): final_statistics = collections.defaultdict(dict)
proc = multiprocessing.Process( for ret in self.pool.imap_unordered(
target=_run_checks_from_queue, _run_checks, self.checkers,
args=(self.process_queue, self.results_queue, chunksize=_pool_chunksize(len(self.checkers), self.jobs),
self.statistics_queue) ):
) filename, results, statistics = ret
proc.daemon = True
proc.start()
self.processes.append(proc)
final_results = {}
for (filename, results) in self._results():
final_results[filename] = results
final_statistics[filename] = statistics
self.pool.close()
self.pool.join()
self.pool = None
for checker in self.checkers:
filename = checker.display_name
checker.results = sorted(final_results.get(filename, []), checker.results = sorted(final_results[filename],
key=lambda tup: (tup[2], tup[2])) key=lambda tup: (tup[2], tup[2]))
checker.statistics = final_statistics[filename]
def run_serial(self):
"""Run the checkers in serial."""
for checker in self.checkers:
checker.run_checks(self.results_queue, self.statistics_queue) checker.run_checks()
def run(self):
"""Run all the checkers.
@ -374,15 +337,6 @@ class Manager(object):
""" """
LOG.info('Making checkers') LOG.info('Making checkers')
self.make_checkers(paths) self.make_checkers(paths)
if not self.using_multiprocessing:
return
LOG.info('Populating process queue')
for checker in self.checkers:
self.process_queue.put(checker)
for i in range(self.jobs):
self.process_queue.put('DONE')
def stop(self):
"""Stop checking files."""
@ -413,13 +367,18 @@ class FileChecker(object):
self.filename = filename
self.checks = checks
self.results = []
self.processor = self._make_processor()
self.display_name = self.processor.filename
self.statistics = {
'tokens': 0,
'logical lines': 0,
'physical lines': len(self.processor.lines), 'physical lines': 0,
} }
self.processor = self._make_processor()
self.display_name = filename
self.should_process = False
if self.processor is not None:
self.display_name = self.processor.filename
self.should_process = not self.processor.should_ignore_file()
self.statistics['physical lines'] = len(self.processor.lines)
def _make_processor(self):
try:
@ -597,11 +556,8 @@ class FileChecker(object):
self.run_physical_checks(file_processor.lines[-1])
self.run_logical_checks()
def run_checks(self, results_queue, statistics_queue): def run_checks(self):
"""Run checks against the file.""" """Run checks against the file."""
if self.processor.should_ignore_file():
return
try:
self.process_tokens()
except exceptions.InvalidSyntax as exc:
@ -610,13 +566,9 @@ class FileChecker(object):
self.run_ast_checks()
if results_queue is not None:
results_queue.put((self.filename, self.results))
logical_lines = self.processor.statistics['logical lines']
self.statistics['logical lines'] = logical_lines
if statistics_queue is not None: return self.filename, self.results, self.statistics
statistics_queue.put(self.statistics)
def handle_comment(self, token, token_text):
"""Handle the logic when encountering a comment token."""
@ -663,19 +615,25 @@ class FileChecker(object):
override_error_line=token[4])
def _run_checks_from_queue(process_queue, results_queue, statistics_queue): def _pool_init():
LOG.info('Running checks in parallel') """Ensure correct signaling of ^C using multiprocessing.Pool."""
try: signal.signal(signal.SIGINT, signal.SIG_IGN)
for checker in iter(process_queue.get, 'DONE'):
LOG.info('Checking "%s"', checker.filename)
checker.run_checks(results_queue, statistics_queue) def _pool_chunksize(num_checkers, num_jobs):
except exceptions.PluginRequestedUnknownParameters as exc: """Determine the chunksize for the multiprocessing Pool.
print(str(exc))
except Exception as exc: - For chunksize, see: https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool.imap # noqa
LOG.error('Unhandled exception occurred') - This formula, while not perfect, aims to give each worker two batches of
raise work.
finally: - See: https://gitlab.com/pycqa/flake8/merge_requests/156#note_18878876
results_queue.put('DONE') - See: https://gitlab.com/pycqa/flake8/issues/265
"""
return max(num_checkers // (num_jobs * 2), 1)
def _run_checks(checker):
return checker.run_checks()
def find_offset(offset, mapping): def find_offset(offset, mapping):

View file

@ -367,6 +367,7 @@ def is_eol_token(token):
"""Check if the token is an end-of-line token.""" """Check if the token is an end-of-line token."""
return token[0] in NEWLINE or token[4][token[3][1]:].lstrip() == '\\\n' return token[0] in NEWLINE or token[4][token[3][1]:].lstrip() == '\\\n'
if COMMENT_WITH_NL: # If on Python 2.6 if COMMENT_WITH_NL: # If on Python 2.6
def is_eol_token(token, _is_eol_token=is_eol_token): def is_eol_token(token, _is_eol_token=is_eol_token):
"""Check if the token is an end-of-line token.""" """Check if the token is an end-of-line token."""

View file

@ -20,7 +20,7 @@ def test_oserrors_cause_serial_fall_back():
"""Verify that OSErrors will cause the Manager to fallback to serial.""" """Verify that OSErrors will cause the Manager to fallback to serial."""
err = OSError(errno.ENOSPC, 'Ominous message about spaceeeeee') err = OSError(errno.ENOSPC, 'Ominous message about spaceeeeee')
style_guide = style_guide_mock() style_guide = style_guide_mock()
with mock.patch('multiprocessing.Queue', side_effect=err): with mock.patch('_multiprocessing.SemLock', side_effect=err):
manager = checker.Manager(style_guide, [], [])
assert manager.using_multiprocessing is False
@ -30,7 +30,7 @@ def test_oserrors_are_reraised(is_windows):
"""Verify that OSErrors will cause the Manager to fallback to serial.""" """Verify that OSErrors will cause the Manager to fallback to serial."""
err = OSError(errno.EAGAIN, 'Ominous message') err = OSError(errno.EAGAIN, 'Ominous message')
style_guide = style_guide_mock() style_guide = style_guide_mock()
with mock.patch('multiprocessing.Queue', side_effect=err): with mock.patch('_multiprocessing.SemLock', side_effect=err):
with pytest.raises(OSError):
checker.Manager(style_guide, [], [])

View file

@ -59,7 +59,7 @@ def test_read_lines_from_stdin(stdin_get_value):
stdin_value = mock.Mock()
stdin_value.splitlines.return_value = []
stdin_get_value.return_value = stdin_value
file_processor = processor.FileProcessor('-', options_from()) processor.FileProcessor('-', options_from())
stdin_get_value.assert_called_once_with()
stdin_value.splitlines.assert_called_once_with(True)