From 2c8f6639ba2fc1b83af5c95115045fb8b1b9fc88 Mon Sep 17 00:00:00 2001 From: phlax Date: Thu, 18 Mar 2021 02:05:27 +0000 Subject: [PATCH] python: Set indents to 4-space (#15539) Signed-off-by: Ryan Northey --- api/tools/generate_listeners.py | 54 +- api/tools/generate_listeners_test.py | 8 +- api/tools/tap2pcap.py | 96 +- api/tools/tap2pcap_test.py | 27 +- ci/flaky_test/process_xml.py | 414 ++-- configs/configgen.py | 12 +- configs/example_configs_validation.py | 22 +- docs/_ext/validating_code_block.py | 78 +- docs/conf.py | 40 +- examples/cache/service.py | 40 +- examples/cors/backend/service.py | 4 +- examples/cors/frontend/service.py | 6 +- examples/csrf/crosssite/service.py | 6 +- examples/csrf/samesite/service.py | 10 +- examples/double-proxy/service.py | 14 +- examples/ext_authz/upstream/service/server.py | 4 +- examples/front-proxy/service.py | 28 +- examples/grpc-bridge/client/client.py | 80 +- .../load-reporting-service/http_server.py | 4 +- restarter/hot-restarter.py | 268 +-- .../network/kafka/protocol/generator.py | 1084 +++++----- .../network/kafka/protocol/launcher.py | 18 +- .../network/kafka/serialization/generator.py | 46 +- .../network/kafka/serialization/launcher.py | 8 +- .../generate_test_data.py | 26 +- .../test_access_log_schema.py | 160 +- .../test_cluster_schema.py | 28 +- .../test_http_conn_network_filter_schema.py | 12 +- .../test_http_router_schema.py | 24 +- .../test_listener_schema.py | 12 +- .../test_route_configuration_schema.py | 12 +- .../test_route_entry_schema.py | 26 +- .../test_top_level_config_schema.py | 28 +- .../json/config_schemas_test_data/util.py | 19 +- .../kafka_broker_integration_test.py | 971 ++++----- .../network/kafka/protocol/launcher.py | 18 +- .../network/kafka/serialization/launcher.py | 8 +- .../network/thrift_proxy/driver/client.py | 438 ++-- .../driver/fbthrift/THeaderTransport.py | 1041 ++++----- .../driver/finagle/TFinagleServerProcessor.py | 82 +- .../driver/finagle/TFinagleServerProtocol.py | 50 +- .../network/thrift_proxy/driver/server.py | 372 ++-- test/integration/capture_fuzz_gen.py | 98 +- tools/api/generate_go_protobuf.py | 169 +- tools/api/validate_structure.py | 66 +- tools/api_boost/api_boost.py | 291 +-- tools/api_boost/api_boost_test.py | 102 +- tools/api_proto_plugin/annotations.py | 74 +- tools/api_proto_plugin/plugin.py | 91 +- tools/api_proto_plugin/traverse.py | 81 +- tools/api_proto_plugin/type_context.py | 311 +-- tools/api_proto_plugin/utils.py | 14 +- tools/api_proto_plugin/visitor.py | 26 +- .../generate_api_version_header.py | 42 +- .../generate_api_version_header_test.py | 135 +- tools/build_profile.py | 20 +- tools/code_format/.style.yapf | 2 +- tools/code_format/check_format.py | 1917 +++++++++-------- tools/code_format/check_format_test_helper.py | 522 ++--- tools/code_format/common.py | 18 +- tools/code_format/envoy_build_fixer.py | 252 +-- tools/code_format/flake8.conf | 2 +- tools/code_format/format_python_tools.py | 95 +- tools/code_format/header_order.py | 190 +- tools/code_format/paths.py | 8 +- tools/code_format/python_flake8.py | 6 +- tools/config_validation/validate_fragment.py | 56 +- tools/dependency/cve_scan.py | 290 +-- tools/dependency/cve_scan_test.py | 507 ++--- tools/dependency/exports.py | 8 +- tools/dependency/generate_external_dep_rst.py | 132 +- tools/dependency/ossf_scorecard.py | 175 +- tools/dependency/release_dates.py | 123 +- tools/dependency/utils.py | 66 +- tools/dependency/validate.py | 374 ++-- tools/dependency/validate_test.py | 208 +- 
.../deprecate_features/deprecate_features.py | 69 +- tools/deprecate_version/deprecate_version.py | 247 +-- tools/envoy_collect/envoy_collect.py | 298 +-- tools/envoy_headersplit/headersplit.py | 345 +-- tools/envoy_headersplit/headersplit_test.py | 153 +- tools/envoy_headersplit/replace_includes.py | 117 +- .../replace_includes_test.py | 106 +- tools/extensions/generate_extension_db.py | 150 +- tools/extensions/generate_extension_rst.py | 70 +- tools/find_related_envoy_files.py | 54 +- tools/gen_compilation_database.py | 122 +- tools/github/sync_assignable.py | 44 +- tools/print_dependencies.py | 54 +- tools/proto_format/active_protos_gen.py | 40 +- tools/proto_format/proto_sync.py | 499 ++--- tools/protodoc/generate_empty.py | 39 +- tools/protodoc/protodoc.py | 699 +++--- tools/protoxform/merge_active_shadow.py | 360 ++-- tools/protoxform/merge_active_shadow_test.py | 260 +-- tools/protoxform/migrate.py | 413 ++-- tools/protoxform/options.py | 24 +- tools/protoxform/protoprint.py | 827 +++---- tools/protoxform/protoxform.py | 116 +- tools/protoxform/protoxform_test_helper.py | 120 +- tools/protoxform/utils.py | 12 +- tools/run_command.py | 6 +- tools/socket_passing.py | 165 +- tools/spelling/check_spelling_pedantic.py | 1207 +++++------ .../spelling/check_spelling_pedantic_test.py | 105 +- tools/stack_decode.py | 147 +- .../file_descriptor_set_text_gen.py | 18 +- .../type_whisperer/proto_build_targets_gen.py | 72 +- tools/type_whisperer/proto_cc_source_gen.py | 12 +- tools/type_whisperer/type_whisperer.py | 81 +- tools/type_whisperer/typedb_gen.py | 202 +- tools/vscode/generate_debug_config.py | 155 +- 112 files changed, 9847 insertions(+), 9730 deletions(-) diff --git a/api/tools/generate_listeners.py b/api/tools/generate_listeners.py index 181c7755c644..f94a975d7e51 100644 --- a/api/tools/generate_listeners.py +++ b/api/tools/generate_listeners.py @@ -22,44 +22,44 @@ # Convert an arbitrary proto object to its Struct proto representation. def proto_to_struct(proto): - json_rep = json_format.MessageToJson(proto) - parsed_msg = struct_pb2.Struct() - json_format.Parse(json_rep, parsed_msg) - return parsed_msg + json_rep = json_format.MessageToJson(proto) + parsed_msg = struct_pb2.Struct() + json_format.Parse(json_rep, parsed_msg) + return parsed_msg # Parse a proto from the filesystem. def parse_proto(path, filter_name): - # We only know about some filter config protos ahead of time. - KNOWN_FILTERS = { - 'http_connection_manager': lambda: http_connection_manager_pb2.HttpConnectionManager() - } - filter_config = KNOWN_FILTERS[filter_name]() - with open(path, 'r') as f: - text_format.Merge(f.read(), filter_config) - return filter_config + # We only know about some filter config protos ahead of time. 
+ KNOWN_FILTERS = { + 'http_connection_manager': lambda: http_connection_manager_pb2.HttpConnectionManager() + } + filter_config = KNOWN_FILTERS[filter_name]() + with open(path, 'r') as f: + text_format.Merge(f.read(), filter_config) + return filter_config def generate_listeners(listeners_pb_path, output_pb_path, output_json_path, fragments): - listener = lds_pb2.Listener() - with open(listeners_pb_path, 'r') as f: - text_format.Merge(f.read(), listener) + listener = lds_pb2.Listener() + with open(listeners_pb_path, 'r') as f: + text_format.Merge(f.read(), listener) - for filter_chain in listener.filter_chains: - for f in filter_chain.filters: - f.config.CopyFrom(proto_to_struct(parse_proto(next(fragments), f.name))) + for filter_chain in listener.filter_chains: + for f in filter_chain.filters: + f.config.CopyFrom(proto_to_struct(parse_proto(next(fragments), f.name))) - with open(output_pb_path, 'w') as f: - f.write(str(listener)) + with open(output_pb_path, 'w') as f: + f.write(str(listener)) - with open(output_json_path, 'w') as f: - f.write(json_format.MessageToJson(listener)) + with open(output_json_path, 'w') as f: + f.write(json_format.MessageToJson(listener)) if __name__ == '__main__': - if len(sys.argv) < 4: - print('Usage: %s ') % sys.argv[0] - sys.exit(1) + if len(sys.argv) < 4: + print('Usage: %s ') % sys.argv[0] + sys.exit(1) - generate_listeners(sys.argv[1], sys.argv[2], sys.argv[3], iter(sys.argv[4:])) + generate_listeners(sys.argv[1], sys.argv[2], sys.argv[3], iter(sys.argv[4:])) diff --git a/api/tools/generate_listeners_test.py b/api/tools/generate_listeners_test.py index 6d0af4edfcc7..f67ef4bbb5aa 100644 --- a/api/tools/generate_listeners_test.py +++ b/api/tools/generate_listeners_test.py @@ -5,7 +5,7 @@ import generate_listeners if __name__ == "__main__": - srcdir = os.path.join(os.getenv("TEST_SRCDIR"), 'envoy_api_canonical') - generate_listeners.generate_listeners( - os.path.join(srcdir, "examples/service_envoy/listeners.pb"), "/dev/stdout", "/dev/stdout", - iter([os.path.join(srcdir, "examples/service_envoy/http_connection_manager.pb")])) + srcdir = os.path.join(os.getenv("TEST_SRCDIR"), 'envoy_api_canonical') + generate_listeners.generate_listeners( + os.path.join(srcdir, "examples/service_envoy/listeners.pb"), "/dev/stdout", "/dev/stdout", + iter([os.path.join(srcdir, "examples/service_envoy/http_connection_manager.pb")])) diff --git a/api/tools/tap2pcap.py b/api/tools/tap2pcap.py index a9f7a3e71760..93a861039928 100644 --- a/api/tools/tap2pcap.py +++ b/api/tools/tap2pcap.py @@ -32,57 +32,57 @@ def dump_event(direction, timestamp, data): - dump = io.StringIO() - dump.write('%s\n' % direction) - # Adjust to local timezone - adjusted_dt = timestamp.ToDatetime() - datetime.timedelta(seconds=time.altzone) - dump.write('%s\n' % adjusted_dt) - od = sp.Popen(['od', '-Ax', '-tx1', '-v'], stdout=sp.PIPE, stdin=sp.PIPE, stderr=sp.PIPE) - packet_dump = od.communicate(data)[0] - dump.write(packet_dump.decode()) - return dump.getvalue() + dump = io.StringIO() + dump.write('%s\n' % direction) + # Adjust to local timezone + adjusted_dt = timestamp.ToDatetime() - datetime.timedelta(seconds=time.altzone) + dump.write('%s\n' % adjusted_dt) + od = sp.Popen(['od', '-Ax', '-tx1', '-v'], stdout=sp.PIPE, stdin=sp.PIPE, stderr=sp.PIPE) + packet_dump = od.communicate(data)[0] + dump.write(packet_dump.decode()) + return dump.getvalue() def tap2pcap(tap_path, pcap_path): - wrapper = wrapper_pb2.TraceWrapper() - if tap_path.endswith('.pb_text'): - with open(tap_path, 'r') as f: - 
text_format.Merge(f.read(), wrapper) - else: - with open(tap_path, 'r') as f: - wrapper.ParseFromString(f.read()) - - trace = wrapper.socket_buffered_trace - local_address = trace.connection.local_address.socket_address.address - local_port = trace.connection.local_address.socket_address.port_value - remote_address = trace.connection.remote_address.socket_address.address - remote_port = trace.connection.remote_address.socket_address.port_value - - dumps = [] - for event in trace.events: - if event.HasField('read'): - dumps.append(dump_event('I', event.timestamp, event.read.data.as_bytes)) - elif event.HasField('write'): - dumps.append(dump_event('O', event.timestamp, event.write.data.as_bytes)) - - ipv6 = False - try: - socket.inet_pton(socket.AF_INET6, local_address) - ipv6 = True - except socket.error: - pass - - text2pcap_args = [ - 'text2pcap', '-D', '-t', '%Y-%m-%d %H:%M:%S.', '-6' if ipv6 else '-4', - '%s,%s' % (remote_address, local_address), '-T', - '%d,%d' % (remote_port, local_port), '-', pcap_path - ] - text2pcap = sp.Popen(text2pcap_args, stdout=sp.PIPE, stdin=sp.PIPE) - text2pcap.communicate('\n'.join(dumps).encode()) + wrapper = wrapper_pb2.TraceWrapper() + if tap_path.endswith('.pb_text'): + with open(tap_path, 'r') as f: + text_format.Merge(f.read(), wrapper) + else: + with open(tap_path, 'r') as f: + wrapper.ParseFromString(f.read()) + + trace = wrapper.socket_buffered_trace + local_address = trace.connection.local_address.socket_address.address + local_port = trace.connection.local_address.socket_address.port_value + remote_address = trace.connection.remote_address.socket_address.address + remote_port = trace.connection.remote_address.socket_address.port_value + + dumps = [] + for event in trace.events: + if event.HasField('read'): + dumps.append(dump_event('I', event.timestamp, event.read.data.as_bytes)) + elif event.HasField('write'): + dumps.append(dump_event('O', event.timestamp, event.write.data.as_bytes)) + + ipv6 = False + try: + socket.inet_pton(socket.AF_INET6, local_address) + ipv6 = True + except socket.error: + pass + + text2pcap_args = [ + 'text2pcap', '-D', '-t', '%Y-%m-%d %H:%M:%S.', '-6' if ipv6 else '-4', + '%s,%s' % (remote_address, local_address), '-T', + '%d,%d' % (remote_port, local_port), '-', pcap_path + ] + text2pcap = sp.Popen(text2pcap_args, stdout=sp.PIPE, stdin=sp.PIPE) + text2pcap.communicate('\n'.join(dumps).encode()) if __name__ == '__main__': - if len(sys.argv) != 3: - print('Usage: %s ' % sys.argv[0]) - sys.exit(1) - tap2pcap(sys.argv[1], sys.argv[2]) + if len(sys.argv) != 3: + print('Usage: %s ' % sys.argv[0]) + sys.exit(1) + tap2pcap(sys.argv[1], sys.argv[2]) diff --git a/api/tools/tap2pcap_test.py b/api/tools/tap2pcap_test.py index 183a40decb72..fd13cf32ff69 100644 --- a/api/tools/tap2pcap_test.py +++ b/api/tools/tap2pcap_test.py @@ -11,17 +11,18 @@ # a golden output file for the tshark dump. Since we run tap2pcap in a # subshell with a limited environment, the inferred time zone should be UTC. 
if __name__ == '__main__': - srcdir = os.path.join(os.getenv('TEST_SRCDIR'), 'envoy_api_canonical') - tap_path = os.path.join(srcdir, 'tools/data/tap2pcap_h2_ipv4.pb_text') - expected_path = os.path.join(srcdir, 'tools/data/tap2pcap_h2_ipv4.txt') - pcap_path = os.path.join(os.getenv('TEST_TMPDIR'), 'generated.pcap') + srcdir = os.path.join(os.getenv('TEST_SRCDIR'), 'envoy_api_canonical') + tap_path = os.path.join(srcdir, 'tools/data/tap2pcap_h2_ipv4.pb_text') + expected_path = os.path.join(srcdir, 'tools/data/tap2pcap_h2_ipv4.txt') + pcap_path = os.path.join(os.getenv('TEST_TMPDIR'), 'generated.pcap') - tap2pcap.tap2pcap(tap_path, pcap_path) - actual_output = sp.check_output(['tshark', '-r', pcap_path, '-d', 'tcp.port==10000,http2', '-P']) - with open(expected_path, 'rb') as f: - expected_output = f.read() - if actual_output != expected_output: - print('Mismatch') - print('Expected: %s' % expected_output) - print('Actual: %s' % actual_output) - sys.exit(1) + tap2pcap.tap2pcap(tap_path, pcap_path) + actual_output = sp.check_output( + ['tshark', '-r', pcap_path, '-d', 'tcp.port==10000,http2', '-P']) + with open(expected_path, 'rb') as f: + expected_output = f.read() + if actual_output != expected_output: + print('Mismatch') + print('Expected: %s' % expected_output) + print('Actual: %s' % actual_output) + sys.exit(1) diff --git a/ci/flaky_test/process_xml.py b/ci/flaky_test/process_xml.py index be7572a8e419..54abfad37f36 100755 --- a/ci/flaky_test/process_xml.py +++ b/ci/flaky_test/process_xml.py @@ -13,261 +13,263 @@ # Returns a boolean indicating if a test passed. def did_test_pass(file): - tree = ET.parse(file) - root = tree.getroot() - for testsuite in root: - if testsuite.attrib['failures'] != '0' or testsuite.attrib['errors'] != '0': - return False - return True + tree = ET.parse(file) + root = tree.getroot() + for testsuite in root: + if testsuite.attrib['failures'] != '0' or testsuite.attrib['errors'] != '0': + return False + return True # Returns a pretty-printed string of a test case failure. def print_test_case_failure(testcase, testsuite, failure_msg, log_path): - ret = "Test flake details:\n" - ret += "- Test suite: {}\n".format(testsuite) - ret += "- Test case: {}\n".format(testcase) - ret += "- Log path: {}\n".format(log_path) - ret += "- Details:\n" - for line in failure_msg.splitlines(): - ret += "\t" + line + "\n" - ret += section_delimiter + "\n" - return ret + ret = "Test flake details:\n" + ret += "- Test suite: {}\n".format(testsuite) + ret += "- Test case: {}\n".format(testcase) + ret += "- Log path: {}\n".format(log_path) + ret += "- Details:\n" + for line in failure_msg.splitlines(): + ret += "\t" + line + "\n" + ret += section_delimiter + "\n" + return ret # Returns a pretty-printed string of a test suite error, such as an exception or a timeout. def print_test_suite_error(testsuite, testcase, log_path, duration, time, error_msg, output): - ret = "Test flake details:\n" - ret += "- Test suite: {}\n".format(testsuite) - ret += "- Test case: {}\n".format(testcase) - ret += "- Log path: {}\n".format(log_path) - - errno_string = os.strerror(int(error_msg.split(' ')[-1])) - ret += "- Error: {} ({})\n".format(error_msg.capitalize(), errno_string) - - if duration == time and duration in well_known_timeouts: - ret += "- Note: This error is likely a timeout (test duration == {}, a well known timeout value).\n".format( - duration) - - # If there's a call stack, print it. Otherwise, attempt to print the most recent, - # relevant lines. 
- output = output.rstrip('\n') - traceback_index = output.rfind('Traceback (most recent call last)') - - if traceback_index != -1: - ret += "- Relevant snippet:\n" - for line in output[traceback_index:].splitlines(): - ret += "\t" + line + "\n" - else: - # No traceback found. Attempt to print the most recent snippet from the last test case. - max_snippet_size = 20 - last_testcase_index = output.rfind('[ RUN ]') - output_lines = output[last_testcase_index:].splitlines() - num_lines_to_print = min(len(output_lines), max_snippet_size) + ret = "Test flake details:\n" + ret += "- Test suite: {}\n".format(testsuite) + ret += "- Test case: {}\n".format(testcase) + ret += "- Log path: {}\n".format(log_path) + + errno_string = os.strerror(int(error_msg.split(' ')[-1])) + ret += "- Error: {} ({})\n".format(error_msg.capitalize(), errno_string) + + if duration == time and duration in well_known_timeouts: + ret += "- Note: This error is likely a timeout (test duration == {}, a well known timeout value).\n".format( + duration) + + # If there's a call stack, print it. Otherwise, attempt to print the most recent, + # relevant lines. + output = output.rstrip('\n') + traceback_index = output.rfind('Traceback (most recent call last)') + + if traceback_index != -1: + ret += "- Relevant snippet:\n" + for line in output[traceback_index:].splitlines(): + ret += "\t" + line + "\n" + else: + # No traceback found. Attempt to print the most recent snippet from the last test case. + max_snippet_size = 20 + last_testcase_index = output.rfind('[ RUN ]') + output_lines = output[last_testcase_index:].splitlines() + num_lines_to_print = min(len(output_lines), max_snippet_size) - ret += "- Last {} line(s):\n".format(num_lines_to_print) - for line in output_lines[-num_lines_to_print:]: - ret += "\t" + line + "\n" + ret += "- Last {} line(s):\n".format(num_lines_to_print) + for line in output_lines[-num_lines_to_print:]: + ret += "\t" + line + "\n" - ret += "\n" + section_delimiter + "\n" + ret += "\n" + section_delimiter + "\n" - return ret + return ret # Parses a test suite error, such as an exception or a timeout, and returns a pretty-printed # string of the error. This function is dependent on the structure of the XML and the contents # of the test log and will need to be adjusted should those change. def parse_and_print_test_suite_error(testsuite, log_path): - error_msg = "" - test_duration = 0 - test_time = 0 - last_testsuite = testsuite.attrib['name'] - last_testcase = testsuite.attrib['name'] - test_output = "" - - # Test suites with errors are expected to have 2 children elements: a generic testcase tag - # with the runtimes and a child containing the error message, and another with the entire - # output of the test suite. - for testcase in testsuite: - if testcase.tag == "testcase": - test_duration = int(testcase.attrib['duration']) - test_time = int(testcase.attrib['time']) - - for child in testcase: - if child.tag == "error": - error_msg = child.attrib['message'] - elif testcase.tag == "system-out": - test_output = testcase.text - - # For test suites with errors like this one, the test suite and test case names were not - # parsed into the XML metadata. Here we attempt to extract those names from the log by - # finding the last test case to run. The expected format of that is: - # "[ RUN ] /.\n". 
- last_test_fullname = test_output.split('[ RUN ]')[-1].splitlines()[0] - last_testsuite = last_test_fullname.split('/')[1].split('.')[0] - last_testcase = last_test_fullname.split('.')[1] - - if error_msg != "": - return print_test_suite_error(last_testsuite, last_testcase, log_path, test_duration, test_time, - error_msg, test_output) - - return "" + error_msg = "" + test_duration = 0 + test_time = 0 + last_testsuite = testsuite.attrib['name'] + last_testcase = testsuite.attrib['name'] + test_output = "" + + # Test suites with errors are expected to have 2 children elements: a generic testcase tag + # with the runtimes and a child containing the error message, and another with the entire + # output of the test suite. + for testcase in testsuite: + if testcase.tag == "testcase": + test_duration = int(testcase.attrib['duration']) + test_time = int(testcase.attrib['time']) + + for child in testcase: + if child.tag == "error": + error_msg = child.attrib['message'] + elif testcase.tag == "system-out": + test_output = testcase.text + + # For test suites with errors like this one, the test suite and test case names were not + # parsed into the XML metadata. Here we attempt to extract those names from the log by + # finding the last test case to run. The expected format of that is: + # "[ RUN ] /.\n". + last_test_fullname = test_output.split('[ RUN ]')[-1].splitlines()[0] + last_testsuite = last_test_fullname.split('/')[1].split('.')[0] + last_testcase = last_test_fullname.split('.')[1] + + if error_msg != "": + return print_test_suite_error(last_testsuite, last_testcase, log_path, test_duration, + test_time, error_msg, test_output) + + return "" # Parses a failed test's XML, adds any flaky tests found to the visited set, and returns a # well-formatted string describing all failures and errors. def parse_xml(file, visited): - # This is dependent on the fact that log files reside in the same directory - # as their corresponding xml files. - log_file = file.split('.') - log_file_path = "" - for token in log_file[:-1]: - log_file_path += token - log_file_path += ".log" - - tree = ET.parse(file) - root = tree.getroot() - - # This loop is dependent on the structure of xml file emitted for test runs. - # Should this change in the future, appropriate adjustments need to be made. - ret = "" - for testsuite in root: - if testsuite.attrib['failures'] != '0': - for testcase in testsuite: - for failure_msg in testcase: - if (testcase.attrib['name'], testsuite.attrib['name']) not in visited: - ret += print_test_case_failure(testcase.attrib['name'], testsuite.attrib['name'], - failure_msg.text, log_file_path) - visited.add((testcase.attrib['name'], testsuite.attrib['name'])) - elif testsuite.attrib['errors'] != '0': - # If an unexpected error occurred, such as an exception or a timeout, the test suite was - # likely not parsed into XML properly, including the suite's name and the test case that - # caused the error. More parsing is needed to extract details about the error. - if (testsuite.attrib['name'], testsuite.attrib['name']) not in visited: - ret += parse_and_print_test_suite_error(testsuite, log_file_path) - visited.add((testsuite.attrib['name'], testsuite.attrib['name'])) - - return ret + # This is dependent on the fact that log files reside in the same directory + # as their corresponding xml files. 
+ log_file = file.split('.') + log_file_path = "" + for token in log_file[:-1]: + log_file_path += token + log_file_path += ".log" + + tree = ET.parse(file) + root = tree.getroot() + + # This loop is dependent on the structure of xml file emitted for test runs. + # Should this change in the future, appropriate adjustments need to be made. + ret = "" + for testsuite in root: + if testsuite.attrib['failures'] != '0': + for testcase in testsuite: + for failure_msg in testcase: + if (testcase.attrib['name'], testsuite.attrib['name']) not in visited: + ret += print_test_case_failure(testcase.attrib['name'], + testsuite.attrib['name'], failure_msg.text, + log_file_path) + visited.add((testcase.attrib['name'], testsuite.attrib['name'])) + elif testsuite.attrib['errors'] != '0': + # If an unexpected error occurred, such as an exception or a timeout, the test suite was + # likely not parsed into XML properly, including the suite's name and the test case that + # caused the error. More parsing is needed to extract details about the error. + if (testsuite.attrib['name'], testsuite.attrib['name']) not in visited: + ret += parse_and_print_test_suite_error(testsuite, log_file_path) + visited.add((testsuite.attrib['name'], testsuite.attrib['name'])) + + return ret # The following function links the filepath of 'test.xml' (the result for the last attempt) with # that of its 'attempt_n.xml' file and stores it in a dictionary for easy lookup. def process_find_output(f, problematic_tests): - for line in f: - lineList = line.split('/') - filepath = "" - for i in range(len(lineList)): - if i >= len(lineList) - 2: - break - filepath += lineList[i] + "/" - filepath += "test.xml" - problematic_tests[filepath] = line.strip('\n') + for line in f: + lineList = line.split('/') + filepath = "" + for i in range(len(lineList)): + if i >= len(lineList) - 2: + break + filepath += lineList[i] + "/" + filepath += "test.xml" + problematic_tests[filepath] = line.strip('\n') # Returns helpful information on the run using Git. # Should Git change the output of the used commands in the future, # this will likely need adjustments as well. 
def get_git_info(CI_TARGET): - ret = "" + ret = "" - if CI_TARGET != "": - ret += "Target: {}\n".format(CI_TARGET) + if CI_TARGET != "": + ret += "Target: {}\n".format(CI_TARGET) - if os.getenv('SYSTEM_STAGEDISPLAYNAME') and os.getenv('SYSTEM_STAGEJOBNAME'): - ret += "Stage: {} {}\n".format(os.environ['SYSTEM_STAGEDISPLAYNAME'], - os.environ['SYSTEM_STAGEJOBNAME']) + if os.getenv('SYSTEM_STAGEDISPLAYNAME') and os.getenv('SYSTEM_STAGEJOBNAME'): + ret += "Stage: {} {}\n".format(os.environ['SYSTEM_STAGEDISPLAYNAME'], + os.environ['SYSTEM_STAGEJOBNAME']) - if os.getenv('BUILD_REASON') == "PullRequest" and os.getenv( - 'SYSTEM_PULLREQUEST_PULLREQUESTNUMBER'): - ret += "Pull request: {}/pull/{}\n".format(os.environ['REPO_URI'], - os.environ['SYSTEM_PULLREQUEST_PULLREQUESTNUMBER']) - elif os.getenv('BUILD_REASON'): - ret += "Build reason: {}\n".format(os.environ['BUILD_REASON']) + if os.getenv('BUILD_REASON') == "PullRequest" and os.getenv( + 'SYSTEM_PULLREQUEST_PULLREQUESTNUMBER'): + ret += "Pull request: {}/pull/{}\n".format( + os.environ['REPO_URI'], os.environ['SYSTEM_PULLREQUEST_PULLREQUESTNUMBER']) + elif os.getenv('BUILD_REASON'): + ret += "Build reason: {}\n".format(os.environ['BUILD_REASON']) - output = subprocess.check_output(['git', 'log', '--format=%H', '-n', '1'], encoding='utf-8') - ret += "Commmit: {}/commit/{}".format(os.environ['REPO_URI'], output) + output = subprocess.check_output(['git', 'log', '--format=%H', '-n', '1'], encoding='utf-8') + ret += "Commmit: {}/commit/{}".format(os.environ['REPO_URI'], output) - build_id = os.environ['BUILD_URI'].split('/')[-1] - ret += "CI results: https://dev.azure.com/cncf/envoy/_build/results?buildId=" + build_id + "\n" + build_id = os.environ['BUILD_URI'].split('/')[-1] + ret += "CI results: https://dev.azure.com/cncf/envoy/_build/results?buildId=" + build_id + "\n" - ret += "\n" + ret += "\n" - remotes = subprocess.check_output(['git', 'remote'], encoding='utf-8').splitlines() + remotes = subprocess.check_output(['git', 'remote'], encoding='utf-8').splitlines() - if ("origin" in remotes): - output = subprocess.check_output(['git', 'remote', 'get-url', 'origin'], encoding='utf-8') - ret += "Origin: {}".format(output.replace('.git', '')) + if ("origin" in remotes): + output = subprocess.check_output(['git', 'remote', 'get-url', 'origin'], encoding='utf-8') + ret += "Origin: {}".format(output.replace('.git', '')) - if ("upstream" in remotes): - output = subprocess.check_output(['git', 'remote', 'get-url', 'upstream'], encoding='utf-8') - ret += "Upstream: {}".format(output.replace('.git', '')) + if ("upstream" in remotes): + output = subprocess.check_output(['git', 'remote', 'get-url', 'upstream'], encoding='utf-8') + ret += "Upstream: {}".format(output.replace('.git', '')) - output = subprocess.check_output(['git', 'describe', '--all', '--always'], encoding='utf-8') - ret += "Latest ref: {}".format(output) + output = subprocess.check_output(['git', 'describe', '--all', '--always'], encoding='utf-8') + ret += "Latest ref: {}".format(output) - ret += "\n" + ret += "\n" - ret += "Last commit:\n" - output = subprocess.check_output(['git', 'show', '-s'], encoding='utf-8') - for line in output.splitlines(): - ret += "\t" + line + "\n" + ret += "Last commit:\n" + output = subprocess.check_output(['git', 'show', '-s'], encoding='utf-8') + for line in output.splitlines(): + ret += "\t" + line + "\n" - ret += section_delimiter + ret += section_delimiter - return ret + return ret if __name__ == "__main__": - CI_TARGET = "" - if len(sys.argv) == 2: - 
CI_TARGET = sys.argv[1] - - if os.getenv('TEST_TMPDIR') and os.getenv('REPO_URI') and os.getenv("BUILD_URI"): - os.environ["TMP_OUTPUT_PROCESS_XML"] = os.getenv("TEST_TMPDIR") + "/tmp_output_process_xml.txt" - else: - print("Set the env variables TEST_TMPDIR, REPO_URI, and BUILD_URI first.") - sys.exit(0) - - find_dir = "{}/**/**/**/**/bazel-testlogs/".format(os.environ['TEST_TMPDIR']).replace('\\', '/') - if CI_TARGET == "MacOS": - find_dir = '${TEST_TMPDIR}/' - os.system( - 'sh -c "/usr/bin/find {} -name attempt_*.xml > ${{TMP_OUTPUT_PROCESS_XML}}"'.format(find_dir)) - - # All output of find command should be either failed or flaky tests, as only then will - # a test be rerun and have an 'attempt_n.xml' file. problematic_tests holds a lookup - # table between the most recent run's xml filepath and the original attempt's failed xml - # filepath. - problematic_tests = {} - with open(os.environ['TMP_OUTPUT_PROCESS_XML'], 'r+') as f: - process_find_output(f, problematic_tests) - - # The logic here goes as follows: If there is a test suite that has run multiple times, - # which produces attempt_*.xml files, it means that the end result of that test - # is either flaky or failed. So if we find that the last run of the test succeeds - # we know for sure that this is a flaky test. - has_flaky_test = False - failure_output = "" - flaky_tests_visited = set() - for k in problematic_tests.keys(): - if did_test_pass(k): - has_flaky_test = True - failure_output += parse_xml(problematic_tests[k], flaky_tests_visited) - - if has_flaky_test: - output_msg = "``` \n" + get_git_info(CI_TARGET) + "\n" + failure_output + "``` \n" - - if os.getenv("SLACK_TOKEN"): - SLACKTOKEN = os.environ["SLACK_TOKEN"] - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - # Due to a weird interaction between `websocket-client` and Slack client - # we need to set the ssl context. See `slackapi/python-slack-sdk/issues/334` - client = slack.WebClient(token=SLACKTOKEN, ssl=ssl_context) - client.chat_postMessage(channel='test-flaky', text=output_msg, as_user="true") + CI_TARGET = "" + if len(sys.argv) == 2: + CI_TARGET = sys.argv[1] + + if os.getenv('TEST_TMPDIR') and os.getenv('REPO_URI') and os.getenv("BUILD_URI"): + os.environ["TMP_OUTPUT_PROCESS_XML"] = os.getenv( + "TEST_TMPDIR") + "/tmp_output_process_xml.txt" + else: + print("Set the env variables TEST_TMPDIR, REPO_URI, and BUILD_URI first.") + sys.exit(0) + + find_dir = "{}/**/**/**/**/bazel-testlogs/".format(os.environ['TEST_TMPDIR']).replace('\\', '/') + if CI_TARGET == "MacOS": + find_dir = '${TEST_TMPDIR}/' + os.system('sh -c "/usr/bin/find {} -name attempt_*.xml > ${{TMP_OUTPUT_PROCESS_XML}}"'.format( + find_dir)) + + # All output of find command should be either failed or flaky tests, as only then will + # a test be rerun and have an 'attempt_n.xml' file. problematic_tests holds a lookup + # table between the most recent run's xml filepath and the original attempt's failed xml + # filepath. + problematic_tests = {} + with open(os.environ['TMP_OUTPUT_PROCESS_XML'], 'r+') as f: + process_find_output(f, problematic_tests) + + # The logic here goes as follows: If there is a test suite that has run multiple times, + # which produces attempt_*.xml files, it means that the end result of that test + # is either flaky or failed. So if we find that the last run of the test succeeds + # we know for sure that this is a flaky test. 
+ has_flaky_test = False + failure_output = "" + flaky_tests_visited = set() + for k in problematic_tests.keys(): + if did_test_pass(k): + has_flaky_test = True + failure_output += parse_xml(problematic_tests[k], flaky_tests_visited) + + if has_flaky_test: + output_msg = "``` \n" + get_git_info(CI_TARGET) + "\n" + failure_output + "``` \n" + + if os.getenv("SLACK_TOKEN"): + SLACKTOKEN = os.environ["SLACK_TOKEN"] + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + # Due to a weird interaction between `websocket-client` and Slack client + # we need to set the ssl context. See `slackapi/python-slack-sdk/issues/334` + client = slack.WebClient(token=SLACKTOKEN, ssl=ssl_context) + client.chat_postMessage(channel='test-flaky', text=output_msg, as_user="true") + else: + print(output_msg) else: - print(output_msg) - else: - print('No flaky tests found.\n') + print('No flaky tests found.\n') - os.remove(os.environ["TMP_OUTPUT_PROCESS_XML"]) + os.remove(os.environ["TMP_OUTPUT_PROCESS_XML"]) diff --git a/configs/configgen.py b/configs/configgen.py index 4cc5292ba595..b42c32276934 100755 --- a/configs/configgen.py +++ b/configs/configgen.py @@ -97,12 +97,12 @@ def generate_config(template_path, template, output_file, **context): - """ Generate a final config file based on a template and some context. """ - env = jinja2.Environment(loader=jinja2.FileSystemLoader(template_path, followlinks=True), - undefined=jinja2.StrictUndefined) - raw_output = env.get_template(template).render(**context) - with open(output_file, 'w') as fh: - fh.write(raw_output) + """ Generate a final config file based on a template and some context. """ + env = jinja2.Environment(loader=jinja2.FileSystemLoader(template_path, followlinks=True), + undefined=jinja2.StrictUndefined) + raw_output = env.get_template(template).render(**context) + with open(output_file, 'w') as fh: + fh.write(raw_output) # TODO(sunjayBhatia, wrowe): Avoiding tracing extensions until they build on Windows diff --git a/configs/example_configs_validation.py b/configs/example_configs_validation.py index 885693118d7f..2481ec7af044 100644 --- a/configs/example_configs_validation.py +++ b/configs/example_configs_validation.py @@ -11,18 +11,18 @@ def main(): - errors = [] - for arg in sys.argv[1:]: - try: - validate_fragment("envoy.config.bootstrap.v3.Bootstrap", - yaml.safe_load(pathlib.Path(arg).read_text())) - except (ParseError, KeyError) as e: - errors.append(arg) - print(f"\nERROR (validation failed): {arg}\n{e}\n\n") + errors = [] + for arg in sys.argv[1:]: + try: + validate_fragment("envoy.config.bootstrap.v3.Bootstrap", + yaml.safe_load(pathlib.Path(arg).read_text())) + except (ParseError, KeyError) as e: + errors.append(arg) + print(f"\nERROR (validation failed): {arg}\n{e}\n\n") - if errors: - raise SystemExit(f"ERROR: some configuration files ({len(errors)}) failed to validate") + if errors: + raise SystemExit(f"ERROR: some configuration files ({len(errors)}) failed to validate") if __name__ == "__main__": - main() + main() diff --git a/docs/_ext/validating_code_block.py b/docs/_ext/validating_code_block.py index 6220ae98618b..83ef9f2a28f7 100644 --- a/docs/_ext/validating_code_block.py +++ b/docs/_ext/validating_code_block.py @@ -12,51 +12,51 @@ class ValidatingCodeBlock(CodeBlock): - """A directive that provides protobuf yaml formatting and validation. + """A directive that provides protobuf yaml formatting and validation. 
'type-name' option is required and expected to conain full Envoy API type. An ExtensionError is raised on validation failure. Validation will be skipped if SPHINX_SKIP_CONFIG_VALIDATION environment variable is set. """ - has_content = True - required_arguments = CodeBlock.required_arguments - optional_arguments = CodeBlock.optional_arguments - final_argument_whitespace = CodeBlock.final_argument_whitespace - option_spec = { - 'type-name': directives.unchanged, - } - option_spec.update(CodeBlock.option_spec) - skip_validation = (os.getenv('SPHINX_SKIP_CONFIG_VALIDATION') or 'false').lower() == 'true' - - def run(self): - source, line = self.state_machine.get_source_and_line(self.lineno) - # built-in directives.unchanged_required option validator produces a confusing error message - if self.options.get('type-name') == None: - raise ExtensionError("Expected type name in: {0} line: {1}".format(source, line)) - - if not ValidatingCodeBlock.skip_validation: - args = [ - 'bazel-bin/tools/config_validation/validate_fragment', - self.options.get('type-name'), '-s', '\n'.join(self.content) - ] - completed = subprocess.run(args, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding='utf-8') - if completed.returncode != 0: - raise ExtensionError( - "Failed config validation for type: '{0}' in: {1} line: {2}:\n {3}".format( - self.options.get('type-name'), source, line, completed.stderr)) - - self.options.pop('type-name', None) - return list(CodeBlock.run(self)) + has_content = True + required_arguments = CodeBlock.required_arguments + optional_arguments = CodeBlock.optional_arguments + final_argument_whitespace = CodeBlock.final_argument_whitespace + option_spec = { + 'type-name': directives.unchanged, + } + option_spec.update(CodeBlock.option_spec) + skip_validation = (os.getenv('SPHINX_SKIP_CONFIG_VALIDATION') or 'false').lower() == 'true' + + def run(self): + source, line = self.state_machine.get_source_and_line(self.lineno) + # built-in directives.unchanged_required option validator produces a confusing error message + if self.options.get('type-name') == None: + raise ExtensionError("Expected type name in: {0} line: {1}".format(source, line)) + + if not ValidatingCodeBlock.skip_validation: + args = [ + 'bazel-bin/tools/config_validation/validate_fragment', + self.options.get('type-name'), '-s', '\n'.join(self.content) + ] + completed = subprocess.run(args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding='utf-8') + if completed.returncode != 0: + raise ExtensionError( + "Failed config validation for type: '{0}' in: {1} line: {2}:\n {3}".format( + self.options.get('type-name'), source, line, completed.stderr)) + + self.options.pop('type-name', None) + return list(CodeBlock.run(self)) def setup(app): - app.add_directive("validated-code-block", ValidatingCodeBlock) + app.add_directive("validated-code-block", ValidatingCodeBlock) - return { - 'version': '0.1', - 'parallel_read_safe': True, - 'parallel_write_safe': True, - } + return { + 'version': '0.1', + 'parallel_read_safe': True, + 'parallel_write_safe': True, + } diff --git a/docs/conf.py b/docs/conf.py index fd3c098417c8..d4fa5605f588 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,35 +21,35 @@ # https://stackoverflow.com/questions/44761197/how-to-use-substitution-definitions-with-code-blocks class SubstitutionCodeBlock(CodeBlock): - """ + """ Similar to CodeBlock but replaces placeholders with variables. See "substitutions" below. 
""" - def run(self): - """ + def run(self): + """ Replace placeholders with given variables. """ - app = self.state.document.settings.env.app - new_content = [] - existing_content = self.content - for item in existing_content: - for pair in app.config.substitutions: - original, replacement = pair - item = item.replace(original, replacement) - new_content.append(item) + app = self.state.document.settings.env.app + new_content = [] + existing_content = self.content + for item in existing_content: + for pair in app.config.substitutions: + original, replacement = pair + item = item.replace(original, replacement) + new_content.append(item) - self.content = new_content - return list(CodeBlock.run(self)) + self.content = new_content + return list(CodeBlock.run(self)) def setup(app): - app.add_config_value('release_level', '', 'env') - app.add_config_value('substitutions', [], 'html') - app.add_directive('substitution-code-block', SubstitutionCodeBlock) + app.add_config_value('release_level', '', 'env') + app.add_config_value('substitutions', [], 'html') + app.add_directive('substitution-code-block', SubstitutionCodeBlock) if not os.environ.get('ENVOY_DOCS_RELEASE_LEVEL'): - raise Exception("ENVOY_DOCS_RELEASE_LEVEL env var must be defined") + raise Exception("ENVOY_DOCS_RELEASE_LEVEL env var must be defined") release_level = os.environ['ENVOY_DOCS_RELEASE_LEVEL'] blob_sha = os.environ['ENVOY_BLOB_SHA'] @@ -81,9 +81,9 @@ def setup(app): # Setup global substitutions if 'pre-release' in release_level: - substitutions = [('|envoy_docker_image|', 'envoy-dev:{}'.format(blob_sha))] + substitutions = [('|envoy_docker_image|', 'envoy-dev:{}'.format(blob_sha))] else: - substitutions = [('|envoy_docker_image|', 'envoy:{}'.format(blob_sha))] + substitutions = [('|envoy_docker_image|', 'envoy:{}'.format(blob_sha))] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -112,7 +112,7 @@ def setup(app): # built documents. if not os.environ.get('ENVOY_DOCS_VERSION_STRING'): - raise Exception("ENVOY_DOCS_VERSION_STRING env var must be defined") + raise Exception("ENVOY_DOCS_VERSION_STRING env var must be defined") # The short X.Y version. 
version = os.environ['ENVOY_DOCS_VERSION_STRING'] diff --git a/examples/cache/service.py b/examples/cache/service.py index 100f82c1545d..4433bc4e3c24 100644 --- a/examples/cache/service.py +++ b/examples/cache/service.py @@ -13,32 +13,32 @@ @app.route('/service//') def get(service_number, response_id): - stored_response = yaml.load(open('/etc/responses.yaml', 'r')).get(response_id) + stored_response = yaml.load(open('/etc/responses.yaml', 'r')).get(response_id) - if stored_response is None: - abort(404, 'No response found with the given id') + if stored_response is None: + abort(404, 'No response found with the given id') - response = make_response(stored_response.get('body') + '\n') - if stored_response.get('headers'): - response.headers.update(stored_response.get('headers')) + response = make_response(stored_response.get('body') + '\n') + if stored_response.get('headers'): + response.headers.update(stored_response.get('headers')) - # Generate etag header - response.add_etag() + # Generate etag header + response.add_etag() - # Append the date of response generation - body_with_date = "{}\nResponse generated at: {}\n".format( - response.get_data(as_text=True), - datetime.datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")) + # Append the date of response generation + body_with_date = "{}\nResponse generated at: {}\n".format( + response.get_data(as_text=True), + datetime.datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")) - response.set_data(body_with_date) + response.set_data(body_with_date) - # response.make_conditional() will change the response to a 304 response - # if a 'if-none-match' header exists in the request and matches the etag - return response.make_conditional(request) + # response.make_conditional() will change the response to a 304 response + # if a 'if-none-match' header exists in the request and matches the etag + return response.make_conditional(request) if __name__ == "__main__": - if not os.path.isfile('/etc/responses.yaml'): - print('Responses file not found at /etc/responses.yaml') - exit(1) - app.run(host='127.0.0.1', port=8080, debug=True) + if not os.path.isfile('/etc/responses.yaml'): + print('Responses file not found at /etc/responses.yaml') + exit(1) + app.run(host='127.0.0.1', port=8080, debug=True) diff --git a/examples/cors/backend/service.py b/examples/cors/backend/service.py index b8813a3bfea0..23641d79c402 100644 --- a/examples/cors/backend/service.py +++ b/examples/cors/backend/service.py @@ -6,8 +6,8 @@ @app.route('/cors/') def cors_enabled(status): - return 'Success!' + return 'Success!' 
if __name__ == "__main__": - app.run(host='127.0.0.1', port=8080, debug=True) + app.run(host='127.0.0.1', port=8080, debug=True) diff --git a/examples/cors/frontend/service.py b/examples/cors/frontend/service.py index ccd032094d2a..5db220ed34bf 100644 --- a/examples/cors/frontend/service.py +++ b/examples/cors/frontend/service.py @@ -6,9 +6,9 @@ @app.route('/') def index(): - file_dir = os.path.dirname(os.path.realpath(__file__)) - return send_from_directory(file_dir, 'index.html') + file_dir = os.path.dirname(os.path.realpath(__file__)) + return send_from_directory(file_dir, 'index.html') if __name__ == "__main__": - app.run(host='127.0.0.1', port=8080, debug=True) + app.run(host='127.0.0.1', port=8080, debug=True) diff --git a/examples/csrf/crosssite/service.py b/examples/csrf/crosssite/service.py index d46ece5eb8b4..10b71e0e6e47 100644 --- a/examples/csrf/crosssite/service.py +++ b/examples/csrf/crosssite/service.py @@ -8,9 +8,9 @@ @app.route('/', methods=['GET']) def index(): - file_dir = os.path.dirname(os.path.realpath(__file__)) - return send_from_directory(file_dir, 'index.html') + file_dir = os.path.dirname(os.path.realpath(__file__)) + return send_from_directory(file_dir, 'index.html') if __name__ == "__main__": - app.run(host='127.0.0.1', port=8080, debug=True) + app.run(host='127.0.0.1', port=8080, debug=True) diff --git a/examples/csrf/samesite/service.py b/examples/csrf/samesite/service.py index 290cc8bee48f..3d54346532af 100644 --- a/examples/csrf/samesite/service.py +++ b/examples/csrf/samesite/service.py @@ -8,19 +8,19 @@ @app.route('/csrf/ignored', methods=['GET']) def csrf_ignored(): - return 'Success!' + return 'Success!' @app.route('/csrf/', methods=['POST']) def csrf_with_status(status): - return 'Success!' + return 'Success!' @app.route('/', methods=['GET']) def index(): - file_dir = os.path.dirname(os.path.realpath(__file__)) - return send_from_directory(file_dir, 'index.html') + file_dir = os.path.dirname(os.path.realpath(__file__)) + return send_from_directory(file_dir, 'index.html') if __name__ == "__main__": - app.run(host='127.0.0.1', port=8080, debug=True) + app.run(host='127.0.0.1', port=8080, debug=True) diff --git a/examples/double-proxy/service.py b/examples/double-proxy/service.py index 40c693c11efc..ac40a6c7ca4f 100644 --- a/examples/double-proxy/service.py +++ b/examples/double-proxy/service.py @@ -9,13 +9,13 @@ @app.route('/') def hello(): - conn = psycopg2.connect("host=postgres user=postgres") - cur = conn.cursor() - cur.execute('SELECT version()') - msg = 'Connected to Postgres, version: %s' % cur.fetchone() - cur.close() - return msg + conn = psycopg2.connect("host=postgres user=postgres") + cur = conn.cursor() + cur.execute('SELECT version()') + msg = 'Connected to Postgres, version: %s' % cur.fetchone() + cur.close() + return msg if __name__ == "__main__": - app.run(host='0.0.0.0', port=8000, debug=True) + app.run(host='0.0.0.0', port=8000, debug=True) diff --git a/examples/ext_authz/upstream/service/server.py b/examples/ext_authz/upstream/service/server.py index a3d539f195ab..f80428858918 100644 --- a/examples/ext_authz/upstream/service/server.py +++ b/examples/ext_authz/upstream/service/server.py @@ -5,8 +5,8 @@ @app.route('/service') def hello(): - return 'Hello ' + request.headers.get('x-current-user') + ' from behind Envoy!' + return 'Hello ' + request.headers.get('x-current-user') + ' from behind Envoy!' 
if __name__ == "__main__": - app.run(host='0.0.0.0', port=8080, debug=False) + app.run(host='0.0.0.0', port=8080, debug=False) diff --git a/examples/front-proxy/service.py b/examples/front-proxy/service.py index c0e008d73404..e0a80e2a3fa9 100644 --- a/examples/front-proxy/service.py +++ b/examples/front-proxy/service.py @@ -28,24 +28,24 @@ @app.route('/service/') def hello(service_number): - return ('Hello from behind Envoy (service {})! hostname: {} resolved' - 'hostname: {}\n'.format(os.environ['SERVICE_NAME'], socket.gethostname(), - socket.gethostbyname(socket.gethostname()))) + return ('Hello from behind Envoy (service {})! hostname: {} resolved' + 'hostname: {}\n'.format(os.environ['SERVICE_NAME'], socket.gethostname(), + socket.gethostbyname(socket.gethostname()))) @app.route('/trace/') def trace(service_number): - headers = {} - # call service 2 from service 1 - if int(os.environ['SERVICE_NAME']) == 1: - for header in TRACE_HEADERS_TO_PROPAGATE: - if header in request.headers: - headers[header] = request.headers[header] - requests.get("http://localhost:9000/trace/2", headers=headers) - return ('Hello from behind Envoy (service {})! hostname: {} resolved' - 'hostname: {}\n'.format(os.environ['SERVICE_NAME'], socket.gethostname(), - socket.gethostbyname(socket.gethostname()))) + headers = {} + # call service 2 from service 1 + if int(os.environ['SERVICE_NAME']) == 1: + for header in TRACE_HEADERS_TO_PROPAGATE: + if header in request.headers: + headers[header] = request.headers[header] + requests.get("http://localhost:9000/trace/2", headers=headers) + return ('Hello from behind Envoy (service {})! hostname: {} resolved' + 'hostname: {}\n'.format(os.environ['SERVICE_NAME'], socket.gethostname(), + socket.gethostbyname(socket.gethostname()))) if __name__ == "__main__": - app.run(host='127.0.0.1', port=8080, debug=True) + app.run(host='127.0.0.1', port=8080, debug=True) diff --git a/examples/grpc-bridge/client/client.py b/examples/grpc-bridge/client/client.py index 3c9f9d075509..8bcf29f22cba 100755 --- a/examples/grpc-bridge/client/client.py +++ b/examples/grpc-bridge/client/client.py @@ -21,65 +21,65 @@ class KVClient(): - def get(self, key): - r = kv.GetRequest(key=key) + def get(self, key): + r = kv.GetRequest(key=key) - # Build the gRPC frame - data = r.SerializeToString() - data = pack('!cI', b'\0', len(data)) + data + # Build the gRPC frame + data = r.SerializeToString() + data = pack('!cI', b'\0', len(data)) + data - resp = requests.post(HOST + "/kv.KV/Get", data=data, headers=HEADERS) + resp = requests.post(HOST + "/kv.KV/Get", data=data, headers=HEADERS) - return kv.GetResponse().FromString(resp.content[5:]) + return kv.GetResponse().FromString(resp.content[5:]) - def set(self, key, value): - r = kv.SetRequest(key=key, value=value) - data = r.SerializeToString() - data = pack('!cI', b'\0', len(data)) + data + def set(self, key, value): + r = kv.SetRequest(key=key, value=value) + data = r.SerializeToString() + data = pack('!cI', b'\0', len(data)) + data - return requests.post(HOST + "/kv.KV/Set", data=data, headers=HEADERS) + return requests.post(HOST + "/kv.KV/Set", data=data, headers=HEADERS) def run(): - if len(sys.argv) == 1: - print(USAGE) + if len(sys.argv) == 1: + print(USAGE) - sys.exit(0) + sys.exit(0) - cmd = sys.argv[1] + cmd = sys.argv[1] - client = KVClient() + client = KVClient() - if cmd == "get": - # ensure a key was provided - if len(sys.argv) != 3: - print(USAGE) - sys.exit(1) + if cmd == "get": + # ensure a key was provided + if len(sys.argv) != 3: + 
print(USAGE) + sys.exit(1) - # get the key to fetch - key = sys.argv[2] + # get the key to fetch + key = sys.argv[2] - # send the request to the server - response = client.get(key) + # send the request to the server + response = client.get(key) - print(response.value) - sys.exit(0) + print(response.value) + sys.exit(0) - elif cmd == "set": - # ensure a key and value were provided - if len(sys.argv) < 4: - print(USAGE) - sys.exit(1) + elif cmd == "set": + # ensure a key and value were provided + if len(sys.argv) < 4: + print(USAGE) + sys.exit(1) - # get the key and the full text of value - key = sys.argv[2] - value = " ".join(sys.argv[3:]) + # get the key and the full text of value + key = sys.argv[2] + value = " ".join(sys.argv[3:]) - # send the request to the server - response = client.set(key, value) + # send the request to the server + response = client.set(key, value) - print("setf %s to %s" % (key, value)) + print("setf %s to %s" % (key, value)) if __name__ == '__main__': - run() + run() diff --git a/examples/load-reporting-service/http_server.py b/examples/load-reporting-service/http_server.py index c2cb7136788a..ada6194e7f73 100644 --- a/examples/load-reporting-service/http_server.py +++ b/examples/load-reporting-service/http_server.py @@ -5,8 +5,8 @@ @app.route('/service') def hello(): - return 'Hello from behind Envoy!' + return 'Hello from behind Envoy!' if __name__ == "__main__": - app.run(host='0.0.0.0', port=8082, debug=False) + app.run(host='0.0.0.0', port=8082, debug=False) diff --git a/restarter/hot-restarter.py b/restarter/hot-restarter.py index 5750fa61feac..3577c83182e6 100644 --- a/restarter/hot-restarter.py +++ b/restarter/hot-restarter.py @@ -18,192 +18,192 @@ def term_all_children(): - """ Iterate through all known child processes, send a TERM signal to each of + """ Iterate through all known child processes, send a TERM signal to each of them, and then wait up to TERM_WAIT_SECONDS for them to exit gracefully, exiting early if all children go away. If one or more children have not exited after TERM_WAIT_SECONDS, they will be forcibly killed """ - # First uninstall the SIGCHLD handler so that we don't get called again. - signal.signal(signal.SIGCHLD, signal.SIG_DFL) - - global pid_list - for pid in pid_list: - print("sending TERM to PID={}".format(pid)) - try: - os.kill(pid, signal.SIGTERM) - except OSError: - print("error sending TERM to PID={} continuing".format(pid)) - - all_exited = False - - # wait for TERM_WAIT_SECONDS seconds for children to exit cleanly - retries = 0 - while not all_exited and retries < TERM_WAIT_SECONDS: - for pid in list(pid_list): - ret_pid, exit_status = os.waitpid(pid, os.WNOHANG) - if ret_pid == 0 and exit_status == 0: - # the child is still running - continue - - pid_list.remove(pid) - - if len(pid_list) == 0: - all_exited = True - else: - retries += 1 - time.sleep(1) + # First uninstall the SIGCHLD handler so that we don't get called again. 
+ signal.signal(signal.SIGCHLD, signal.SIG_DFL) - if all_exited: - print("all children exited cleanly") - else: + global pid_list for pid in pid_list: - print("child PID={} did not exit cleanly, killing".format(pid)) - force_kill_all_children() - sys.exit(1) # error status because a child did not exit cleanly + print("sending TERM to PID={}".format(pid)) + try: + os.kill(pid, signal.SIGTERM) + except OSError: + print("error sending TERM to PID={} continuing".format(pid)) + + all_exited = False + + # wait for TERM_WAIT_SECONDS seconds for children to exit cleanly + retries = 0 + while not all_exited and retries < TERM_WAIT_SECONDS: + for pid in list(pid_list): + ret_pid, exit_status = os.waitpid(pid, os.WNOHANG) + if ret_pid == 0 and exit_status == 0: + # the child is still running + continue + + pid_list.remove(pid) + + if len(pid_list) == 0: + all_exited = True + else: + retries += 1 + time.sleep(1) + + if all_exited: + print("all children exited cleanly") + else: + for pid in pid_list: + print("child PID={} did not exit cleanly, killing".format(pid)) + force_kill_all_children() + sys.exit(1) # error status because a child did not exit cleanly def force_kill_all_children(): - """ Iterate through all known child processes and force kill them. Typically + """ Iterate through all known child processes and force kill them. Typically term_all_children() should be attempted first to give child processes an opportunity to clean up state before exiting """ - global pid_list - for pid in pid_list: - print("force killing PID={}".format(pid)) - try: - os.kill(pid, signal.SIGKILL) - except OSError: - print("error force killing PID={} continuing".format(pid)) + global pid_list + for pid in pid_list: + print("force killing PID={}".format(pid)) + try: + os.kill(pid, signal.SIGKILL) + except OSError: + print("error force killing PID={} continuing".format(pid)) - pid_list = [] + pid_list = [] def shutdown(): - """ Attempt to gracefully shutdown all child Envoy processes and then exit. + """ Attempt to gracefully shutdown all child Envoy processes and then exit. See term_all_children() for further discussion. """ - term_all_children() - sys.exit(0) + term_all_children() + sys.exit(0) def sigterm_handler(signum, frame): - """ Handler for SIGTERM. """ - print("got SIGTERM") - shutdown() + """ Handler for SIGTERM. """ + print("got SIGTERM") + shutdown() def sigint_handler(signum, frame): - """ Handler for SIGINT (ctrl-c). The same as the SIGTERM handler. """ - print("got SIGINT") - shutdown() + """ Handler for SIGINT (ctrl-c). The same as the SIGTERM handler. """ + print("got SIGINT") + shutdown() def sighup_handler(signum, frame): - """ Handler for SIGHUP. This signal is used to cause the restarter to fork and exec a new + """ Handler for SIGHUP. This signal is used to cause the restarter to fork and exec a new child. """ - print("got SIGHUP") - fork_and_exec() + print("got SIGHUP") + fork_and_exec() def sigusr1_handler(signum, frame): - """ Handler for SIGUSR1. Propagate SIGUSR1 to all of the child processes """ + """ Handler for SIGUSR1. 
Propagate SIGUSR1 to all of the child processes """ - global pid_list - for pid in pid_list: - print("sending SIGUSR1 to PID={}".format(pid)) - try: - os.kill(pid, signal.SIGUSR1) - except OSError: - print("error in SIGUSR1 to PID={} continuing".format(pid)) + global pid_list + for pid in pid_list: + print("sending SIGUSR1 to PID={}".format(pid)) + try: + os.kill(pid, signal.SIGUSR1) + except OSError: + print("error in SIGUSR1 to PID={} continuing".format(pid)) def sigchld_handler(signum, frame): - """ Handler for SIGCHLD. Iterates through all of our known child processes and figures out whether + """ Handler for SIGCHLD. Iterates through all of our known child processes and figures out whether the signal/exit was expected or not. Python doesn't have any of the native signal handlers ability to get the child process info directly from the signal handler so we need to iterate through all child processes and see what happened.""" - print("got SIGCHLD") - - kill_all_and_exit = False - global pid_list - pid_list_copy = list(pid_list) - for pid in pid_list_copy: - ret_pid, exit_status = os.waitpid(pid, os.WNOHANG) - if ret_pid == 0 and exit_status == 0: - # This child is still running. - continue - - pid_list.remove(pid) - - # Now we see how the child exited. - if os.WIFEXITED(exit_status): - exit_code = os.WEXITSTATUS(exit_status) - print("PID={} exited with code={}".format(ret_pid, exit_code)) - if exit_code == 0: - # Normal exit. We assume this was on purpose. - pass - else: - # Something bad happened. We need to tear everything down so that whoever started the - # restarter can know about this situation and restart the whole thing. - kill_all_and_exit = True - elif os.WIFSIGNALED(exit_status): - print("PID={} was killed with signal={}".format(ret_pid, os.WTERMSIG(exit_status))) - kill_all_and_exit = True - else: - kill_all_and_exit = True - - if kill_all_and_exit: - print("Due to abnormal exit, force killing all child processes and exiting") - - # First uninstall the SIGCHLD handler so that we don't get called again. - signal.signal(signal.SIGCHLD, signal.SIG_DFL) - - force_kill_all_children() - - # Our last child died, so we have no purpose. Exit. - if not pid_list: - print("exiting due to lack of child processes") - sys.exit(1 if kill_all_and_exit else 0) + print("got SIGCHLD") + + kill_all_and_exit = False + global pid_list + pid_list_copy = list(pid_list) + for pid in pid_list_copy: + ret_pid, exit_status = os.waitpid(pid, os.WNOHANG) + if ret_pid == 0 and exit_status == 0: + # This child is still running. + continue + + pid_list.remove(pid) + + # Now we see how the child exited. + if os.WIFEXITED(exit_status): + exit_code = os.WEXITSTATUS(exit_status) + print("PID={} exited with code={}".format(ret_pid, exit_code)) + if exit_code == 0: + # Normal exit. We assume this was on purpose. + pass + else: + # Something bad happened. We need to tear everything down so that whoever started the + # restarter can know about this situation and restart the whole thing. + kill_all_and_exit = True + elif os.WIFSIGNALED(exit_status): + print("PID={} was killed with signal={}".format(ret_pid, os.WTERMSIG(exit_status))) + kill_all_and_exit = True + else: + kill_all_and_exit = True + + if kill_all_and_exit: + print("Due to abnormal exit, force killing all child processes and exiting") + + # First uninstall the SIGCHLD handler so that we don't get called again. + signal.signal(signal.SIGCHLD, signal.SIG_DFL) + + force_kill_all_children() + + # Our last child died, so we have no purpose. Exit. 
+ if not pid_list: + print("exiting due to lack of child processes") + sys.exit(1 if kill_all_and_exit else 0) def fork_and_exec(): - """ This routine forks and execs a new child process and keeps track of its PID. Before we fork, + """ This routine forks and execs a new child process and keeps track of its PID. Before we fork, set the current restart epoch in an env variable that processes can read if they care. """ - global restart_epoch - os.environ['RESTART_EPOCH'] = str(restart_epoch) - print("forking and execing new child process at epoch {}".format(restart_epoch)) - restart_epoch += 1 + global restart_epoch + os.environ['RESTART_EPOCH'] = str(restart_epoch) + print("forking and execing new child process at epoch {}".format(restart_epoch)) + restart_epoch += 1 - child_pid = os.fork() - if child_pid == 0: - # Child process - os.execl(sys.argv[1], sys.argv[1]) - else: - # Parent process - print("forked new child process with PID={}".format(child_pid)) - pid_list.append(child_pid) + child_pid = os.fork() + if child_pid == 0: + # Child process + os.execl(sys.argv[1], sys.argv[1]) + else: + # Parent process + print("forked new child process with PID={}".format(child_pid)) + pid_list.append(child_pid) def main(): - """ Script main. This script is designed so that a process watcher like runit or monit can watch + """ Script main. This script is designed so that a process watcher like runit or monit can watch this process and take corrective action if it ever goes away. """ - print("starting hot-restarter with target: {}".format(sys.argv[1])) + print("starting hot-restarter with target: {}".format(sys.argv[1])) - signal.signal(signal.SIGTERM, sigterm_handler) - signal.signal(signal.SIGINT, sigint_handler) - signal.signal(signal.SIGHUP, sighup_handler) - signal.signal(signal.SIGCHLD, sigchld_handler) - signal.signal(signal.SIGUSR1, sigusr1_handler) + signal.signal(signal.SIGTERM, sigterm_handler) + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGHUP, sighup_handler) + signal.signal(signal.SIGCHLD, sigchld_handler) + signal.signal(signal.SIGUSR1, sigusr1_handler) - # Start the first child process and then go into an endless loop since everything else happens via - # signals. - fork_and_exec() - while True: - time.sleep(60) + # Start the first child process and then go into an endless loop since everything else happens via + # signals. + fork_and_exec() + while True: + time.sleep(60) if __name__ == '__main__': - main() + main() diff --git a/source/extensions/filters/network/kafka/protocol/generator.py b/source/extensions/filters/network/kafka/protocol/generator.py index d03791c6eef6..d54b6cb5e2c8 100755 --- a/source/extensions/filters/network/kafka/protocol/generator.py +++ b/source/extensions/filters/network/kafka/protocol/generator.py @@ -4,7 +4,7 @@ def generate_main_code(type, main_header_file, resolver_cc_file, metrics_header_file, input_files): - """ + """ Main code generator. Takes input files and processes them into structures representing a Kafka message (request or @@ -15,47 +15,47 @@ def generate_main_code(type, main_header_file, resolver_cc_file, metrics_header_ - resolver_cc_file - contains request api key & version mapping to deserializer (from header file) - metrics_header_file - contains metrics with names corresponding to messages """ - processor = StatefulProcessor() - # Parse provided input files. - messages = processor.parse_messages(input_files) + processor = StatefulProcessor() + # Parse provided input files. 
+ messages = processor.parse_messages(input_files) - complex_type_template = RenderingHelper.get_template('complex_type_template.j2') - parsers_template = RenderingHelper.get_template("%s_parser.j2" % type) + complex_type_template = RenderingHelper.get_template('complex_type_template.j2') + parsers_template = RenderingHelper.get_template("%s_parser.j2" % type) - main_header_contents = '' + main_header_contents = '' - for message in messages: - # For each child structure that is used by request/response, render its matching C++ code. - dependencies = message.compute_declaration_chain() - for dependency in dependencies: - main_header_contents += complex_type_template.render(complex_type=dependency) - # Each top-level structure (e.g. FetchRequest/FetchResponse) needs corresponding parsers. - main_header_contents += parsers_template.render(complex_type=message) + for message in messages: + # For each child structure that is used by request/response, render its matching C++ code. + dependencies = message.compute_declaration_chain() + for dependency in dependencies: + main_header_contents += complex_type_template.render(complex_type=dependency) + # Each top-level structure (e.g. FetchRequest/FetchResponse) needs corresponding parsers. + main_header_contents += parsers_template.render(complex_type=message) - # Full file with headers, namespace declaration etc. - template = RenderingHelper.get_template("%ss_h.j2" % type) - contents = template.render(contents=main_header_contents) + # Full file with headers, namespace declaration etc. + template = RenderingHelper.get_template("%ss_h.j2" % type) + contents = template.render(contents=main_header_contents) - # Generate main header file. - with open(main_header_file, 'w') as fd: - fd.write(contents) + # Generate main header file. + with open(main_header_file, 'w') as fd: + fd.write(contents) - # Generate ...resolver.cc file. - template = RenderingHelper.get_template("kafka_%s_resolver_cc.j2" % type) - contents = template.render(message_types=messages) - with open(resolver_cc_file, 'w') as fd: - fd.write(contents) + # Generate ...resolver.cc file. + template = RenderingHelper.get_template("kafka_%s_resolver_cc.j2" % type) + contents = template.render(message_types=messages) + with open(resolver_cc_file, 'w') as fd: + fd.write(contents) - # Generate ...metrics.h file. - template = RenderingHelper.get_template("%s_metrics_h.j2" % type) - contents = template.render(message_types=messages) - with open(metrics_header_file, 'w') as fd: - fd.write(contents) + # Generate ...metrics.h file. + template = RenderingHelper.get_template("%s_metrics_h.j2" % type) + contents = template.render(message_types=messages) + with open(metrics_header_file, 'w') as fd: + fd.write(contents) def generate_test_code(type, header_test_cc_file, codec_test_cc_file, utilities_cc_file, input_files): - """ + """ Test code generator. Takes input files and processes them into structures representing a Kafka message (request or @@ -66,642 +66,652 @@ def generate_test_code(type, header_test_cc_file, codec_test_cc_file, utilities_ - codec_test_cc_file - tests involving codec and Request/ResponseParserResolver, - utilities_cc_file - utilities for creating sample messages. """ - processor = StatefulProcessor() - # Parse provided input files. - messages = processor.parse_messages(input_files) + processor = StatefulProcessor() + # Parse provided input files. + messages = processor.parse_messages(input_files) - # Generate header-test file. 
- template = RenderingHelper.get_template("%ss_test_cc.j2" % type) - contents = template.render(message_types=messages) - with open(header_test_cc_file, 'w') as fd: - fd.write(contents) + # Generate header-test file. + template = RenderingHelper.get_template("%ss_test_cc.j2" % type) + contents = template.render(message_types=messages) + with open(header_test_cc_file, 'w') as fd: + fd.write(contents) - # Generate codec-test file. - template = RenderingHelper.get_template("%s_codec_%s_test_cc.j2" % (type, type)) - contents = template.render(message_types=messages) - with open(codec_test_cc_file, 'w') as fd: - fd.write(contents) + # Generate codec-test file. + template = RenderingHelper.get_template("%s_codec_%s_test_cc.j2" % (type, type)) + contents = template.render(message_types=messages) + with open(codec_test_cc_file, 'w') as fd: + fd.write(contents) - # Generate utilities file. - template = RenderingHelper.get_template("%s_utilities_cc.j2" % type) - contents = template.render(message_types=messages) - with open(utilities_cc_file, 'w') as fd: - fd.write(contents) + # Generate utilities file. + template = RenderingHelper.get_template("%s_utilities_cc.j2" % type) + contents = template.render(message_types=messages) + with open(utilities_cc_file, 'w') as fd: + fd.write(contents) class StatefulProcessor: - """ + """ Helper entity that keeps state during the processing. Some state needs to be shared across multiple message types, as we need to handle identical sub-type names (e.g. both AlterConfigsRequest & IncrementalAlterConfigsRequest have child AlterConfigsResource, what would cause a compile-time error if we were to handle it trivially). """ - def __init__(self): - # Complex types that have been encountered during processing. - self.known_types = set() - # Name of parent message type that's being processed right now. - self.currently_processed_message_type = None - # Common structs declared in this message type. - self.common_structs = {} + def __init__(self): + # Complex types that have been encountered during processing. + self.known_types = set() + # Name of parent message type that's being processed right now. + self.currently_processed_message_type = None + # Common structs declared in this message type. + self.common_structs = {} - def parse_messages(self, input_files): - """ + def parse_messages(self, input_files): + """ Parse request/response structures from provided input files. """ - import re - import json - - messages = [] - # Sort the input files, as the processing is stateful, as we want the same order every time. - input_files.sort() - # For each specification file, remove comments, and parse the remains. - for input_file in input_files: - try: - with open(input_file, 'r') as fd: - raw_contents = fd.read() - without_comments = re.sub(r'\s*//.*\n', '\n', raw_contents) - without_empty_newlines = re.sub(r'^\s*$', '', without_comments, flags=re.MULTILINE) - message_spec = json.loads(without_empty_newlines) - message = self.parse_top_level_element(message_spec) - messages.append(message) - except Exception as e: - print('could not process %s' % input_file) - raise - - # Sort messages by api_key. - messages.sort(key=lambda x: x.get_extra('api_key')) - return messages - - def parse_top_level_element(self, spec): - """ + import re + import json + + messages = [] + # Sort the input files, as the processing is stateful, as we want the same order every time. + input_files.sort() + # For each specification file, remove comments, and parse the remains. 
+ for input_file in input_files: + try: + with open(input_file, 'r') as fd: + raw_contents = fd.read() + without_comments = re.sub(r'\s*//.*\n', '\n', raw_contents) + without_empty_newlines = re.sub(r'^\s*$', + '', + without_comments, + flags=re.MULTILINE) + message_spec = json.loads(without_empty_newlines) + message = self.parse_top_level_element(message_spec) + messages.append(message) + except Exception as e: + print('could not process %s' % input_file) + raise + + # Sort messages by api_key. + messages.sort(key=lambda x: x.get_extra('api_key')) + return messages + + def parse_top_level_element(self, spec): + """ Parse a given structure into a request/response. Request/response is just a complex type, that has name & version information kept in differently named fields, compared to sub-structures in a message. """ - self.currently_processed_message_type = spec['name'] - - # Figure out all versions of this message type. - versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) - - # Figure out the flexible versions. - flexible_versions_string = spec.get('flexibleVersions', 'none') - if 'none' != flexible_versions_string: - flexible_versions = Statics.parse_version_string(flexible_versions_string, versions[-1]) - else: - flexible_versions = [] - - # Sanity check - all flexible versions need to be versioned. - if [x for x in flexible_versions if x not in versions]: - raise ValueError('invalid flexible versions') - - try: - # In 2.4 some types are declared at top level, and only referenced inside. - # So let's parse them and store them in state. - common_structs = spec.get('commonStructs') - if common_structs is not None: - for common_struct in common_structs: - common_struct_name = common_struct['name'] - common_struct_versions = Statics.parse_version_string(common_struct['versions'], - versions[-1]) - parsed_complex = self.parse_complex_type(common_struct_name, common_struct, - common_struct_versions) - self.common_structs[parsed_complex.name] = parsed_complex - - # Parse the type itself. - complex_type = self.parse_complex_type(self.currently_processed_message_type, spec, versions) - complex_type.register_flexible_versions(flexible_versions) - - # Request / response types need to carry api key version. - result = complex_type.with_extra('api_key', spec['apiKey']) - return result - - finally: - self.common_structs = {} - self.currently_processed_message_type = None - - def parse_complex_type(self, type_name, field_spec, versions): - """ + self.currently_processed_message_type = spec['name'] + + # Figure out all versions of this message type. + versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) + + # Figure out the flexible versions. + flexible_versions_string = spec.get('flexibleVersions', 'none') + if 'none' != flexible_versions_string: + flexible_versions = Statics.parse_version_string(flexible_versions_string, versions[-1]) + else: + flexible_versions = [] + + # Sanity check - all flexible versions need to be versioned. + if [x for x in flexible_versions if x not in versions]: + raise ValueError('invalid flexible versions') + + try: + # In 2.4 some types are declared at top level, and only referenced inside. + # So let's parse them and store them in state. 
+ common_structs = spec.get('commonStructs') + if common_structs is not None: + for common_struct in common_structs: + common_struct_name = common_struct['name'] + common_struct_versions = Statics.parse_version_string( + common_struct['versions'], versions[-1]) + parsed_complex = self.parse_complex_type(common_struct_name, common_struct, + common_struct_versions) + self.common_structs[parsed_complex.name] = parsed_complex + + # Parse the type itself. + complex_type = self.parse_complex_type(self.currently_processed_message_type, spec, + versions) + complex_type.register_flexible_versions(flexible_versions) + + # Request / response types need to carry api key version. + result = complex_type.with_extra('api_key', spec['apiKey']) + return result + + finally: + self.common_structs = {} + self.currently_processed_message_type = None + + def parse_complex_type(self, type_name, field_spec, versions): + """ Parse given complex type, returning a structure that holds its name, field specification and allowed versions. """ - fields_el = field_spec.get('fields') - - if fields_el is not None: - fields = [] - for child_field in field_spec['fields']: - child = self.parse_field(child_field, versions[-1]) - if child is not None: - fields.append(child) - - # Some of the types repeat multiple times (e.g. AlterableConfig). - # In such a case, every second or later occurrence of the same name is going to be prefixed - # with parent type, e.g. we have AlterableConfig (for AlterConfigsRequest) and then - # IncrementalAlterConfigsRequestAlterableConfig (for IncrementalAlterConfigsRequest). - # This keeps names unique, while keeping non-duplicate ones short. - if type_name not in self.known_types: - self.known_types.add(type_name) - else: - type_name = self.currently_processed_message_type + type_name - self.known_types.add(type_name) - - return Complex(type_name, fields, versions) - - else: - return self.common_structs[type_name] - - def parse_field(self, field_spec, highest_possible_version): - """ + fields_el = field_spec.get('fields') + + if fields_el is not None: + fields = [] + for child_field in field_spec['fields']: + child = self.parse_field(child_field, versions[-1]) + if child is not None: + fields.append(child) + + # Some of the types repeat multiple times (e.g. AlterableConfig). + # In such a case, every second or later occurrence of the same name is going to be prefixed + # with parent type, e.g. we have AlterableConfig (for AlterConfigsRequest) and then + # IncrementalAlterConfigsRequestAlterableConfig (for IncrementalAlterConfigsRequest). + # This keeps names unique, while keeping non-duplicate ones short. + if type_name not in self.known_types: + self.known_types.add(type_name) + else: + type_name = self.currently_processed_message_type + type_name + self.known_types.add(type_name) + + return Complex(type_name, fields, versions) + + else: + return self.common_structs[type_name] + + def parse_field(self, field_spec, highest_possible_version): + """ Parse given field, returning a structure holding the name, type, and versions when this field is actually used (nullable or not). Obviously, field cannot be used in version higher than its type's usage. 
""" - if field_spec.get('tag') is not None: - return None + if field_spec.get('tag') is not None: + return None - version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) - version_usage_as_nullable = Statics.parse_version_string( - field_spec['nullableVersions'], - highest_possible_version) if 'nullableVersions' in field_spec else range(-1) - parsed_type = self.parse_type(field_spec['type'], field_spec, highest_possible_version) - return FieldSpec(field_spec['name'], parsed_type, version_usage, version_usage_as_nullable) + version_usage = Statics.parse_version_string(field_spec['versions'], + highest_possible_version) + version_usage_as_nullable = Statics.parse_version_string( + field_spec['nullableVersions'], + highest_possible_version) if 'nullableVersions' in field_spec else range(-1) + parsed_type = self.parse_type(field_spec['type'], field_spec, highest_possible_version) + return FieldSpec(field_spec['name'], parsed_type, version_usage, version_usage_as_nullable) - def parse_type(self, type_name, field_spec, highest_possible_version): - """ + def parse_type(self, type_name, field_spec, highest_possible_version): + """ Parse a given type element - returns an array type, primitive (e.g. uint32_t) or complex one. """ - if (type_name.startswith('[]')): - # In spec files, array types are defined as `[]underlying_type` instead of having its own - # element with type inside. - underlying_type = self.parse_type(type_name[2:], field_spec, highest_possible_version) - return Array(underlying_type) - else: - if (type_name in Primitive.USABLE_PRIMITIVE_TYPE_NAMES): - return Primitive(type_name, field_spec.get('default')) - else: - versions = Statics.parse_version_string(field_spec['versions'], highest_possible_version) - return self.parse_complex_type(type_name, field_spec, versions) + if (type_name.startswith('[]')): + # In spec files, array types are defined as `[]underlying_type` instead of having its own + # element with type inside. + underlying_type = self.parse_type(type_name[2:], field_spec, highest_possible_version) + return Array(underlying_type) + else: + if (type_name in Primitive.USABLE_PRIMITIVE_TYPE_NAMES): + return Primitive(type_name, field_spec.get('default')) + else: + versions = Statics.parse_version_string(field_spec['versions'], + highest_possible_version) + return self.parse_complex_type(type_name, field_spec, versions) class Statics: - @staticmethod - def parse_version_string(raw_versions, highest_possible_version): - """ + @staticmethod + def parse_version_string(raw_versions, highest_possible_version): + """ Return integer range that corresponds to version string in spec file. """ - if raw_versions.endswith('+'): - return range(int(raw_versions[:-1]), highest_possible_version + 1) - else: - if '-' in raw_versions: - tokens = raw_versions.split('-', 1) - return range(int(tokens[0]), int(tokens[1]) + 1) - else: - single_version = int(raw_versions) - return range(single_version, single_version + 1) + if raw_versions.endswith('+'): + return range(int(raw_versions[:-1]), highest_possible_version + 1) + else: + if '-' in raw_versions: + tokens = raw_versions.split('-', 1) + return range(int(tokens[0]), int(tokens[1]) + 1) + else: + single_version = int(raw_versions) + return range(single_version, single_version + 1) class FieldList: - """ + """ List of fields used by given entity (request or child structure) in given message version (as fields get added or removed across versions and/or they change compaction level). 
""" - def __init__(self, version, uses_compact_fields, fields): - self.version = version - self.uses_compact_fields = uses_compact_fields - self.fields = fields + def __init__(self, version, uses_compact_fields, fields): + self.version = version + self.uses_compact_fields = uses_compact_fields + self.fields = fields - def used_fields(self): - """ + def used_fields(self): + """ Return list of fields that are actually used in this version of structure. """ - return filter(lambda x: x.used_in_version(self.version), self.fields) + return filter(lambda x: x.used_in_version(self.version), self.fields) - def constructor_signature(self): - """ + def constructor_signature(self): + """ Return constructor signature. Multiple versions of the same structure can have identical signatures (due to version bumps in Kafka). """ - parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) - return ', '.join(parameter_spec) + parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) + return ', '.join(parameter_spec) - def constructor_init_list(self): - """ + def constructor_init_list(self): + """ Renders member initialization list in constructor. Takes care of potential optional conversions (as field could be T in V1, but optional in V2). """ - init_list = [] - for field in self.fields: - if field.used_in_version(self.version): - if field.is_nullable(): - if field.is_nullable_in_version(self.version): - # Field is optional, and the parameter is optional in this version. - init_list_item = '%s_{%s}' % (field.name, field.name) - init_list.append(init_list_item) - else: - # Field is optional, and the parameter is T in this version. - init_list_item = '%s_{absl::make_optional(%s)}' % (field.name, field.name) - init_list.append(init_list_item) - else: - # Field is T, so parameter cannot be optional. - init_list_item = '%s_{%s}' % (field.name, field.name) - init_list.append(init_list_item) - else: - # Field is not used in this version, so we need to put in default value. - init_list_item = '%s_{%s}' % (field.name, field.default_value()) - init_list.append(init_list_item) - pass - return ', '.join(init_list) - - def field_count(self): - return len(list(self.used_fields())) - - def example_value(self): - return ', '.join(map(lambda x: x.example_value_for_test(self.version), self.used_fields())) + init_list = [] + for field in self.fields: + if field.used_in_version(self.version): + if field.is_nullable(): + if field.is_nullable_in_version(self.version): + # Field is optional, and the parameter is optional in this version. + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # Field is optional, and the parameter is T in this version. + init_list_item = '%s_{absl::make_optional(%s)}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # Field is T, so parameter cannot be optional. + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # Field is not used in this version, so we need to put in default value. + init_list_item = '%s_{%s}' % (field.name, field.default_value()) + init_list.append(init_list_item) + pass + return ', '.join(init_list) + + def field_count(self): + return len(list(self.used_fields())) + + def example_value(self): + return ', '.join(map(lambda x: x.example_value_for_test(self.version), self.used_fields())) class FieldSpec: - """ + """ Represents a field present in a structure (request, or child structure thereof). 
Contains name, type, and versions when it is used (nullable or not). """ - def __init__(self, name, type, version_usage, version_usage_as_nullable): - import re - separated = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) - self.name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', separated).lower() - self.type = type - self.version_usage = version_usage - self.version_usage_as_nullable = version_usage_as_nullable + def __init__(self, name, type, version_usage, version_usage_as_nullable): + import re + separated = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + self.name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', separated).lower() + self.type = type + self.version_usage = version_usage + self.version_usage_as_nullable = version_usage_as_nullable - def is_nullable(self): - return len(self.version_usage_as_nullable) > 0 + def is_nullable(self): + return len(self.version_usage_as_nullable) > 0 - def is_nullable_in_version(self, version): - """ + def is_nullable_in_version(self, version): + """ Whether the field is nullable in given version. Fields can be non-nullable in earlier versions. See https://github.com/apache/kafka/tree/2.2.0-rc0/clients/src/main/resources/common/message#nullable-fields """ - return version in self.version_usage_as_nullable - - def used_in_version(self, version): - return version in self.version_usage - - def field_declaration(self): - if self.is_nullable(): - return 'absl::optional<%s> %s' % (self.type.name, self.name) - else: - return '%s %s' % (self.type.name, self.name) - - def parameter_declaration(self, version): - if self.is_nullable_in_version(version): - return 'absl::optional<%s> %s' % (self.type.name, self.name) - else: - return '%s %s' % (self.type.name, self.name) - - def default_value(self): - if self.is_nullable(): - type_default_value = self.type.default_value() - # For nullable fields, it's possible to have (Java) null as default value. - if type_default_value != 'null': - return '{%s}' % type_default_value - else: - return 'absl::nullopt' - else: - return str(self.type.default_value()) - - def example_value_for_test(self, version): - if self.is_nullable(): - return 'absl::make_optional<%s>(%s)' % (self.type.name, - self.type.example_value_for_test(version)) - else: - return str(self.type.example_value_for_test(version)) - - def deserializer_name_in_version(self, version, compact): - if self.is_nullable_in_version(version): - return 'Nullable%s' % self.type.deserializer_name_in_version(version, compact) - else: - return self.type.deserializer_name_in_version(version, compact) - - def is_printable(self): - return self.type.is_printable() + return version in self.version_usage_as_nullable + + def used_in_version(self, version): + return version in self.version_usage + + def field_declaration(self): + if self.is_nullable(): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def parameter_declaration(self, version): + if self.is_nullable_in_version(version): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def default_value(self): + if self.is_nullable(): + type_default_value = self.type.default_value() + # For nullable fields, it's possible to have (Java) null as default value. 
+ if type_default_value != 'null': + return '{%s}' % type_default_value + else: + return 'absl::nullopt' + else: + return str(self.type.default_value()) + + def example_value_for_test(self, version): + if self.is_nullable(): + return 'absl::make_optional<%s>(%s)' % (self.type.name, + self.type.example_value_for_test(version)) + else: + return str(self.type.example_value_for_test(version)) + + def deserializer_name_in_version(self, version, compact): + if self.is_nullable_in_version(version): + return 'Nullable%s' % self.type.deserializer_name_in_version(version, compact) + else: + return self.type.deserializer_name_in_version(version, compact) + + def is_printable(self): + return self.type.is_printable() class TypeSpecification: - def compute_declaration_chain(self): - """ + def compute_declaration_chain(self): + """ Computes types that need to be declared before this type can be declared, in C++ sense. """ - raise NotImplementedError() + raise NotImplementedError() - def deserializer_name_in_version(self, version, compact): - """ + def deserializer_name_in_version(self, version, compact): + """ Renders the deserializer name of given type, in message with given version. """ - raise NotImplementedError() + raise NotImplementedError() - def default_value(self): - """ + def default_value(self): + """ Returns a default value for given type. """ - raise NotImplementedError() + raise NotImplementedError() - def has_flexible_handling(self): - """ + def has_flexible_handling(self): + """ Whether the given type has special encoding when carrying message is using flexible encoding. """ - raise NotImplementedError() + raise NotImplementedError() - def example_value_for_test(self, version): - raise NotImplementedError() + def example_value_for_test(self, version): + raise NotImplementedError() - def is_printable(self): - raise NotImplementedError() + def is_printable(self): + raise NotImplementedError() class Array(TypeSpecification): - """ + """ Represents array complex type. To use instance of this type, it is necessary to declare structures required by self.underlying (e.g. to use Array, we need to have `struct Foo {...}`). """ - def __init__(self, underlying): - self.underlying = underlying + def __init__(self, underlying): + self.underlying = underlying - @property - def name(self): - return 'std::vector<%s>' % self.underlying.name + @property + def name(self): + return 'std::vector<%s>' % self.underlying.name - def compute_declaration_chain(self): - # To use an array of type T, we just need to be capable of using type T. - return self.underlying.compute_declaration_chain() + def compute_declaration_chain(self): + # To use an array of type T, we just need to be capable of using type T. + return self.underlying.compute_declaration_chain() - def deserializer_name_in_version(self, version, compact): - # For arrays, deserializer name is (Compact)(Nullable)ArrayDeserializer. - element_deserializer_name = self.underlying.deserializer_name_in_version(version, compact) - return '%sArrayDeserializer<%s>' % ("Compact" if compact else "", element_deserializer_name) + def deserializer_name_in_version(self, version, compact): + # For arrays, deserializer name is (Compact)(Nullable)ArrayDeserializer. 
+ element_deserializer_name = self.underlying.deserializer_name_in_version(version, compact) + return '%sArrayDeserializer<%s>' % ("Compact" if compact else "", element_deserializer_name) - def default_value(self): - return 'std::vector<%s>{}' % (self.underlying.name) + def default_value(self): + return 'std::vector<%s>{}' % (self.underlying.name) - def has_flexible_handling(self): - return True + def has_flexible_handling(self): + return True - def example_value_for_test(self, version): - return 'std::vector<%s>{ %s }' % (self.underlying.name, - self.underlying.example_value_for_test(version)) + def example_value_for_test(self, version): + return 'std::vector<%s>{ %s }' % (self.underlying.name, + self.underlying.example_value_for_test(version)) - def is_printable(self): - return self.underlying.is_printable() + def is_printable(self): + return self.underlying.is_printable() class Primitive(TypeSpecification): - """ + """ Represents a Kafka primitive value. """ - USABLE_PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] - - KAFKA_TYPE_TO_ENVOY_TYPE = { - 'string': 'std::string', - 'bool': 'bool', - 'int8': 'int8_t', - 'int16': 'int16_t', - 'int32': 'int32_t', - 'int64': 'int64_t', - 'bytes': 'Bytes', - 'tagged_fields': 'TaggedFields', - } - - KAFKA_TYPE_TO_DESERIALIZER = { - 'string': 'StringDeserializer', - 'bool': 'BooleanDeserializer', - 'int8': 'Int8Deserializer', - 'int16': 'Int16Deserializer', - 'int32': 'Int32Deserializer', - 'int64': 'Int64Deserializer', - 'bytes': 'BytesDeserializer', - 'tagged_fields': 'TaggedFieldsDeserializer', - } - - KAFKA_TYPE_TO_COMPACT_DESERIALIZER = { - 'string': 'CompactStringDeserializer', - 'bytes': 'CompactBytesDeserializer' - } - - # See https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages - KAFKA_TYPE_TO_DEFAULT_VALUE = { - 'string': '""', - 'bool': 'false', - 'int8': '0', - 'int16': '0', - 'int32': '0', - 'int64': '0', - 'bytes': '{}', - 'tagged_fields': 'TaggedFields({})', - } - - # Custom values that make test code more readable. - KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { - 'string': - '"string"', - 'bool': - 'false', - 'int8': - 'static_cast(8)', - 'int16': - 'static_cast(16)', - 'int32': - 'static_cast(32)', - 'int64': - 'static_cast(64)', - 'bytes': - 'Bytes({0, 1, 2, 3})', - 'tagged_fields': - 'TaggedFields{std::vector{{10, Bytes({1, 2, 3})}, {20, Bytes({4, 5, 6})}}}', - } - - def __init__(self, name, custom_default_value): - self.original_name = name - self.name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_ENVOY_TYPE) - self.custom_default_value = custom_default_value - - @staticmethod - def compute(name, map): - if name in map: - return map[name] - else: - raise ValueError(name) - - def compute_declaration_chain(self): - # Primitives need no declarations. 
- return [] - - def deserializer_name_in_version(self, version, compact): - if compact and self.original_name in Primitive.KAFKA_TYPE_TO_COMPACT_DESERIALIZER.keys(): - return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_COMPACT_DESERIALIZER) - else: - return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) - - def default_value(self): - if self.custom_default_value is not None: - return self.custom_default_value - else: - return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DEFAULT_VALUE) - - def has_flexible_handling(self): - return self.original_name in ['string', 'bytes', 'tagged_fields'] - - def example_value_for_test(self, version): - return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST) - - def is_printable(self): - return self.name not in ['Bytes'] + USABLE_PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] + + KAFKA_TYPE_TO_ENVOY_TYPE = { + 'string': 'std::string', + 'bool': 'bool', + 'int8': 'int8_t', + 'int16': 'int16_t', + 'int32': 'int32_t', + 'int64': 'int64_t', + 'bytes': 'Bytes', + 'tagged_fields': 'TaggedFields', + } + + KAFKA_TYPE_TO_DESERIALIZER = { + 'string': 'StringDeserializer', + 'bool': 'BooleanDeserializer', + 'int8': 'Int8Deserializer', + 'int16': 'Int16Deserializer', + 'int32': 'Int32Deserializer', + 'int64': 'Int64Deserializer', + 'bytes': 'BytesDeserializer', + 'tagged_fields': 'TaggedFieldsDeserializer', + } + + KAFKA_TYPE_TO_COMPACT_DESERIALIZER = { + 'string': 'CompactStringDeserializer', + 'bytes': 'CompactBytesDeserializer' + } + + # See https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages + KAFKA_TYPE_TO_DEFAULT_VALUE = { + 'string': '""', + 'bool': 'false', + 'int8': '0', + 'int16': '0', + 'int32': '0', + 'int64': '0', + 'bytes': '{}', + 'tagged_fields': 'TaggedFields({})', + } + + # Custom values that make test code more readable. + KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { + 'string': + '"string"', + 'bool': + 'false', + 'int8': + 'static_cast(8)', + 'int16': + 'static_cast(16)', + 'int32': + 'static_cast(32)', + 'int64': + 'static_cast(64)', + 'bytes': + 'Bytes({0, 1, 2, 3})', + 'tagged_fields': + 'TaggedFields{std::vector{{10, Bytes({1, 2, 3})}, {20, Bytes({4, 5, 6})}}}', + } + + def __init__(self, name, custom_default_value): + self.original_name = name + self.name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_ENVOY_TYPE) + self.custom_default_value = custom_default_value + + @staticmethod + def compute(name, map): + if name in map: + return map[name] + else: + raise ValueError(name) + + def compute_declaration_chain(self): + # Primitives need no declarations. 
+ return [] + + def deserializer_name_in_version(self, version, compact): + if compact and self.original_name in Primitive.KAFKA_TYPE_TO_COMPACT_DESERIALIZER.keys(): + return Primitive.compute(self.original_name, + Primitive.KAFKA_TYPE_TO_COMPACT_DESERIALIZER) + else: + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) + + def default_value(self): + if self.custom_default_value is not None: + return self.custom_default_value + else: + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DEFAULT_VALUE) + + def has_flexible_handling(self): + return self.original_name in ['string', 'bytes', 'tagged_fields'] + + def example_value_for_test(self, version): + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST) + + def is_printable(self): + return self.name not in ['Bytes'] class FieldSerializationSpec(): - def __init__(self, field, versions, compute_size_method_name, encode_method_name): - self.field = field - self.versions = versions - self.compute_size_method_name = compute_size_method_name - self.encode_method_name = encode_method_name + def __init__(self, field, versions, compute_size_method_name, encode_method_name): + self.field = field + self.versions = versions + self.compute_size_method_name = compute_size_method_name + self.encode_method_name = encode_method_name class Complex(TypeSpecification): - """ + """ Represents a complex type (multiple types aggregated into one). This type gets mapped to a C++ struct. """ - def __init__(self, name, fields, versions): - self.name = name - self.fields = fields - self.versions = versions - self.flexible_versions = None # Will be set in 'register_flexible_versions'. - self.attributes = {} - - def register_flexible_versions(self, flexible_versions): - # If flexible versions are present, so we need to add placeholder 'tagged_fields' field to - # *every* type that's used in by this message type. - for type in self.compute_declaration_chain(): - type.flexible_versions = flexible_versions - if len(flexible_versions) > 0: - tagged_fields_field = FieldSpec('tagged_fields', Primitive('tagged_fields', None), - flexible_versions, []) - type.fields.append(tagged_fields_field) - - def compute_declaration_chain(self): - """ + def __init__(self, name, fields, versions): + self.name = name + self.fields = fields + self.versions = versions + self.flexible_versions = None # Will be set in 'register_flexible_versions'. + self.attributes = {} + + def register_flexible_versions(self, flexible_versions): + # If flexible versions are present, so we need to add placeholder 'tagged_fields' field to + # *every* type that's used in by this message type. + for type in self.compute_declaration_chain(): + type.flexible_versions = flexible_versions + if len(flexible_versions) > 0: + tagged_fields_field = FieldSpec('tagged_fields', Primitive('tagged_fields', None), + flexible_versions, []) + type.fields.append(tagged_fields_field) + + def compute_declaration_chain(self): + """ Computes all dependencies, what means all non-primitive types used by this type. They need to be declared before this struct is declared. 
""" - result = [] - for field in self.fields: - field_dependencies = field.type.compute_declaration_chain() - for field_dependency in field_dependencies: - if field_dependency not in result: - result.append(field_dependency) - result.append(self) - return result - - def with_extra(self, key, value): - self.attributes[key] = value - return self - - def get_extra(self, key): - return self.attributes[key] - - def compute_constructors(self): - """ + result = [] + for field in self.fields: + field_dependencies = field.type.compute_declaration_chain() + for field_dependency in field_dependencies: + if field_dependency not in result: + result.append(field_dependency) + result.append(self) + return result + + def with_extra(self, key, value): + self.attributes[key] = value + return self + + def get_extra(self, key): + return self.attributes[key] + + def compute_constructors(self): + """ Field lists for different versions may not differ (as Kafka can bump version without any changes). But constructors need to be unique, so we need to remove duplicates if the signatures match. """ - signature_to_constructor = {} - for field_list in self.compute_field_lists(): - signature = field_list.constructor_signature() - constructor = signature_to_constructor.get(signature) - if constructor is None: - entry = {} - entry['versions'] = [field_list.version] - entry['signature'] = signature - if (len(signature) > 0): - entry['full_declaration'] = '%s(%s): %s {};' % (self.name, signature, - field_list.constructor_init_list()) - else: - entry['full_declaration'] = '%s() {};' % self.name - signature_to_constructor[signature] = entry - else: - constructor['versions'].append(field_list.version) - return sorted(signature_to_constructor.values(), key=lambda x: x['versions'][0]) - - def compute_field_lists(self): - """ + signature_to_constructor = {} + for field_list in self.compute_field_lists(): + signature = field_list.constructor_signature() + constructor = signature_to_constructor.get(signature) + if constructor is None: + entry = {} + entry['versions'] = [field_list.version] + entry['signature'] = signature + if (len(signature) > 0): + entry['full_declaration'] = '%s(%s): %s {};' % ( + self.name, signature, field_list.constructor_init_list()) + else: + entry['full_declaration'] = '%s() {};' % self.name + signature_to_constructor[signature] = entry + else: + constructor['versions'].append(field_list.version) + return sorted(signature_to_constructor.values(), key=lambda x: x['versions'][0]) + + def compute_field_lists(self): + """ Return field lists representing each of structure versions. 
""" - field_lists = [] - for version in self.versions: - field_list = FieldList(version, version in self.flexible_versions, self.fields) - field_lists.append(field_list) - return field_lists - - def compute_serialization_specs(self): - result = [] - for field in self.fields: - if field.type.has_flexible_handling(): - flexible = [x for x in field.version_usage if x in self.flexible_versions] - non_flexible = [x for x in field.version_usage if x not in flexible] - if non_flexible: - result.append(FieldSerializationSpec(field, non_flexible, 'computeSize', 'encode')) - if flexible: - result.append( - FieldSerializationSpec(field, flexible, 'computeCompactSize', 'encodeCompact')) - else: - result.append(FieldSerializationSpec(field, field.version_usage, 'computeSize', 'encode')) - return result - - def deserializer_name_in_version(self, version, compact): - return '%sV%dDeserializer' % (self.name, version) - - def name_in_c_case(self): - import re - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', self.name) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() - - def default_value(self): - raise NotImplementedError('unable to create default value of complex type') - - def has_flexible_handling(self): - return False - - def example_value_for_test(self, version): - field_list = next(fl for fl in self.compute_field_lists() if fl.version == version) - example_values = map(lambda x: x.example_value_for_test(version), field_list.used_fields()) - return '%s(%s)' % (self.name, ', '.join(example_values)) - - def is_printable(self): - return True + field_lists = [] + for version in self.versions: + field_list = FieldList(version, version in self.flexible_versions, self.fields) + field_lists.append(field_list) + return field_lists + + def compute_serialization_specs(self): + result = [] + for field in self.fields: + if field.type.has_flexible_handling(): + flexible = [x for x in field.version_usage if x in self.flexible_versions] + non_flexible = [x for x in field.version_usage if x not in flexible] + if non_flexible: + result.append( + FieldSerializationSpec(field, non_flexible, 'computeSize', 'encode')) + if flexible: + result.append( + FieldSerializationSpec(field, flexible, 'computeCompactSize', + 'encodeCompact')) + else: + result.append( + FieldSerializationSpec(field, field.version_usage, 'computeSize', 'encode')) + return result + + def deserializer_name_in_version(self, version, compact): + return '%sV%dDeserializer' % (self.name, version) + + def name_in_c_case(self): + import re + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', self.name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + def default_value(self): + raise NotImplementedError('unable to create default value of complex type') + + def has_flexible_handling(self): + return False + + def example_value_for_test(self, version): + field_list = next(fl for fl in self.compute_field_lists() if fl.version == version) + example_values = map(lambda x: x.example_value_for_test(version), field_list.used_fields()) + return '%s(%s)' % (self.name, ', '.join(example_values)) + + def is_printable(self): + return True class RenderingHelper: - """ + """ Helper for jinja templates. """ - @staticmethod - def get_template(template): - import jinja2 - import os - import sys - # Templates are resolved relatively to main start script, due to main & test templates being - # stored in different directories. 
- env = jinja2.Environment(loader=jinja2.FileSystemLoader( - searchpath=os.path.dirname(os.path.abspath(sys.argv[0])))) - return env.get_template(template) + @staticmethod + def get_template(template): + import jinja2 + import os + import sys + # Templates are resolved relatively to main start script, due to main & test templates being + # stored in different directories. + env = jinja2.Environment(loader=jinja2.FileSystemLoader( + searchpath=os.path.dirname(os.path.abspath(sys.argv[0])))) + return env.get_template(template) diff --git a/source/extensions/filters/network/kafka/protocol/launcher.py b/source/extensions/filters/network/kafka/protocol/launcher.py index ce16acb63242..449e41e70d8c 100644 --- a/source/extensions/filters/network/kafka/protocol/launcher.py +++ b/source/extensions/filters/network/kafka/protocol/launcher.py @@ -8,7 +8,7 @@ def main(): - """ + """ Kafka code generator script ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Generates C++ code from Kafka protocol specification for Kafka codec. @@ -38,14 +38,14 @@ def main(): - to create '${MESSAGE_TYPE}_metrics.h': ${MESSAGE_TYPE}_metrics_h.j2. """ - type = sys.argv[1] - main_header_file = os.path.abspath(sys.argv[2]) - resolver_cc_file = os.path.abspath(sys.argv[3]) - metrics_h_file = os.path.abspath(sys.argv[4]) - input_files = sys.argv[5:] - generator.generate_main_code(type, main_header_file, resolver_cc_file, metrics_h_file, - input_files) + type = sys.argv[1] + main_header_file = os.path.abspath(sys.argv[2]) + resolver_cc_file = os.path.abspath(sys.argv[3]) + metrics_h_file = os.path.abspath(sys.argv[4]) + input_files = sys.argv[5:] + generator.generate_main_code(type, main_header_file, resolver_cc_file, metrics_h_file, + input_files) if __name__ == "__main__": - main() + main() diff --git a/source/extensions/filters/network/kafka/serialization/generator.py b/source/extensions/filters/network/kafka/serialization/generator.py index a05012e4837d..d28e2650300a 100755 --- a/source/extensions/filters/network/kafka/serialization/generator.py +++ b/source/extensions/filters/network/kafka/serialization/generator.py @@ -4,54 +4,54 @@ def generate_main_code(serialization_composite_h_file): - """ + """ Main code generator. Renders the header file for serialization composites. The location of output file is provided as argument. """ - generate_code('serialization_composite_h.j2', serialization_composite_h_file) + generate_code('serialization_composite_h.j2', serialization_composite_h_file) def generate_test_code(serialization_composite_test_cc_file): - """ + """ Test code generator. Renders the test file for serialization composites. The location of output file is provided as argument. """ - generate_code('serialization_composite_test_cc.j2', serialization_composite_test_cc_file) + generate_code('serialization_composite_test_cc.j2', serialization_composite_test_cc_file) def generate_code(template_name, output_file): - """ + """ Gets definition of structures to render. Then renders these structures using template provided into provided output file. 
""" - field_counts = get_field_counts() - template = RenderingHelper.get_template(template_name) - contents = template.render(counts=field_counts) - with open(output_file, 'w') as fd: - fd.write(contents) + field_counts = get_field_counts() + template = RenderingHelper.get_template(template_name) + contents = template.render(counts=field_counts) + with open(output_file, 'w') as fd: + fd.write(contents) def get_field_counts(): - """ + """ Generate argument counts that should be processed by composite deserializers. """ - return range(1, 12) + return range(1, 12) class RenderingHelper: - """ + """ Helper for jinja templates. """ - @staticmethod - def get_template(template): - import jinja2 - import os - import sys - # Templates are resolved relatively to main start script, due to main & test templates being - # stored in different directories. - env = jinja2.Environment(loader=jinja2.FileSystemLoader( - searchpath=os.path.dirname(os.path.abspath(sys.argv[0])))) - return env.get_template(template) + @staticmethod + def get_template(template): + import jinja2 + import os + import sys + # Templates are resolved relatively to main start script, due to main & test templates being + # stored in different directories. + env = jinja2.Environment(loader=jinja2.FileSystemLoader( + searchpath=os.path.dirname(os.path.abspath(sys.argv[0])))) + return env.get_template(template) diff --git a/source/extensions/filters/network/kafka/serialization/launcher.py b/source/extensions/filters/network/kafka/serialization/launcher.py index 2f63ffb45b96..571f448086a9 100644 --- a/source/extensions/filters/network/kafka/serialization/launcher.py +++ b/source/extensions/filters/network/kafka/serialization/launcher.py @@ -8,7 +8,7 @@ def main(): - """ + """ Serialization composite code generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Generates main source code files for composite deserializers. @@ -25,9 +25,9 @@ def main(): Template used: 'serialization_composite_h.j2'. """ - serialization_composite_h_file = os.path.abspath(sys.argv[1]) - generator.generate_main_code(serialization_composite_h_file) + serialization_composite_h_file = os.path.abspath(sys.argv[1]) + generator.generate_main_code(serialization_composite_h_file) if __name__ == "__main__": - main() + main() diff --git a/test/common/json/config_schemas_test_data/generate_test_data.py b/test/common/json/config_schemas_test_data/generate_test_data.py index ac06cb0f5e42..eb5a6e5e7cd6 100755 --- a/test/common/json/config_schemas_test_data/generate_test_data.py +++ b/test/common/json/config_schemas_test_data/generate_test_data.py @@ -7,20 +7,20 @@ def main(): - test_dir = os.path.join(os.environ['TEST_TMPDIR'], 'config_schemas_test') - # Clean after previous run. This might happen e.g. with "threadsafe" Death Tests, - # where child process re-executes the unit test binary in the same workspace. - if os.path.isdir(test_dir): - shutil.rmtree(test_dir) - os.mkdir(test_dir) - writer = util.TestWriter(test_dir) + test_dir = os.path.join(os.environ['TEST_TMPDIR'], 'config_schemas_test') + # Clean after previous run. This might happen e.g. with "threadsafe" Death Tests, + # where child process re-executes the unit test binary in the same workspace. 
+ if os.path.isdir(test_dir): + shutil.rmtree(test_dir) + os.mkdir(test_dir) + writer = util.TestWriter(test_dir) - # test discovery and execution - test_files = glob.glob(os.path.join(os.path.dirname(__file__), "test_*.py")) - for test_file in test_files: - module_name = os.path.splitext(os.path.basename(test_file))[0] - __import__(module_name).test(writer) + # test discovery and execution + test_files = glob.glob(os.path.join(os.path.dirname(__file__), "test_*.py")) + for test_file in test_files: + module_name = os.path.splitext(os.path.basename(test_file))[0] + __import__(module_name).test(writer) if __name__ == '__main__': - main() + main() diff --git a/test/common/json/config_schemas_test_data/test_access_log_schema.py b/test/common/json/config_schemas_test_data/test_access_log_schema.py index 09aae1658095..c4391985fb42 100644 --- a/test/common/json/config_schemas_test_data/test_access_log_schema.py +++ b/test/common/json/config_schemas_test_data/test_access_log_schema.py @@ -42,93 +42,93 @@ def test(writer): - for idx, item in enumerate(ACCESS_LOG_BLOB["access_log"]): + for idx, item in enumerate(ACCESS_LOG_BLOB["access_log"]): + writer.write_test_file( + 'Valid_idx_' + str(idx), + schema='ACCESS_LOG_SCHEMA', + data=get_blob(item), + throws=False, + ) + + blob = get_blob(ACCESS_LOG_BLOB)['access_log'][1] + blob['filter']['filters'][0]['op'] = '<' writer.write_test_file( - 'Valid_idx_' + str(idx), + 'FilterOperatorIsNotSupportedLessThan', schema='ACCESS_LOG_SCHEMA', - data=get_blob(item), - throws=False, + data=blob, + throws=True, ) - blob = get_blob(ACCESS_LOG_BLOB)['access_log'][1] - blob['filter']['filters'][0]['op'] = '<' - writer.write_test_file( - 'FilterOperatorIsNotSupportedLessThan', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) - - blob = get_blob(ACCESS_LOG_BLOB)['access_log'][1] - blob['filter']['filters'][0]['op'] = '<=' - writer.write_test_file( - 'FilterOperatorIsNotSupportedLessThanEqual', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) + blob = get_blob(ACCESS_LOG_BLOB)['access_log'][1] + blob['filter']['filters'][0]['op'] = '<=' + writer.write_test_file( + 'FilterOperatorIsNotSupportedLessThanEqual', + schema='ACCESS_LOG_SCHEMA', + data=blob, + throws=True, + ) - blob = get_blob(ACCESS_LOG_BLOB)['access_log'][1] - blob['filter']['filters'][0]['op'] = '>' - writer.write_test_file( - 'FilterOperatorIsNotSupportedGreaterThan', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) + blob = get_blob(ACCESS_LOG_BLOB)['access_log'][1] + blob['filter']['filters'][0]['op'] = '>' + writer.write_test_file( + 'FilterOperatorIsNotSupportedGreaterThan', + schema='ACCESS_LOG_SCHEMA', + data=blob, + throws=True, + ) - blob = {"path": "/dev/null", "filter": {"type": "unknown"}} - writer.write_test_file( - 'FilterTypeIsNotSupported', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) + blob = {"path": "/dev/null", "filter": {"type": "unknown"}} + writer.write_test_file( + 'FilterTypeIsNotSupported', + schema='ACCESS_LOG_SCHEMA', + data=blob, + throws=True, + ) - blob = {"path": "/dev/null", "filter": {"type": "logical_or", "filters": []}} - writer.write_test_file( - 'LessThanTwoFiltersInListNoneLogicalOrThrows', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) + blob = {"path": "/dev/null", "filter": {"type": "logical_or", "filters": []}} + writer.write_test_file( + 'LessThanTwoFiltersInListNoneLogicalOrThrows', + schema='ACCESS_LOG_SCHEMA', + data=blob, + throws=True, + ) - blob = {"path": "/dev/null", "filter": {"type": 
"logical_and", "filters": []}} - writer.write_test_file( - 'LessThanTwoFiltersInListNoneLogicalAndThrows', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) + blob = {"path": "/dev/null", "filter": {"type": "logical_and", "filters": []}} + writer.write_test_file( + 'LessThanTwoFiltersInListNoneLogicalAndThrows', + schema='ACCESS_LOG_SCHEMA', + data=blob, + throws=True, + ) - blob = { - "path": "/dev/null", - "filter": { - "type": "logical_or", - "filters": [{ - "type": "not_healthcheck" - }] - } - } - writer.write_test_file( - 'LessThanTwoFiltersInListOneLogicalOrThrows', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) + blob = { + "path": "/dev/null", + "filter": { + "type": "logical_or", + "filters": [{ + "type": "not_healthcheck" + }] + } + } + writer.write_test_file( + 'LessThanTwoFiltersInListOneLogicalOrThrows', + schema='ACCESS_LOG_SCHEMA', + data=blob, + throws=True, + ) - blob = { - "path": "/dev/null", - "filter": { - "type": "logical_and", - "filters": [{ - "type": "not_healthcheck" - }] - } - } - writer.write_test_file( - 'LessThanTwoFiltersInListOneLogicalAndThrows', - schema='ACCESS_LOG_SCHEMA', - data=blob, - throws=True, - ) + blob = { + "path": "/dev/null", + "filter": { + "type": "logical_and", + "filters": [{ + "type": "not_healthcheck" + }] + } + } + writer.write_test_file( + 'LessThanTwoFiltersInListOneLogicalAndThrows', + schema='ACCESS_LOG_SCHEMA', + data=blob, + throws=True, + ) diff --git a/test/common/json/config_schemas_test_data/test_cluster_schema.py b/test/common/json/config_schemas_test_data/test_cluster_schema.py index a16ac2e0e370..eaa5b1151ea3 100644 --- a/test/common/json/config_schemas_test_data/test_cluster_schema.py +++ b/test/common/json/config_schemas_test_data/test_cluster_schema.py @@ -24,18 +24,18 @@ def test(writer): - writer.write_test_file( - 'Valid', - schema='CLUSTER_SCHEMA', - data=get_blob(CLUSTER_BLOB), - throws=False, - ) + writer.write_test_file( + 'Valid', + schema='CLUSTER_SCHEMA', + data=get_blob(CLUSTER_BLOB), + throws=False, + ) - blob = get_blob(CLUSTER_BLOB) - blob['features'] = "nonexistentfeature" - writer.write_test_file( - 'UnsupportedFeature', - schema='CLUSTER_SCHEMA', - data=blob, - throws=True, - ) + blob = get_blob(CLUSTER_BLOB) + blob['features'] = "nonexistentfeature" + writer.write_test_file( + 'UnsupportedFeature', + schema='CLUSTER_SCHEMA', + data=blob, + throws=True, + ) diff --git a/test/common/json/config_schemas_test_data/test_http_conn_network_filter_schema.py b/test/common/json/config_schemas_test_data/test_http_conn_network_filter_schema.py index e7566cdfcf6a..a7cb2ba5e5b6 100644 --- a/test/common/json/config_schemas_test_data/test_http_conn_network_filter_schema.py +++ b/test/common/json/config_schemas_test_data/test_http_conn_network_filter_schema.py @@ -29,9 +29,9 @@ def test(writer): - writer.write_test_file( - 'Valid', - schema='HTTP_CONN_NETWORK_FILTER_SCHEMA', - data=get_blob(HTTP_CONN_NETWORK_FILTER_BLOB), - throws=False, - ) + writer.write_test_file( + 'Valid', + schema='HTTP_CONN_NETWORK_FILTER_SCHEMA', + data=get_blob(HTTP_CONN_NETWORK_FILTER_BLOB), + throws=False, + ) diff --git a/test/common/json/config_schemas_test_data/test_http_router_schema.py b/test/common/json/config_schemas_test_data/test_http_router_schema.py index 644a6b20f4b4..7f1663986449 100644 --- a/test/common/json/config_schemas_test_data/test_http_router_schema.py +++ b/test/common/json/config_schemas_test_data/test_http_router_schema.py @@ -5,16 +5,16 @@ def test(writer): - writer.write_test_file( - 
'Valid', - schema='ROUTER_HTTP_FILTER_SCHEMA', - data=get_blob(ROUTER_HTTP_FILTER_BLOB), - throws=False, - ) + writer.write_test_file( + 'Valid', + schema='ROUTER_HTTP_FILTER_SCHEMA', + data=get_blob(ROUTER_HTTP_FILTER_BLOB), + throws=False, + ) - writer.write_test_file( - 'ValidDefaults', - schema='ROUTER_HTTP_FILTER_SCHEMA', - data={}, - throws=False, - ) + writer.write_test_file( + 'ValidDefaults', + schema='ROUTER_HTTP_FILTER_SCHEMA', + data={}, + throws=False, + ) diff --git a/test/common/json/config_schemas_test_data/test_listener_schema.py b/test/common/json/config_schemas_test_data/test_listener_schema.py index 11ee3a5b45c2..586213f5c87f 100644 --- a/test/common/json/config_schemas_test_data/test_listener_schema.py +++ b/test/common/json/config_schemas_test_data/test_listener_schema.py @@ -15,9 +15,9 @@ def test(writer): - writer.write_test_file( - 'Valid', - schema='LISTENER_SCHEMA', - data=get_blob(LISTENER_BLOB), - throws=False, - ) + writer.write_test_file( + 'Valid', + schema='LISTENER_SCHEMA', + data=get_blob(LISTENER_BLOB), + throws=False, + ) diff --git a/test/common/json/config_schemas_test_data/test_route_configuration_schema.py b/test/common/json/config_schemas_test_data/test_route_configuration_schema.py index fb492747e1fb..71d73d9d5def 100644 --- a/test/common/json/config_schemas_test_data/test_route_configuration_schema.py +++ b/test/common/json/config_schemas_test_data/test_route_configuration_schema.py @@ -18,9 +18,9 @@ def test(writer): - writer.write_test_file( - 'Valid', - schema='ROUTE_CONFIGURATION_SCHEMA', - data=get_blob(ROUTE_CONFIGURATION_BLOB), - throws=False, - ) + writer.write_test_file( + 'Valid', + schema='ROUTE_CONFIGURATION_SCHEMA', + data=get_blob(ROUTE_CONFIGURATION_BLOB), + throws=False, + ) diff --git a/test/common/json/config_schemas_test_data/test_route_entry_schema.py b/test/common/json/config_schemas_test_data/test_route_entry_schema.py index 0aa1565a5fba..30c2e413ec22 100644 --- a/test/common/json/config_schemas_test_data/test_route_entry_schema.py +++ b/test/common/json/config_schemas_test_data/test_route_entry_schema.py @@ -10,17 +10,17 @@ def test(writer): - writer.write_test_file( - 'Valid', - schema='ROUTE_ENTRY_CONFIGURATION_SCHEMA', - data=get_blob(ROUTE_ENTRY_CONFIGURATION_BLOB), - throws=False, - ) + writer.write_test_file( + 'Valid', + schema='ROUTE_ENTRY_CONFIGURATION_SCHEMA', + data=get_blob(ROUTE_ENTRY_CONFIGURATION_BLOB), + throws=False, + ) - blob = {"prefix": "/foo", "cluster": "local_service_grpc", "priority": "foo"} - writer.write_test_file( - 'InvalidPriority', - schema='ROUTE_ENTRY_CONFIGURATION_SCHEMA', - data=blob, - throws=True, - ) + blob = {"prefix": "/foo", "cluster": "local_service_grpc", "priority": "foo"} + writer.write_test_file( + 'InvalidPriority', + schema='ROUTE_ENTRY_CONFIGURATION_SCHEMA', + data=blob, + throws=True, + ) diff --git a/test/common/json/config_schemas_test_data/test_top_level_config_schema.py b/test/common/json/config_schemas_test_data/test_top_level_config_schema.py index 206d9f65ae33..87a89cc538c3 100644 --- a/test/common/json/config_schemas_test_data/test_top_level_config_schema.py +++ b/test/common/json/config_schemas_test_data/test_top_level_config_schema.py @@ -32,18 +32,18 @@ def test(writer): - writer.write_test_file( - 'Valid', - schema='TOP_LEVEL_CONFIG_SCHEMA', - data=get_blob(TOP_LEVEL_CONFIG_BLOB), - throws=False, - ) + writer.write_test_file( + 'Valid', + schema='TOP_LEVEL_CONFIG_SCHEMA', + data=get_blob(TOP_LEVEL_CONFIG_BLOB), + throws=False, + ) - blob = 
get_blob(TOP_LEVEL_CONFIG_BLOB) - blob['tracing']['http']['driver']['type'] = 'unknown' - writer.write_test_file( - 'UnsupportedTracingDriver', - schema='TOP_LEVEL_CONFIG_SCHEMA', - data=blob, - throws=True, - ) + blob = get_blob(TOP_LEVEL_CONFIG_BLOB) + blob['tracing']['http']['driver']['type'] = 'unknown' + writer.write_test_file( + 'UnsupportedTracingDriver', + schema='TOP_LEVEL_CONFIG_SCHEMA', + data=blob, + throws=True, + ) diff --git a/test/common/json/config_schemas_test_data/util.py b/test/common/json/config_schemas_test_data/util.py index 75349059b2fc..69096a442d65 100644 --- a/test/common/json/config_schemas_test_data/util.py +++ b/test/common/json/config_schemas_test_data/util.py @@ -7,18 +7,19 @@ def get_blob(blob): - return copy.deepcopy(blob) + return copy.deepcopy(blob) class TestWriter(object): - def __init__(self, test_dir): - self.test_dir = test_dir + def __init__(self, test_dir): + self.test_dir = test_dir - def write_test_file(self, name, schema, data, throws): - test_filename = os.path.join(self.test_dir, 'schematest-%s-%s.json' % (schema, name)) - if os.path.isfile(test_filename): - raise ValueError('Test with that name and schema already exists: {}'.format(test_filename)) + def write_test_file(self, name, schema, data, throws): + test_filename = os.path.join(self.test_dir, 'schematest-%s-%s.json' % (schema, name)) + if os.path.isfile(test_filename): + raise ValueError( + 'Test with that name and schema already exists: {}'.format(test_filename)) - with open(test_filename, 'w+') as fh: - json.dump({"schema": schema, "throws": throws, "data": data}, fh, indent=True) + with open(test_filename, 'w+') as fh: + json.dump({"schema": schema, "throws": throws, "data": data}, fh, indent=True) diff --git a/test/extensions/filters/network/kafka/broker/integration_test/kafka_broker_integration_test.py b/test/extensions/filters/network/kafka/broker/integration_test/kafka_broker_integration_test.py index 490f1ae7ce61..7a809573b372 100644 --- a/test/extensions/filters/network/kafka/broker/integration_test/kafka_broker_integration_test.py +++ b/test/extensions/filters/network/kafka/broker/integration_test/kafka_broker_integration_test.py @@ -16,7 +16,7 @@ class KafkaBrokerIntegrationTest(unittest.TestCase): - """ + """ All tests in this class depend on Envoy/Zookeeper/Kafka running. For each of these tests we are going to create Kafka consumers/producers/admins and point them to Envoy (that proxies Kafka). @@ -24,601 +24,606 @@ class KafkaBrokerIntegrationTest(unittest.TestCase): to increase on Envoy side (to show that messages were received and forwarded successfully). """ - services = None + services = None - @classmethod - def setUpClass(cls): - KafkaBrokerIntegrationTest.services = ServicesHolder() - KafkaBrokerIntegrationTest.services.start() + @classmethod + def setUpClass(cls): + KafkaBrokerIntegrationTest.services = ServicesHolder() + KafkaBrokerIntegrationTest.services.start() - @classmethod - def tearDownClass(cls): - KafkaBrokerIntegrationTest.services.shut_down() + @classmethod + def tearDownClass(cls): + KafkaBrokerIntegrationTest.services.shut_down() - def setUp(self): - # We want to check if our services are okay before running any kind of test. - KafkaBrokerIntegrationTest.services.check_state() - self.metrics = MetricsHolder(self) + def setUp(self): + # We want to check if our services are okay before running any kind of test. 
+ KafkaBrokerIntegrationTest.services.check_state() + self.metrics = MetricsHolder(self) - def tearDown(self): - # We want to check if our services are okay after running any test. - KafkaBrokerIntegrationTest.services.check_state() + def tearDown(self): + # We want to check if our services are okay after running any test. + KafkaBrokerIntegrationTest.services.check_state() - @classmethod - def kafka_address(cls): - return '127.0.0.1:%s' % KafkaBrokerIntegrationTest.services.kafka_envoy_port + @classmethod + def kafka_address(cls): + return '127.0.0.1:%s' % KafkaBrokerIntegrationTest.services.kafka_envoy_port - @classmethod - def envoy_stats_address(cls): - return 'http://127.0.0.1:%s/stats' % KafkaBrokerIntegrationTest.services.envoy_monitoring_port + @classmethod + def envoy_stats_address(cls): + return 'http://127.0.0.1:%s/stats' % KafkaBrokerIntegrationTest.services.envoy_monitoring_port - def test_kafka_consumer_with_no_messages_received(self): - """ + def test_kafka_consumer_with_no_messages_received(self): + """ This test verifies that consumer sends fetches correctly, and receives nothing. """ - consumer = KafkaConsumer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address(), - fetch_max_wait_ms=500) - consumer.assign([TopicPartition('test_kafka_consumer_with_no_messages_received', 0)]) - for _ in range(10): - records = consumer.poll(timeout_ms=1000) - self.assertEqual(len(records), 0) - - self.metrics.collect_final_metrics() - # 'consumer.poll()' can translate into 0 or more fetch requests. - # We have set API timeout to 1000ms, while fetch_max_wait is 500ms. - # This means that consumer will send roughly 2 (1000/500) requests per API call (so 20 total). - # So increase of 10 (half of that value) should be safe enough to test. - self.metrics.assert_metric_increase('fetch', 10) - # Metadata is used by consumer to figure out current partition leader. - self.metrics.assert_metric_increase('metadata', 1) - - def test_kafka_producer_and_consumer(self): - """ + consumer = KafkaConsumer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address(), + fetch_max_wait_ms=500) + consumer.assign([TopicPartition('test_kafka_consumer_with_no_messages_received', 0)]) + for _ in range(10): + records = consumer.poll(timeout_ms=1000) + self.assertEqual(len(records), 0) + + self.metrics.collect_final_metrics() + # 'consumer.poll()' can translate into 0 or more fetch requests. + # We have set API timeout to 1000ms, while fetch_max_wait is 500ms. + # This means that consumer will send roughly 2 (1000/500) requests per API call (so 20 total). + # So increase of 10 (half of that value) should be safe enough to test. + self.metrics.assert_metric_increase('fetch', 10) + # Metadata is used by consumer to figure out current partition leader. + self.metrics.assert_metric_increase('metadata', 1) + + def test_kafka_producer_and_consumer(self): + """ This test verifies that producer can send messages, and consumer can receive them. 
""" - messages_to_send = 100 - partition = TopicPartition('test_kafka_producer_and_consumer', 0) - - producer = KafkaProducer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address()) - for _ in range(messages_to_send): - future = producer.send(value=b'some_message_bytes', - topic=partition.topic, - partition=partition.partition) - send_status = future.get() - self.assertTrue(send_status.offset >= 0) - - consumer = KafkaConsumer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address(), - auto_offset_reset='earliest', - fetch_max_bytes=100) - consumer.assign([partition]) - received_messages = [] - while (len(received_messages) < messages_to_send): - poll_result = consumer.poll(timeout_ms=1000) - received_messages += poll_result[partition] - - self.metrics.collect_final_metrics() - self.metrics.assert_metric_increase('metadata', 2) - self.metrics.assert_metric_increase('produce', 100) - # 'fetch_max_bytes' was set to a very low value, so client will need to send a FetchRequest - # multiple times to broker to get all 100 messages (otherwise all 100 records could have been - # received in one go). - self.metrics.assert_metric_increase('fetch', 20) - # Both producer & consumer had to fetch cluster metadata. - self.metrics.assert_metric_increase('metadata', 2) - - def test_consumer_with_consumer_groups(self): - """ + messages_to_send = 100 + partition = TopicPartition('test_kafka_producer_and_consumer', 0) + + producer = KafkaProducer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address()) + for _ in range(messages_to_send): + future = producer.send(value=b'some_message_bytes', + topic=partition.topic, + partition=partition.partition) + send_status = future.get() + self.assertTrue(send_status.offset >= 0) + + consumer = KafkaConsumer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address(), + auto_offset_reset='earliest', + fetch_max_bytes=100) + consumer.assign([partition]) + received_messages = [] + while (len(received_messages) < messages_to_send): + poll_result = consumer.poll(timeout_ms=1000) + received_messages += poll_result[partition] + + self.metrics.collect_final_metrics() + self.metrics.assert_metric_increase('metadata', 2) + self.metrics.assert_metric_increase('produce', 100) + # 'fetch_max_bytes' was set to a very low value, so client will need to send a FetchRequest + # multiple times to broker to get all 100 messages (otherwise all 100 records could have been + # received in one go). + self.metrics.assert_metric_increase('fetch', 20) + # Both producer & consumer had to fetch cluster metadata. + self.metrics.assert_metric_increase('metadata', 2) + + def test_consumer_with_consumer_groups(self): + """ This test verifies that multiple consumers can form a Kafka consumer group. 
""" - consumer_count = 10 - consumers = [] - for id in range(consumer_count): - consumer = KafkaConsumer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address(), - group_id='test', - client_id='test-%s' % id) - consumer.subscribe(['test_consumer_with_consumer_groups']) - consumers.append(consumer) - - worker_threads = [] - for consumer in consumers: - thread = Thread(target=KafkaBrokerIntegrationTest.worker, args=(consumer,)) - thread.start() - worker_threads.append(thread) - - for thread in worker_threads: - thread.join() - - for consumer in consumers: - consumer.close() - - self.metrics.collect_final_metrics() - self.metrics.assert_metric_increase('api_versions', consumer_count) - self.metrics.assert_metric_increase('metadata', consumer_count) - self.metrics.assert_metric_increase('join_group', consumer_count) - self.metrics.assert_metric_increase('find_coordinator', consumer_count) - self.metrics.assert_metric_increase('leave_group', consumer_count) - - @staticmethod - def worker(consumer): - """ + consumer_count = 10 + consumers = [] + for id in range(consumer_count): + consumer = KafkaConsumer(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address(), + group_id='test', + client_id='test-%s' % id) + consumer.subscribe(['test_consumer_with_consumer_groups']) + consumers.append(consumer) + + worker_threads = [] + for consumer in consumers: + thread = Thread(target=KafkaBrokerIntegrationTest.worker, args=(consumer,)) + thread.start() + worker_threads.append(thread) + + for thread in worker_threads: + thread.join() + + for consumer in consumers: + consumer.close() + + self.metrics.collect_final_metrics() + self.metrics.assert_metric_increase('api_versions', consumer_count) + self.metrics.assert_metric_increase('metadata', consumer_count) + self.metrics.assert_metric_increase('join_group', consumer_count) + self.metrics.assert_metric_increase('find_coordinator', consumer_count) + self.metrics.assert_metric_increase('leave_group', consumer_count) + + @staticmethod + def worker(consumer): + """ Worker thread for Kafka consumer. Multiple poll-s are done here, so that the group can safely form. """ - poll_operations = 10 - for i in range(poll_operations): - consumer.poll(timeout_ms=1000) + poll_operations = 10 + for i in range(poll_operations): + consumer.poll(timeout_ms=1000) - def test_admin_client(self): - """ + def test_admin_client(self): + """ This test verifies that Kafka Admin Client can still be used to manage Kafka. """ - admin_client = KafkaAdminClient(bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address()) - - # Create a topic with 3 partitions. - new_topic_spec = NewTopic(name='test_admin_client', num_partitions=3, replication_factor=1) - create_response = admin_client.create_topics([new_topic_spec]) - error_data = create_response.topic_errors - self.assertEqual(len(error_data), 1) - self.assertEqual(error_data[0], (new_topic_spec.name, 0, None)) - - # Alter topic (change some Kafka-level property). - config_resource = ConfigResource(ConfigResourceType.TOPIC, new_topic_spec.name, - {'flush.messages': 42}) - alter_response = admin_client.alter_configs([config_resource]) - error_data = alter_response.resources - self.assertEqual(len(error_data), 1) - self.assertEqual(error_data[0][0], 0) - - # Add 2 more partitions to topic. 
- new_partitions_spec = {new_topic_spec.name: NewPartitions(5)} - new_partitions_response = admin_client.create_partitions(new_partitions_spec) - error_data = create_response.topic_errors - self.assertEqual(len(error_data), 1) - self.assertEqual(error_data[0], (new_topic_spec.name, 0, None)) - - # Delete a topic. - delete_response = admin_client.delete_topics([new_topic_spec.name]) - error_data = create_response.topic_errors - self.assertEqual(len(error_data), 1) - self.assertEqual(error_data[0], (new_topic_spec.name, 0, None)) - - self.metrics.collect_final_metrics() - self.metrics.assert_metric_increase('create_topics', 1) - self.metrics.assert_metric_increase('alter_configs', 1) - self.metrics.assert_metric_increase('create_partitions', 1) - self.metrics.assert_metric_increase('delete_topics', 1) + admin_client = KafkaAdminClient( + bootstrap_servers=KafkaBrokerIntegrationTest.kafka_address()) + + # Create a topic with 3 partitions. + new_topic_spec = NewTopic(name='test_admin_client', num_partitions=3, replication_factor=1) + create_response = admin_client.create_topics([new_topic_spec]) + error_data = create_response.topic_errors + self.assertEqual(len(error_data), 1) + self.assertEqual(error_data[0], (new_topic_spec.name, 0, None)) + + # Alter topic (change some Kafka-level property). + config_resource = ConfigResource(ConfigResourceType.TOPIC, new_topic_spec.name, + {'flush.messages': 42}) + alter_response = admin_client.alter_configs([config_resource]) + error_data = alter_response.resources + self.assertEqual(len(error_data), 1) + self.assertEqual(error_data[0][0], 0) + + # Add 2 more partitions to topic. + new_partitions_spec = {new_topic_spec.name: NewPartitions(5)} + new_partitions_response = admin_client.create_partitions(new_partitions_spec) + error_data = create_response.topic_errors + self.assertEqual(len(error_data), 1) + self.assertEqual(error_data[0], (new_topic_spec.name, 0, None)) + + # Delete a topic. + delete_response = admin_client.delete_topics([new_topic_spec.name]) + error_data = create_response.topic_errors + self.assertEqual(len(error_data), 1) + self.assertEqual(error_data[0], (new_topic_spec.name, 0, None)) + + self.metrics.collect_final_metrics() + self.metrics.assert_metric_increase('create_topics', 1) + self.metrics.assert_metric_increase('alter_configs', 1) + self.metrics.assert_metric_increase('create_partitions', 1) + self.metrics.assert_metric_increase('delete_topics', 1) class MetricsHolder: - """ + """ Utility for storing Envoy metrics. Expected to be created before the test (to get initial metrics), and then to collect them at the end of test, so the expected increases can be verified. 
""" - def __init__(self, owner): - self.owner = owner - self.initial_requests, self.inital_responses = MetricsHolder.get_envoy_stats() - self.final_requests = None - self.final_responses = None + def __init__(self, owner): + self.owner = owner + self.initial_requests, self.inital_responses = MetricsHolder.get_envoy_stats() + self.final_requests = None + self.final_responses = None - def collect_final_metrics(self): - self.final_requests, self.final_responses = MetricsHolder.get_envoy_stats() + def collect_final_metrics(self): + self.final_requests, self.final_responses = MetricsHolder.get_envoy_stats() - def assert_metric_increase(self, message_type, count): - request_type = message_type + '_request' - response_type = message_type + '_response' + def assert_metric_increase(self, message_type, count): + request_type = message_type + '_request' + response_type = message_type + '_response' - initial_request_value = self.initial_requests.get(request_type, 0) - final_request_value = self.final_requests.get(request_type, 0) - self.owner.assertGreaterEqual(final_request_value, initial_request_value + count) + initial_request_value = self.initial_requests.get(request_type, 0) + final_request_value = self.final_requests.get(request_type, 0) + self.owner.assertGreaterEqual(final_request_value, initial_request_value + count) - initial_response_value = self.inital_responses.get(response_type, 0) - final_response_value = self.final_responses.get(response_type, 0) - self.owner.assertGreaterEqual(final_response_value, initial_response_value + count) + initial_response_value = self.inital_responses.get(response_type, 0) + final_response_value = self.final_responses.get(response_type, 0) + self.owner.assertGreaterEqual(final_response_value, initial_response_value + count) - @staticmethod - def get_envoy_stats(): - """ + @staticmethod + def get_envoy_stats(): + """ Grab request/response metrics from envoy's stats interface. """ - stats_url = KafkaBrokerIntegrationTest.envoy_stats_address() - requests = {} - responses = {} - with urllib.request.urlopen(stats_url) as remote_metrics_url: - payload = remote_metrics_url.read().decode() - lines = payload.splitlines() - for line in lines: - request_prefix = 'kafka.testfilter.request.' - response_prefix = 'kafka.testfilter.response.' - if line.startswith(request_prefix): - data = line[len(request_prefix):].split(': ') - requests[data[0]] = int(data[1]) - pass - if line.startswith(response_prefix) and '_response:' in line: - data = line[len(response_prefix):].split(': ') - responses[data[0]] = int(data[1]) - return [requests, responses] + stats_url = KafkaBrokerIntegrationTest.envoy_stats_address() + requests = {} + responses = {} + with urllib.request.urlopen(stats_url) as remote_metrics_url: + payload = remote_metrics_url.read().decode() + lines = payload.splitlines() + for line in lines: + request_prefix = 'kafka.testfilter.request.' + response_prefix = 'kafka.testfilter.response.' + if line.startswith(request_prefix): + data = line[len(request_prefix):].split(': ') + requests[data[0]] = int(data[1]) + pass + if line.startswith(response_prefix) and '_response:' in line: + data = line[len(response_prefix):].split(': ') + responses[data[0]] = int(data[1]) + return [requests, responses] class ServicesHolder: - """ + """ Utility class for setting up our external dependencies: Envoy, Zookeeper & Kafka. 
""" - def __init__(self): - self.kafka_tmp_dir = None + def __init__(self): + self.kafka_tmp_dir = None - self.envoy_worker = None - self.zk_worker = None - self.kafka_worker = None + self.envoy_worker = None + self.zk_worker = None + self.kafka_worker = None - @staticmethod - def get_random_listener_port(): - """ + @staticmethod + def get_random_listener_port(): + """ Here we count on OS to give us some random socket. Obviously this method will need to be invoked in a try loop anyways, as in degenerate scenario someone else might have bound to it after we had closed the socket and before the service that's supposed to use it binds to it. """ - import socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: - server_socket.bind(('0.0.0.0', 0)) - socket_port = server_socket.getsockname()[1] - print('returning %s' % socket_port) - return socket_port + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: + server_socket.bind(('0.0.0.0', 0)) + socket_port = server_socket.getsockname()[1] + print('returning %s' % socket_port) + return socket_port - def start(self): - """ + def start(self): + """ Starts all the services we need for integration tests. """ - # Find java installation that we are going to use to start Zookeeper & Kafka. - java_directory = ServicesHolder.find_java() - - launcher_environment = os.environ.copy() - # Make `java` visible to build script: - # https://github.com/apache/kafka/blob/2.2.0/bin/kafka-run-class.sh#L226 - new_path = os.path.abspath(java_directory) + os.pathsep + launcher_environment['PATH'] - launcher_environment['PATH'] = new_path - # Both ZK & Kafka use Kafka launcher script. - # By default it sets up JMX options: - # https://github.com/apache/kafka/blob/2.2.0/bin/kafka-run-class.sh#L167 - # But that forces the JVM to load file that is not present due to: - # https://docs.oracle.com/javase/9/management/monitoring-and-management-using-jmx-technology.htm - # Let's make it simple and just disable JMX. - launcher_environment['KAFKA_JMX_OPTS'] = ' ' - - # Setup a temporary directory, which will be used by Kafka & Zookeeper servers. - self.kafka_tmp_dir = tempfile.mkdtemp() - print('Temporary directory used for tests: ' + self.kafka_tmp_dir) - - # This directory will store the configuration files fed to services. - config_dir = self.kafka_tmp_dir + '/config' - os.mkdir(config_dir) - # This directory will store Zookeeper's data (== Kafka server metadata). - zookeeper_store_dir = self.kafka_tmp_dir + '/zookeeper_data' - os.mkdir(zookeeper_store_dir) - # This directory will store Kafka's data (== partitions). - kafka_store_dir = self.kafka_tmp_dir + '/kafka_data' - os.mkdir(kafka_store_dir) - - # Find the Kafka server 'bin' directory. - kafka_bin_dir = os.path.join('.', 'external', 'kafka_server_binary', 'bin') - - # Main initialization block: - # - generate random ports, - # - render configuration with these ports, - # - start services and check if they are running okay, - # - if anything is having problems, kill everything and start again. - while True: - - # Generate random ports. - zk_port = ServicesHolder.get_random_listener_port() - kafka_real_port = ServicesHolder.get_random_listener_port() - kafka_envoy_port = ServicesHolder.get_random_listener_port() - envoy_monitoring_port = ServicesHolder.get_random_listener_port() - - # These ports need to be exposed to tests. - self.kafka_envoy_port = kafka_envoy_port - self.envoy_monitoring_port = envoy_monitoring_port - - # Render config file for Envoy. 
- template = RenderingHelper.get_template('envoy_config_yaml.j2') - contents = template.render( - data={ - 'kafka_real_port': kafka_real_port, - 'kafka_envoy_port': kafka_envoy_port, - 'envoy_monitoring_port': envoy_monitoring_port - }) - envoy_config_file = os.path.join(config_dir, 'envoy_config.yaml') - with open(envoy_config_file, 'w') as fd: - fd.write(contents) - print('Envoy config file rendered at: ' + envoy_config_file) - - # Render config file for Zookeeper. - template = RenderingHelper.get_template('zookeeper_properties.j2') - contents = template.render(data={'data_dir': zookeeper_store_dir, 'zk_port': zk_port}) - zookeeper_config_file = os.path.join(config_dir, 'zookeeper.properties') - with open(zookeeper_config_file, 'w') as fd: - fd.write(contents) - print('Zookeeper config file rendered at: ' + zookeeper_config_file) - - # Render config file for Kafka. - template = RenderingHelper.get_template('kafka_server_properties.j2') - contents = template.render( - data={ - 'data_dir': kafka_store_dir, - 'zk_port': zk_port, - 'kafka_real_port': kafka_real_port, - 'kafka_envoy_port': kafka_envoy_port - }) - kafka_config_file = os.path.join(config_dir, 'kafka_server.properties') - with open(kafka_config_file, 'w') as fd: - fd.write(contents) - print('Kafka config file rendered at: ' + kafka_config_file) - - # Start the services now. - try: - - # Start Envoy in the background, pointing to rendered config file. - envoy_binary = ServicesHolder.find_envoy() - # --base-id is added to allow multiple Envoy instances to run at the same time. - envoy_args = [ - os.path.abspath(envoy_binary), '-c', envoy_config_file, '--base-id', - str(random.randint(1, 999999)) - ] - envoy_handle = subprocess.Popen(envoy_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - self.envoy_worker = ProcessWorker(envoy_handle, 'Envoy', 'starting main dispatch loop') - self.envoy_worker.await_startup() - - # Start Zookeeper in background, pointing to rendered config file. - zk_binary = os.path.join(kafka_bin_dir, 'zookeeper-server-start.sh') - zk_args = [os.path.abspath(zk_binary), zookeeper_config_file] - zk_handle = subprocess.Popen(zk_args, - env=launcher_environment, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - self.zk_worker = ProcessWorker(zk_handle, 'Zookeeper', 'binding to port') - self.zk_worker.await_startup() - - # Start Kafka in background, pointing to rendered config file. - kafka_binary = os.path.join(kafka_bin_dir, 'kafka-server-start.sh') - kafka_args = [os.path.abspath(kafka_binary), kafka_config_file] - kafka_handle = subprocess.Popen(kafka_args, - env=launcher_environment, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - self.kafka_worker = ProcessWorker(kafka_handle, 'Kafka', '[KafkaServer id=0] started') - self.kafka_worker.await_startup() - - # All services have started without problems - now we can finally finish. - break - - except Exception as e: - print('Could not start services, will try again', e) - - if self.kafka_worker: - self.kafka_worker.kill() - self.kafka_worker = None - if self.zk_worker: - self.zk_worker.kill() - self.zk_worker = None - if self.envoy_worker: - self.envoy_worker.kill() - self.envoy_worker = None - - @staticmethod - def find_java(): - """ + # Find java installation that we are going to use to start Zookeeper & Kafka. 
+ java_directory = ServicesHolder.find_java() + + launcher_environment = os.environ.copy() + # Make `java` visible to build script: + # https://github.com/apache/kafka/blob/2.2.0/bin/kafka-run-class.sh#L226 + new_path = os.path.abspath(java_directory) + os.pathsep + launcher_environment['PATH'] + launcher_environment['PATH'] = new_path + # Both ZK & Kafka use Kafka launcher script. + # By default it sets up JMX options: + # https://github.com/apache/kafka/blob/2.2.0/bin/kafka-run-class.sh#L167 + # But that forces the JVM to load file that is not present due to: + # https://docs.oracle.com/javase/9/management/monitoring-and-management-using-jmx-technology.htm + # Let's make it simple and just disable JMX. + launcher_environment['KAFKA_JMX_OPTS'] = ' ' + + # Setup a temporary directory, which will be used by Kafka & Zookeeper servers. + self.kafka_tmp_dir = tempfile.mkdtemp() + print('Temporary directory used for tests: ' + self.kafka_tmp_dir) + + # This directory will store the configuration files fed to services. + config_dir = self.kafka_tmp_dir + '/config' + os.mkdir(config_dir) + # This directory will store Zookeeper's data (== Kafka server metadata). + zookeeper_store_dir = self.kafka_tmp_dir + '/zookeeper_data' + os.mkdir(zookeeper_store_dir) + # This directory will store Kafka's data (== partitions). + kafka_store_dir = self.kafka_tmp_dir + '/kafka_data' + os.mkdir(kafka_store_dir) + + # Find the Kafka server 'bin' directory. + kafka_bin_dir = os.path.join('.', 'external', 'kafka_server_binary', 'bin') + + # Main initialization block: + # - generate random ports, + # - render configuration with these ports, + # - start services and check if they are running okay, + # - if anything is having problems, kill everything and start again. + while True: + + # Generate random ports. + zk_port = ServicesHolder.get_random_listener_port() + kafka_real_port = ServicesHolder.get_random_listener_port() + kafka_envoy_port = ServicesHolder.get_random_listener_port() + envoy_monitoring_port = ServicesHolder.get_random_listener_port() + + # These ports need to be exposed to tests. + self.kafka_envoy_port = kafka_envoy_port + self.envoy_monitoring_port = envoy_monitoring_port + + # Render config file for Envoy. + template = RenderingHelper.get_template('envoy_config_yaml.j2') + contents = template.render( + data={ + 'kafka_real_port': kafka_real_port, + 'kafka_envoy_port': kafka_envoy_port, + 'envoy_monitoring_port': envoy_monitoring_port + }) + envoy_config_file = os.path.join(config_dir, 'envoy_config.yaml') + with open(envoy_config_file, 'w') as fd: + fd.write(contents) + print('Envoy config file rendered at: ' + envoy_config_file) + + # Render config file for Zookeeper. + template = RenderingHelper.get_template('zookeeper_properties.j2') + contents = template.render(data={'data_dir': zookeeper_store_dir, 'zk_port': zk_port}) + zookeeper_config_file = os.path.join(config_dir, 'zookeeper.properties') + with open(zookeeper_config_file, 'w') as fd: + fd.write(contents) + print('Zookeeper config file rendered at: ' + zookeeper_config_file) + + # Render config file for Kafka. 
+ template = RenderingHelper.get_template('kafka_server_properties.j2') + contents = template.render( + data={ + 'data_dir': kafka_store_dir, + 'zk_port': zk_port, + 'kafka_real_port': kafka_real_port, + 'kafka_envoy_port': kafka_envoy_port + }) + kafka_config_file = os.path.join(config_dir, 'kafka_server.properties') + with open(kafka_config_file, 'w') as fd: + fd.write(contents) + print('Kafka config file rendered at: ' + kafka_config_file) + + # Start the services now. + try: + + # Start Envoy in the background, pointing to rendered config file. + envoy_binary = ServicesHolder.find_envoy() + # --base-id is added to allow multiple Envoy instances to run at the same time. + envoy_args = [ + os.path.abspath(envoy_binary), '-c', envoy_config_file, '--base-id', + str(random.randint(1, 999999)) + ] + envoy_handle = subprocess.Popen(envoy_args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + self.envoy_worker = ProcessWorker(envoy_handle, 'Envoy', + 'starting main dispatch loop') + self.envoy_worker.await_startup() + + # Start Zookeeper in background, pointing to rendered config file. + zk_binary = os.path.join(kafka_bin_dir, 'zookeeper-server-start.sh') + zk_args = [os.path.abspath(zk_binary), zookeeper_config_file] + zk_handle = subprocess.Popen(zk_args, + env=launcher_environment, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + self.zk_worker = ProcessWorker(zk_handle, 'Zookeeper', 'binding to port') + self.zk_worker.await_startup() + + # Start Kafka in background, pointing to rendered config file. + kafka_binary = os.path.join(kafka_bin_dir, 'kafka-server-start.sh') + kafka_args = [os.path.abspath(kafka_binary), kafka_config_file] + kafka_handle = subprocess.Popen(kafka_args, + env=launcher_environment, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + self.kafka_worker = ProcessWorker(kafka_handle, 'Kafka', + '[KafkaServer id=0] started') + self.kafka_worker.await_startup() + + # All services have started without problems - now we can finally finish. + break + + except Exception as e: + print('Could not start services, will try again', e) + + if self.kafka_worker: + self.kafka_worker.kill() + self.kafka_worker = None + if self.zk_worker: + self.zk_worker.kill() + self.zk_worker = None + if self.envoy_worker: + self.envoy_worker.kill() + self.envoy_worker = None + + @staticmethod + def find_java(): + """ This method just locates the Java installation in current directory. We cannot hardcode the name, as the dirname changes as per: https://github.com/bazelbuild/bazel/blob/master/tools/jdk/BUILD#L491 """ - external_dir = os.path.join('.', 'external') - for directory in os.listdir(external_dir): - if 'remotejdk11' in directory: - result = os.path.join(external_dir, directory, 'bin') - print('Using Java: ' + result) - return result - raise Exception('Could not find Java in: ' + external_dir) - - @staticmethod - def find_envoy(): - """ + external_dir = os.path.join('.', 'external') + for directory in os.listdir(external_dir): + if 'remotejdk11' in directory: + result = os.path.join(external_dir, directory, 'bin') + print('Using Java: ' + result) + return result + raise Exception('Could not find Java in: ' + external_dir) + + @staticmethod + def find_envoy(): + """ This method locates envoy binary. It's present at ./source/exe/envoy-static (at least for mac/bazel-asan/bazel-tsan), or at ./external/envoy/source/exe/envoy-static (for bazel-compile_time_options). 
""" - candidate = os.path.join('.', 'source', 'exe', 'envoy-static') - if os.path.isfile(candidate): - return candidate - candidate = os.path.join('.', 'external', 'envoy', 'source', 'exe', 'envoy-static') - if os.path.isfile(candidate): - return candidate - raise Exception("Could not find Envoy") + candidate = os.path.join('.', 'source', 'exe', 'envoy-static') + if os.path.isfile(candidate): + return candidate + candidate = os.path.join('.', 'external', 'envoy', 'source', 'exe', 'envoy-static') + if os.path.isfile(candidate): + return candidate + raise Exception("Could not find Envoy") - def shut_down(self): - # Teardown - kill Kafka, Zookeeper, and Envoy. Then delete their data directory. - print('Cleaning up') + def shut_down(self): + # Teardown - kill Kafka, Zookeeper, and Envoy. Then delete their data directory. + print('Cleaning up') - if self.kafka_worker: - self.kafka_worker.kill() + if self.kafka_worker: + self.kafka_worker.kill() - if self.zk_worker: - self.zk_worker.kill() + if self.zk_worker: + self.zk_worker.kill() - if self.envoy_worker: - self.envoy_worker.kill() + if self.envoy_worker: + self.envoy_worker.kill() - if self.kafka_tmp_dir: - print('Removing temporary directory: ' + self.kafka_tmp_dir) - shutil.rmtree(self.kafka_tmp_dir) + if self.kafka_tmp_dir: + print('Removing temporary directory: ' + self.kafka_tmp_dir) + shutil.rmtree(self.kafka_tmp_dir) - def check_state(self): - self.envoy_worker.check_state() - self.zk_worker.check_state() - self.kafka_worker.check_state() + def check_state(self): + self.envoy_worker.check_state() + self.zk_worker.check_state() + self.kafka_worker.check_state() class ProcessWorker: - """ + """ Helper class that wraps the external service process. Provides ability to wait until service is ready to use (this is done by tracing logs) and printing service's output to stdout. """ - # Service is considered to be properly initialized after it has logged its startup message - # and has been alive for INITIALIZATION_WAIT_SECONDS after that message has been seen. - # This (clunky) design is needed because Zookeeper happens to log "binding to port" and then - # might fail to bind. - INITIALIZATION_WAIT_SECONDS = 3 - - def __init__(self, process_handle, name, startup_message): - # Handle to process and pretty name. - self.process_handle = process_handle - self.name = name - - self.startup_message = startup_message - self.startup_message_ts = None - - # Semaphore raised when startup has finished and information regarding startup's success. - self.initialization_semaphore = Semaphore(value=0) - self.initialization_ok = False - - self.state_worker = Thread(target=ProcessWorker.initialization_worker, args=(self,)) - self.state_worker.start() - self.out_worker = Thread(target=ProcessWorker.pipe_handler, - args=(self, self.process_handle.stdout, 'out')) - self.out_worker.start() - self.err_worker = Thread(target=ProcessWorker.pipe_handler, - args=(self, self.process_handle.stderr, 'err')) - self.err_worker.start() - - @staticmethod - def initialization_worker(owner): - """ + # Service is considered to be properly initialized after it has logged its startup message + # and has been alive for INITIALIZATION_WAIT_SECONDS after that message has been seen. + # This (clunky) design is needed because Zookeeper happens to log "binding to port" and then + # might fail to bind. + INITIALIZATION_WAIT_SECONDS = 3 + + def __init__(self, process_handle, name, startup_message): + # Handle to process and pretty name. 
+ self.process_handle = process_handle + self.name = name + + self.startup_message = startup_message + self.startup_message_ts = None + + # Semaphore raised when startup has finished and information regarding startup's success. + self.initialization_semaphore = Semaphore(value=0) + self.initialization_ok = False + + self.state_worker = Thread(target=ProcessWorker.initialization_worker, args=(self,)) + self.state_worker.start() + self.out_worker = Thread(target=ProcessWorker.pipe_handler, + args=(self, self.process_handle.stdout, 'out')) + self.out_worker.start() + self.err_worker = Thread(target=ProcessWorker.pipe_handler, + args=(self, self.process_handle.stderr, 'err')) + self.err_worker.start() + + @staticmethod + def initialization_worker(owner): + """ Worker thread. Responsible for detecting if service died during initialization steps and ensuring if enough time has passed since the startup message has been seen. When either of these happens, we just raise the initialization semaphore. """ - while True: - status = owner.process_handle.poll() - if status: - # Service died. - print('%s did not initialize properly - finished with: %s' % (owner.name, status)) - owner.initialization_ok = False - owner.initialization_semaphore.release() - break - else: - # Service is still running. - startup_message_ts = owner.startup_message_ts - if startup_message_ts: - # The log message has been registered (by pipe_handler thread), let's just ensure that - # some time has passed and mark the service as running. - current_time = int(round(time.time())) - if current_time - startup_message_ts >= ProcessWorker.INITIALIZATION_WAIT_SECONDS: - print('Startup message seen %s seconds ago, and service is still running' % - (ProcessWorker.INITIALIZATION_WAIT_SECONDS), - flush=True) - owner.initialization_ok = True - owner.initialization_semaphore.release() - break - time.sleep(1) - print('Initialization worker for %s has finished' % (owner.name)) - - @staticmethod - def pipe_handler(owner, pipe, pipe_name): - """ + while True: + status = owner.process_handle.poll() + if status: + # Service died. + print('%s did not initialize properly - finished with: %s' % (owner.name, status)) + owner.initialization_ok = False + owner.initialization_semaphore.release() + break + else: + # Service is still running. + startup_message_ts = owner.startup_message_ts + if startup_message_ts: + # The log message has been registered (by pipe_handler thread), let's just ensure that + # some time has passed and mark the service as running. + current_time = int(round(time.time())) + if current_time - startup_message_ts >= ProcessWorker.INITIALIZATION_WAIT_SECONDS: + print('Startup message seen %s seconds ago, and service is still running' % + (ProcessWorker.INITIALIZATION_WAIT_SECONDS), + flush=True) + owner.initialization_ok = True + owner.initialization_semaphore.release() + break + time.sleep(1) + print('Initialization worker for %s has finished' % (owner.name)) + + @staticmethod + def pipe_handler(owner, pipe, pipe_name): + """ Worker thread. If a service startup message is seen, then it just registers the timestamp of its appearance. Also prints every received message. 
""" - try: - for raw_line in pipe: - line = raw_line.decode().rstrip() - print('%s(%s):' % (owner.name, pipe_name), line, flush=True) - if owner.startup_message in line: - print('%s initialization message [%s] has been logged' % - (owner.name, owner.startup_message)) - owner.startup_message_ts = int(round(time.time())) - finally: - pipe.close() - print('Pipe handler for %s(%s) has finished' % (owner.name, pipe_name)) - - def await_startup(self): - """ + try: + for raw_line in pipe: + line = raw_line.decode().rstrip() + print('%s(%s):' % (owner.name, pipe_name), line, flush=True) + if owner.startup_message in line: + print('%s initialization message [%s] has been logged' % + (owner.name, owner.startup_message)) + owner.startup_message_ts = int(round(time.time())) + finally: + pipe.close() + print('Pipe handler for %s(%s) has finished' % (owner.name, pipe_name)) + + def await_startup(self): + """ Awaits on initialization semaphore, and then verifies the initialization state. If everything is okay, we just continue (we can use the service), otherwise throw. """ - print('Waiting for %s to start...' % (self.name)) - self.initialization_semaphore.acquire() - try: - if self.initialization_ok: - print('Service %s started successfully' % (self.name)) - else: - raise Exception('%s could not start' % (self.name)) - finally: - self.initialization_semaphore.release() - - def check_state(self): - """ + print('Waiting for %s to start...' % (self.name)) + self.initialization_semaphore.acquire() + try: + if self.initialization_ok: + print('Service %s started successfully' % (self.name)) + else: + raise Exception('%s could not start' % (self.name)) + finally: + self.initialization_semaphore.release() + + def check_state(self): + """ Verifies if the service is still running. Throws if it is not. """ - status = self.process_handle.poll() - if status: - raise Exception('%s died with: %s' % (self.name, str(status))) + status = self.process_handle.poll() + if status: + raise Exception('%s died with: %s' % (self.name, str(status))) - def kill(self): - """ + def kill(self): + """ Utility method to kill the main service thread and all related workers. """ - print('Stopping service %s' % self.name) + print('Stopping service %s' % self.name) - # Kill the real process. - self.process_handle.kill() - self.process_handle.wait() + # Kill the real process. + self.process_handle.kill() + self.process_handle.wait() - # The sub-workers are going to finish on their own, as they will detect main thread dying - # (through pipes closing, or .poll() returning a non-null value). - self.state_worker.join() - self.out_worker.join() - self.err_worker.join() + # The sub-workers are going to finish on their own, as they will detect main thread dying + # (through pipes closing, or .poll() returning a non-null value). + self.state_worker.join() + self.out_worker.join() + self.err_worker.join() - print('Service %s has been stopped' % self.name) + print('Service %s has been stopped' % self.name) class RenderingHelper: - """ + """ Helper for jinja templates. """ - @staticmethod - def get_template(template): - import jinja2 - import os - import sys - # Templates are resolved relatively to main start script, due to main & test templates being - # stored in different directories. 
- env = jinja2.Environment(loader=jinja2.FileSystemLoader( - searchpath=os.path.dirname(os.path.abspath(__file__)))) - return env.get_template(template) + @staticmethod + def get_template(template): + import jinja2 + import os + import sys + # Templates are resolved relatively to main start script, due to main & test templates being + # stored in different directories. + env = jinja2.Environment(loader=jinja2.FileSystemLoader( + searchpath=os.path.dirname(os.path.abspath(__file__)))) + return env.get_template(template) if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/test/extensions/filters/network/kafka/protocol/launcher.py b/test/extensions/filters/network/kafka/protocol/launcher.py index d277457f0e2e..e75437ae9ade 100644 --- a/test/extensions/filters/network/kafka/protocol/launcher.py +++ b/test/extensions/filters/network/kafka/protocol/launcher.py @@ -8,7 +8,7 @@ def main(): - """ + """ Kafka test generator script ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Generates tests from Kafka protocol specification. @@ -36,14 +36,14 @@ def main(): ${MESSAGE_TYPE}_codec_${MESSAGE_TYPE}_test_cc.j2, - to create '${MESSAGE_TYPE}_utilities.cc' - ${MESSAGE_TYPE}_utilities_cc.j2. """ - type = sys.argv[1] - header_test_cc_file = os.path.abspath(sys.argv[2]) - codec_test_cc_file = os.path.abspath(sys.argv[3]) - utilities_cc_file = os.path.abspath(sys.argv[4]) - input_files = sys.argv[5:] - generator.generate_test_code(type, header_test_cc_file, codec_test_cc_file, utilities_cc_file, - input_files) + type = sys.argv[1] + header_test_cc_file = os.path.abspath(sys.argv[2]) + codec_test_cc_file = os.path.abspath(sys.argv[3]) + utilities_cc_file = os.path.abspath(sys.argv[4]) + input_files = sys.argv[5:] + generator.generate_test_code(type, header_test_cc_file, codec_test_cc_file, utilities_cc_file, + input_files) if __name__ == "__main__": - main() + main() diff --git a/test/extensions/filters/network/kafka/serialization/launcher.py b/test/extensions/filters/network/kafka/serialization/launcher.py index 223e56ef3e90..5efd339d687e 100644 --- a/test/extensions/filters/network/kafka/serialization/launcher.py +++ b/test/extensions/filters/network/kafka/serialization/launcher.py @@ -8,7 +8,7 @@ def main(): - """ + """ Serialization composite test generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Generates test source files for composite deserializers. @@ -24,9 +24,9 @@ def main(): Template used is 'serialization_composite_test_cc.j2'. 
""" - serialization_composite_test_cc_file = os.path.abspath(sys.argv[1]) - generator.generate_test_code(serialization_composite_test_cc_file) + serialization_composite_test_cc_file = os.path.abspath(sys.argv[1]) + generator.generate_test_code(serialization_composite_test_cc_file) if __name__ == "__main__": - main() + main() diff --git a/test/extensions/filters/network/thrift_proxy/driver/client.py b/test/extensions/filters/network/thrift_proxy/driver/client.py index 544e5acff16a..0468f40ed81d 100755 --- a/test/extensions/filters/network/thrift_proxy/driver/client.py +++ b/test/extensions/filters/network/thrift_proxy/driver/client.py @@ -16,246 +16,246 @@ # On Windows we run this test on Python3 if sys.version_info[0] != 2: - sys.stdin.reconfigure(encoding='utf-8') - sys.stdout.reconfigure(encoding='utf-8') + sys.stdin.reconfigure(encoding='utf-8') + sys.stdout.reconfigure(encoding='utf-8') class TRecordingTransport(TTransport.TTransportBase): - def __init__(self, underlying, writehandle, readhandle): - self._underlying = underlying - self._whandle = writehandle - self._rhandle = readhandle + def __init__(self, underlying, writehandle, readhandle): + self._underlying = underlying + self._whandle = writehandle + self._rhandle = readhandle - def isOpen(self): - return self._underlying.isOpen() + def isOpen(self): + return self._underlying.isOpen() - def open(self): - if not self._underlying.isOpen(): - self._underlying.open() + def open(self): + if not self._underlying.isOpen(): + self._underlying.open() - def close(self): - self._underlying.close() - self._whandle.close() - self._rhandle.close() + def close(self): + self._underlying.close() + self._whandle.close() + self._rhandle.close() - def read(self, sz): - buf = self._underlying.read(sz) - if len(buf) != 0: - self._rhandle.write(buf) - return buf + def read(self, sz): + buf = self._underlying.read(sz) + if len(buf) != 0: + self._rhandle.write(buf) + return buf - def write(self, buf): - if len(buf) != 0: - self._whandle.write(buf) - self._underlying.write(buf) + def write(self, buf): + if len(buf) != 0: + self._whandle.write(buf) + self._underlying.write(buf) - def flush(self): - self._underlying.flush() - self._whandle.flush() - self._rhandle.flush() + def flush(self): + self._underlying.flush() + self._whandle.flush() + self._rhandle.flush() def main(cfg, reqhandle, resphandle): - if cfg.unix: - if cfg.addr == "": - sys.exit("invalid unix domain socket: {}".format(cfg.addr)) - socket = TSocket.TSocket(unix_socket=cfg.addr) - else: - try: - (host, port) = cfg.addr.rsplit(":", 1) - if host == "": - host = "localhost" - socket = TSocket.TSocket(host=host, port=int(port)) - except ValueError: - sys.exit("invalid address: {}".format(cfg.addr)) - - transport = TRecordingTransport(socket, reqhandle, resphandle) - - if cfg.transport == "framed": - transport = TTransport.TFramedTransport(transport) - elif cfg.transport == "unframed": - transport = TTransport.TBufferedTransport(transport) - elif cfg.transport == "header": - transport = THeaderTransport.THeaderTransport( - transport, - client_type=THeaderTransport.CLIENT_TYPE.HEADER, - ) + if cfg.unix: + if cfg.addr == "": + sys.exit("invalid unix domain socket: {}".format(cfg.addr)) + socket = TSocket.TSocket(unix_socket=cfg.addr) + else: + try: + (host, port) = cfg.addr.rsplit(":", 1) + if host == "": + host = "localhost" + socket = TSocket.TSocket(host=host, port=int(port)) + except ValueError: + sys.exit("invalid address: {}".format(cfg.addr)) + + transport = TRecordingTransport(socket, 
reqhandle, resphandle) + + if cfg.transport == "framed": + transport = TTransport.TFramedTransport(transport) + elif cfg.transport == "unframed": + transport = TTransport.TBufferedTransport(transport) + elif cfg.transport == "header": + transport = THeaderTransport.THeaderTransport( + transport, + client_type=THeaderTransport.CLIENT_TYPE.HEADER, + ) + + if cfg.headers is not None: + pairs = cfg.headers.split(",") + for p in pairs: + key, value = p.split("=") + transport.set_header(key, value) + + if cfg.protocol == "binary": + transport.set_protocol_id(THeaderTransport.T_BINARY_PROTOCOL) + elif cfg.protocol == "compact": + transport.set_protocol_id(THeaderTransport.T_COMPACT_PROTOCOL) + else: + sys.exit("header transport cannot be used with protocol {0}".format(cfg.protocol)) + else: + sys.exit("unknown transport {0}".format(cfg.transport)) - if cfg.headers is not None: - pairs = cfg.headers.split(",") - for p in pairs: - key, value = p.split("=") - transport.set_header(key, value) + transport.open() if cfg.protocol == "binary": - transport.set_protocol_id(THeaderTransport.T_BINARY_PROTOCOL) + protocol = TBinaryProtocol.TBinaryProtocol(transport) elif cfg.protocol == "compact": - transport.set_protocol_id(THeaderTransport.T_COMPACT_PROTOCOL) - else: - sys.exit("header transport cannot be used with protocol {0}".format(cfg.protocol)) - else: - sys.exit("unknown transport {0}".format(cfg.transport)) - - transport.open() - - if cfg.protocol == "binary": - protocol = TBinaryProtocol.TBinaryProtocol(transport) - elif cfg.protocol == "compact": - protocol = TCompactProtocol.TCompactProtocol(transport) - elif cfg.protocol == "json": - protocol = TJSONProtocol.TJSONProtocol(transport) - elif cfg.protocol == "finagle": - protocol = TFinagleProtocol(transport, client_id="thrift-playground") - else: - sys.exit("unknown protocol {0}".format(cfg.protocol)) - - if cfg.service is not None: - protocol = TMultiplexedProtocol.TMultiplexedProtocol(protocol, cfg.service) - - client = Example.Client(protocol) - - try: - if cfg.method == "ping": - client.ping() - print("client: pinged") - elif cfg.method == "poke": - client.poke() - print("client: poked") - elif cfg.method == "add": - if len(cfg.params) != 2: - sys.exit("add takes 2 arguments, got: {0}".format(cfg.params)) - - a = int(cfg.params[0]) - b = int(cfg.params[1]) - v = client.add(a, b) - print("client: added {0} + {1} = {2}".format(a, b, v)) - elif cfg.method == "execute": - param = Param(return_fields=cfg.params, - the_works=TheWorks( - field_1=True, - field_2=0x7f, - field_3=0x7fff, - field_4=0x7fffffff, - field_5=0x7fffffffffffffff, - field_6=-1.5, - field_7=u"string is UTF-8: \U0001f60e", - field_8=b"binary is bytes: \x80\x7f\x00\x01", - field_9={ - 1: "one", - 2: "two", - 3: "three" - }, - field_10=[1, 2, 4, 8], - field_11=set(["a", "b", "c"]), - field_12=False, - )) - - try: - result = client.execute(param) - print("client: executed {0}: {1}".format(param, result)) - except AppException as e: - print("client: execute failed with IDL Exception: {0}".format(e.why)) + protocol = TCompactProtocol.TCompactProtocol(transport) + elif cfg.protocol == "json": + protocol = TJSONProtocol.TJSONProtocol(transport) + elif cfg.protocol == "finagle": + protocol = TFinagleProtocol(transport, client_id="thrift-playground") else: - sys.exit("unknown method {0}".format(cfg.method)) - except Thrift.TApplicationException as e: - print("client exception: {0}: {1}".format(e.type, e.message)) + sys.exit("unknown protocol {0}".format(cfg.protocol)) - if cfg.request is 
None: - req = "".join(["%02X " % ord(x) for x in reqhandle.getvalue()]).strip() - print("request: {}".format(req)) - if cfg.response is None: - resp = "".join(["%02X " % ord(x) for x in resphandle.getvalue()]).strip() - print("response: {}".format(resp)) + if cfg.service is not None: + protocol = TMultiplexedProtocol.TMultiplexedProtocol(protocol, cfg.service) - transport.close() + client = Example.Client(protocol) + + try: + if cfg.method == "ping": + client.ping() + print("client: pinged") + elif cfg.method == "poke": + client.poke() + print("client: poked") + elif cfg.method == "add": + if len(cfg.params) != 2: + sys.exit("add takes 2 arguments, got: {0}".format(cfg.params)) + + a = int(cfg.params[0]) + b = int(cfg.params[1]) + v = client.add(a, b) + print("client: added {0} + {1} = {2}".format(a, b, v)) + elif cfg.method == "execute": + param = Param(return_fields=cfg.params, + the_works=TheWorks( + field_1=True, + field_2=0x7f, + field_3=0x7fff, + field_4=0x7fffffff, + field_5=0x7fffffffffffffff, + field_6=-1.5, + field_7=u"string is UTF-8: \U0001f60e", + field_8=b"binary is bytes: \x80\x7f\x00\x01", + field_9={ + 1: "one", + 2: "two", + 3: "three" + }, + field_10=[1, 2, 4, 8], + field_11=set(["a", "b", "c"]), + field_12=False, + )) + + try: + result = client.execute(param) + print("client: executed {0}: {1}".format(param, result)) + except AppException as e: + print("client: execute failed with IDL Exception: {0}".format(e.why)) + else: + sys.exit("unknown method {0}".format(cfg.method)) + except Thrift.TApplicationException as e: + print("client exception: {0}: {1}".format(e.type, e.message)) + + if cfg.request is None: + req = "".join(["%02X " % ord(x) for x in reqhandle.getvalue()]).strip() + print("request: {}".format(req)) + if cfg.response is None: + resp = "".join(["%02X " % ord(x) for x in resphandle.getvalue()]).strip() + print("response: {}".format(resp)) + + transport.close() if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Thrift client tool.",) - parser.add_argument( - "method", - metavar="METHOD", - help="Name of the service method to invoke.", - ) - parser.add_argument( - "params", - metavar="PARAMS", - nargs="*", - help="Method parameters", - ) - parser.add_argument( - "-a", - "--addr", - metavar="ADDR", - dest="addr", - required=True, - help="Target address for requests in the form host:port. The host is optional. 
If --unix" + - " is set, the address is the socket name.", - ) - parser.add_argument( - "-m", - "--multiplex", - metavar="SERVICE", - dest="service", - help="Enable service multiplexing and set the service name.", - ) - parser.add_argument( - "-p", - "--protocol", - dest="protocol", - default="binary", - choices=["binary", "compact", "json", "finagle"], - help="selects a protocol.", - ) - parser.add_argument( - "--request", - metavar="FILE", - dest="request", - help="Writes the Thrift request to a file.", - ) - parser.add_argument( - "--response", - metavar="FILE", - dest="response", - help="Writes the Thrift response to a file.", - ) - parser.add_argument( - "-t", - "--transport", - dest="transport", - default="framed", - choices=["framed", "unframed", "header"], - help="selects a transport.", - ) - parser.add_argument( - "-u", - "--unix", - dest="unix", - action="store_true", - ) - parser.add_argument( - "--headers", - dest="headers", - metavar="KEY=VALUE[,KEY=VALUE]", - help="list of comma-delimited, key value pairs to include as transport headers.", - ) - - cfg = parser.parse_args() - - reqhandle = io.BytesIO() - resphandle = io.BytesIO() - if cfg.request is not None: - try: - reqhandle = io.open(cfg.request, "wb") - except IOError as e: - sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) - if cfg.response is not None: + parser = argparse.ArgumentParser(description="Thrift client tool.",) + parser.add_argument( + "method", + metavar="METHOD", + help="Name of the service method to invoke.", + ) + parser.add_argument( + "params", + metavar="PARAMS", + nargs="*", + help="Method parameters", + ) + parser.add_argument( + "-a", + "--addr", + metavar="ADDR", + dest="addr", + required=True, + help="Target address for requests in the form host:port. The host is optional. 
If --unix" + + " is set, the address is the socket name.", + ) + parser.add_argument( + "-m", + "--multiplex", + metavar="SERVICE", + dest="service", + help="Enable service multiplexing and set the service name.", + ) + parser.add_argument( + "-p", + "--protocol", + dest="protocol", + default="binary", + choices=["binary", "compact", "json", "finagle"], + help="selects a protocol.", + ) + parser.add_argument( + "--request", + metavar="FILE", + dest="request", + help="Writes the Thrift request to a file.", + ) + parser.add_argument( + "--response", + metavar="FILE", + dest="response", + help="Writes the Thrift response to a file.", + ) + parser.add_argument( + "-t", + "--transport", + dest="transport", + default="framed", + choices=["framed", "unframed", "header"], + help="selects a transport.", + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + action="store_true", + ) + parser.add_argument( + "--headers", + dest="headers", + metavar="KEY=VALUE[,KEY=VALUE]", + help="list of comma-delimited, key value pairs to include as transport headers.", + ) + + cfg = parser.parse_args() + + reqhandle = io.BytesIO() + resphandle = io.BytesIO() + if cfg.request is not None: + try: + reqhandle = io.open(cfg.request, "wb") + except IOError as e: + sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) + if cfg.response is not None: + try: + resphandle = io.open(cfg.response, "wb") + except IOError as e: + sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) try: - resphandle = io.open(cfg.response, "wb") - except IOError as e: - sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) - try: - main(cfg, reqhandle, resphandle) - except Thrift.TException as tx: - sys.exit("Unhandled Thrift Exception: {0}".format(tx.message)) + main(cfg, reqhandle, resphandle) + except Thrift.TException as tx: + sys.exit("Unhandled Thrift Exception: {0}".format(tx.message)) diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py index 7f14ecebb21e..eeb9c810f4b3 100644 --- a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py +++ b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py @@ -27,15 +27,15 @@ import sys if sys.version_info[0] >= 3: - from http import server - BaseHTTPServer = server - xrange = range - from io import BytesIO as StringIO - PY3 = True + from http import server + BaseHTTPServer = server + xrange = range + from io import BytesIO as StringIO + PY3 = True else: - import BaseHTTPServer - from cStringIO import StringIO - PY3 = False + import BaseHTTPServer + from cStringIO import StringIO + PY3 = False from struct import pack, unpack import zlib @@ -51,87 +51,87 @@ # INFO:(zuercher): Copied from: # https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/protocol/TCompactProtocol.py def getVarint(n): - out = [] - while True: - if n & ~0x7f == 0: - out.append(n) - break + out = [] + while True: + if n & ~0x7f == 0: + out.append(n) + break + else: + out.append((n & 0xff) | 0x80) + n = n >> 7 + if sys.version_info[0] >= 3: + return bytes(out) else: - out.append((n & 0xff) | 0x80) - n = n >> 7 - if sys.version_info[0] >= 3: - return bytes(out) - else: - return b''.join(map(chr, out)) + return b''.join(map(chr, out)) # INFO:(zuercher): Copied from # https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/protocol/TCompactProtocol.py def readVarint(trans): - result = 0 - shift = 0 - while True: - 
x = trans.read(1) - byte = ord(x) - result |= (byte & 0x7f) << shift - if byte >> 7 == 0: - return result - shift += 7 + result = 0 + shift = 0 + while True: + x = trans.read(1) + byte = ord(x) + result |= (byte & 0x7f) << shift + if byte >> 7 == 0: + return result + shift += 7 # Import the snappy module if it is available try: - import snappy + import snappy except ImportError: - # If snappy is not available, don't fail immediately. - # Only raise an error if we actually ever need to perform snappy - # compression. - class DummySnappy(object): + # If snappy is not available, don't fail immediately. + # Only raise an error if we actually ever need to perform snappy + # compression. + class DummySnappy(object): - def compress(self, buf): - raise TTransportException(TTransportException.INVALID_TRANSFORM, - 'snappy module not available') + def compress(self, buf): + raise TTransportException(TTransportException.INVALID_TRANSFORM, + 'snappy module not available') - def decompress(self, buf): - raise TTransportException(TTransportException.INVALID_TRANSFORM, - 'snappy module not available') + def decompress(self, buf): + raise TTransportException(TTransportException.INVALID_TRANSFORM, + 'snappy module not available') - snappy = DummySnappy() # type: ignore + snappy = DummySnappy() # type: ignore # Definitions from THeader.h class CLIENT_TYPE: - HEADER = 0 - FRAMED_DEPRECATED = 1 - UNFRAMED_DEPRECATED = 2 - HTTP_SERVER = 3 - HTTP_CLIENT = 4 - FRAMED_COMPACT = 5 - HEADER_SASL = 6 - HTTP_GET = 7 - UNKNOWN = 8 - UNFRAMED_COMPACT_DEPRECATED = 9 + HEADER = 0 + FRAMED_DEPRECATED = 1 + UNFRAMED_DEPRECATED = 2 + HTTP_SERVER = 3 + HTTP_CLIENT = 4 + FRAMED_COMPACT = 5 + HEADER_SASL = 6 + HTTP_GET = 7 + UNKNOWN = 8 + UNFRAMED_COMPACT_DEPRECATED = 9 class HEADER_FLAG: - SUPPORT_OUT_OF_ORDER = 0x01 - DUPLEX_REVERSE = 0x08 - SASL = 0x10 + SUPPORT_OUT_OF_ORDER = 0x01 + DUPLEX_REVERSE = 0x08 + SASL = 0x10 class TRANSFORM: - NONE = 0x00 - ZLIB = 0x01 - HMAC = 0x02 - SNAPPY = 0x03 - QLZ = 0x04 - ZSTD = 0x05 + NONE = 0x00 + ZLIB = 0x01 + HMAC = 0x02 + SNAPPY = 0x03 + QLZ = 0x04 + ZSTD = 0x05 class INFO: - NORMAL = 1 - PERSISTENT = 2 + NORMAL = 1 + PERSISTENT = 2 T_BINARY_PROTOCOL = 0 @@ -150,495 +150,498 @@ class INFO: class THeaderTransport(TTransportBase, CReadableTransport): - """Transport that sends headers. Also understands framed/unframed/HTTP + """Transport that sends headers. Also understands framed/unframed/HTTP transports and will do the right thing""" - __max_frame_size = MAX_FRAME_SIZE - - # Defaults to current user, but there is also a setter below. 
- __identity = None - IDENTITY_HEADER = "identity" - ID_VERSION_HEADER = "id_version" - ID_VERSION = "1" - - def __init__(self, trans, client_types=None, client_type=None): - self.__trans = trans - self.__rbuf = StringIO() - self.__rbuf_frame = False - self.__wbuf = StringIO() - self.seq_id = 0 - self.__flags = 0 - self.__read_transforms = [] - self.__write_transforms = [] - self.__supported_client_types = set(client_types or (CLIENT_TYPE.HEADER,)) - self.__proto_id = T_COMPACT_PROTOCOL # default to compact like c++ - self.__client_type = client_type or CLIENT_TYPE.HEADER - self.__read_headers = {} - self.__read_persistent_headers = {} - self.__write_headers = {} - self.__write_persistent_headers = {} - - self.__supported_client_types.add(self.__client_type) - - # If we support unframed binary / framed binary also support compact - if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types: - self.__supported_client_types.add(CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED) - if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types: - self.__supported_client_types.add(CLIENT_TYPE.FRAMED_COMPACT) - - def set_header_flag(self, flag): - self.__flags |= flag - - def clear_header_flag(self, flag): - self.__flags &= ~flag - - def header_flags(self): - return self.__flags - - def set_max_frame_size(self, size): - if size > MAX_BIG_FRAME_SIZE: - raise TTransportException(TTransportException.INVALID_FRAME_SIZE, - "Cannot set max frame size > %s" % MAX_BIG_FRAME_SIZE) - if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER: - raise TTransportException( - TTransportException.INVALID_FRAME_SIZE, - "Cannot set max frame size > %s for clients other than HEADER" % MAX_FRAME_SIZE) - self.__max_frame_size = size - - def get_peer_identity(self): - if self.IDENTITY_HEADER in self.__read_headers: - if self.__read_headers[self.ID_VERSION_HEADER] == self.ID_VERSION: - return self.__read_headers[self.IDENTITY_HEADER] - return None - - def set_identity(self, identity): - self.__identity = identity - - def get_protocol_id(self): - return self.__proto_id - - def set_protocol_id(self, proto_id): - self.__proto_id = proto_id - - def set_header(self, str_key, str_value): - self.__write_headers[str_key] = str_value - - def get_write_headers(self): - return self.__write_headers - - def get_headers(self): - return self.__read_headers - - def clear_headers(self): - self.__write_headers.clear() - - def set_persistent_header(self, str_key, str_value): - self.__write_persistent_headers[str_key] = str_value - - def get_write_persistent_headers(self): - return self.__write_persistent_headers - - def clear_persistent_headers(self): - self.__write_persistent_headers.clear() - - def add_transform(self, trans_id): - self.__write_transforms.append(trans_id) - - def _reset_protocol(self): - # HTTP calls that are one way need to flush here. 
- if self.__client_type == CLIENT_TYPE.HTTP_SERVER: - self.flush() - # set to anything except unframed - self.__client_type = CLIENT_TYPE.UNKNOWN - # Read header bytes to check which protocol to decode - self.readFrame(0) - - def getTransport(self): - return self.__trans - - def isOpen(self): - return self.getTransport().isOpen() - - def open(self): - return self.getTransport().open() - - def close(self): - return self.getTransport().close() - - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) == sz: - return ret - - if self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, - CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): - return ret + self.getTransport().readAll(sz - len(ret)) - - self.readFrame(sz - len(ret)) - return ret + self.__rbuf.read(sz - len(ret)) - - readAll = read # TTransportBase.readAll does a needless copy here. - - def readFrame(self, req_sz): - self.__rbuf_frame = True - word1 = self.getTransport().readAll(4) - sz = unpack(b'!I', word1)[0] - proto_id = word1[0] if PY3 else ord(word1[0]) - if proto_id == BINARY_PROTO_ID: - # unframed - self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED - self.__proto_id = T_BINARY_PROTOCOL - if req_sz <= 4: # check for reads < 0. - self.__rbuf = StringIO(word1) - else: - self.__rbuf = StringIO(word1 + self.getTransport().read(req_sz - 4)) - elif proto_id == COMPACT_PROTO_ID: - self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED - self.__proto_id = T_COMPACT_PROTOCOL - if req_sz <= 4: # check for reads < 0. - self.__rbuf = StringIO(word1) - else: - self.__rbuf = StringIO(word1 + self.getTransport().read(req_sz - 4)) - elif sz == HTTP_SERVER_MAGIC: - self.__client_type = CLIENT_TYPE.HTTP_SERVER - mf = self.getTransport().handle.makefile('rb', -1) - - self.handler = RequestHandler(mf, 'client_address:port', '') - self.header = self.handler.wfile - self.__rbuf = StringIO(self.handler.data) - else: - if sz == BIG_FRAME_MAGIC: - sz = unpack(b'!Q', self.getTransport().readAll(8))[0] - # could be header format or framed. Check next two bytes. 
- magic = self.getTransport().readAll(2) - proto_id = magic[0] if PY3 else ord(magic[0]) - if proto_id == COMPACT_PROTO_ID: - self.__client_type = CLIENT_TYPE.FRAMED_COMPACT - self.__proto_id = T_COMPACT_PROTOCOL - _frame_size_check(sz, self.__max_frame_size, header=False) - self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2)) - elif proto_id == BINARY_PROTO_ID: - self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED - self.__proto_id = T_BINARY_PROTOCOL - _frame_size_check(sz, self.__max_frame_size, header=False) - self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2)) - elif magic == PACKED_HEADER_MAGIC: - self.__client_type = CLIENT_TYPE.HEADER - _frame_size_check(sz, self.__max_frame_size) - # flags(2), seq_id(4), header_size(2) - n_header_meta = self.getTransport().readAll(8) - self.__flags, self.seq_id, header_size = unpack(b'!HIH', n_header_meta) - data = StringIO() - data.write(magic) - data.write(n_header_meta) - data.write(self.getTransport().readAll(sz - 10)) - data.seek(10) - self.read_header_format(sz - 10, header_size, data) - else: - self.__client_type = CLIENT_TYPE.UNKNOWN - raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, - "Could not detect client transport type") - - if self.__client_type not in self.__supported_client_types: - raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, - "Client type {} not supported on server".format(self.__client_type)) - - def read_header_format(self, sz, header_size, data): - # clear out any previous transforms - self.__read_transforms = [] - - header_size = header_size * 4 - if header_size > sz: - raise TTransportException(TTransportException.INVALID_FRAME_SIZE, - "Header size is larger than frame") - end_header = header_size + data.tell() - - self.__proto_id = readVarint(data) - num_headers = readVarint(data) - - if self.__proto_id == 1 and self.__client_type != \ - CLIENT_TYPE.HTTP_SERVER: - raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, - "Trying to recv JSON encoding over binary") - - # Read the headers. Data for each header varies. - for _ in range(0, num_headers): - trans_id = readVarint(data) - if trans_id == TRANSFORM.ZLIB: - self.__read_transforms.insert(0, trans_id) - elif trans_id == TRANSFORM.SNAPPY: - self.__read_transforms.insert(0, trans_id) - elif trans_id == TRANSFORM.HMAC: - raise TApplicationException(TApplicationException.INVALID_TRANSFORM, - "Hmac transform is no longer supported: %i" % trans_id) - else: - # TApplicationException will be sent back to client - raise TApplicationException(TApplicationException.INVALID_TRANSFORM, - "Unknown transform in client request: %i" % trans_id) - - # Clear out previous info headers. - self.__read_headers.clear() - - # Read the info headers. - while data.tell() < end_header: - info_id = readVarint(data) - if info_id == INFO.NORMAL: - _read_info_headers(data, end_header, self.__read_headers) - elif info_id == INFO.PERSISTENT: - _read_info_headers(data, end_header, self.__read_persistent_headers) - else: - break # Unknown header. Stop info processing. - - if self.__read_persistent_headers: - self.__read_headers.update(self.__read_persistent_headers) - - # Skip the rest of the header - data.seek(end_header) - - payload = data.read(sz - header_size) - - # Read the data section. 
- self.__rbuf = StringIO(self.untransform(payload)) - - def write(self, buf): - self.__wbuf.write(buf) - - def transform(self, buf): - for trans_id in self.__write_transforms: - if trans_id == TRANSFORM.ZLIB: - buf = zlib.compress(buf) - elif trans_id == TRANSFORM.SNAPPY: - buf = snappy.compress(buf) - else: - raise TTransportException(TTransportException.INVALID_TRANSFORM, - "Unknown transform during send") - return buf - - def untransform(self, buf): - for trans_id in self.__read_transforms: - if trans_id == TRANSFORM.ZLIB: - buf = zlib.decompress(buf) - elif trans_id == TRANSFORM.SNAPPY: - buf = snappy.decompress(buf) - if trans_id not in self.__write_transforms: + __max_frame_size = MAX_FRAME_SIZE + + # Defaults to current user, but there is also a setter below. + __identity = None + IDENTITY_HEADER = "identity" + ID_VERSION_HEADER = "id_version" + ID_VERSION = "1" + + def __init__(self, trans, client_types=None, client_type=None): + self.__trans = trans + self.__rbuf = StringIO() + self.__rbuf_frame = False + self.__wbuf = StringIO() + self.seq_id = 0 + self.__flags = 0 + self.__read_transforms = [] + self.__write_transforms = [] + self.__supported_client_types = set(client_types or (CLIENT_TYPE.HEADER,)) + self.__proto_id = T_COMPACT_PROTOCOL # default to compact like c++ + self.__client_type = client_type or CLIENT_TYPE.HEADER + self.__read_headers = {} + self.__read_persistent_headers = {} + self.__write_headers = {} + self.__write_persistent_headers = {} + + self.__supported_client_types.add(self.__client_type) + + # If we support unframed binary / framed binary also support compact + if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types: + self.__supported_client_types.add(CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED) + if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types: + self.__supported_client_types.add(CLIENT_TYPE.FRAMED_COMPACT) + + def set_header_flag(self, flag): + self.__flags |= flag + + def clear_header_flag(self, flag): + self.__flags &= ~flag + + def header_flags(self): + return self.__flags + + def set_max_frame_size(self, size): + if size > MAX_BIG_FRAME_SIZE: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "Cannot set max frame size > %s" % MAX_BIG_FRAME_SIZE) + if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER: + raise TTransportException( + TTransportException.INVALID_FRAME_SIZE, + "Cannot set max frame size > %s for clients other than HEADER" % MAX_FRAME_SIZE) + self.__max_frame_size = size + + def get_peer_identity(self): + if self.IDENTITY_HEADER in self.__read_headers: + if self.__read_headers[self.ID_VERSION_HEADER] == self.ID_VERSION: + return self.__read_headers[self.IDENTITY_HEADER] + return None + + def set_identity(self, identity): + self.__identity = identity + + def get_protocol_id(self): + return self.__proto_id + + def set_protocol_id(self, proto_id): + self.__proto_id = proto_id + + def set_header(self, str_key, str_value): + self.__write_headers[str_key] = str_value + + def get_write_headers(self): + return self.__write_headers + + def get_headers(self): + return self.__read_headers + + def clear_headers(self): + self.__write_headers.clear() + + def set_persistent_header(self, str_key, str_value): + self.__write_persistent_headers[str_key] = str_value + + def get_write_persistent_headers(self): + return self.__write_persistent_headers + + def clear_persistent_headers(self): + self.__write_persistent_headers.clear() + + def add_transform(self, trans_id): 
self.__write_transforms.append(trans_id) - return buf - - def flush(self): - self.flushImpl(False) - def onewayFlush(self): - self.flushImpl(True) - - def _flushHeaderMessage(self, buf, wout, wsz): - """Write a message for CLIENT_TYPE.HEADER + def _reset_protocol(self): + # HTTP calls that are one way need to flush here. + if self.__client_type == CLIENT_TYPE.HTTP_SERVER: + self.flush() + # set to anything except unframed + self.__client_type = CLIENT_TYPE.UNKNOWN + # Read header bytes to check which protocol to decode + self.readFrame(0) + + def getTransport(self): + return self.__trans + + def isOpen(self): + return self.getTransport().isOpen() + + def open(self): + return self.getTransport().open() + + def close(self): + return self.getTransport().close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) == sz: + return ret + + if self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): + return ret + self.getTransport().readAll(sz - len(ret)) + + self.readFrame(sz - len(ret)) + return ret + self.__rbuf.read(sz - len(ret)) + + readAll = read # TTransportBase.readAll does a needless copy here. + + def readFrame(self, req_sz): + self.__rbuf_frame = True + word1 = self.getTransport().readAll(4) + sz = unpack(b'!I', word1)[0] + proto_id = word1[0] if PY3 else ord(word1[0]) + if proto_id == BINARY_PROTO_ID: + # unframed + self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED + self.__proto_id = T_BINARY_PROTOCOL + if req_sz <= 4: # check for reads < 0. + self.__rbuf = StringIO(word1) + else: + self.__rbuf = StringIO(word1 + self.getTransport().read(req_sz - 4)) + elif proto_id == COMPACT_PROTO_ID: + self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED + self.__proto_id = T_COMPACT_PROTOCOL + if req_sz <= 4: # check for reads < 0. + self.__rbuf = StringIO(word1) + else: + self.__rbuf = StringIO(word1 + self.getTransport().read(req_sz - 4)) + elif sz == HTTP_SERVER_MAGIC: + self.__client_type = CLIENT_TYPE.HTTP_SERVER + mf = self.getTransport().handle.makefile('rb', -1) + + self.handler = RequestHandler(mf, 'client_address:port', '') + self.header = self.handler.wfile + self.__rbuf = StringIO(self.handler.data) + else: + if sz == BIG_FRAME_MAGIC: + sz = unpack(b'!Q', self.getTransport().readAll(8))[0] + # could be header format or framed. Check next two bytes. 
+ magic = self.getTransport().readAll(2) + proto_id = magic[0] if PY3 else ord(magic[0]) + if proto_id == COMPACT_PROTO_ID: + self.__client_type = CLIENT_TYPE.FRAMED_COMPACT + self.__proto_id = T_COMPACT_PROTOCOL + _frame_size_check(sz, self.__max_frame_size, header=False) + self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2)) + elif proto_id == BINARY_PROTO_ID: + self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED + self.__proto_id = T_BINARY_PROTOCOL + _frame_size_check(sz, self.__max_frame_size, header=False) + self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2)) + elif magic == PACKED_HEADER_MAGIC: + self.__client_type = CLIENT_TYPE.HEADER + _frame_size_check(sz, self.__max_frame_size) + # flags(2), seq_id(4), header_size(2) + n_header_meta = self.getTransport().readAll(8) + self.__flags, self.seq_id, header_size = unpack(b'!HIH', n_header_meta) + data = StringIO() + data.write(magic) + data.write(n_header_meta) + data.write(self.getTransport().readAll(sz - 10)) + data.seek(10) + self.read_header_format(sz - 10, header_size, data) + else: + self.__client_type = CLIENT_TYPE.UNKNOWN + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Could not detect client transport type") + + if self.__client_type not in self.__supported_client_types: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Client type {} not supported on server".format(self.__client_type)) + + def read_header_format(self, sz, header_size, data): + # clear out any previous transforms + self.__read_transforms = [] + + header_size = header_size * 4 + if header_size > sz: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "Header size is larger than frame") + end_header = header_size + data.tell() + + self.__proto_id = readVarint(data) + num_headers = readVarint(data) + + if self.__proto_id == 1 and self.__client_type != \ + CLIENT_TYPE.HTTP_SERVER: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Trying to recv JSON encoding over binary") + + # Read the headers. Data for each header varies. + for _ in range(0, num_headers): + trans_id = readVarint(data) + if trans_id == TRANSFORM.ZLIB: + self.__read_transforms.insert(0, trans_id) + elif trans_id == TRANSFORM.SNAPPY: + self.__read_transforms.insert(0, trans_id) + elif trans_id == TRANSFORM.HMAC: + raise TApplicationException(TApplicationException.INVALID_TRANSFORM, + "Hmac transform is no longer supported: %i" % trans_id) + else: + # TApplicationException will be sent back to client + raise TApplicationException(TApplicationException.INVALID_TRANSFORM, + "Unknown transform in client request: %i" % trans_id) + + # Clear out previous info headers. + self.__read_headers.clear() + + # Read the info headers. + while data.tell() < end_header: + info_id = readVarint(data) + if info_id == INFO.NORMAL: + _read_info_headers(data, end_header, self.__read_headers) + elif info_id == INFO.PERSISTENT: + _read_info_headers(data, end_header, self.__read_persistent_headers) + else: + break # Unknown header. Stop info processing. + + if self.__read_persistent_headers: + self.__read_headers.update(self.__read_persistent_headers) + + # Skip the rest of the header + data.seek(end_header) + + payload = data.read(sz - header_size) + + # Read the data section. 
+ self.__rbuf = StringIO(self.untransform(payload)) + + def write(self, buf): + self.__wbuf.write(buf) + + def transform(self, buf): + for trans_id in self.__write_transforms: + if trans_id == TRANSFORM.ZLIB: + buf = zlib.compress(buf) + elif trans_id == TRANSFORM.SNAPPY: + buf = snappy.compress(buf) + else: + raise TTransportException(TTransportException.INVALID_TRANSFORM, + "Unknown transform during send") + return buf + + def untransform(self, buf): + for trans_id in self.__read_transforms: + if trans_id == TRANSFORM.ZLIB: + buf = zlib.decompress(buf) + elif trans_id == TRANSFORM.SNAPPY: + buf = snappy.decompress(buf) + if trans_id not in self.__write_transforms: + self.__write_transforms.append(trans_id) + return buf + + def flush(self): + self.flushImpl(False) + + def onewayFlush(self): + self.flushImpl(True) + + def _flushHeaderMessage(self, buf, wout, wsz): + """Write a message for CLIENT_TYPE.HEADER @param buf(StringIO): Buffer to write message to @param wout(str): Payload @param wsz(int): Payload length """ - transform_data = StringIO() - # For now, all transforms don't require data. - num_transforms = len(self.__write_transforms) - for trans_id in self.__write_transforms: - transform_data.write(getVarint(trans_id)) - - # Add in special flags. - if self.__identity: - self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION - self.__write_headers[self.IDENTITY_HEADER] = self.__identity - - info_data = StringIO() - - # Write persistent kv-headers - _flush_info_headers(info_data, self.get_write_persistent_headers(), INFO.PERSISTENT) - - # Write non-persistent kv-headers - _flush_info_headers(info_data, self.__write_headers, INFO.NORMAL) - - header_data = StringIO() - header_data.write(getVarint(self.__proto_id)) - header_data.write(getVarint(num_transforms)) - - header_size = transform_data.tell() + header_data.tell() + \ - info_data.tell() - - padding_size = 4 - (header_size % 4) - header_size = header_size + padding_size - - # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2) - wsz += header_size + 10 - if wsz > MAX_FRAME_SIZE: - buf.write(pack(b"!I", BIG_FRAME_MAGIC)) - buf.write(pack(b"!Q", wsz)) - else: - buf.write(pack(b"!I", wsz)) - buf.write(pack(b"!HH", HEADER_MAGIC >> 16, self.__flags)) - buf.write(pack(b"!I", self.seq_id)) - buf.write(pack(b"!H", header_size // 4)) - - buf.write(header_data.getvalue()) - buf.write(transform_data.getvalue()) - buf.write(info_data.getvalue()) - - # Pad out the header with 0x00 - for _ in range(0, padding_size, 1): - buf.write(pack(b"!c", b'\0')) - - # Send data section - buf.write(wout) - - def flushImpl(self, oneway): - wout = self.__wbuf.getvalue() - wout = self.transform(wout) - wsz = len(wout) - - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf.seek(0) - self.__wbuf.truncate() - - if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER: - raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, - "Trying to send JSON encoding over binary") - - buf = StringIO() - if self.__client_type == CLIENT_TYPE.HEADER: - self._flushHeaderMessage(buf, wout, wsz) - elif self.__client_type in (CLIENT_TYPE.FRAMED_DEPRECATED, CLIENT_TYPE.FRAMED_COMPACT): - buf.write(pack(b"!i", wsz)) - buf.write(wout) - elif self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, - CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): - buf.write(wout) - elif self.__client_type == CLIENT_TYPE.HTTP_SERVER: - # Reset the client type if we sent something - - # oneway calls via HTTP expect a status response otherwise 
- buf.write(self.header.getvalue()) - buf.write(wout) - self.__client_type == CLIENT_TYPE.HEADER - elif self.__client_type == CLIENT_TYPE.UNKNOWN: - raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, "Unknown client type") - - # We don't include the framing bytes as part of the frame size check - frame_size = buf.tell() - (4 if wsz < MAX_FRAME_SIZE else 12) - _frame_size_check(frame_size, - self.__max_frame_size, - header=self.__client_type == CLIENT_TYPE.HEADER) - self.getTransport().write(buf.getvalue()) - if oneway: - self.getTransport().onewayFlush() - else: - self.getTransport().flush() - - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - if not self.__rbuf_frame: - self.readFrame(0) - return self.__rbuf - - def cstringio_refill(self, prefix, reqlen): - # self.__rbuf will already be empty here because fastproto doesn't - # ask for a refill until the previous buffer is empty. Therefore, - # we can start reading new frames immediately. - - # On unframed clients, there is a chance there is something left - # in rbuf, and the read pointer is not advanced by fastproto - # so seek to the end to be safe - self.__rbuf.seek(0, 2) - while len(prefix) < reqlen: - prefix += self.read(reqlen) - self.__rbuf = StringIO(prefix) - return self.__rbuf + transform_data = StringIO() + # For now, all transforms don't require data. + num_transforms = len(self.__write_transforms) + for trans_id in self.__write_transforms: + transform_data.write(getVarint(trans_id)) + + # Add in special flags. + if self.__identity: + self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION + self.__write_headers[self.IDENTITY_HEADER] = self.__identity + + info_data = StringIO() + + # Write persistent kv-headers + _flush_info_headers(info_data, self.get_write_persistent_headers(), INFO.PERSISTENT) + + # Write non-persistent kv-headers + _flush_info_headers(info_data, self.__write_headers, INFO.NORMAL) + + header_data = StringIO() + header_data.write(getVarint(self.__proto_id)) + header_data.write(getVarint(num_transforms)) + + header_size = transform_data.tell() + header_data.tell() + \ + info_data.tell() + + padding_size = 4 - (header_size % 4) + header_size = header_size + padding_size + + # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2) + wsz += header_size + 10 + if wsz > MAX_FRAME_SIZE: + buf.write(pack(b"!I", BIG_FRAME_MAGIC)) + buf.write(pack(b"!Q", wsz)) + else: + buf.write(pack(b"!I", wsz)) + buf.write(pack(b"!HH", HEADER_MAGIC >> 16, self.__flags)) + buf.write(pack(b"!I", self.seq_id)) + buf.write(pack(b"!H", header_size // 4)) + + buf.write(header_data.getvalue()) + buf.write(transform_data.getvalue()) + buf.write(info_data.getvalue()) + + # Pad out the header with 0x00 + for _ in range(0, padding_size, 1): + buf.write(pack(b"!c", b'\0')) + + # Send data section + buf.write(wout) + + def flushImpl(self, oneway): + wout = self.__wbuf.getvalue() + wout = self.transform(wout) + wsz = len(wout) + + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf.seek(0) + self.__wbuf.truncate() + + if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Trying to send JSON encoding over binary") + + buf = StringIO() + if self.__client_type == CLIENT_TYPE.HEADER: + self._flushHeaderMessage(buf, wout, wsz) + elif self.__client_type in (CLIENT_TYPE.FRAMED_DEPRECATED, CLIENT_TYPE.FRAMED_COMPACT): + buf.write(pack(b"!i", wsz)) + buf.write(wout) + elif 
self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): + buf.write(wout) + elif self.__client_type == CLIENT_TYPE.HTTP_SERVER: + # Reset the client type if we sent something - + # oneway calls via HTTP expect a status response otherwise + buf.write(self.header.getvalue()) + buf.write(wout) + self.__client_type == CLIENT_TYPE.HEADER + elif self.__client_type == CLIENT_TYPE.UNKNOWN: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Unknown client type") + + # We don't include the framing bytes as part of the frame size check + frame_size = buf.tell() - (4 if wsz < MAX_FRAME_SIZE else 12) + _frame_size_check(frame_size, + self.__max_frame_size, + header=self.__client_type == CLIENT_TYPE.HEADER) + self.getTransport().write(buf.getvalue()) + if oneway: + self.getTransport().onewayFlush() + else: + self.getTransport().flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + if not self.__rbuf_frame: + self.readFrame(0) + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastproto doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + + # On unframed clients, there is a chance there is something left + # in rbuf, and the read pointer is not advanced by fastproto + # so seek to the end to be safe + self.__rbuf.seek(0, 2) + while len(prefix) < reqlen: + prefix += self.read(reqlen) + self.__rbuf = StringIO(prefix) + return self.__rbuf def _serialize_string(str_): - if PY3 and not isinstance(str_, bytes): - str_ = str_.encode() - return getVarint(len(str_)) + str_ + if PY3 and not isinstance(str_, bytes): + str_ = str_.encode() + return getVarint(len(str_)) + str_ def _flush_info_headers(info_data, write_headers, type): - if (len(write_headers) > 0): - info_data.write(getVarint(type)) - info_data.write(getVarint(len(write_headers))) - write_headers_iter = write_headers.items() - for str_key, str_value in write_headers_iter: - info_data.write(_serialize_string(str_key)) - info_data.write(_serialize_string(str_value)) - write_headers.clear() + if (len(write_headers) > 0): + info_data.write(getVarint(type)) + info_data.write(getVarint(len(write_headers))) + write_headers_iter = write_headers.items() + for str_key, str_value in write_headers_iter: + info_data.write(_serialize_string(str_key)) + info_data.write(_serialize_string(str_value)) + write_headers.clear() def _read_string(bufio, buflimit): - str_sz = readVarint(bufio) - if str_sz + bufio.tell() > buflimit: - raise TTransportException(TTransportException.INVALID_FRAME_SIZE, "String read too big") - return bufio.read(str_sz) + str_sz = readVarint(bufio) + if str_sz + bufio.tell() > buflimit: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, "String read too big") + return bufio.read(str_sz) def _read_info_headers(data, end_header, read_headers): - num_keys = readVarint(data) - for _ in xrange(num_keys): - str_key = _read_string(data, end_header) - str_value = _read_string(data, end_header) - read_headers[str_key] = str_value + num_keys = readVarint(data) + for _ in xrange(num_keys): + str_key = _read_string(data, end_header) + str_value = _read_string(data, end_header) + read_headers[str_key] = str_value def _frame_size_check(sz, set_max_size, header=True): - if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE): - raise TTransportException(TTransportException.INVALID_FRAME_SIZE, 
- "%s transport frame was too large" % 'Header' if header else 'Framed') + if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE): + raise TTransportException( + TTransportException.INVALID_FRAME_SIZE, + "%s transport frame was too large" % 'Header' if header else 'Framed') class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): - # Same as superclass function, but append 'POST' because we - # stripped it in the calling function. Would be nice if - # we had an ungetch instead - def handle_one_request(self): - self.raw_requestline = self.rfile.readline() - if not self.raw_requestline: - self.close_connection = 1 - return - self.raw_requestline = "POST" + self.raw_requestline - if not self.parse_request(): - # An error code has been sent, just exit - return - mname = 'do_' + self.command - if not hasattr(self, mname): - self.send_error(501, "Unsupported method (%r)" % self.command) - return - method = getattr(self, mname) - method() - - def setup(self): - self.rfile = self.request - self.wfile = StringIO() # New output buffer - - def finish(self): - if not self.rfile.closed: - self.rfile.close() - # leave wfile open for reading. - - def do_POST(self): - if int(self.headers['Content-Length']) > 0: - self.data = self.rfile.read(int(self.headers['Content-Length'])) - else: - self.data = "" - - # Prepare a response header, to be sent later. - self.send_response(200) - self.send_header("content-type", "application/x-thrift") - self.end_headers() + # Same as superclass function, but append 'POST' because we + # stripped it in the calling function. Would be nice if + # we had an ungetch instead + def handle_one_request(self): + self.raw_requestline = self.rfile.readline() + if not self.raw_requestline: + self.close_connection = 1 + return + self.raw_requestline = "POST" + self.raw_requestline + if not self.parse_request(): + # An error code has been sent, just exit + return + mname = 'do_' + self.command + if not hasattr(self, mname): + self.send_error(501, "Unsupported method (%r)" % self.command) + return + method = getattr(self, mname) + method() + + def setup(self): + self.rfile = self.request + self.wfile = StringIO() # New output buffer + + def finish(self): + if not self.rfile.closed: + self.rfile.close() + # leave wfile open for reading. + + def do_POST(self): + if int(self.headers['Content-Length']) > 0: + self.data = self.rfile.read(int(self.headers['Content-Length'])) + else: + self.data = "" + + # Prepare a response header, to be sent later. 
+ self.send_response(200) + self.send_header("content-type", "application/x-thrift") + self.end_headers() # INFO:(zuercher): Added to simplify usage class THeaderTransportFactory: - def __init__(self, proto_id): - self.__proto_id = proto_id + def __init__(self, proto_id): + self.__proto_id = proto_id - def getTransport(self, trans): - header_trans = THeaderTransport(trans, client_type=CLIENT_TYPE.HEADER) - header_trans.set_protocol_id(self.__proto_id) - return header_trans + def getTransport(self, trans): + header_trans = THeaderTransport(trans, client_type=CLIENT_TYPE.HEADER) + header_trans.set_protocol_id(self.__proto_id) + return header_trans diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py index 45e57fd8733c..6370649468b8 100644 --- a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py @@ -11,49 +11,49 @@ # Twitter's TFinagleProcessor only works for the client side of an RPC. class TFinagleServerProcessor(TProcessor): - def __init__(self, underlying): - self._underlying = underlying - - def process(self, iprot, oprot): - try: - if iprot.upgraded() is not None: - return self._underlying.process(iprot, oprot) - except AttributeError as e: - logging.exception("underlying protocol object is not a TFinagleServerProtocol", e) - return self._underlying.process(iprot, oprot) - - (name, ttype, seqid) = iprot.readMessageBegin() - if ttype != TMessageType.CALL and ttype != TMessageType.ONEWAY: - raise TException("TFinagle protocol only supports CALL & ONEWAY") - - # Check if this is an upgrade request. - if name == UPGRADE_METHOD: - connection_options = ConnectionOptions() - connection_options.read(iprot) - iprot.readMessageEnd() - - oprot.writeMessageBegin(UPGRADE_METHOD, TMessageType.REPLY, seqid) - upgrade_reply = UpgradeReply() - upgrade_reply.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - iprot.set_upgraded(True) - oprot.set_upgraded(True) - return True - - # Not upgraded. Replay the message begin to the underlying processor. - iprot.set_upgraded(False) - oprot.set_upgraded(False) - msg = (name, ttype, seqid) - return self._underlying.process(StoredMessageProtocol(iprot, msg), oprot) + def __init__(self, underlying): + self._underlying = underlying + + def process(self, iprot, oprot): + try: + if iprot.upgraded() is not None: + return self._underlying.process(iprot, oprot) + except AttributeError as e: + logging.exception("underlying protocol object is not a TFinagleServerProtocol", e) + return self._underlying.process(iprot, oprot) + + (name, ttype, seqid) = iprot.readMessageBegin() + if ttype != TMessageType.CALL and ttype != TMessageType.ONEWAY: + raise TException("TFinagle protocol only supports CALL & ONEWAY") + + # Check if this is an upgrade request. + if name == UPGRADE_METHOD: + connection_options = ConnectionOptions() + connection_options.read(iprot) + iprot.readMessageEnd() + + oprot.writeMessageBegin(UPGRADE_METHOD, TMessageType.REPLY, seqid) + upgrade_reply = UpgradeReply() + upgrade_reply.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + + iprot.set_upgraded(True) + oprot.set_upgraded(True) + return True + + # Not upgraded. Replay the message begin to the underlying processor. 
+ iprot.set_upgraded(False) + oprot.set_upgraded(False) + msg = (name, ttype, seqid) + return self._underlying.process(StoredMessageProtocol(iprot, msg), oprot) class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator): - def __init__(self, protocol, messageBegin): - TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) - self.messageBegin = messageBegin + def __init__(self, protocol, messageBegin): + TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) + self.messageBegin = messageBegin - def readMessageBegin(self): - return self.messageBegin + def readMessageBegin(self): + return self.messageBegin diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py index 3b6eb56c53e3..87d914476825 100644 --- a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py @@ -4,32 +4,32 @@ class TFinagleServerProtocolFactory(object): - def getProtocol(self, trans): - return TFinagleServerProtocol(trans) + def getProtocol(self, trans): + return TFinagleServerProtocol(trans) class TFinagleServerProtocol(TBinaryProtocol.TBinaryProtocol): - def __init__(self, *args, **kw): - self._last_request = None - self._upgraded = None - TBinaryProtocol.TBinaryProtocol.__init__(self, *args, **kw) - - def upgraded(self): - return self._upgraded - - def set_upgraded(self, upgraded): - self._upgraded = upgraded - - def writeMessageBegin(self, *args, **kwargs): - if self._upgraded: - header = ResponseHeader() # .. TODO set some fields - header.write(self) - return TBinaryProtocol.TBinaryProtocol.writeMessageBegin(self, *args, **kwargs) - - def readMessageBegin(self, *args, **kwargs): - if self._upgraded: - header = RequestHeader() - header.read(self) - self._last_request = header - return TBinaryProtocol.TBinaryProtocol.readMessageBegin(self, *args, **kwargs) + def __init__(self, *args, **kw): + self._last_request = None + self._upgraded = None + TBinaryProtocol.TBinaryProtocol.__init__(self, *args, **kw) + + def upgraded(self): + return self._upgraded + + def set_upgraded(self, upgraded): + self._upgraded = upgraded + + def writeMessageBegin(self, *args, **kwargs): + if self._upgraded: + header = ResponseHeader() # .. 
TODO set some fields + header.write(self) + return TBinaryProtocol.TBinaryProtocol.writeMessageBegin(self, *args, **kwargs) + + def readMessageBegin(self, *args, **kwargs): + if self._upgraded: + header = RequestHeader() + header.read(self) + self._last_request = header + return TBinaryProtocol.TBinaryProtocol.readMessageBegin(self, *args, **kwargs) diff --git a/test/extensions/filters/network/thrift_proxy/driver/server.py b/test/extensions/filters/network/thrift_proxy/driver/server.py index 650280919eaa..b8246d1153e4 100755 --- a/test/extensions/filters/network/thrift_proxy/driver/server.py +++ b/test/extensions/filters/network/thrift_proxy/driver/server.py @@ -17,218 +17,218 @@ # On Windows we run this test on Python3 if sys.version_info[0] != 2: - sys.stdin.reconfigure(encoding='utf-8') - sys.stdout.reconfigure(encoding='utf-8') + sys.stdin.reconfigure(encoding='utf-8') + sys.stdout.reconfigure(encoding='utf-8') class SuccessHandler: - def ping(self): - print("server: ping()") + def ping(self): + print("server: ping()") - def poke(self): - print("server: poke()") + def poke(self): + print("server: poke()") - def add(self, a, b): - result = a + b - print("server: add({0}, {1}) = {2}".format(a, b, result)) - return result + def add(self, a, b): + result = a + b + print("server: add({0}, {1}) = {2}".format(a, b, result)) + return result - def execute(self, param): - print("server: execute({0})".format(param)) - if "all" in param.return_fields: - return Result(param.the_works) - elif "none" in param.return_fields: - return Result(TheWorks()) - the_works = TheWorks() - for field, value in vars(param.the_works).items(): - if field in param.return_fields: - setattr(the_works, field, value) - return Result(the_works) + def execute(self, param): + print("server: execute({0})".format(param)) + if "all" in param.return_fields: + return Result(param.the_works) + elif "none" in param.return_fields: + return Result(TheWorks()) + the_works = TheWorks() + for field, value in vars(param.the_works).items(): + if field in param.return_fields: + setattr(the_works, field, value) + return Result(the_works) class IDLExceptionHandler: - def ping(self): - print("server: ping()") + def ping(self): + print("server: ping()") - def poke(self): - print("server: poke()") + def poke(self): + print("server: poke()") - def add(self, a, b): - result = a + b - print("server: add({0}, {1}) = {2}".format(a, b, result)) - return result + def add(self, a, b): + result = a + b + print("server: add({0}, {1}) = {2}".format(a, b, result)) + return result - def execute(self, param): - print("server: app error: execute failed") - raise AppException("execute failed") + def execute(self, param): + print("server: app error: execute failed") + raise AppException("execute failed") class ExceptionHandler: - def ping(self): - print("server: ping failure") - raise Thrift.TApplicationException( - type=Thrift.TApplicationException.INTERNAL_ERROR, - message="for ping", - ) + def ping(self): + print("server: ping failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for ping", + ) + + def poke(self): + print("server: poke failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for poke", + ) + + def add(self, a, b): + print("server: add failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for add", + ) + + def execute(self, param): + print("server: execute failure") + raise 
Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for execute", + ) - def poke(self): - print("server: poke failure") - raise Thrift.TApplicationException( - type=Thrift.TApplicationException.INTERNAL_ERROR, - message="for poke", - ) - def add(self, a, b): - print("server: add failure") - raise Thrift.TApplicationException( - type=Thrift.TApplicationException.INTERNAL_ERROR, - message="for add", - ) +def main(cfg): + if cfg.unix: + if cfg.addr == "": + sys.exit("invalid listener unix domain socket: {}".format(cfg.addr)) + else: + try: + (host, port) = cfg.addr.rsplit(":", 1) + port = int(port) + except ValueError: + sys.exit("invalid listener address: {}".format(cfg.addr)) + + if cfg.response == "success": + handler = SuccessHandler() + elif cfg.response == "idl-exception": + handler = IDLExceptionHandler() + elif cfg.response == "exception": + # squelch traceback for the exception we throw + logging.getLogger().setLevel(logging.CRITICAL) + handler = ExceptionHandler() + else: + sys.exit("unknown server response mode {0}".format(cfg.response)) - def execute(self, param): - print("server: execute failure") - raise Thrift.TApplicationException( - type=Thrift.TApplicationException.INTERNAL_ERROR, - message="for execute", - ) + processor = Example.Processor(handler) + if cfg.service is not None: + # wrap processor with multiplexor + multi = TMultiplexedProcessor.TMultiplexedProcessor() + multi.registerProcessor(cfg.service, processor) + processor = multi + if cfg.protocol == "finagle": + # wrap processor with finagle request/response header handler + processor = TFinagleServerProcessor.TFinagleServerProcessor(processor) + + if cfg.unix: + transport = TSocket.TServerSocket(unix_socket=cfg.addr) + else: + transport = TSocket.TServerSocket(host=host, port=port) + + if cfg.transport == "framed": + transport_factory = TTransport.TFramedTransportFactory() + elif cfg.transport == "unframed": + transport_factory = TTransport.TBufferedTransportFactory() + elif cfg.transport == "header": + if cfg.protocol == "binary": + transport_factory = THeaderTransport.THeaderTransportFactory( + THeaderTransport.T_BINARY_PROTOCOL) + elif cfg.protocol == "compact": + transport_factory = THeaderTransport.THeaderTransportFactory( + THeaderTransport.T_COMPACT_PROTOCOL) + else: + sys.exit("header transport cannot be used with protocol {0}".format(cfg.protocol)) + else: + sys.exit("unknown transport {0}".format(cfg.transport)) -def main(cfg): - if cfg.unix: - if cfg.addr == "": - sys.exit("invalid listener unix domain socket: {}".format(cfg.addr)) - else: - try: - (host, port) = cfg.addr.rsplit(":", 1) - port = int(port) - except ValueError: - sys.exit("invalid listener address: {}".format(cfg.addr)) - - if cfg.response == "success": - handler = SuccessHandler() - elif cfg.response == "idl-exception": - handler = IDLExceptionHandler() - elif cfg.response == "exception": - # squelch traceback for the exception we throw - logging.getLogger().setLevel(logging.CRITICAL) - handler = ExceptionHandler() - else: - sys.exit("unknown server response mode {0}".format(cfg.response)) - - processor = Example.Processor(handler) - if cfg.service is not None: - # wrap processor with multiplexor - multi = TMultiplexedProcessor.TMultiplexedProcessor() - multi.registerProcessor(cfg.service, processor) - processor = multi - - if cfg.protocol == "finagle": - # wrap processor with finagle request/response header handler - processor = TFinagleServerProcessor.TFinagleServerProcessor(processor) - - if 
cfg.unix: - transport = TSocket.TServerSocket(unix_socket=cfg.addr) - else: - transport = TSocket.TServerSocket(host=host, port=port) - - if cfg.transport == "framed": - transport_factory = TTransport.TFramedTransportFactory() - elif cfg.transport == "unframed": - transport_factory = TTransport.TBufferedTransportFactory() - elif cfg.transport == "header": if cfg.protocol == "binary": - transport_factory = THeaderTransport.THeaderTransportFactory( - THeaderTransport.T_BINARY_PROTOCOL) + protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() elif cfg.protocol == "compact": - transport_factory = THeaderTransport.THeaderTransportFactory( - THeaderTransport.T_COMPACT_PROTOCOL) + protocol_factory = TCompactProtocol.TCompactProtocolFactory() + elif cfg.protocol == "json": + protocol_factory = TJSONProtocol.TJSONProtocolFactory() + elif cfg.protocol == "finagle": + protocol_factory = TFinagleServerProtocol.TFinagleServerProtocolFactory() else: - sys.exit("header transport cannot be used with protocol {0}".format(cfg.protocol)) - else: - sys.exit("unknown transport {0}".format(cfg.transport)) - - if cfg.protocol == "binary": - protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() - elif cfg.protocol == "compact": - protocol_factory = TCompactProtocol.TCompactProtocolFactory() - elif cfg.protocol == "json": - protocol_factory = TJSONProtocol.TJSONProtocolFactory() - elif cfg.protocol == "finagle": - protocol_factory = TFinagleServerProtocol.TFinagleServerProtocolFactory() - else: - sys.exit("unknown protocol {0}".format(cfg.protocol)) - - print("Thrift Server listening on {0} for {1} {2} requests".format(cfg.addr, cfg.transport, - cfg.protocol)) - if cfg.service is not None: - print("Thrift Server service name {0}".format(cfg.service)) - if cfg.response == "idl-exception": - print("Thrift Server will throw IDL exceptions when defined") - elif cfg.response == "exception": - print("Thrift Server will throw Thrift exceptions for all messages") - - server = TServer.TThreadedServer(processor, transport, transport_factory, protocol_factory) - try: - server.serve() - except KeyboardInterrupt: - print + sys.exit("unknown protocol {0}".format(cfg.protocol)) + + print("Thrift Server listening on {0} for {1} {2} requests".format( + cfg.addr, cfg.transport, cfg.protocol)) + if cfg.service is not None: + print("Thrift Server service name {0}".format(cfg.service)) + if cfg.response == "idl-exception": + print("Thrift Server will throw IDL exceptions when defined") + elif cfg.response == "exception": + print("Thrift Server will throw Thrift exceptions for all messages") + + server = TServer.TThreadedServer(processor, transport, transport_factory, protocol_factory) + try: + server.serve() + except KeyboardInterrupt: + print if __name__ == "__main__": - logging.basicConfig() - parser = argparse.ArgumentParser(description="Thrift server to match client.py.") - parser.add_argument( - "-a", - "--addr", - metavar="ADDR", - dest="addr", - default=":0", - help="Listener address for server in the form host:port. The host is optional. 
If --unix" + - " is set, the address is the socket name.", - ) - parser.add_argument( - "-m", - "--multiplex", - metavar="SERVICE", - dest="service", - help="Enable service multiplexing and set the service name.", - ) - parser.add_argument( - "-p", - "--protocol", - help="Selects a protocol.", - dest="protocol", - default="binary", - choices=["binary", "compact", "json", "finagle"], - ) - parser.add_argument( - "-r", - "--response", - dest="response", - default="success", - choices=["success", "idl-exception", "exception"], - help="Controls how the server responds to requests", - ) - parser.add_argument( - "-t", - "--transport", - help="Selects a transport.", - dest="transport", - default="framed", - choices=["framed", "unframed", "header"], - ) - parser.add_argument( - "-u", - "--unix", - dest="unix", - action="store_true", - ) - cfg = parser.parse_args() - - try: - main(cfg) - except Thrift.TException as tx: - sys.exit("Thrift exception: {0}".format(tx.message)) + logging.basicConfig() + parser = argparse.ArgumentParser(description="Thrift server to match client.py.") + parser.add_argument( + "-a", + "--addr", + metavar="ADDR", + dest="addr", + default=":0", + help="Listener address for server in the form host:port. The host is optional. If --unix" + + " is set, the address is the socket name.", + ) + parser.add_argument( + "-m", + "--multiplex", + metavar="SERVICE", + dest="service", + help="Enable service multiplexing and set the service name.", + ) + parser.add_argument( + "-p", + "--protocol", + help="Selects a protocol.", + dest="protocol", + default="binary", + choices=["binary", "compact", "json", "finagle"], + ) + parser.add_argument( + "-r", + "--response", + dest="response", + default="success", + choices=["success", "idl-exception", "exception"], + help="Controls how the server responds to requests", + ) + parser.add_argument( + "-t", + "--transport", + help="Selects a transport.", + dest="transport", + default="framed", + choices=["framed", "unframed", "header"], + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + action="store_true", + ) + cfg = parser.parse_args() + + try: + main(cfg) + except Thrift.TException as tx: + sys.exit("Thrift exception: {0}".format(tx.message)) diff --git a/test/integration/capture_fuzz_gen.py b/test/integration/capture_fuzz_gen.py index b6e77001bbe7..a057c9e34e5b 100644 --- a/test/integration/capture_fuzz_gen.py +++ b/test/integration/capture_fuzz_gen.py @@ -19,75 +19,75 @@ # Collapse adjacent event in the trace that are of the same type. def Coalesce(trace): - if not trace.events: - return [] - events = [trace.events[0]] - for event in trace.events[1:]: - if events[-1].HasField('read') and event.HasField('read'): - events[-1].read.data += event.read.data - elif events[-1].HasField('write') and event.HasField('write'): - events[-1].write.data += event.write.data - else: - events.append(event) - return events + if not trace.events: + return [] + events = [trace.events[0]] + for event in trace.events[1:]: + if events[-1].HasField('read') and event.HasField('read'): + events[-1].read.data += event.read.data + elif events[-1].HasField('write') and event.HasField('write'): + events[-1].write.data += event.write.data + else: + events.append(event) + return events # Convert from transport socket Event to test Event. 
def ToTestEvent(direction, event): - test_event = capture_fuzz_pb2.Event() - if event.HasField('read'): - setattr(test_event, '%s_send_bytes' % direction, event.read.data) - elif event.HasField('write'): - getattr(test_event, '%s_recv_bytes' % direction).MergeFrom(empty_pb2.Empty()) - return test_event + test_event = capture_fuzz_pb2.Event() + if event.HasField('read'): + setattr(test_event, '%s_send_bytes' % direction, event.read.data) + elif event.HasField('write'): + getattr(test_event, '%s_recv_bytes' % direction).MergeFrom(empty_pb2.Empty()) + return test_event def ToDownstreamTestEvent(event): - return ToTestEvent('downstream', event) + return ToTestEvent('downstream', event) def ToUpstreamTestEvent(event): - return ToTestEvent('upstream', event) + return ToTestEvent('upstream', event) # Zip together the listener/cluster events to produce a single trace for replay. def TestCaseGen(listener_events, cluster_events): - test_case = capture_fuzz_pb2.CaptureFuzzTestCase() - if not listener_events: - return test_case - test_case.events.extend([ToDownstreamTestEvent(listener_events[0])]) - del listener_events[0] - while listener_events or cluster_events: + test_case = capture_fuzz_pb2.CaptureFuzzTestCase() if not listener_events: - test_case.events.extend(map(ToUpstreamTestEvent, cluster_events)) - return test_case - if not cluster_events: - test_case.events.extend(map(ToDownstreamTestEvent, listener_events)) - return test_case - if listener_events[0].timestamp.ToDatetime() < cluster_events[0].timestamp.ToDatetime(): - test_case.events.extend([ToDownstreamTestEvent(listener_events[0])]) - del listener_events[0] - test_case.events.extend([ToUpstreamTestEvent(cluster_events[0])]) - del cluster_events[0] + return test_case + test_case.events.extend([ToDownstreamTestEvent(listener_events[0])]) + del listener_events[0] + while listener_events or cluster_events: + if not listener_events: + test_case.events.extend(map(ToUpstreamTestEvent, cluster_events)) + return test_case + if not cluster_events: + test_case.events.extend(map(ToDownstreamTestEvent, listener_events)) + return test_case + if listener_events[0].timestamp.ToDatetime() < cluster_events[0].timestamp.ToDatetime(): + test_case.events.extend([ToDownstreamTestEvent(listener_events[0])]) + del listener_events[0] + test_case.events.extend([ToUpstreamTestEvent(cluster_events[0])]) + del cluster_events[0] def CaptureFuzzGen(listener_path, cluster_path=None): - listener_trace = capture_pb2.Trace() - with open(listener_path, 'r') as f: - text_format.Merge(f.read(), listener_trace) - listener_events = Coalesce(listener_trace) + listener_trace = capture_pb2.Trace() + with open(listener_path, 'r') as f: + text_format.Merge(f.read(), listener_trace) + listener_events = Coalesce(listener_trace) - cluster_trace = capture_pb2.Trace() - if cluster_path: - with open(cluster_path, 'r') as f: - text_format.Merge(f.read(), cluster_trace) - cluster_events = Coalesce(cluster_trace) + cluster_trace = capture_pb2.Trace() + if cluster_path: + with open(cluster_path, 'r') as f: + text_format.Merge(f.read(), cluster_trace) + cluster_events = Coalesce(cluster_trace) - print(TestCaseGen(listener_events, cluster_events)) + print(TestCaseGen(listener_events, cluster_events)) if __name__ == '__main__': - if len(sys.argv) < 2 or len(sys.argv) > 3: - print('Usage: %s []' % sys.argv[0]) - sys.exit(1) - CaptureFuzzGen(*sys.argv[1:]) + if len(sys.argv) < 2 or len(sys.argv) > 3: + print('Usage: %s []' % sys.argv[0]) + sys.exit(1) + CaptureFuzzGen(*sys.argv[1:]) diff --git 
a/tools/api/generate_go_protobuf.py b/tools/api/generate_go_protobuf.py index 7ac636fc00a3..07cb496579a2 100755 --- a/tools/api/generate_go_protobuf.py +++ b/tools/api/generate_go_protobuf.py @@ -23,114 +23,115 @@ def generate_protobufs(output): - bazel_bin = check_output(['bazel', 'info', 'bazel-bin']).decode().strip() - go_protos = check_output([ - 'bazel', - 'query', - 'kind("go_proto_library", %s)' % TARGETS, - ]).split() - - # Each rule has the form @envoy_api//foo/bar:baz_go_proto. - # First build all the rules to ensure we have the output files. - # We preserve source info so comments are retained on generated code. - check_call([ - 'bazel', 'build', '-c', 'fastbuild', - '--experimental_proto_descriptor_sets_include_source_info' - ] + BAZEL_BUILD_OPTIONS + go_protos) - - for rule in go_protos: - # Example rule: - # @envoy_api//envoy/config/bootstrap/v2:pkg_go_proto - # - # Example generated directory: - # bazel-bin/external/envoy_api/envoy/config/bootstrap/v2/linux_amd64_stripped/pkg_go_proto%/github.com/envoyproxy/go-control-plane/envoy/config/bootstrap/v2/ - # - # Example output directory: - # go_out/envoy/config/bootstrap/v2 - rule_dir, proto = rule.decode()[len('@envoy_api//'):].rsplit(':', 1) - input_dir = os.path.join(bazel_bin, 'external', 'envoy_api', rule_dir, proto + '_', IMPORT_BASE, - rule_dir) - input_files = glob.glob(os.path.join(input_dir, '*.go')) - output_dir = os.path.join(output, rule_dir) - - # Ensure the output directory exists - os.makedirs(output_dir, 0o755, exist_ok=True) - for generated_file in input_files: - shutil.copy(generated_file, output_dir) - print('Go artifacts placed into: ' + output) + bazel_bin = check_output(['bazel', 'info', 'bazel-bin']).decode().strip() + go_protos = check_output([ + 'bazel', + 'query', + 'kind("go_proto_library", %s)' % TARGETS, + ]).split() + + # Each rule has the form @envoy_api//foo/bar:baz_go_proto. + # First build all the rules to ensure we have the output files. + # We preserve source info so comments are retained on generated code. 
+ check_call([ + 'bazel', 'build', '-c', 'fastbuild', + '--experimental_proto_descriptor_sets_include_source_info' + ] + BAZEL_BUILD_OPTIONS + go_protos) + + for rule in go_protos: + # Example rule: + # @envoy_api//envoy/config/bootstrap/v2:pkg_go_proto + # + # Example generated directory: + # bazel-bin/external/envoy_api/envoy/config/bootstrap/v2/linux_amd64_stripped/pkg_go_proto%/github.com/envoyproxy/go-control-plane/envoy/config/bootstrap/v2/ + # + # Example output directory: + # go_out/envoy/config/bootstrap/v2 + rule_dir, proto = rule.decode()[len('@envoy_api//'):].rsplit(':', 1) + input_dir = os.path.join(bazel_bin, 'external', 'envoy_api', rule_dir, proto + '_', + IMPORT_BASE, rule_dir) + input_files = glob.glob(os.path.join(input_dir, '*.go')) + output_dir = os.path.join(output, rule_dir) + + # Ensure the output directory exists + os.makedirs(output_dir, 0o755, exist_ok=True) + for generated_file in input_files: + shutil.copy(generated_file, output_dir) + print('Go artifacts placed into: ' + output) def git(repo, *args): - cmd = ['git'] - if repo: - cmd = cmd + ['-C', repo] - for arg in args: - cmd = cmd + [arg] - return check_output(cmd).decode() + cmd = ['git'] + if repo: + cmd = cmd + ['-C', repo] + for arg in args: + cmd = cmd + [arg] + return check_output(cmd).decode() def clone_go_protobufs(repo): - # Create a local clone of go-control-plane - git(None, 'clone', 'git@github.com:envoyproxy/go-control-plane', repo, '-b', BRANCH) + # Create a local clone of go-control-plane + git(None, 'clone', 'git@github.com:envoyproxy/go-control-plane', repo, '-b', BRANCH) def find_last_sync_sha(repo): - # Determine last envoyproxy/envoy SHA in envoyproxy/go-control-plane - last_commit = git(repo, 'log', '--grep=' + MIRROR_MSG, '-n', '1', '--format=%B').strip() - # Initial SHA from which the APIs start syncing. Prior to that it was done manually. - if last_commit == "": - return 'e7f0b7176efdc65f96eb1697b829d1e6187f4502' - m = re.search(MIRROR_MSG + '(\w+)', last_commit) - return m.group(1) + # Determine last envoyproxy/envoy SHA in envoyproxy/go-control-plane + last_commit = git(repo, 'log', '--grep=' + MIRROR_MSG, '-n', '1', '--format=%B').strip() + # Initial SHA from which the APIs start syncing. Prior to that it was done manually. 
+ if last_commit == "": + return 'e7f0b7176efdc65f96eb1697b829d1e6187f4502' + m = re.search(MIRROR_MSG + '(\w+)', last_commit) + return m.group(1) def updated_since_sha(repo, last_sha): - # Determine if there are changes to API since last SHA - return git(None, 'rev-list', '%s..HEAD' % last_sha).split() + # Determine if there are changes to API since last SHA + return git(None, 'rev-list', '%s..HEAD' % last_sha).split() def write_revision_info(repo, sha): - # Put a file in the generated code root containing the latest mirrored SHA - dst = os.path.join(repo, 'envoy', 'COMMIT') - with open(dst, 'w') as fh: - fh.write(sha) + # Put a file in the generated code root containing the latest mirrored SHA + dst = os.path.join(repo, 'envoy', 'COMMIT') + with open(dst, 'w') as fh: + fh.write(sha) def sync_go_protobufs(output, repo): - # Sync generated content against repo and return true if there is a commit necessary - dst = os.path.join(repo, 'envoy') - # Remove subtree at envoy in repo - git(repo, 'rm', '-r', 'envoy') - # Copy subtree at envoy from output to repo - shutil.copytree(os.path.join(output, 'envoy'), dst) - git(repo, 'add', 'envoy') + # Sync generated content against repo and return true if there is a commit necessary + dst = os.path.join(repo, 'envoy') + # Remove subtree at envoy in repo + git(repo, 'rm', '-r', 'envoy') + # Copy subtree at envoy from output to repo + shutil.copytree(os.path.join(output, 'envoy'), dst) + git(repo, 'add', 'envoy') def publish_go_protobufs(repo, sha): - # Publish generated files with the last SHA changes to API - git(repo, 'config', 'user.name', USER_NAME) - git(repo, 'config', 'user.email', USER_EMAIL) - git(repo, 'add', 'envoy') - git(repo, 'commit', '--allow-empty', '-s', '-m', MIRROR_MSG + sha) - git(repo, 'push', 'origin', BRANCH) + # Publish generated files with the last SHA changes to API + git(repo, 'config', 'user.name', USER_NAME) + git(repo, 'config', 'user.email', USER_EMAIL) + git(repo, 'add', 'envoy') + git(repo, 'commit', '--allow-empty', '-s', '-m', MIRROR_MSG + sha) + git(repo, 'push', 'origin', BRANCH) def updated(repo): - return len( - [f for f in git(repo, 'diff', 'HEAD', '--name-only').splitlines() if f != 'envoy/COMMIT']) > 0 + return len([ + f for f in git(repo, 'diff', 'HEAD', '--name-only').splitlines() if f != 'envoy/COMMIT' + ]) > 0 if __name__ == "__main__": - workspace = check_output(['bazel', 'info', 'workspace']).decode().strip() - output = os.path.join(workspace, OUTPUT_BASE) - generate_protobufs(output) - repo = os.path.join(workspace, REPO_BASE) - clone_go_protobufs(repo) - sync_go_protobufs(output, repo) - last_sha = find_last_sync_sha(repo) - changes = updated_since_sha(repo, last_sha) - if updated(repo): - print('Changes detected: %s' % changes) - new_sha = changes[0] - write_revision_info(repo, new_sha) - publish_go_protobufs(repo, new_sha) + workspace = check_output(['bazel', 'info', 'workspace']).decode().strip() + output = os.path.join(workspace, OUTPUT_BASE) + generate_protobufs(output) + repo = os.path.join(workspace, REPO_BASE) + clone_go_protobufs(repo) + sync_go_protobufs(output, repo) + last_sha = find_last_sync_sha(repo) + changes = updated_since_sha(repo, last_sha) + if updated(repo): + print('Changes detected: %s' % changes) + new_sha = changes[0] + write_revision_info(repo, new_sha) + publish_go_protobufs(repo, new_sha) diff --git a/tools/api/validate_structure.py b/tools/api/validate_structure.py index 2a5c121d5752..0ccddf38ffc3 100755 --- a/tools/api/validate_structure.py +++ 
b/tools/api/validate_structure.py @@ -35,54 +35,54 @@ class ValidationError(Exception): - pass + pass # Extract major version and full API version string from a proto path. def proto_api_version(proto_path): - match = re.match('v(\d+).*', proto_path.parent.name) - if match: - return str(proto_path.parent.name)[1:], int(match.group(1)) - return None, 0 + match = re.match('v(\d+).*', proto_path.parent.name) + if match: + return str(proto_path.parent.name)[1:], int(match.group(1)) + return None, 0 # Validate a single proto path. def validate_proto_path(proto_path): - version_str, major_version = proto_api_version(proto_path) + version_str, major_version = proto_api_version(proto_path) - # Validate version-less paths. - if major_version == 0: - if not any(str(proto_path.parent) == p for p in VERSIONLESS_PATHS): - raise ValidationError('Package is missing a version') + # Validate version-less paths. + if major_version == 0: + if not any(str(proto_path.parent) == p for p in VERSIONLESS_PATHS): + raise ValidationError('Package is missing a version') - # Validate that v3+ versions are regular. - if major_version >= 3: - if not re.match('\d+(alpha)?$', version_str): - raise ValidationError('Invalid v3+ version: %s' % version_str) + # Validate that v3+ versions are regular. + if major_version >= 3: + if not re.match('\d+(alpha)?$', version_str): + raise ValidationError('Invalid v3+ version: %s' % version_str) - # Validate v2-only paths. - for p in V2_ONLY_PATHS: - if str(proto_path).startswith(p): - raise ValidationError('v3+ protos are not allowed in %s' % p) + # Validate v2-only paths. + for p in V2_ONLY_PATHS: + if str(proto_path).startswith(p): + raise ValidationError('v3+ protos are not allowed in %s' % p) # Validate a list of proto paths. def validate_proto_paths(proto_paths): - error_msgs = [] - for proto_path in proto_paths: - try: - validate_proto_path(proto_path) - except ValidationError as e: - error_msgs.append('Invalid .proto location [%s]: %s' % (proto_path, e)) - return error_msgs + error_msgs = [] + for proto_path in proto_paths: + try: + validate_proto_path(proto_path) + except ValidationError as e: + error_msgs.append('Invalid .proto location [%s]: %s' % (proto_path, e)) + return error_msgs if __name__ == '__main__': - api_root = 'api/envoy' - api_protos = pathlib.Path(api_root).rglob('*.proto') - error_msgs = validate_proto_paths(p.relative_to(api_root) for p in api_protos) - if error_msgs: - for m in error_msgs: - print(m) - sys.exit(1) - sys.exit(0) + api_root = 'api/envoy' + api_protos = pathlib.Path(api_root).rglob('*.proto') + error_msgs = validate_proto_paths(p.relative_to(api_root) for p in api_protos) + if error_msgs: + for m in error_msgs: + print(m) + sys.exit(1) + sys.exit(0) diff --git a/tools/api_boost/api_boost.py b/tools/api_boost/api_boost.py index 112ffaa8e61b..b9b3bad7c2d8 100755 --- a/tools/api_boost/api_boost.py +++ b/tools/api_boost/api_boost.py @@ -29,34 +29,34 @@ # Obtain the directory containing a path prefix, e.g. ./foo/bar.txt is ./foo, # ./foo/ba is ./foo, ./foo/bar/ is ./foo/bar. def prefix_directory(path_prefix): - return path_prefix if os.path.isdir(path_prefix) else os.path.dirname(path_prefix) + return path_prefix if os.path.isdir(path_prefix) else os.path.dirname(path_prefix) # Update a C++ file to the latest API. 
def api_boost_file(llvm_include_path, debug_log, path): - print('Processing %s' % path) - if 'API_NO_BOOST_FILE' in pathlib.Path(path).read_text(): + print('Processing %s' % path) + if 'API_NO_BOOST_FILE' in pathlib.Path(path).read_text(): + if debug_log: + print('Not boosting %s due to API_NO_BOOST_FILE\n' % path) + return None + # Run the booster + try: + result = sp.run([ + './bazel-bin/external/envoy_dev/clang_tools/api_booster/api_booster', + '--extra-arg-before=-xc++', + '--extra-arg=-isystem%s' % llvm_include_path, '--extra-arg=-Wno-undefined-internal', + '--extra-arg=-Wno-old-style-cast', path + ], + capture_output=True, + check=True) + except sp.CalledProcessError as e: + print('api_booster failure for %s: %s %s' % (path, e, e.stderr.decode('utf-8'))) + raise if debug_log: - print('Not boosting %s due to API_NO_BOOST_FILE\n' % path) - return None - # Run the booster - try: - result = sp.run([ - './bazel-bin/external/envoy_dev/clang_tools/api_booster/api_booster', - '--extra-arg-before=-xc++', - '--extra-arg=-isystem%s' % llvm_include_path, '--extra-arg=-Wno-undefined-internal', - '--extra-arg=-Wno-old-style-cast', path - ], - capture_output=True, - check=True) - except sp.CalledProcessError as e: - print('api_booster failure for %s: %s %s' % (path, e, e.stderr.decode('utf-8'))) - raise - if debug_log: - print(result.stderr.decode('utf-8')) - - # Consume stdout containing the list of inferred API headers. - return sorted(set(result.stdout.decode('utf-8').splitlines())) + print(result.stderr.decode('utf-8')) + + # Consume stdout containing the list of inferred API headers. + return sorted(set(result.stdout.decode('utf-8').splitlines())) # Rewrite API includes to the inferred headers. Currently this is handled @@ -64,29 +64,29 @@ def api_boost_file(llvm_include_path, debug_log, path): # with this or with clang-include-fixer, but it's pretty simply to handle as done # below, we have more control over special casing as well, so ¯\_(ツ)_/¯. def rewrite_includes(args): - path, api_includes = args - # Files with API_NO_BOOST_FILE will have None returned by api_boost_file. - if api_includes is None: - return - # We just dump the inferred API header includes at the start of the #includes - # in the file and remove all the present API header includes. This does not - # match Envoy style; we rely on later invocations of fix_format.sh to take - # care of this alignment. - output_lines = [] - include_lines = ['#include "%s"' % f for f in api_includes] - input_text = pathlib.Path(path).read_text() - for line in input_text.splitlines(): - if include_lines and line.startswith('#include'): - output_lines.extend(include_lines) - include_lines = None - # Exclude API includes, except for a special case related to v2alpha - # ext_authz; this is needed to include the service descriptor in the build - # and is a hack that will go away when we remove v2. - if re.match(API_INCLUDE_REGEX, line) and 'envoy/service/auth/v2alpha' not in line: - continue - output_lines.append(line) - # Rewrite file. - pathlib.Path(path).write_text('\n'.join(output_lines) + '\n') + path, api_includes = args + # Files with API_NO_BOOST_FILE will have None returned by api_boost_file. + if api_includes is None: + return + # We just dump the inferred API header includes at the start of the #includes + # in the file and remove all the present API header includes. This does not + # match Envoy style; we rely on later invocations of fix_format.sh to take + # care of this alignment. 
+ output_lines = [] + include_lines = ['#include "%s"' % f for f in api_includes] + input_text = pathlib.Path(path).read_text() + for line in input_text.splitlines(): + if include_lines and line.startswith('#include'): + output_lines.extend(include_lines) + include_lines = None + # Exclude API includes, except for a special case related to v2alpha + # ext_authz; this is needed to include the service descriptor in the build + # and is a hack that will go away when we remove v2. + if re.match(API_INCLUDE_REGEX, line) and 'envoy/service/auth/v2alpha' not in line: + continue + output_lines.append(line) + # Rewrite file. + pathlib.Path(path).write_text('\n'.join(output_lines) + '\n') # Update the Envoy source tree the latest API. @@ -95,105 +95,106 @@ def api_boost_tree(target_paths, build_api_booster=False, debug_log=False, sequential=False): - dep_build_targets = ['//%s/...' % prefix_directory(prefix) for prefix in target_paths] - - # Optional setup of state. We need the compilation database and api_booster - # tool in place before we can start boosting. - if generate_compilation_database: - print('Building compilation database for %s' % dep_build_targets) - sp.run(['./tools/gen_compilation_database.py', '--include_headers'] + dep_build_targets, - check=True) - - if build_api_booster: - # Similar to gen_compilation_database.py, we only need the cc_library for - # setup. The long term fix for this is in - # https://github.com/bazelbuild/bazel/issues/9578. - # - # Figure out some cc_libraries that cover most of our external deps. This is - # the same logic as in gen_compilation_database.py. - query = 'kind(cc_library, {})'.format(' union '.join(dep_build_targets)) - dep_lib_build_targets = sp.check_output(['bazel', 'query', query]).decode().splitlines() - # We also need some misc. stuff such as test binaries for setup of benchmark - # dep. - query = 'attr("tags", "compilation_db_dep", {})'.format(' union '.join(dep_build_targets)) - dep_lib_build_targets.extend(sp.check_output(['bazel', 'query', query]).decode().splitlines()) - extra_api_booster_args = [] - if debug_log: - extra_api_booster_args.append('--copt=-DENABLE_DEBUG_LOG') - - # Slightly easier to debug when we build api_booster on its own. - sp.run([ - 'bazel', - 'build', - '--strip=always', - '@envoy_dev//clang_tools/api_booster', - ] + BAZEL_BUILD_OPTIONS + extra_api_booster_args, - check=True) - sp.run([ - 'bazel', - 'build', - '--strip=always', - ] + BAZEL_BUILD_OPTIONS + dep_lib_build_targets, - check=True) - - # Figure out where the LLVM include path is. We need to provide this - # explicitly as the api_booster is built inside the Bazel cache and doesn't - # know about this path. - # TODO(htuch): this is fragile and depends on Clang version, should figure out - # a cleaner approach. - llvm_include_path = os.path.join( - sp.check_output([os.getenv('LLVM_CONFIG'), '--libdir']).decode().rstrip(), - 'clang/11.0.1/include') - - # Determine the files in the target dirs eligible for API boosting, based on - # known files in the compilation database. - file_paths = set([]) - for entry in json.loads(pathlib.Path('compile_commands.json').read_text()): - file_path = entry['file'] - if any(file_path.startswith(prefix) for prefix in target_paths): - file_paths.add(file_path) - # Ensure a determinstic ordering if we are going to process sequentially. 
- if sequential: - file_paths = sorted(file_paths) - - # The API boosting is file local, so this is trivially parallelizable, use - # multiprocessing pool with default worker pool sized to cpu_count(), since - # this is CPU bound. - try: - with mp.Pool(processes=1 if sequential else None) as p: - # We need multiple phases, to ensure that any dependency on files being modified - # in one thread on consumed transitive headers on the other thread isn't an - # issue. This also ensures that we complete all analysis error free before - # any mutation takes place. - # TODO(htuch): we should move to run-clang-tidy.py once the headers fixups - # are Clang-based. - api_includes = p.map(functools.partial(api_boost_file, llvm_include_path, debug_log), - file_paths) - # Apply Clang replacements before header fixups, since the replacements - # are all relative to the original file. - for prefix_dir in set(map(prefix_directory, target_paths)): - sp.run(['clang-apply-replacements', prefix_dir], check=True) - # Fixup headers. - p.map(rewrite_includes, zip(file_paths, api_includes)) - finally: - # Cleanup any stray **/*.clang-replacements.yaml. - for prefix in target_paths: - clang_replacements = pathlib.Path( - prefix_directory(prefix)).glob('**/*.clang-replacements.yaml') - for path in clang_replacements: - path.unlink() + dep_build_targets = ['//%s/...' % prefix_directory(prefix) for prefix in target_paths] + + # Optional setup of state. We need the compilation database and api_booster + # tool in place before we can start boosting. + if generate_compilation_database: + print('Building compilation database for %s' % dep_build_targets) + sp.run(['./tools/gen_compilation_database.py', '--include_headers'] + dep_build_targets, + check=True) + + if build_api_booster: + # Similar to gen_compilation_database.py, we only need the cc_library for + # setup. The long term fix for this is in + # https://github.com/bazelbuild/bazel/issues/9578. + # + # Figure out some cc_libraries that cover most of our external deps. This is + # the same logic as in gen_compilation_database.py. + query = 'kind(cc_library, {})'.format(' union '.join(dep_build_targets)) + dep_lib_build_targets = sp.check_output(['bazel', 'query', query]).decode().splitlines() + # We also need some misc. stuff such as test binaries for setup of benchmark + # dep. + query = 'attr("tags", "compilation_db_dep", {})'.format(' union '.join(dep_build_targets)) + dep_lib_build_targets.extend( + sp.check_output(['bazel', 'query', query]).decode().splitlines()) + extra_api_booster_args = [] + if debug_log: + extra_api_booster_args.append('--copt=-DENABLE_DEBUG_LOG') + + # Slightly easier to debug when we build api_booster on its own. + sp.run([ + 'bazel', + 'build', + '--strip=always', + '@envoy_dev//clang_tools/api_booster', + ] + BAZEL_BUILD_OPTIONS + extra_api_booster_args, + check=True) + sp.run([ + 'bazel', + 'build', + '--strip=always', + ] + BAZEL_BUILD_OPTIONS + dep_lib_build_targets, + check=True) + + # Figure out where the LLVM include path is. We need to provide this + # explicitly as the api_booster is built inside the Bazel cache and doesn't + # know about this path. + # TODO(htuch): this is fragile and depends on Clang version, should figure out + # a cleaner approach. + llvm_include_path = os.path.join( + sp.check_output([os.getenv('LLVM_CONFIG'), '--libdir']).decode().rstrip(), + 'clang/11.0.1/include') + + # Determine the files in the target dirs eligible for API boosting, based on + # known files in the compilation database. 
+ file_paths = set([]) + for entry in json.loads(pathlib.Path('compile_commands.json').read_text()): + file_path = entry['file'] + if any(file_path.startswith(prefix) for prefix in target_paths): + file_paths.add(file_path) + # Ensure a deterministic ordering if we are going to process sequentially. + if sequential: + file_paths = sorted(file_paths) + + # The API boosting is file local, so this is trivially parallelizable, use + # multiprocessing pool with default worker pool sized to cpu_count(), since + # this is CPU bound. + try: + with mp.Pool(processes=1 if sequential else None) as p: + # We need multiple phases, to ensure that any dependency on files being modified + # in one thread on consumed transitive headers on the other thread isn't an + # issue. This also ensures that we complete all analysis error free before + # any mutation takes place. + # TODO(htuch): we should move to run-clang-tidy.py once the headers fixups + # are Clang-based. + api_includes = p.map(functools.partial(api_boost_file, llvm_include_path, debug_log), + file_paths) + # Apply Clang replacements before header fixups, since the replacements + # are all relative to the original file. + for prefix_dir in set(map(prefix_directory, target_paths)): + sp.run(['clang-apply-replacements', prefix_dir], check=True) + # Fixup headers. + p.map(rewrite_includes, zip(file_paths, api_includes)) + finally: + # Cleanup any stray **/*.clang-replacements.yaml. + for prefix in target_paths: + clang_replacements = pathlib.Path( + prefix_directory(prefix)).glob('**/*.clang-replacements.yaml') + for path in clang_replacements: + path.unlink() if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Update Envoy tree to the latest API') - parser.add_argument('--generate_compilation_database', action='store_true') - parser.add_argument('--build_api_booster', action='store_true') - parser.add_argument('--debug_log', action='store_true') - parser.add_argument('--sequential', action='store_true') - parser.add_argument('paths', nargs='*', default=['source', 'test', 'include']) - args = parser.parse_args() - api_boost_tree(args.paths, - generate_compilation_database=args.generate_compilation_database, - build_api_booster=args.build_api_booster, - debug_log=args.debug_log, - sequential=args.sequential) + parser = argparse.ArgumentParser(description='Update Envoy tree to the latest API') + parser.add_argument('--generate_compilation_database', action='store_true') + parser.add_argument('--build_api_booster', action='store_true') + parser.add_argument('--debug_log', action='store_true') + parser.add_argument('--sequential', action='store_true') + parser.add_argument('paths', nargs='*', default=['source', 'test', 'include']) + args = parser.parse_args() + api_boost_tree(args.paths, + generate_compilation_database=args.generate_compilation_database, + build_api_booster=args.build_api_booster, + debug_log=args.debug_log, + sequential=args.sequential) diff --git a/tools/api_boost/api_boost_test.py b/tools/api_boost/api_boost_test.py index f4ff381e72ea..3c8a3f6e4a33 100755 --- a/tools/api_boost/api_boost_test.py +++ b/tools/api_boost/api_boost_test.py @@ -35,56 +35,58 @@ def diff(some_path, other_path): - result = subprocess.run(['diff', '-u', some_path, other_path], capture_output=True) - if result.returncode == 0: - return None - return result.stdout.decode('utf-8') + result.stderr.decode('utf-8') + result = subprocess.run(['diff', '-u', some_path, other_path], capture_output=True) + if result.returncode == 0: + return None +
return result.stdout.decode('utf-8') + result.stderr.decode('utf-8') if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Golden C++ source tests for api_boost.py') - parser.add_argument('tests', nargs='*') - args = parser.parse_args() - - # Accumulated error messages. - logging.basicConfig(format='%(message)s') - messages = [] - - def should_run_test(test_name): - return len(args.tests) == 0 or test_name in args.tests - - # Run API booster against test artifacts in a directory relative to workspace. - # We use a temporary copy as the API booster does in-place rewriting. - with tempfile.TemporaryDirectory(dir=pathlib.Path.cwd()) as path: - # Setup temporary tree. - shutil.copy(os.path.join(TESTDATA_PATH, 'BUILD'), path) - for test in TESTS: - if should_run_test(test.name): - shutil.copy(os.path.join(TESTDATA_PATH, test.name + '.cc'), path) - else: - # Place an empty file to make Bazel happy. - pathlib.Path(path, test.name + '.cc').write_text('') - - # Run API booster. - relpath_to_testdata = str(pathlib.Path(path).relative_to(pathlib.Path.cwd())) - api_boost.api_boost_tree([ - os.path.join(relpath_to_testdata, test.name) for test in TESTS if should_run_test(test.name) - ], - generate_compilation_database=True, - build_api_booster=True, - debug_log=True, - sequential=True) - - # Validate output against golden files. - for test in TESTS: - if should_run_test(test.name): - delta = diff(os.path.join(TESTDATA_PATH, test.name + '.cc.gold'), - os.path.join(path, test.name + '.cc')) - if delta is not None: - messages.append('Non-empty diff for %s (%s):\n%s\n' % - (test.name, test.description, delta)) - - if len(messages) > 0: - logging.error('FAILED:\n{}'.format('\n'.join(messages))) - sys.exit(1) - logging.warning('PASS') + parser = argparse.ArgumentParser(description='Golden C++ source tests for api_boost.py') + parser.add_argument('tests', nargs='*') + args = parser.parse_args() + + # Accumulated error messages. + logging.basicConfig(format='%(message)s') + messages = [] + + def should_run_test(test_name): + return len(args.tests) == 0 or test_name in args.tests + + # Run API booster against test artifacts in a directory relative to workspace. + # We use a temporary copy as the API booster does in-place rewriting. + with tempfile.TemporaryDirectory(dir=pathlib.Path.cwd()) as path: + # Setup temporary tree. + shutil.copy(os.path.join(TESTDATA_PATH, 'BUILD'), path) + for test in TESTS: + if should_run_test(test.name): + shutil.copy(os.path.join(TESTDATA_PATH, test.name + '.cc'), path) + else: + # Place an empty file to make Bazel happy. + pathlib.Path(path, test.name + '.cc').write_text('') + + # Run API booster. + relpath_to_testdata = str(pathlib.Path(path).relative_to(pathlib.Path.cwd())) + api_boost.api_boost_tree([ + os.path.join(relpath_to_testdata, test.name) + for test in TESTS + if should_run_test(test.name) + ], + generate_compilation_database=True, + build_api_booster=True, + debug_log=True, + sequential=True) + + # Validate output against golden files. 
+ for test in TESTS: + if should_run_test(test.name): + delta = diff(os.path.join(TESTDATA_PATH, test.name + '.cc.gold'), + os.path.join(path, test.name + '.cc')) + if delta is not None: + messages.append('Non-empty diff for %s (%s):\n%s\n' % + (test.name, test.description, delta)) + + if len(messages) > 0: + logging.error('FAILED:\n{}'.format('\n'.join(messages))) + sys.exit(1) + logging.warning('PASS') diff --git a/tools/api_proto_plugin/annotations.py b/tools/api_proto_plugin/annotations.py index 8ffd4d5b9597..ca55b9e14b82 100644 --- a/tools/api_proto_plugin/annotations.py +++ b/tools/api_proto_plugin/annotations.py @@ -50,11 +50,11 @@ class AnnotationError(Exception): - """Base error class for the annotations module.""" + """Base error class for the annotations module.""" def extract_annotations(s, inherited_annotations=None): - """Extract annotations map from a given comment string. + """Extract annotations map from a given comment string. Args: s: string that may contains annotations. @@ -64,21 +64,21 @@ def extract_annotations(s, inherited_annotations=None): Returns: Annotation map. """ - annotations = { - k: v for k, v in (inherited_annotations or {}).items() if k in INHERITED_ANNOTATIONS - } - # Extract annotations. - groups = re.findall(ANNOTATION_REGEX, s) - for group in groups: - annotation = group[0] - if annotation not in VALID_ANNOTATIONS: - raise AnnotationError('Unknown annotation: %s' % annotation) - annotations[group[0]] = group[1].lstrip() - return annotations + annotations = { + k: v for k, v in (inherited_annotations or {}).items() if k in INHERITED_ANNOTATIONS + } + # Extract annotations. + groups = re.findall(ANNOTATION_REGEX, s) + for group in groups: + annotation = group[0] + if annotation not in VALID_ANNOTATIONS: + raise AnnotationError('Unknown annotation: %s' % annotation) + annotations[group[0]] = group[1].lstrip() + return annotations def xform_annotation(s, annotation_xforms): - """Return transformed string with annotation transformers. + """Return transformed string with annotation transformers. The annotation will be replaced with the new value returned by the transformer. If the transformer returns None, then the annotation will be removed. @@ -91,29 +91,29 @@ def xform_annotation(s, annotation_xforms): Returns: transformed string. 
""" - present_annotations = set() - - def xform(match): - annotation, content, trailing = match.groups() - present_annotations.add(annotation) - annotation_xform = annotation_xforms.get(annotation) - if annotation_xform: - value = annotation_xform(annotation) - return '[#%s: %s]%s' % (annotation, value, trailing) if value is not None else '' - else: - return match.group(0) - - def append(s, annotation, content): - return '%s [#%s: %s]\n' % (s, annotation, content) - - xformed = re.sub(ANNOTATION_REGEX, xform, s) - for annotation, xform in sorted(annotation_xforms.items()): - if annotation not in present_annotations: - value = xform(None) - if value is not None: - xformed = append(xformed, annotation, value) - return xformed + present_annotations = set() + + def xform(match): + annotation, content, trailing = match.groups() + present_annotations.add(annotation) + annotation_xform = annotation_xforms.get(annotation) + if annotation_xform: + value = annotation_xform(annotation) + return '[#%s: %s]%s' % (annotation, value, trailing) if value is not None else '' + else: + return match.group(0) + + def append(s, annotation, content): + return '%s [#%s: %s]\n' % (s, annotation, content) + + xformed = re.sub(ANNOTATION_REGEX, xform, s) + for annotation, xform in sorted(annotation_xforms.items()): + if annotation not in present_annotations: + value = xform(None) + if value is not None: + xformed = append(xformed, annotation, value) + return xformed def without_annotations(s): - return re.sub(ANNOTATION_REGEX, '', s) + return re.sub(ANNOTATION_REGEX, '', s) diff --git a/tools/api_proto_plugin/plugin.py b/tools/api_proto_plugin/plugin.py index 794d0f2d8bb9..f9935ca62f23 100644 --- a/tools/api_proto_plugin/plugin.py +++ b/tools/api_proto_plugin/plugin.py @@ -30,13 +30,13 @@ def direct_output_descriptor(output_suffix, visitor, want_params=False): - return OutputDescriptor(output_suffix, visitor, (lambda x, _: x) if want_params else lambda x: x, - want_params) + return OutputDescriptor(output_suffix, visitor, + (lambda x, _: x) if want_params else lambda x: x, want_params) # TODO(phlax): make this into a class def plugin(output_descriptors): - """Protoc plugin entry point. + """Protoc plugin entry point. This defines protoc plugin and manages the stdin -> stdout flow. An api_proto_plugin is defined by the provided visitor. @@ -48,46 +48,47 @@ def plugin(output_descriptors): Args: output_descriptors: a list of OutputDescriptors. """ - request = plugin_pb2.CodeGeneratorRequest() - request.ParseFromString(sys.stdin.buffer.read()) - response = plugin_pb2.CodeGeneratorResponse() - cprofile_enabled = os.getenv('CPROFILE_ENABLED') + request = plugin_pb2.CodeGeneratorRequest() + request.ParseFromString(sys.stdin.buffer.read()) + response = plugin_pb2.CodeGeneratorResponse() + cprofile_enabled = os.getenv('CPROFILE_ENABLED') - # We use request.file_to_generate rather than request.file_proto here since we - # are invoked inside a Bazel aspect, each node in the DAG will be visited once - # by the aspect and we only want to generate docs for the current node. - for file_to_generate in request.file_to_generate: - # Find the FileDescriptorProto for the file we actually are generating. - file_proto = [pf for pf in request.proto_file if pf.name == file_to_generate][0] - if cprofile_enabled: - pr = cProfile.Profile() - pr.enable() - for od in output_descriptors: - f = response.file.add() - f.name = file_proto.name + od.output_suffix - # Don't run API proto plugins on things like WKT types etc. 
- if not file_proto.package.startswith('envoy.'): - continue - if request.HasField("parameter") and od.want_params: - params = dict(param.split('=') for param in request.parameter.split(',')) - xformed_proto = od.xform(file_proto, params) - visitor_factory = od.visitor_factory(params) - else: - xformed_proto = od.xform(file_proto) - visitor_factory = od.visitor_factory() - f.content = traverse.traverse_file(xformed_proto, visitor_factory) if xformed_proto else '' - if cprofile_enabled: - pr.disable() - stats_stream = io.StringIO() - ps = pstats.Stats(pr, - stream=stats_stream).sort_stats(os.getenv('CPROFILE_SORTBY', 'cumulative')) - stats_file = response.file.add() - stats_file.name = file_proto.name + '.profile' - ps.print_stats() - stats_file.content = stats_stream.getvalue() - # Also include the original FileDescriptorProto as text proto, this is - # useful when debugging. - descriptor_file = response.file.add() - descriptor_file.name = file_proto.name + ".descriptor.proto" - descriptor_file.content = str(file_proto) - sys.stdout.buffer.write(response.SerializeToString()) + # We use request.file_to_generate rather than request.file_proto here since we + # are invoked inside a Bazel aspect, each node in the DAG will be visited once + # by the aspect and we only want to generate docs for the current node. + for file_to_generate in request.file_to_generate: + # Find the FileDescriptorProto for the file we actually are generating. + file_proto = [pf for pf in request.proto_file if pf.name == file_to_generate][0] + if cprofile_enabled: + pr = cProfile.Profile() + pr.enable() + for od in output_descriptors: + f = response.file.add() + f.name = file_proto.name + od.output_suffix + # Don't run API proto plugins on things like WKT types etc. + if not file_proto.package.startswith('envoy.'): + continue + if request.HasField("parameter") and od.want_params: + params = dict(param.split('=') for param in request.parameter.split(',')) + xformed_proto = od.xform(file_proto, params) + visitor_factory = od.visitor_factory(params) + else: + xformed_proto = od.xform(file_proto) + visitor_factory = od.visitor_factory() + f.content = traverse.traverse_file(xformed_proto, + visitor_factory) if xformed_proto else '' + if cprofile_enabled: + pr.disable() + stats_stream = io.StringIO() + ps = pstats.Stats(pr, stream=stats_stream).sort_stats( + os.getenv('CPROFILE_SORTBY', 'cumulative')) + stats_file = response.file.add() + stats_file.name = file_proto.name + '.profile' + ps.print_stats() + stats_file.content = stats_stream.getvalue() + # Also include the original FileDescriptorProto as text proto, this is + # useful when debugging. + descriptor_file = response.file.add() + descriptor_file.name = file_proto.name + ".descriptor.proto" + descriptor_file.content = str(file_proto) + sys.stdout.buffer.write(response.SerializeToString()) diff --git a/tools/api_proto_plugin/traverse.py b/tools/api_proto_plugin/traverse.py index eb53b0be3d63..0ddc2f131b14 100644 --- a/tools/api_proto_plugin/traverse.py +++ b/tools/api_proto_plugin/traverse.py @@ -4,7 +4,7 @@ def traverse_service(type_context, service_proto, visitor): - """Traverse a service definition. + """Traverse a service definition. Args: type_context: type_context.TypeContext for service type. @@ -14,11 +14,11 @@ def traverse_service(type_context, service_proto, visitor): Returns: Plugin specific output. 
""" - return visitor.visit_service(service_proto, type_context) + return visitor.visit_service(service_proto, type_context) def traverse_enum(type_context, enum_proto, visitor): - """Traverse an enum definition. + """Traverse an enum definition. Args: type_context: type_context.TypeContext for enum type. @@ -28,11 +28,11 @@ def traverse_enum(type_context, enum_proto, visitor): Returns: Plugin specific output. """ - return visitor.visit_enum(enum_proto, type_context) + return visitor.visit_enum(enum_proto, type_context) def traverse_message(type_context, msg_proto, visitor): - """Traverse a message definition. + """Traverse a message definition. Args: type_context: type_context.TypeContext for message type. @@ -42,28 +42,30 @@ def traverse_message(type_context, msg_proto, visitor): Returns: Plugin specific output. """ - # We need to do some extra work to recover the map type annotation from the - # synthesized messages. - type_context.map_typenames = { - '%s.%s' % (type_context.name, nested_msg.name): (nested_msg.field[0], nested_msg.field[1]) - for nested_msg in msg_proto.nested_type - if nested_msg.options.map_entry - } - nested_msgs = [ - traverse_message( - type_context.extend_nested_message(index, nested_msg.name, nested_msg.options.deprecated), - nested_msg, visitor) for index, nested_msg in enumerate(msg_proto.nested_type) - ] - nested_enums = [ - traverse_enum( - type_context.extend_nested_enum(index, nested_enum.name, nested_enum.options.deprecated), - nested_enum, visitor) for index, nested_enum in enumerate(msg_proto.enum_type) - ] - return visitor.visit_message(msg_proto, type_context, nested_msgs, nested_enums) + # We need to do some extra work to recover the map type annotation from the + # synthesized messages. + type_context.map_typenames = { + '%s.%s' % (type_context.name, nested_msg.name): (nested_msg.field[0], nested_msg.field[1]) + for nested_msg in msg_proto.nested_type + if nested_msg.options.map_entry + } + nested_msgs = [ + traverse_message( + type_context.extend_nested_message(index, nested_msg.name, + nested_msg.options.deprecated), nested_msg, visitor) + for index, nested_msg in enumerate(msg_proto.nested_type) + ] + nested_enums = [ + traverse_enum( + type_context.extend_nested_enum(index, nested_enum.name, + nested_enum.options.deprecated), nested_enum, visitor) + for index, nested_enum in enumerate(msg_proto.enum_type) + ] + return visitor.visit_message(msg_proto, type_context, nested_msgs, nested_enums) def traverse_file(file_proto, visitor): - """Traverse a proto file definition. + """Traverse a proto file definition. Args: file_proto: FileDescriptorProto for file. @@ -72,18 +74,19 @@ def traverse_file(file_proto, visitor): Returns: Plugin specific output. 
""" - source_code_info = type_context.SourceCodeInfo(file_proto.name, file_proto.source_code_info) - package_type_context = type_context.TypeContext(source_code_info, file_proto.package) - services = [ - traverse_service(package_type_context.extend_service(index, service.name), service, visitor) - for index, service in enumerate(file_proto.service) - ] - msgs = [ - traverse_message(package_type_context.extend_message(index, msg.name, msg.options.deprecated), - msg, visitor) for index, msg in enumerate(file_proto.message_type) - ] - enums = [ - traverse_enum(package_type_context.extend_enum(index, enum.name, enum.options.deprecated), - enum, visitor) for index, enum in enumerate(file_proto.enum_type) - ] - return visitor.visit_file(file_proto, package_type_context, services, msgs, enums) + source_code_info = type_context.SourceCodeInfo(file_proto.name, file_proto.source_code_info) + package_type_context = type_context.TypeContext(source_code_info, file_proto.package) + services = [ + traverse_service(package_type_context.extend_service(index, service.name), service, visitor) + for index, service in enumerate(file_proto.service) + ] + msgs = [ + traverse_message( + package_type_context.extend_message(index, msg.name, msg.options.deprecated), msg, + visitor) for index, msg in enumerate(file_proto.message_type) + ] + enums = [ + traverse_enum(package_type_context.extend_enum(index, enum.name, enum.options.deprecated), + enum, visitor) for index, enum in enumerate(file_proto.enum_type) + ] + return visitor.visit_file(file_proto, package_type_context, services, msgs, enums) diff --git a/tools/api_proto_plugin/type_context.py b/tools/api_proto_plugin/type_context.py index 6beea9163770..70fcc2ec31ea 100644 --- a/tools/api_proto_plugin/type_context.py +++ b/tools/api_proto_plugin/type_context.py @@ -6,15 +6,15 @@ class Comment(object): - """Wrapper for proto source comments.""" + """Wrapper for proto source comments.""" - def __init__(self, comment, file_level_annotations=None): - self.raw = comment - self.file_level_annotations = file_level_annotations - self.annotations = annotations.extract_annotations(self.raw, file_level_annotations) + def __init__(self, comment, file_level_annotations=None): + self.raw = comment + self.file_level_annotations = file_level_annotations + self.annotations = annotations.extract_annotations(self.raw, file_level_annotations) - def get_comment_with_transforms(self, annotation_xforms): - """Return transformed comment with annotation transformers. + def get_comment_with_transforms(self, annotation_xforms): + """Return transformed comment with annotation transformers. Args: annotation_xforms: a dict of transformers for annotations in leading comment. @@ -22,50 +22,51 @@ def get_comment_with_transforms(self, annotation_xforms): Returns: transformed Comment object. 
""" - return Comment(annotations.xform_annotation(self.raw, annotation_xforms), - self.file_level_annotations) + return Comment(annotations.xform_annotation(self.raw, annotation_xforms), + self.file_level_annotations) class SourceCodeInfo(object): - """Wrapper for SourceCodeInfo proto.""" - - def __init__(self, name, source_code_info): - self.name = name - self.proto = source_code_info - # Map from path to SourceCodeInfo.Location - self._locations = {str(location.path): location for location in self.proto.location} - self._file_level_comments = None - self._file_level_annotations = None - - @property - def file_level_comments(self): - """Obtain inferred file level comment.""" - if self._file_level_comments: - return self._file_level_comments - comments = [] - # We find the earliest detached comment by first finding the maximum start - # line for any location and then scanning for any earlier locations with - # detached comments. - earliest_detached_comment = max(location.span[0] for location in self.proto.location) + 1 - for location in self.proto.location: - if location.leading_detached_comments and location.span[0] < earliest_detached_comment: - comments = location.leading_detached_comments - earliest_detached_comment = location.span[0] - self._file_level_comments = comments - return comments - - @property - def file_level_annotations(self): - """Obtain inferred file level annotations.""" - if self._file_level_annotations: - return self._file_level_annotations - self._file_level_annotations = dict( - sum([list(annotations.extract_annotations(c).items()) for c in self.file_level_comments], - [])) - return self._file_level_annotations - - def location_path_lookup(self, path): - """Lookup SourceCodeInfo.Location by path in SourceCodeInfo. + """Wrapper for SourceCodeInfo proto.""" + + def __init__(self, name, source_code_info): + self.name = name + self.proto = source_code_info + # Map from path to SourceCodeInfo.Location + self._locations = {str(location.path): location for location in self.proto.location} + self._file_level_comments = None + self._file_level_annotations = None + + @property + def file_level_comments(self): + """Obtain inferred file level comment.""" + if self._file_level_comments: + return self._file_level_comments + comments = [] + # We find the earliest detached comment by first finding the maximum start + # line for any location and then scanning for any earlier locations with + # detached comments. + earliest_detached_comment = max(location.span[0] for location in self.proto.location) + 1 + for location in self.proto.location: + if location.leading_detached_comments and location.span[0] < earliest_detached_comment: + comments = location.leading_detached_comments + earliest_detached_comment = location.span[0] + self._file_level_comments = comments + return comments + + @property + def file_level_annotations(self): + """Obtain inferred file level annotations.""" + if self._file_level_annotations: + return self._file_level_annotations + self._file_level_annotations = dict( + sum([ + list(annotations.extract_annotations(c).items()) for c in self.file_level_comments + ], [])) + return self._file_level_annotations + + def location_path_lookup(self, path): + """Lookup SourceCodeInfo.Location by path in SourceCodeInfo. Args: path: a list of path indexes as per @@ -74,12 +75,12 @@ def location_path_lookup(self, path): Returns: SourceCodeInfo.Location object if found, otherwise None. 
""" - return self._locations.get(str(path), None) + return self._locations.get(str(path), None) - # TODO(htuch): consider integrating comment lookup with overall - # FileDescriptorProto, perhaps via two passes. - def leading_comment_path_lookup(self, path): - """Lookup leading comment by path in SourceCodeInfo. + # TODO(htuch): consider integrating comment lookup with overall + # FileDescriptorProto, perhaps via two passes. + def leading_comment_path_lookup(self, path): + """Lookup leading comment by path in SourceCodeInfo. Args: path: a list of path indexes as per @@ -88,13 +89,13 @@ def leading_comment_path_lookup(self, path): Returns: Comment object. """ - location = self.location_path_lookup(path) - if location is not None: - return Comment(location.leading_comments, self.file_level_annotations) - return Comment('') + location = self.location_path_lookup(path) + if location is not None: + return Comment(location.leading_comments, self.file_level_annotations) + return Comment('') - def leading_detached_comments_path_lookup(self, path): - """Lookup leading detached comments by path in SourceCodeInfo. + def leading_detached_comments_path_lookup(self, path): + """Lookup leading detached comments by path in SourceCodeInfo. Args: path: a list of path indexes as per @@ -103,13 +104,13 @@ def leading_detached_comments_path_lookup(self, path): Returns: List of detached comment strings. """ - location = self.location_path_lookup(path) - if location is not None and location.leading_detached_comments != self.file_level_comments: - return location.leading_detached_comments - return [] + location = self.location_path_lookup(path) + if location is not None and location.leading_detached_comments != self.file_level_comments: + return location.leading_detached_comments + return [] - def trailing_comment_path_lookup(self, path): - """Lookup trailing comment by path in SourceCodeInfo. + def trailing_comment_path_lookup(self, path): + """Lookup trailing comment by path in SourceCodeInfo. Args: path: a list of path indexes as per @@ -118,158 +119,158 @@ def trailing_comment_path_lookup(self, path): Returns: Raw detached comment string """ - location = self.location_path_lookup(path) - if location is not None: - return location.trailing_comments - return '' + location = self.location_path_lookup(path) + if location is not None: + return location.trailing_comments + return '' class TypeContext(object): - """Contextual information for a message/field. + """Contextual information for a message/field. Provides information around namespaces and enclosing types for fields and nested messages/enums. """ - def __init__(self, source_code_info, name): - # SourceCodeInfo as per - # https://github.com/google/protobuf/blob/a08b03d4c00a5793b88b494f672513f6ad46a681/src/google/protobuf/descriptor.proto. - self.source_code_info = source_code_info - # path: a list of path indexes as per - # https://github.com/google/protobuf/blob/a08b03d4c00a5793b88b494f672513f6ad46a681/src/google/protobuf/descriptor.proto#L717. - # Extended as nested objects are traversed. - self.path = [] - # Message/enum/field name. Extended as nested objects are traversed. - self.name = name - # Map from type name to the correct type annotation string, e.g. from - # ".envoy.api.v2.Foo.Bar" to "map". This is lost during - # proto synthesis and is dynamically recovered in traverse_message. - self.map_typenames = {} - # Map from a message's oneof index to the fields sharing a oneof. 
- self.oneof_fields = {} - # Map from a message's oneof index to the name of oneof. - self.oneof_names = {} - # Map from a message's oneof index to the "required" bool property. - self.oneof_required = {} - self.type_name = 'file' - self.deprecated = False - - def _extend(self, path, type_name, name, deprecated=False): - if not self.name: - extended_name = name - else: - extended_name = '%s.%s' % (self.name, name) - extended = TypeContext(self.source_code_info, extended_name) - extended.path = self.path + path - extended.type_name = type_name - extended.map_typenames = self.map_typenames.copy() - extended.oneof_fields = self.oneof_fields.copy() - extended.oneof_names = self.oneof_names.copy() - extended.oneof_required = self.oneof_required.copy() - extended.deprecated = self.deprecated or deprecated - return extended - - def extend_message(self, index, name, deprecated): - """Extend type context with a message. + def __init__(self, source_code_info, name): + # SourceCodeInfo as per + # https://github.com/google/protobuf/blob/a08b03d4c00a5793b88b494f672513f6ad46a681/src/google/protobuf/descriptor.proto. + self.source_code_info = source_code_info + # path: a list of path indexes as per + # https://github.com/google/protobuf/blob/a08b03d4c00a5793b88b494f672513f6ad46a681/src/google/protobuf/descriptor.proto#L717. + # Extended as nested objects are traversed. + self.path = [] + # Message/enum/field name. Extended as nested objects are traversed. + self.name = name + # Map from type name to the correct type annotation string, e.g. from + # ".envoy.api.v2.Foo.Bar" to "map". This is lost during + # proto synthesis and is dynamically recovered in traverse_message. + self.map_typenames = {} + # Map from a message's oneof index to the fields sharing a oneof. + self.oneof_fields = {} + # Map from a message's oneof index to the name of oneof. + self.oneof_names = {} + # Map from a message's oneof index to the "required" bool property. + self.oneof_required = {} + self.type_name = 'file' + self.deprecated = False + + def _extend(self, path, type_name, name, deprecated=False): + if not self.name: + extended_name = name + else: + extended_name = '%s.%s' % (self.name, name) + extended = TypeContext(self.source_code_info, extended_name) + extended.path = self.path + path + extended.type_name = type_name + extended.map_typenames = self.map_typenames.copy() + extended.oneof_fields = self.oneof_fields.copy() + extended.oneof_names = self.oneof_names.copy() + extended.oneof_required = self.oneof_required.copy() + extended.deprecated = self.deprecated or deprecated + return extended + + def extend_message(self, index, name, deprecated): + """Extend type context with a message. Args: index: message index in file. name: message name. deprecated: is the message depreacted? """ - return self._extend([4, index], 'message', name, deprecated) + return self._extend([4, index], 'message', name, deprecated) - def extend_nested_message(self, index, name, deprecated): - """Extend type context with a nested message. + def extend_nested_message(self, index, name, deprecated): + """Extend type context with a nested message. Args: index: nested message index in message. name: message name. deprecated: is the message depreacted? """ - return self._extend([3, index], 'message', name, deprecated) + return self._extend([3, index], 'message', name, deprecated) - def extend_field(self, index, name): - """Extend type context with a field. + def extend_field(self, index, name): + """Extend type context with a field. 
Args: index: field index in message. name: field name. """ - return self._extend([2, index], 'field', name) + return self._extend([2, index], 'field', name) - def extend_enum(self, index, name, deprecated): - """Extend type context with an enum. + def extend_enum(self, index, name, deprecated): + """Extend type context with an enum. Args: index: enum index in file. name: enum name. deprecated: is the message depreacted? """ - return self._extend([5, index], 'enum', name, deprecated) + return self._extend([5, index], 'enum', name, deprecated) - def extend_service(self, index, name): - """Extend type context with a service. + def extend_service(self, index, name): + """Extend type context with a service. Args: index: service index in file. name: service name. """ - return self._extend([6, index], 'service', name) + return self._extend([6, index], 'service', name) - def extend_nested_enum(self, index, name, deprecated): - """Extend type context with a nested enum. + def extend_nested_enum(self, index, name, deprecated): + """Extend type context with a nested enum. Args: index: enum index in message. name: enum name. deprecated: is the message depreacted? """ - return self._extend([4, index], 'enum', name, deprecated) + return self._extend([4, index], 'enum', name, deprecated) - def extend_enum_value(self, index, name): - """Extend type context with an enum enum. + def extend_enum_value(self, index, name): + """Extend type context with an enum enum. Args: index: enum value index in enum. name: value name. """ - return self._extend([2, index], 'enum_value', name) + return self._extend([2, index], 'enum_value', name) - def extend_oneof(self, index, name): - """Extend type context with an oneof declaration. + def extend_oneof(self, index, name): + """Extend type context with an oneof declaration. Args: index: oneof index in oneof_decl. name: oneof name. """ - return self._extend([8, index], 'oneof', name) + return self._extend([8, index], 'oneof', name) - def extend_method(self, index, name): - """Extend type context with a service method declaration. + def extend_method(self, index, name): + """Extend type context with a service method declaration. Args: index: method index in service. name: method name. 
""" - return self._extend([2, index], 'method', name) - - @property - def location(self): - """SourceCodeInfo.Location for type context.""" - return self.source_code_info.location_path_lookup(self.path) - - @property - def leading_comment(self): - """Leading comment for type context.""" - return self.source_code_info.leading_comment_path_lookup(self.path) - - @property - def leading_detached_comments(self): - """Leading detached comments for type context.""" - return self.source_code_info.leading_detached_comments_path_lookup(self.path) - - @property - def trailing_comment(self): - """Trailing comment for type context.""" - return self.source_code_info.trailing_comment_path_lookup(self.path) + return self._extend([2, index], 'method', name) + + @property + def location(self): + """SourceCodeInfo.Location for type context.""" + return self.source_code_info.location_path_lookup(self.path) + + @property + def leading_comment(self): + """Leading comment for type context.""" + return self.source_code_info.leading_comment_path_lookup(self.path) + + @property + def leading_detached_comments(self): + """Leading detached comments for type context.""" + return self.source_code_info.leading_detached_comments_path_lookup(self.path) + + @property + def trailing_comment(self): + """Trailing comment for type context.""" + return self.source_code_info.trailing_comment_path_lookup(self.path) diff --git a/tools/api_proto_plugin/utils.py b/tools/api_proto_plugin/utils.py index fa7cfce3e432..256fb793353b 100644 --- a/tools/api_proto_plugin/utils.py +++ b/tools/api_proto_plugin/utils.py @@ -2,7 +2,7 @@ def proto_file_canonical_from_label(label): - """Compute path from API root to a proto file from a Bazel proto label. + """Compute path from API root to a proto file from a Bazel proto label. Args: label: Bazel source proto label string. @@ -11,12 +11,12 @@ def proto_file_canonical_from_label(label): A string with the path, e.g. for @envoy_api//envoy/type/matcher:metadata.proto this would be envoy/type/matcher/matcher.proto. """ - assert (label.startswith('@envoy_api_canonical//')) - return label[len('@envoy_api_canonical//'):].replace(':', '/') + assert (label.startswith('@envoy_api_canonical//')) + return label[len('@envoy_api_canonical//'):].replace(':', '/') def bazel_bin_path_for_output_artifact(label, suffix, root=''): - """Find the location in bazel-bin/ for an api_proto_plugin output file. + """Find the location in bazel-bin/ for an api_proto_plugin output file. Args: label: Bazel source proto label string. @@ -26,6 +26,6 @@ def bazel_bin_path_for_output_artifact(label, suffix, root=''): Returns: Path in bazel-bin/external/envoy_api_canonical for label output with given suffix. 
""" - proto_file_path = proto_file_canonical_from_label(label) - return os.path.join(root, 'bazel-bin/external/envoy_api_canonical', - os.path.dirname(proto_file_path), 'pkg', proto_file_path + suffix) + proto_file_path = proto_file_canonical_from_label(label) + return os.path.join(root, 'bazel-bin/external/envoy_api_canonical', + os.path.dirname(proto_file_path), 'pkg', proto_file_path + suffix) diff --git a/tools/api_proto_plugin/visitor.py b/tools/api_proto_plugin/visitor.py index 38f54f36cd24..b831320dbcfc 100644 --- a/tools/api_proto_plugin/visitor.py +++ b/tools/api_proto_plugin/visitor.py @@ -2,10 +2,10 @@ class Visitor(object): - """Abstract visitor interface for api_proto_plugin implementation.""" + """Abstract visitor interface for api_proto_plugin implementation.""" - def visit_service(self, service_proto, type_context): - """Visit a service definition. + def visit_service(self, service_proto, type_context): + """Visit a service definition. Args: service_proto: ServiceDescriptorProto for service. @@ -14,10 +14,10 @@ def visit_service(self, service_proto, type_context): Returns: Plugin specific output. """ - pass + pass - def visit_enum(self, enum_proto, type_context): - """Visit an enum definition. + def visit_enum(self, enum_proto, type_context): + """Visit an enum definition. Args: enum_proto: EnumDescriptorProto for enum. @@ -26,10 +26,10 @@ def visit_enum(self, enum_proto, type_context): Returns: Plugin specific output. """ - pass + pass - def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): - """Visit a message definition. + def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): + """Visit a message definition. Args: msg_proto: DescriptorProto for message. @@ -40,10 +40,10 @@ def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): Returns: Plugin specific output. """ - pass + pass - def visit_file(self, file_proto, type_context, services, msgs, enums): - """Visit a proto file definition. + def visit_file(self, file_proto, type_context, services, msgs, enums): + """Visit a proto file definition. Args: file_proto: FileDescriptorProto for file. @@ -55,4 +55,4 @@ def visit_file(self, file_proto, type_context, services, msgs, enums): Returns: Plugin specific output. """ - pass + pass diff --git a/tools/api_versioning/generate_api_version_header.py b/tools/api_versioning/generate_api_version_header.py index 47859e2679fb..ea661825d21d 100644 --- a/tools/api_versioning/generate_api_version_header.py +++ b/tools/api_versioning/generate_api_version_header.py @@ -21,7 +21,7 @@ def GenerateHeaderFile(input_path): - """Generates a c++ header file containing the api_version variable with the + """Generates a c++ header file containing the api_version variable with the correct value. Args: @@ -30,26 +30,26 @@ def GenerateHeaderFile(input_path): Returns: the header file contents. 
""" - lines = pathlib.Path(input_path).read_text().splitlines() - assert (len(lines) == 1) + lines = pathlib.Path(input_path).read_text().splitlines() + assert (len(lines) == 1) - # Mapping each field to int verifies it is a valid version - version = ApiVersion(*map(int, lines[0].split('.'))) - oldest_version = ComputeOldestApiVersion(version) + # Mapping each field to int verifies it is a valid version + version = ApiVersion(*map(int, lines[0].split('.'))) + oldest_version = ComputeOldestApiVersion(version) - header_file_contents = FILE_TEMPLATE.substitute({ - 'major': version.major, - 'minor': version.minor, - 'patch': version.patch, - 'oldest_major': oldest_version.major, - 'oldest_minor': oldest_version.minor, - 'oldest_patch': oldest_version.patch - }) - return header_file_contents + header_file_contents = FILE_TEMPLATE.substitute({ + 'major': version.major, + 'minor': version.minor, + 'patch': version.patch, + 'oldest_major': oldest_version.major, + 'oldest_minor': oldest_version.minor, + 'oldest_patch': oldest_version.patch + }) + return header_file_contents def ComputeOldestApiVersion(current_version: ApiVersion): - """Computest the oldest API version the client supports. According to the + """Computest the oldest API version the client supports. According to the specification (see: api/API_VERSIONING.md), Envoy supports up to 2 most recent minor versions. Therefore if the latest API version "X.Y.Z", Envoy's oldest API version is "X.Y-1.0". Note that the major number is always the @@ -63,11 +63,11 @@ def ComputeOldestApiVersion(current_version: ApiVersion): Returns: the oldest supported API version. """ - return ApiVersion(current_version.major, max(current_version.minor - 1, 0), 0) + return ApiVersion(current_version.major, max(current_version.minor - 1, 0), 0) if __name__ == '__main__': - input_path = sys.argv[1] - output = GenerateHeaderFile(input_path) - # Print output to stdout - print(output) + input_path = sys.argv[1] + output = GenerateHeaderFile(input_path) + # Print output to stdout + print(output) diff --git a/tools/api_versioning/generate_api_version_header_test.py b/tools/api_versioning/generate_api_version_header_test.py index 709ae4f090cf..3edf886fc587 100644 --- a/tools/api_versioning/generate_api_version_header_test.py +++ b/tools/api_versioning/generate_api_version_header_test.py @@ -10,7 +10,7 @@ class GenerateApiVersionHeaderTest(unittest.TestCase): - EXPECTED_TEMPLATE = string.Template("""#pragma once + EXPECTED_TEMPLATE = string.Template("""#pragma once #include "common/version/api_version_struct.h" namespace Envoy { @@ -20,72 +20,73 @@ class GenerateApiVersionHeaderTest(unittest.TestCase): } // namespace Envoy""") - def setUp(self): - # Using mkstemp instead of NamedTemporaryFile because in windows NT or later - # the created NamedTemporaryFile cannot be reopened again (see comment in: - # https://docs.python.org/3.9/library/tempfile.html#tempfile.NamedTemporaryFile) - self._temp_fd, self._temp_fname = tempfile.mkstemp(text=True) - - def tearDown(self): - # Close and delete the temp file. - os.close(self._temp_fd) - pathlib.Path(self._temp_fname).unlink() - - # General success pattern when valid file contents is detected. - def SuccessfulTestTemplate(self, output_string, current_version: ApiVersion, - oldest_version: ApiVersion): - pathlib.Path(self._temp_fname).write_text(output_string) - - # Read the string from the file, and parse the version. 
- output = generate_api_version_header.GenerateHeaderFile(self._temp_fname) - expected_output = GenerateApiVersionHeaderTest.EXPECTED_TEMPLATE.substitute({ - 'major': current_version.major, - 'minor': current_version.minor, - 'patch': current_version.patch, - 'oldest_major': oldest_version.major, - 'oldest_minor': oldest_version.minor, - 'oldest_patch': oldest_version.patch - }) - self.assertEqual(expected_output, output) - - # General failure pattern when invalid file contents is detected. - def FailedTestTemplate(self, output_string, assertion_error_type): - pathlib.Path(self._temp_fname).write_text(output_string) - - # Read the string from the file, and expect version parsing to fail. - with self.assertRaises(assertion_error_type, - msg='The call to GenerateHeaderFile should have thrown an exception'): - generate_api_version_header.GenerateHeaderFile(self._temp_fname) - - def test_valid_version(self): - self.SuccessfulTestTemplate('1.2.3', ApiVersion(1, 2, 3), ApiVersion(1, 1, 0)) - - def test_valid_version_newline(self): - self.SuccessfulTestTemplate('3.2.1\n', ApiVersion(3, 2, 1), ApiVersion(3, 1, 0)) - - def test_invalid_version_string(self): - self.FailedTestTemplate('1.2.abc3', ValueError) - - def test_invalid_version_partial(self): - self.FailedTestTemplate('1.2.', ValueError) - - def test_empty_file(self): - # Not writing anything to the file - self.FailedTestTemplate('', AssertionError) - - def test_invalid_multiple_lines(self): - self.FailedTestTemplate('1.2.3\n1.2.3', AssertionError) - - def test_valid_oldest_api_version(self): - expected_latest_oldest_pairs = [(ApiVersion(3, 2, 2), ApiVersion(3, 1, 0)), - (ApiVersion(4, 5, 30), ApiVersion(4, 4, 0)), - (ApiVersion(1, 1, 5), ApiVersion(1, 0, 0)), - (ApiVersion(2, 0, 3), ApiVersion(2, 0, 0))] - - for latest_version, expected_oldest_version in expected_latest_oldest_pairs: - self.assertEqual(expected_oldest_version, - generate_api_version_header.ComputeOldestApiVersion(latest_version)) + def setUp(self): + # Using mkstemp instead of NamedTemporaryFile because in windows NT or later + # the created NamedTemporaryFile cannot be reopened again (see comment in: + # https://docs.python.org/3.9/library/tempfile.html#tempfile.NamedTemporaryFile) + self._temp_fd, self._temp_fname = tempfile.mkstemp(text=True) + + def tearDown(self): + # Close and delete the temp file. + os.close(self._temp_fd) + pathlib.Path(self._temp_fname).unlink() + + # General success pattern when valid file contents is detected. + def SuccessfulTestTemplate(self, output_string, current_version: ApiVersion, + oldest_version: ApiVersion): + pathlib.Path(self._temp_fname).write_text(output_string) + + # Read the string from the file, and parse the version. + output = generate_api_version_header.GenerateHeaderFile(self._temp_fname) + expected_output = GenerateApiVersionHeaderTest.EXPECTED_TEMPLATE.substitute({ + 'major': current_version.major, + 'minor': current_version.minor, + 'patch': current_version.patch, + 'oldest_major': oldest_version.major, + 'oldest_minor': oldest_version.minor, + 'oldest_patch': oldest_version.patch + }) + self.assertEqual(expected_output, output) + + # General failure pattern when invalid file contents is detected. + def FailedTestTemplate(self, output_string, assertion_error_type): + pathlib.Path(self._temp_fname).write_text(output_string) + + # Read the string from the file, and expect version parsing to fail. 
+ with self.assertRaises( + assertion_error_type, + msg='The call to GenerateHeaderFile should have thrown an exception'): + generate_api_version_header.GenerateHeaderFile(self._temp_fname) + + def test_valid_version(self): + self.SuccessfulTestTemplate('1.2.3', ApiVersion(1, 2, 3), ApiVersion(1, 1, 0)) + + def test_valid_version_newline(self): + self.SuccessfulTestTemplate('3.2.1\n', ApiVersion(3, 2, 1), ApiVersion(3, 1, 0)) + + def test_invalid_version_string(self): + self.FailedTestTemplate('1.2.abc3', ValueError) + + def test_invalid_version_partial(self): + self.FailedTestTemplate('1.2.', ValueError) + + def test_empty_file(self): + # Not writing anything to the file + self.FailedTestTemplate('', AssertionError) + + def test_invalid_multiple_lines(self): + self.FailedTestTemplate('1.2.3\n1.2.3', AssertionError) + + def test_valid_oldest_api_version(self): + expected_latest_oldest_pairs = [(ApiVersion(3, 2, 2), ApiVersion(3, 1, 0)), + (ApiVersion(4, 5, 30), ApiVersion(4, 4, 0)), + (ApiVersion(1, 1, 5), ApiVersion(1, 0, 0)), + (ApiVersion(2, 0, 3), ApiVersion(2, 0, 0))] + + for latest_version, expected_oldest_version in expected_latest_oldest_pairs: + self.assertEqual(expected_oldest_version, + generate_api_version_header.ComputeOldestApiVersion(latest_version)) if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/tools/build_profile.py b/tools/build_profile.py index dc4fd46d9915..5c47d7a5895f 100755 --- a/tools/build_profile.py +++ b/tools/build_profile.py @@ -10,16 +10,16 @@ def print_profile(f): - prev_cmd = None - prev_timestamp = None - for line in f: - sr = re.match('\++ (\d+\.\d+) (.*)', line) - if sr: - timestamp, cmd = sr.groups() - if prev_cmd: - print('%.2f %s' % (float(timestamp) - float(prev_timestamp), prev_cmd)) - prev_timestamp, prev_cmd = timestamp, cmd + prev_cmd = None + prev_timestamp = None + for line in f: + sr = re.match('\++ (\d+\.\d+) (.*)', line) + if sr: + timestamp, cmd = sr.groups() + if prev_cmd: + print('%.2f %s' % (float(timestamp) - float(prev_timestamp), prev_cmd)) + prev_timestamp, prev_cmd = timestamp, cmd if __name__ == '__main__': - print_profile(sys.stdin) + print_profile(sys.stdin) diff --git a/tools/code_format/.style.yapf b/tools/code_format/.style.yapf index b07cb79714a6..932f9100bd8d 100644 --- a/tools/code_format/.style.yapf +++ b/tools/code_format/.style.yapf @@ -2,5 +2,5 @@ # TODO: Look into enforcing single vs double quote. 
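print_profile in build_profile.py above matches lines of the shape "++ <epoch-seconds> <command>" and charges the elapsed time between consecutive timestamps to the command that just finished. A self-contained sketch of the same logic over made-up sample lines:

import io
import re

# Illustrative sample in the format the regex above expects.
SAMPLE = io.StringIO(
    '++ 1000.00 step_one\n'
    '++ 1002.50 step_two\n'
    '++ 1003.25 step_three\n')

prev_cmd = None
prev_timestamp = None
for line in SAMPLE:
    sr = re.match(r'\++ (\d+\.\d+) (.*)', line)
    if sr:
        timestamp, cmd = sr.groups()
        if prev_cmd:
            # Elapsed time is attributed to the command that just finished.
            print('%.2f %s' % (float(timestamp) - float(prev_timestamp), prev_cmd))
        prev_timestamp, prev_cmd = timestamp, cmd
# Prints: 2.50 step_one, then 0.75 step_two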
[style] based_on_style=Google -indent_width=2 +indent_width=4 column_limit=100 diff --git a/tools/code_format/check_format.py b/tools/code_format/check_format.py index 294534131883..71463a74d823 100755 --- a/tools/code_format/check_format.py +++ b/tools/code_format/check_format.py @@ -252,810 +252,828 @@ class FormatChecker: - def __init__(self, args): - self.operation_type = args.operation_type - self.target_path = args.target_path - self.api_prefix = args.api_prefix - self.api_shadow_root = args.api_shadow_prefix - self.envoy_build_rule_check = not args.skip_envoy_build_rule_check - self.namespace_check = args.namespace_check - self.namespace_check_excluded_paths = args.namespace_check_excluded_paths + [ - "./tools/api_boost/testdata/", - "./tools/clang_tools/", - ] - self.build_fixer_check_excluded_paths = args.build_fixer_check_excluded_paths + [ - "./bazel/external/", - "./bazel/toolchains/", - "./bazel/BUILD", - "./tools/clang_tools", - ] - self.include_dir_order = args.include_dir_order - - # Map a line transformation function across each line of a file, - # writing the result lines as requested. - # If there is a clang format nesting or mismatch error, return the first occurrence - def evaluate_lines(self, path, line_xform, write=True): - error_message = None - format_flag = True - output_lines = [] - for line_number, line in enumerate(self.read_lines(path)): - if line.find("// clang-format off") != -1: - if not format_flag and error_message is None: - error_message = "%s:%d: %s" % (path, line_number + 1, "clang-format nested off") - format_flag = False - if line.find("// clang-format on") != -1: - if format_flag and error_message is None: - error_message = "%s:%d: %s" % (path, line_number + 1, "clang-format nested on") + def __init__(self, args): + self.operation_type = args.operation_type + self.target_path = args.target_path + self.api_prefix = args.api_prefix + self.api_shadow_root = args.api_shadow_prefix + self.envoy_build_rule_check = not args.skip_envoy_build_rule_check + self.namespace_check = args.namespace_check + self.namespace_check_excluded_paths = args.namespace_check_excluded_paths + [ + "./tools/api_boost/testdata/", + "./tools/clang_tools/", + ] + self.build_fixer_check_excluded_paths = args.build_fixer_check_excluded_paths + [ + "./bazel/external/", + "./bazel/toolchains/", + "./bazel/BUILD", + "./tools/clang_tools", + ] + self.include_dir_order = args.include_dir_order + + # Map a line transformation function across each line of a file, + # writing the result lines as requested. + # If there is a clang format nesting or mismatch error, return the first occurrence + def evaluate_lines(self, path, line_xform, write=True): + error_message = None format_flag = True - if format_flag: - output_lines.append(line_xform(line, line_number)) - else: - output_lines.append(line) - # We used to use fileinput in the older Python 2.7 script, but this doesn't do - # inplace mode and UTF-8 in Python 3, so doing it the manual way. - if write: - pathlib.Path(path).write_text('\n'.join(output_lines), encoding='utf-8') - if not format_flag and error_message is None: - error_message = "%s:%d: %s" % (path, line_number + 1, "clang-format remains off") - return error_message - - # Obtain all the lines in a given file. - def read_lines(self, path): - return self.read_file(path).split('\n') - - # Read a UTF-8 encoded file as a str. 
- def read_file(self, path): - return pathlib.Path(path).read_text(encoding='utf-8') - - # look_path searches for the given executable in all directories in PATH - # environment variable. If it cannot be found, empty string is returned. - def look_path(self, executable): - return shutil.which(executable) or '' - - # path_exists checks whether the given path exists. This function assumes that - # the path is absolute and evaluates environment variables. - def path_exists(self, executable): - return os.path.exists(os.path.expandvars(executable)) - - # executable_by_others checks whether the given path has execute permission for - # others. - def executable_by_others(self, executable): - st = os.stat(os.path.expandvars(executable)) - return bool(st.st_mode & stat.S_IXOTH) - - # Check whether all needed external tools (clang-format, buildifier, buildozer) are - # available. - def check_tools(self): - error_messages = [] - - clang_format_abs_path = self.look_path(CLANG_FORMAT_PATH) - if clang_format_abs_path: - if not self.executable_by_others(clang_format_abs_path): - error_messages.append("command {} exists, but cannot be executed by other " - "users".format(CLANG_FORMAT_PATH)) - else: - error_messages.append( - "Command {} not found. If you have clang-format in version 10.x.x " - "installed, but the binary name is different or it's not available in " - "PATH, please use CLANG_FORMAT environment variable to specify the path. " - "Examples:\n" - " export CLANG_FORMAT=clang-format-11.0.1\n" - " export CLANG_FORMAT=/opt/bin/clang-format-11\n" - " export CLANG_FORMAT=/usr/local/opt/llvm@11/bin/clang-format".format( - CLANG_FORMAT_PATH)) - - def check_bazel_tool(name, path, var): - bazel_tool_abs_path = self.look_path(path) - if bazel_tool_abs_path: - if not self.executable_by_others(bazel_tool_abs_path): - error_messages.append("command {} exists, but cannot be executed by other " - "users".format(path)) - elif self.path_exists(path): - if not self.executable_by_others(path): - error_messages.append("command {} exists, but cannot be executed by other " - "users".format(path)) - else: - - error_messages.append("Command {} not found. If you have {} installed, but the binary " - "name is different or it's not available in $GOPATH/bin, please use " - "{} environment variable to specify the path. 
Example:\n" - " export {}=`which {}`\n" - "If you don't have {} installed, you can install it by:\n" - " go get -u github.com/bazelbuild/buildtools/{}".format( - path, name, var, var, name, name, name)) - - check_bazel_tool('buildifier', BUILDIFIER_PATH, 'BUILDIFIER_BIN') - check_bazel_tool('buildozer', BUILDOZER_PATH, 'BUILDOZER_BIN') - - return error_messages - - def check_namespace(self, file_path): - for excluded_path in self.namespace_check_excluded_paths: - if file_path.startswith(excluded_path): + output_lines = [] + for line_number, line in enumerate(self.read_lines(path)): + if line.find("// clang-format off") != -1: + if not format_flag and error_message is None: + error_message = "%s:%d: %s" % (path, line_number + 1, "clang-format nested off") + format_flag = False + if line.find("// clang-format on") != -1: + if format_flag and error_message is None: + error_message = "%s:%d: %s" % (path, line_number + 1, "clang-format nested on") + format_flag = True + if format_flag: + output_lines.append(line_xform(line, line_number)) + else: + output_lines.append(line) + # We used to use fileinput in the older Python 2.7 script, but this doesn't do + # inplace mode and UTF-8 in Python 3, so doing it the manual way. + if write: + pathlib.Path(path).write_text('\n'.join(output_lines), encoding='utf-8') + if not format_flag and error_message is None: + error_message = "%s:%d: %s" % (path, line_number + 1, "clang-format remains off") + return error_message + + # Obtain all the lines in a given file. + def read_lines(self, path): + return self.read_file(path).split('\n') + + # Read a UTF-8 encoded file as a str. + def read_file(self, path): + return pathlib.Path(path).read_text(encoding='utf-8') + + # look_path searches for the given executable in all directories in PATH + # environment variable. If it cannot be found, empty string is returned. + def look_path(self, executable): + return shutil.which(executable) or '' + + # path_exists checks whether the given path exists. This function assumes that + # the path is absolute and evaluates environment variables. + def path_exists(self, executable): + return os.path.exists(os.path.expandvars(executable)) + + # executable_by_others checks whether the given path has execute permission for + # others. + def executable_by_others(self, executable): + st = os.stat(os.path.expandvars(executable)) + return bool(st.st_mode & stat.S_IXOTH) + + # Check whether all needed external tools (clang-format, buildifier, buildozer) are + # available. + def check_tools(self): + error_messages = [] + + clang_format_abs_path = self.look_path(CLANG_FORMAT_PATH) + if clang_format_abs_path: + if not self.executable_by_others(clang_format_abs_path): + error_messages.append("command {} exists, but cannot be executed by other " + "users".format(CLANG_FORMAT_PATH)) + else: + error_messages.append( + "Command {} not found. If you have clang-format in version 10.x.x " + "installed, but the binary name is different or it's not available in " + "PATH, please use CLANG_FORMAT environment variable to specify the path. 
" + "Examples:\n" + " export CLANG_FORMAT=clang-format-11.0.1\n" + " export CLANG_FORMAT=/opt/bin/clang-format-11\n" + " export CLANG_FORMAT=/usr/local/opt/llvm@11/bin/clang-format".format( + CLANG_FORMAT_PATH)) + + def check_bazel_tool(name, path, var): + bazel_tool_abs_path = self.look_path(path) + if bazel_tool_abs_path: + if not self.executable_by_others(bazel_tool_abs_path): + error_messages.append("command {} exists, but cannot be executed by other " + "users".format(path)) + elif self.path_exists(path): + if not self.executable_by_others(path): + error_messages.append("command {} exists, but cannot be executed by other " + "users".format(path)) + else: + + error_messages.append( + "Command {} not found. If you have {} installed, but the binary " + "name is different or it's not available in $GOPATH/bin, please use " + "{} environment variable to specify the path. Example:\n" + " export {}=`which {}`\n" + "If you don't have {} installed, you can install it by:\n" + " go get -u github.com/bazelbuild/buildtools/{}".format( + path, name, var, var, name, name, name)) + + check_bazel_tool('buildifier', BUILDIFIER_PATH, 'BUILDIFIER_BIN') + check_bazel_tool('buildozer', BUILDOZER_PATH, 'BUILDOZER_BIN') + + return error_messages + + def check_namespace(self, file_path): + for excluded_path in self.namespace_check_excluded_paths: + if file_path.startswith(excluded_path): + return [] + + nolint = "NOLINT(namespace-%s)" % self.namespace_check.lower() + text = self.read_file(file_path) + if not re.search("^\s*namespace\s+%s\s*{" % self.namespace_check, text, re.MULTILINE) and \ + not nolint in text: + return [ + "Unable to find %s namespace or %s for file: %s" % + (self.namespace_check, nolint, file_path) + ] return [] - nolint = "NOLINT(namespace-%s)" % self.namespace_check.lower() - text = self.read_file(file_path) - if not re.search("^\s*namespace\s+%s\s*{" % self.namespace_check, text, re.MULTILINE) and \ - not nolint in text: - return [ - "Unable to find %s namespace or %s for file: %s" % - (self.namespace_check, nolint, file_path) - ] - return [] - - def package_name_for_proto(self, file_path): - package_name = None - error_message = [] - result = PROTO_PACKAGE_REGEX.search(self.read_file(file_path)) - if result is not None and len(result.groups()) == 1: - package_name = result.group(1) - if package_name is None: - error_message = ["Unable to find package name for proto file: %s" % file_path] - - return [package_name, error_message] - - # To avoid breaking the Lyft import, we just check for path inclusion here. - def allow_listed_for_protobuf_deps(self, file_path): - return (file_path.endswith(PROTO_SUFFIX) or file_path.endswith(REPOSITORIES_BZL) or \ - any(path_segment in file_path for path_segment in GOOGLE_PROTOBUF_ALLOWLIST)) - - # Real-world time sources should not be instantiated in the source, except for a few - # specific cases. They should be passed down from where they are instantied to where - # they need to be used, e.g. through the ServerInstance, Dispatcher, or ClusterManager. 
- def allow_listed_for_realtime(self, file_path): - if file_path.endswith(".md"): - return True - return file_path in REAL_TIME_ALLOWLIST - - def allow_listed_for_register_factory(self, file_path): - if not file_path.startswith("./test/"): - return True - - return any(file_path.startswith(prefix) for prefix in REGISTER_FACTORY_TEST_ALLOWLIST) - - def allow_listed_for_serialize_as_string(self, file_path): - return file_path in SERIALIZE_AS_STRING_ALLOWLIST or file_path.endswith(DOCS_SUFFIX) - - def allow_listed_for_json_string_to_message(self, file_path): - return file_path in JSON_STRING_TO_MESSAGE_ALLOWLIST - - def allow_listed_for_histogram_si_suffix(self, name): - return name in HISTOGRAM_WITH_SI_SUFFIX_ALLOWLIST - - def allow_listed_for_std_regex(self, file_path): - return file_path.startswith("./test") or file_path in STD_REGEX_ALLOWLIST or file_path.endswith( - DOCS_SUFFIX) - - def allow_listed_for_grpc_init(self, file_path): - return file_path in GRPC_INIT_ALLOWLIST - - def allow_listed_for_unpack_to(self, file_path): - return file_path.startswith("./test") or file_path in [ - "./source/common/protobuf/utility.cc", "./source/common/protobuf/utility.h" - ] - - def deny_listed_for_exceptions(self, file_path): - # Returns true when it is a non test header file or the file_path is in DENYLIST or - # it is under tools/testdata subdirectory. - if file_path.endswith(DOCS_SUFFIX): - return False - - return (file_path.endswith('.h') and not file_path.startswith("./test/") and not file_path in EXCEPTION_ALLOWLIST) or file_path in EXCEPTION_DENYLIST \ - or self.is_in_subdir(file_path, 'tools/testdata') - - def allow_listed_for_build_urls(self, file_path): - return file_path in BUILD_URLS_ALLOWLIST - - def is_api_file(self, file_path): - return file_path.startswith(self.api_prefix) or file_path.startswith(self.api_shadow_root) - - def is_build_file(self, file_path): - basename = os.path.basename(file_path) - if basename in {"BUILD", "BUILD.bazel"} or basename.endswith(".BUILD"): - return True - return False - - def is_external_build_file(self, file_path): - return self.is_build_file(file_path) and (file_path.startswith("./bazel/external/") or - file_path.startswith("./tools/clang_tools")) - - def is_starlark_file(self, file_path): - return file_path.endswith(".bzl") - - def is_workspace_file(self, file_path): - return os.path.basename(file_path) == "WORKSPACE" - - def is_build_fixer_excluded_file(self, file_path): - for excluded_path in self.build_fixer_check_excluded_paths: - if file_path.startswith(excluded_path): - return True - return False - - def has_invalid_angle_bracket_directory(self, line): - if not line.startswith(INCLUDE_ANGLE): - return False - path = line[INCLUDE_ANGLE_LEN:] - slash = path.find("/") - if slash == -1: - return False - subdir = path[0:slash] - return subdir in SUBDIR_SET - - def check_current_release_notes(self, file_path, error_messages): - first_word_of_prior_line = '' - next_word_to_check = '' # first word after : - prior_line = '' - - def ends_with_period(prior_line): - if not prior_line: - return True # Don't punctuation-check empty lines. - if prior_line.endswith('.'): - return True # Actually ends with . - if prior_line.endswith('`') and REF_WITH_PUNCTUATION_REGEX.match(prior_line): - return True # The text in the :ref ends with a . 
- return False - - for line_number, line in enumerate(self.read_lines(file_path)): - - def report_error(message): - error_messages.append("%s:%d: %s" % (file_path, line_number + 1, message)) - - if VERSION_HISTORY_SECTION_NAME.match(line): - if line == "Deprecated": - # The deprecations section is last, and does not have enforced formatting. - break - - # Reset all parsing at the start of a section. + def package_name_for_proto(self, file_path): + package_name = None + error_message = [] + result = PROTO_PACKAGE_REGEX.search(self.read_file(file_path)) + if result is not None and len(result.groups()) == 1: + package_name = result.group(1) + if package_name is None: + error_message = ["Unable to find package name for proto file: %s" % file_path] + + return [package_name, error_message] + + # To avoid breaking the Lyft import, we just check for path inclusion here. + def allow_listed_for_protobuf_deps(self, file_path): + return (file_path.endswith(PROTO_SUFFIX) or file_path.endswith(REPOSITORIES_BZL) or \ + any(path_segment in file_path for path_segment in GOOGLE_PROTOBUF_ALLOWLIST)) + + # Real-world time sources should not be instantiated in the source, except for a few + # specific cases. They should be passed down from where they are instantied to where + # they need to be used, e.g. through the ServerInstance, Dispatcher, or ClusterManager. + def allow_listed_for_realtime(self, file_path): + if file_path.endswith(".md"): + return True + return file_path in REAL_TIME_ALLOWLIST + + def allow_listed_for_register_factory(self, file_path): + if not file_path.startswith("./test/"): + return True + + return any(file_path.startswith(prefix) for prefix in REGISTER_FACTORY_TEST_ALLOWLIST) + + def allow_listed_for_serialize_as_string(self, file_path): + return file_path in SERIALIZE_AS_STRING_ALLOWLIST or file_path.endswith(DOCS_SUFFIX) + + def allow_listed_for_json_string_to_message(self, file_path): + return file_path in JSON_STRING_TO_MESSAGE_ALLOWLIST + + def allow_listed_for_histogram_si_suffix(self, name): + return name in HISTOGRAM_WITH_SI_SUFFIX_ALLOWLIST + + def allow_listed_for_std_regex(self, file_path): + return file_path.startswith( + "./test") or file_path in STD_REGEX_ALLOWLIST or file_path.endswith(DOCS_SUFFIX) + + def allow_listed_for_grpc_init(self, file_path): + return file_path in GRPC_INIT_ALLOWLIST + + def allow_listed_for_unpack_to(self, file_path): + return file_path.startswith("./test") or file_path in [ + "./source/common/protobuf/utility.cc", "./source/common/protobuf/utility.h" + ] + + def deny_listed_for_exceptions(self, file_path): + # Returns true when it is a non test header file or the file_path is in DENYLIST or + # it is under tools/testdata subdirectory. 
+ if file_path.endswith(DOCS_SUFFIX): + return False + + return (file_path.endswith('.h') and not file_path.startswith("./test/") and not file_path in EXCEPTION_ALLOWLIST) or file_path in EXCEPTION_DENYLIST \ + or self.is_in_subdir(file_path, 'tools/testdata') + + def allow_listed_for_build_urls(self, file_path): + return file_path in BUILD_URLS_ALLOWLIST + + def is_api_file(self, file_path): + return file_path.startswith(self.api_prefix) or file_path.startswith(self.api_shadow_root) + + def is_build_file(self, file_path): + basename = os.path.basename(file_path) + if basename in {"BUILD", "BUILD.bazel"} or basename.endswith(".BUILD"): + return True + return False + + def is_external_build_file(self, file_path): + return self.is_build_file(file_path) and (file_path.startswith("./bazel/external/") or + file_path.startswith("./tools/clang_tools")) + + def is_starlark_file(self, file_path): + return file_path.endswith(".bzl") + + def is_workspace_file(self, file_path): + return os.path.basename(file_path) == "WORKSPACE" + + def is_build_fixer_excluded_file(self, file_path): + for excluded_path in self.build_fixer_check_excluded_paths: + if file_path.startswith(excluded_path): + return True + return False + + def has_invalid_angle_bracket_directory(self, line): + if not line.startswith(INCLUDE_ANGLE): + return False + path = line[INCLUDE_ANGLE_LEN:] + slash = path.find("/") + if slash == -1: + return False + subdir = path[0:slash] + return subdir in SUBDIR_SET + + def check_current_release_notes(self, file_path, error_messages): first_word_of_prior_line = '' next_word_to_check = '' # first word after : prior_line = '' - invalid_reflink_match = INVALID_REFLINK.match(line) - if invalid_reflink_match: - report_error("Found text \" ref:\". This should probably be \" :ref:\"\n%s" % line) - - # make sure flags are surrounded by ``s - flag_match = RELOADABLE_FLAG_REGEX.match(line) - if flag_match: - if not flag_match.groups()[0].startswith(' `'): - report_error("Flag `%s` should be enclosed in a single set of back ticks" % - flag_match.groups()[1]) - - if line.startswith("* "): - if not ends_with_period(prior_line): - report_error("The following release note does not end with a '.'\n %s" % prior_line) - - match = VERSION_HISTORY_NEW_LINE_REGEX.match(line) - if not match: - report_error("Version history line malformed. " - "Does not match VERSION_HISTORY_NEW_LINE_REGEX in check_format.py\n %s\n" - "Please use messages in the form 'category: feature explanation.', " - "starting with a lower-cased letter and ending with a period." % line) - else: - first_word = match.groups()[0] - next_word = match.groups()[1] - # Do basic alphabetization checks of the first word on the line and the - # first word after the : - if first_word_of_prior_line and first_word_of_prior_line > first_word: - report_error( - "Version history not in alphabetical order (%s vs %s): please check placement of line\n %s. " - % (first_word_of_prior_line, first_word, line)) - if first_word_of_prior_line == first_word and next_word_to_check and next_word_to_check > next_word: - report_error( - "Version history not in alphabetical order (%s vs %s): please check placement of line\n %s. " - % (next_word_to_check, next_word, line)) - first_word_of_prior_line = first_word - next_word_to_check = next_word - - prior_line = line - elif not line: - # If we hit the end of this release note block block, check the prior line. 
- if not ends_with_period(prior_line): - report_error("The following release note does not end with a '.'\n %s" % prior_line) - prior_line = '' - elif prior_line: - prior_line += line - - def check_file_contents(self, file_path, checker): - error_messages = [] - - if file_path.endswith("version_history/current.rst"): - # Version file checking has enough special cased logic to merit its own checks. - # This only validates entries for the current release as very old release - # notes have a different format. - self.check_current_release_notes(file_path, error_messages) - - def check_format_errors(line, line_number): - - def report_error(message): - error_messages.append("%s:%d: %s" % (file_path, line_number + 1, message)) - - checker(line, file_path, report_error) - - evaluate_failure = self.evaluate_lines(file_path, check_format_errors, False) - if evaluate_failure is not None: - error_messages.append(evaluate_failure) - - return error_messages - - def fix_source_line(self, line, line_number): - # Strip double space after '.' This may prove overenthusiastic and need to - # be restricted to comments and metadata files but works for now. - line = re.sub(DOT_MULTI_SPACE_REGEX, ". ", line) - - if self.has_invalid_angle_bracket_directory(line): - line = line.replace("<", '"').replace(">", '"') - - # Fix incorrect protobuf namespace references. - for invalid_construct, valid_construct in PROTOBUF_TYPE_ERRORS.items(): - line = line.replace(invalid_construct, valid_construct) - - # Use recommended cpp stdlib - for invalid_construct, valid_construct in LIBCXX_REPLACEMENTS.items(): - line = line.replace(invalid_construct, valid_construct) - - # Fix code conventions violations. - for invalid_construct, valid_construct in CODE_CONVENTION_REPLACEMENTS.items(): - line = line.replace(invalid_construct, valid_construct) - - return line - - # We want to look for a call to condvar.waitFor, but there's no strong pattern - # to the variable name of the condvar. If we just look for ".waitFor" we'll also - # pick up time_system_.waitFor(...), and we don't want to return true for that - # pattern. But in that case there is a strong pattern of using time_system in - # various spellings as the variable name. - def has_cond_var_wait_for(self, line): - wait_for = line.find(".waitFor(") - if wait_for == -1: - return False - preceding = line[0:wait_for] - if preceding.endswith("time_system") or preceding.endswith("timeSystem()") or \ - preceding.endswith("time_system_"): - return False - return True - - # Determines whether the filename is either in the specified subdirectory, or - # at the top level. We consider files in the top level for the benefit of - # the check_format testcases in tools/testdata/check_format. - def is_in_subdir(self, filename, *subdirs): - # Skip this check for check_format's unit-tests. - if filename.count("/") <= 1: - return True - for subdir in subdirs: - if filename.startswith('./' + subdir + '/'): + def ends_with_period(prior_line): + if not prior_line: + return True # Don't punctuation-check empty lines. + if prior_line.endswith('.'): + return True # Actually ends with . + if prior_line.endswith('`') and REF_WITH_PUNCTUATION_REGEX.match(prior_line): + return True # The text in the :ref ends with a . 
+ return False + + for line_number, line in enumerate(self.read_lines(file_path)): + + def report_error(message): + error_messages.append("%s:%d: %s" % (file_path, line_number + 1, message)) + + if VERSION_HISTORY_SECTION_NAME.match(line): + if line == "Deprecated": + # The deprecations section is last, and does not have enforced formatting. + break + + # Reset all parsing at the start of a section. + first_word_of_prior_line = '' + next_word_to_check = '' # first word after : + prior_line = '' + + invalid_reflink_match = INVALID_REFLINK.match(line) + if invalid_reflink_match: + report_error("Found text \" ref:\". This should probably be \" :ref:\"\n%s" % line) + + # make sure flags are surrounded by ``s + flag_match = RELOADABLE_FLAG_REGEX.match(line) + if flag_match: + if not flag_match.groups()[0].startswith(' `'): + report_error("Flag `%s` should be enclosed in a single set of back ticks" % + flag_match.groups()[1]) + + if line.startswith("* "): + if not ends_with_period(prior_line): + report_error("The following release note does not end with a '.'\n %s" % + prior_line) + + match = VERSION_HISTORY_NEW_LINE_REGEX.match(line) + if not match: + report_error( + "Version history line malformed. " + "Does not match VERSION_HISTORY_NEW_LINE_REGEX in check_format.py\n %s\n" + "Please use messages in the form 'category: feature explanation.', " + "starting with a lower-cased letter and ending with a period." % line) + else: + first_word = match.groups()[0] + next_word = match.groups()[1] + # Do basic alphabetization checks of the first word on the line and the + # first word after the : + if first_word_of_prior_line and first_word_of_prior_line > first_word: + report_error( + "Version history not in alphabetical order (%s vs %s): please check placement of line\n %s. " + % (first_word_of_prior_line, first_word, line)) + if first_word_of_prior_line == first_word and next_word_to_check and next_word_to_check > next_word: + report_error( + "Version history not in alphabetical order (%s vs %s): please check placement of line\n %s. " + % (next_word_to_check, next_word, line)) + first_word_of_prior_line = first_word + next_word_to_check = next_word + + prior_line = line + elif not line: + # If we hit the end of this release note block block, check the prior line. + if not ends_with_period(prior_line): + report_error("The following release note does not end with a '.'\n %s" % + prior_line) + prior_line = '' + elif prior_line: + prior_line += line + + def check_file_contents(self, file_path, checker): + error_messages = [] + + if file_path.endswith("version_history/current.rst"): + # Version file checking has enough special cased logic to merit its own checks. + # This only validates entries for the current release as very old release + # notes have a different format. + self.check_current_release_notes(file_path, error_messages) + + def check_format_errors(line, line_number): + + def report_error(message): + error_messages.append("%s:%d: %s" % (file_path, line_number + 1, message)) + + checker(line, file_path, report_error) + + evaluate_failure = self.evaluate_lines(file_path, check_format_errors, False) + if evaluate_failure is not None: + error_messages.append(evaluate_failure) + + return error_messages + + def fix_source_line(self, line, line_number): + # Strip double space after '.' This may prove overenthusiastic and need to + # be restricted to comments and metadata files but works for now. + line = re.sub(DOT_MULTI_SPACE_REGEX, ". 
", line) + + if self.has_invalid_angle_bracket_directory(line): + line = line.replace("<", '"').replace(">", '"') + + # Fix incorrect protobuf namespace references. + for invalid_construct, valid_construct in PROTOBUF_TYPE_ERRORS.items(): + line = line.replace(invalid_construct, valid_construct) + + # Use recommended cpp stdlib + for invalid_construct, valid_construct in LIBCXX_REPLACEMENTS.items(): + line = line.replace(invalid_construct, valid_construct) + + # Fix code conventions violations. + for invalid_construct, valid_construct in CODE_CONVENTION_REPLACEMENTS.items(): + line = line.replace(invalid_construct, valid_construct) + + return line + + # We want to look for a call to condvar.waitFor, but there's no strong pattern + # to the variable name of the condvar. If we just look for ".waitFor" we'll also + # pick up time_system_.waitFor(...), and we don't want to return true for that + # pattern. But in that case there is a strong pattern of using time_system in + # various spellings as the variable name. + def has_cond_var_wait_for(self, line): + wait_for = line.find(".waitFor(") + if wait_for == -1: + return False + preceding = line[0:wait_for] + if preceding.endswith("time_system") or preceding.endswith("timeSystem()") or \ + preceding.endswith("time_system_"): + return False return True - return False - - # Determines if given token exists in line without leading or trailing token characters - # e.g. will return True for a line containing foo() but not foo_bar() or baz_foo - def token_in_line(self, token, line): - index = 0 - while True: - index = line.find(token, index) - # the following check has been changed from index < 1 to index < 0 because - # this function incorrectly returns false when the token in question is the - # first one in a line. The following line returns false when the token is present: - # (no leading whitespace) violating_symbol foo; - if index < 0: - break - if index == 0 or not (line[index - 1].isalnum() or line[index - 1] == '_'): - if index + len(token) >= len(line) or not (line[index + len(token)].isalnum() or - line[index + len(token)] == '_'): - return True - index = index + 1 - return False - - def check_source_line(self, line, file_path, report_error): - # Check fixable errors. These may have been fixed already. - if line.find(". ") != -1: - report_error("over-enthusiastic spaces") - if self.is_in_subdir(file_path, 'source', - 'include') and X_ENVOY_USED_DIRECTLY_REGEX.match(line): - report_error( - "Please do not use the raw literal x-envoy in source code. See Envoy::Http::PrefixValue." - ) - if self.has_invalid_angle_bracket_directory(line): - report_error("envoy includes should not have angle brackets") - for invalid_construct, valid_construct in PROTOBUF_TYPE_ERRORS.items(): - if invalid_construct in line: - report_error("incorrect protobuf type reference %s; " - "should be %s" % (invalid_construct, valid_construct)) - for invalid_construct, valid_construct in LIBCXX_REPLACEMENTS.items(): - if invalid_construct in line: - report_error("term %s should be replaced with standard library term %s" % - (invalid_construct, valid_construct)) - for invalid_construct, valid_construct in CODE_CONVENTION_REPLACEMENTS.items(): - if invalid_construct in line: - report_error("term %s should be replaced with preferred term %s" % - (invalid_construct, valid_construct)) - # Do not include the virtual_includes headers. 
- if re.search("#include.*/_virtual_includes/", line): - report_error("Don't include the virtual includes headers.") - - # Some errors cannot be fixed automatically, and actionable, consistent, - # navigable messages should be emitted to make it easy to find and fix - # the errors by hand. - if not self.allow_listed_for_protobuf_deps(file_path): - if '"google/protobuf' in line or "google::protobuf" in line: - report_error("unexpected direct dependency on google.protobuf, use " - "the definitions in common/protobuf/protobuf.h instead.") - if line.startswith("#include ") or line.startswith("#include or , switch to " - "Thread::MutexBasicLockable in source/common/common/thread.h") - if line.startswith("#include "): - # We don't check here for std::shared_timed_mutex because that may - # legitimately show up in comments, for example this one. - report_error("Don't use , use absl::Mutex for reader/writer locks.") - if not self.allow_listed_for_realtime(file_path) and not "NO_CHECK_FORMAT(real_time)" in line: - if "RealTimeSource" in line or \ - ("RealTimeSystem" in line and not "TestRealTimeSystem" in line) or \ - "std::chrono::system_clock::now" in line or "std::chrono::steady_clock::now" in line or \ - "std::this_thread::sleep_for" in line or self.has_cond_var_wait_for(line): - report_error("Don't reference real-world time sources from production code; use injection") - duration_arg = DURATION_VALUE_REGEX.search(line) - if duration_arg and duration_arg.group(1) != "0" and duration_arg.group(1) != "0.0": - # Matching duration(int-const or float-const) other than zero - report_error( - "Don't use ambiguous duration(value), use an explicit duration type, e.g. Event::TimeSystem::Milliseconds(value)" - ) - if not self.allow_listed_for_register_factory(file_path): - if "Registry::RegisterFactory<" in line or "REGISTER_FACTORY" in line: - report_error("Don't use Registry::RegisterFactory or REGISTER_FACTORY in tests, " - "use Registry::InjectFactory instead.") - if not self.allow_listed_for_unpack_to(file_path): - if "UnpackTo" in line: - report_error("Don't use UnpackTo() directly, use MessageUtil::unpackTo() instead") - # Check that we use the absl::Time library - if self.token_in_line("std::get_time", line): - if "test/" in file_path: - report_error("Don't use std::get_time; use TestUtility::parseTime in tests") - else: - report_error("Don't use std::get_time; use the injectable time system") - if self.token_in_line("std::put_time", line): - report_error("Don't use std::put_time; use absl::Time equivalent instead") - if self.token_in_line("gmtime", line): - report_error("Don't use gmtime; use absl::Time equivalent instead") - if self.token_in_line("mktime", line): - report_error("Don't use mktime; use absl::Time equivalent instead") - if self.token_in_line("localtime", line): - report_error("Don't use localtime; use absl::Time equivalent instead") - if self.token_in_line("strftime", line): - report_error("Don't use strftime; use absl::FormatTime instead") - if self.token_in_line("strptime", line): - report_error("Don't use strptime; use absl::FormatTime instead") - if self.token_in_line("strerror", line): - report_error("Don't use strerror; use Envoy::errorDetails instead") - # Prefer using abseil hash maps/sets over std::unordered_map/set for performance optimizations and - # non-deterministic iteration order that exposes faulty assertions. 
- # See: https://abseil.io/docs/cpp/guides/container#hash-tables - if "std::unordered_map" in line: - report_error("Don't use std::unordered_map; use absl::flat_hash_map instead or " - "absl::node_hash_map if pointer stability of keys/values is required") - if "std::unordered_set" in line: - report_error("Don't use std::unordered_set; use absl::flat_hash_set instead or " - "absl::node_hash_set if pointer stability of keys/values is required") - if "std::atomic_" in line: - # The std::atomic_* free functions are functionally equivalent to calling - # operations on std::atomic objects, so prefer to use that instead. - report_error("Don't use free std::atomic_* functions, use std::atomic members instead.") - # Block usage of certain std types/functions as iOS 11 and macOS 10.13 - # do not support these at runtime. - # See: https://github.com/envoyproxy/envoy/issues/12341 - if self.token_in_line("std::any", line): - report_error("Don't use std::any; use absl::any instead") - if self.token_in_line("std::get_if", line): - report_error("Don't use std::get_if; use absl::get_if instead") - if self.token_in_line("std::holds_alternative", line): - report_error("Don't use std::holds_alternative; use absl::holds_alternative instead") - if self.token_in_line("std::make_optional", line): - report_error("Don't use std::make_optional; use absl::make_optional instead") - if self.token_in_line("std::monostate", line): - report_error("Don't use std::monostate; use absl::monostate instead") - if self.token_in_line("std::optional", line): - report_error("Don't use std::optional; use absl::optional instead") - if self.token_in_line("std::string_view", line): - report_error("Don't use std::string_view; use absl::string_view instead") - if self.token_in_line("std::variant", line): - report_error("Don't use std::variant; use absl::variant instead") - if self.token_in_line("std::visit", line): - report_error("Don't use std::visit; use absl::visit instead") - if "__attribute__((packed))" in line and file_path != "./include/envoy/common/platform.h": - # __attribute__((packed)) is not supported by MSVC, we have a PACKED_STRUCT macro that - # can be used instead - report_error("Don't use __attribute__((packed)), use the PACKED_STRUCT macro defined " - "in include/envoy/common/platform.h instead") - if DESIGNATED_INITIALIZER_REGEX.search(line): - # Designated initializers are not part of the C++14 standard and are not supported - # by MSVC - report_error("Don't use designated initializers in struct initialization, " - "they are not part of C++14") - if " ?: " in line: - # The ?: operator is non-standard, it is a GCC extension - report_error("Don't use the '?:' operator, it is a non-standard GCC extension") - if line.startswith("using testing::Test;"): - report_error("Don't use 'using testing::Test;, elaborate the type instead") - if line.startswith("using testing::TestWithParams;"): - report_error("Don't use 'using testing::Test;, elaborate the type instead") - if TEST_NAME_STARTING_LOWER_CASE_REGEX.search(line): - # Matches variants of TEST(), TEST_P(), TEST_F() etc. where the test name begins - # with a lowercase letter. 
- report_error("Test names should be CamelCase, starting with a capital letter") - if OLD_MOCK_METHOD_REGEX.search(line): - report_error("The MOCK_METHODn() macros should not be used, use MOCK_METHOD() instead") - if FOR_EACH_N_REGEX.search(line): - report_error("std::for_each_n should not be used, use an alternative for loop instead") - - if not self.allow_listed_for_serialize_as_string(file_path) and "SerializeAsString" in line: - # The MessageLite::SerializeAsString doesn't generate deterministic serialization, - # use MessageUtil::hash instead. - report_error( - "Don't use MessageLite::SerializeAsString for generating deterministic serialization, use MessageUtil::hash instead." - ) - if not self.allow_listed_for_json_string_to_message( - file_path) and "JsonStringToMessage" in line: - # Centralize all usage of JSON parsing so it is easier to make changes in JSON parsing - # behavior. - report_error("Don't use Protobuf::util::JsonStringToMessage, use TestUtility::loadFromJson.") - - if self.is_in_subdir(file_path, 'source') and file_path.endswith('.cc') and \ - ('.counterFromString(' in line or '.gaugeFromString(' in line or \ - '.histogramFromString(' in line or '.textReadoutFromString(' in line or \ - '->counterFromString(' in line or '->gaugeFromString(' in line or \ - '->histogramFromString(' in line or '->textReadoutFromString(' in line): - report_error("Don't lookup stats by name at runtime; use StatName saved during construction") - - if MANGLED_PROTOBUF_NAME_REGEX.search(line): - report_error("Don't use mangled Protobuf names for enum constants") - - hist_m = HISTOGRAM_SI_SUFFIX_REGEX.search(line) - if hist_m and not self.allow_listed_for_histogram_si_suffix(hist_m.group(0)): - report_error( - "Don't suffix histogram names with the unit symbol, " - "it's already part of the histogram object and unit-supporting sinks can use this information natively, " - "other sinks can add the suffix automatically on flush should they prefer to do so.") - - if not self.allow_listed_for_std_regex(file_path) and "std::regex" in line: - report_error("Don't use std::regex in code that handles untrusted input. Use RegexMatcher") - - if not self.allow_listed_for_grpc_init(file_path): - grpc_init_or_shutdown = line.find("grpc_init()") - grpc_shutdown = line.find("grpc_shutdown()") - if grpc_init_or_shutdown == -1 or (grpc_shutdown != -1 and - grpc_shutdown < grpc_init_or_shutdown): - grpc_init_or_shutdown = grpc_shutdown - if grpc_init_or_shutdown != -1: - comment = line.find("// ") - if comment == -1 or comment > grpc_init_or_shutdown: - report_error("Don't call grpc_init() or grpc_shutdown() directly, instantiate " + - "Grpc::GoogleGrpcContext. See #8282") - - if not self.whitelisted_for_memcpy(file_path) and \ - not ("test/" in file_path) and \ - ("memcpy(" in line) and \ - not ("NOLINT(safe-memcpy)" in line): - report_error( - "Don't call memcpy() directly; use safeMemcpy, safeMemcpyUnsafeSrc, safeMemcpyUnsafeDst or MemBlockBuilder instead." - ) - - if self.deny_listed_for_exceptions(file_path): - # Skpping cases where 'throw' is a substring of a symbol like in "foothrowBar". 
- if "throw" in line.split(): - comment_match = COMMENT_REGEX.search(line) - if comment_match is None or comment_match.start(0) > line.find("throw"): - report_error("Don't introduce throws into exception-free files, use error " + - "statuses instead.") - - if "lua_pushlightuserdata" in line: - report_error( - "Don't use lua_pushlightuserdata, since it can cause unprotected error in call to" + - "Lua API (bad light userdata pointer) on ARM64 architecture. See " + - "https://github.com/LuaJIT/LuaJIT/issues/450#issuecomment-433659873 for details.") - - if file_path.endswith(PROTO_SUFFIX): - exclude_path = ['v1', 'v2', 'generated_api_shadow'] - result = PROTO_VALIDATION_STRING.search(line) - if result is not None: - if not any(x in file_path for x in exclude_path): - report_error("min_bytes is DEPRECATED, Use min_len.") - - def check_build_line(self, line, file_path, report_error): - if "@bazel_tools" in line and not (self.is_starlark_file(file_path) or - file_path.startswith("./bazel/") or - "python/runfiles" in line): - report_error("unexpected @bazel_tools reference, please indirect via a definition in //bazel") - if not self.allow_listed_for_protobuf_deps(file_path) and '"protobuf"' in line: - report_error("unexpected direct external dependency on protobuf, use " - "//source/common/protobuf instead.") - if (self.envoy_build_rule_check and not self.is_starlark_file(file_path) and - not self.is_workspace_file(file_path) and not self.is_external_build_file(file_path) and - "@envoy//" in line): - report_error("Superfluous '@envoy//' prefix") - if not self.allow_listed_for_build_urls(file_path) and (" urls = " in line or - " url = " in line): - report_error("Only repository_locations.bzl may contains URL references") - - def fix_build_line(self, file_path, line, line_number): - if (self.envoy_build_rule_check and not self.is_starlark_file(file_path) and - not self.is_workspace_file(file_path) and not self.is_external_build_file(file_path)): - line = line.replace("@envoy//", "//") - return line - - def fix_build_path(self, file_path): - self.evaluate_lines(file_path, functools.partial(self.fix_build_line, file_path)) - error_messages = [] - - # TODO(htuch): Add API specific BUILD fixer script. - if not self.is_build_fixer_excluded_file(file_path) and not self.is_api_file( - file_path) and not self.is_starlark_file(file_path) and not self.is_workspace_file( - file_path): - if os.system("%s %s %s" % (ENVOY_BUILD_FIXER_PATH, file_path, file_path)) != 0: - error_messages += ["envoy_build_fixer rewrite failed for file: %s" % file_path] + # Determines whether the filename is either in the specified subdirectory, or + # at the top level. We consider files in the top level for the benefit of + # the check_format testcases in tools/testdata/check_format. + def is_in_subdir(self, filename, *subdirs): + # Skip this check for check_format's unit-tests. + if filename.count("/") <= 1: + return True + for subdir in subdirs: + if filename.startswith('./' + subdir + '/'): + return True + return False + + # Determines if given token exists in line without leading or trailing token characters + # e.g. will return True for a line containing foo() but not foo_bar() or baz_foo + def token_in_line(self, token, line): + index = 0 + while True: + index = line.find(token, index) + # the following check has been changed from index < 1 to index < 0 because + # this function incorrectly returns false when the token in question is the + # first one in a line. 
The following line returns false when the token is present: + # (no leading whitespace) violating_symbol foo; + if index < 0: + break + if index == 0 or not (line[index - 1].isalnum() or line[index - 1] == '_'): + if index + len(token) >= len(line) or not (line[index + len(token)].isalnum() or + line[index + len(token)] == '_'): + return True + index = index + 1 + return False + + def check_source_line(self, line, file_path, report_error): + # Check fixable errors. These may have been fixed already. + if line.find(".  ") != -1: + report_error("over-enthusiastic spaces") + if self.is_in_subdir(file_path, 'source', + 'include') and X_ENVOY_USED_DIRECTLY_REGEX.match(line): + report_error( + "Please do not use the raw literal x-envoy in source code. See Envoy::Http::PrefixValue." + ) + if self.has_invalid_angle_bracket_directory(line): + report_error("envoy includes should not have angle brackets") + for invalid_construct, valid_construct in PROTOBUF_TYPE_ERRORS.items(): + if invalid_construct in line: + report_error("incorrect protobuf type reference %s; " + "should be %s" % (invalid_construct, valid_construct)) + for invalid_construct, valid_construct in LIBCXX_REPLACEMENTS.items(): + if invalid_construct in line: + report_error("term %s should be replaced with standard library term %s" % + (invalid_construct, valid_construct)) + for invalid_construct, valid_construct in CODE_CONVENTION_REPLACEMENTS.items(): + if invalid_construct in line: + report_error("term %s should be replaced with preferred term %s" % + (invalid_construct, valid_construct)) + # Do not include the virtual_includes headers. + if re.search("#include.*/_virtual_includes/", line): + report_error("Don't include the virtual includes headers.") + + # Some errors cannot be fixed automatically, and actionable, consistent, + # navigable messages should be emitted to make it easy to find and fix + # the errors by hand. + if not self.allow_listed_for_protobuf_deps(file_path): + if '"google/protobuf' in line or "google::protobuf" in line: + report_error("unexpected direct dependency on google.protobuf, use " + "the definitions in common/protobuf/protobuf.h instead.") + if line.startswith("#include <mutex>") or line.startswith("#include <condition_variable"): + # We don't check here for std::mutex because that may legitimately show up in + # comments, for example this one. + report_error("Don't use <mutex> or <condition_variable*>, switch to " + "Thread::MutexBasicLockable in source/common/common/thread.h") + if line.startswith("#include <shared_mutex>"): + # We don't check here for std::shared_timed_mutex because that may + # legitimately show up in comments, for example this one. + report_error("Don't use <shared_mutex>, use absl::Mutex for reader/writer locks.") + if not self.allow_listed_for_realtime( + file_path) and not "NO_CHECK_FORMAT(real_time)" in line: + if "RealTimeSource" in line or \ + ("RealTimeSystem" in line and not "TestRealTimeSystem" in line) or \ + "std::chrono::system_clock::now" in line or "std::chrono::steady_clock::now" in line or \ + "std::this_thread::sleep_for" in line or self.has_cond_var_wait_for(line): + report_error( + "Don't reference real-world time sources from production code; use injection") + duration_arg = DURATION_VALUE_REGEX.search(line) + if duration_arg and duration_arg.group(1) != "0" and duration_arg.group(1) != "0.0": + # Matching duration(int-const or float-const) other than zero + report_error( + "Don't use ambiguous duration(value), use an explicit duration type, e.g. 
Event::TimeSystem::Milliseconds(value)" + ) + if not self.allow_listed_for_register_factory(file_path): + if "Registry::RegisterFactory<" in line or "REGISTER_FACTORY" in line: + report_error("Don't use Registry::RegisterFactory or REGISTER_FACTORY in tests, " + "use Registry::InjectFactory instead.") + if not self.allow_listed_for_unpack_to(file_path): + if "UnpackTo" in line: + report_error("Don't use UnpackTo() directly, use MessageUtil::unpackTo() instead") + # Check that we use the absl::Time library + if self.token_in_line("std::get_time", line): + if "test/" in file_path: + report_error("Don't use std::get_time; use TestUtility::parseTime in tests") + else: + report_error("Don't use std::get_time; use the injectable time system") + if self.token_in_line("std::put_time", line): + report_error("Don't use std::put_time; use absl::Time equivalent instead") + if self.token_in_line("gmtime", line): + report_error("Don't use gmtime; use absl::Time equivalent instead") + if self.token_in_line("mktime", line): + report_error("Don't use mktime; use absl::Time equivalent instead") + if self.token_in_line("localtime", line): + report_error("Don't use localtime; use absl::Time equivalent instead") + if self.token_in_line("strftime", line): + report_error("Don't use strftime; use absl::FormatTime instead") + if self.token_in_line("strptime", line): + report_error("Don't use strptime; use absl::FormatTime instead") + if self.token_in_line("strerror", line): + report_error("Don't use strerror; use Envoy::errorDetails instead") + # Prefer using abseil hash maps/sets over std::unordered_map/set for performance optimizations and + # non-deterministic iteration order that exposes faulty assertions. + # See: https://abseil.io/docs/cpp/guides/container#hash-tables + if "std::unordered_map" in line: + report_error("Don't use std::unordered_map; use absl::flat_hash_map instead or " + "absl::node_hash_map if pointer stability of keys/values is required") + if "std::unordered_set" in line: + report_error("Don't use std::unordered_set; use absl::flat_hash_set instead or " + "absl::node_hash_set if pointer stability of keys/values is required") + if "std::atomic_" in line: + # The std::atomic_* free functions are functionally equivalent to calling + # operations on std::atomic objects, so prefer to use that instead. + report_error( + "Don't use free std::atomic_* functions, use std::atomic members instead.") + # Block usage of certain std types/functions as iOS 11 and macOS 10.13 + # do not support these at runtime. 
+ # See: https://github.com/envoyproxy/envoy/issues/12341 + if self.token_in_line("std::any", line): + report_error("Don't use std::any; use absl::any instead") + if self.token_in_line("std::get_if", line): + report_error("Don't use std::get_if; use absl::get_if instead") + if self.token_in_line("std::holds_alternative", line): + report_error("Don't use std::holds_alternative; use absl::holds_alternative instead") + if self.token_in_line("std::make_optional", line): + report_error("Don't use std::make_optional; use absl::make_optional instead") + if self.token_in_line("std::monostate", line): + report_error("Don't use std::monostate; use absl::monostate instead") + if self.token_in_line("std::optional", line): + report_error("Don't use std::optional; use absl::optional instead") + if self.token_in_line("std::string_view", line): + report_error("Don't use std::string_view; use absl::string_view instead") + if self.token_in_line("std::variant", line): + report_error("Don't use std::variant; use absl::variant instead") + if self.token_in_line("std::visit", line): + report_error("Don't use std::visit; use absl::visit instead") + if "__attribute__((packed))" in line and file_path != "./include/envoy/common/platform.h": + # __attribute__((packed)) is not supported by MSVC, we have a PACKED_STRUCT macro that + # can be used instead + report_error("Don't use __attribute__((packed)), use the PACKED_STRUCT macro defined " + "in include/envoy/common/platform.h instead") + if DESIGNATED_INITIALIZER_REGEX.search(line): + # Designated initializers are not part of the C++14 standard and are not supported + # by MSVC + report_error("Don't use designated initializers in struct initialization, " + "they are not part of C++14") + if " ?: " in line: + # The ?: operator is non-standard, it is a GCC extension + report_error("Don't use the '?:' operator, it is a non-standard GCC extension") + if line.startswith("using testing::Test;"): + report_error("Don't use 'using testing::Test;, elaborate the type instead") + if line.startswith("using testing::TestWithParams;"): + report_error("Don't use 'using testing::Test;, elaborate the type instead") + if TEST_NAME_STARTING_LOWER_CASE_REGEX.search(line): + # Matches variants of TEST(), TEST_P(), TEST_F() etc. where the test name begins + # with a lowercase letter. + report_error("Test names should be CamelCase, starting with a capital letter") + if OLD_MOCK_METHOD_REGEX.search(line): + report_error("The MOCK_METHODn() macros should not be used, use MOCK_METHOD() instead") + if FOR_EACH_N_REGEX.search(line): + report_error("std::for_each_n should not be used, use an alternative for loop instead") + + if not self.allow_listed_for_serialize_as_string(file_path) and "SerializeAsString" in line: + # The MessageLite::SerializeAsString doesn't generate deterministic serialization, + # use MessageUtil::hash instead. + report_error( + "Don't use MessageLite::SerializeAsString for generating deterministic serialization, use MessageUtil::hash instead." + ) + if not self.allow_listed_for_json_string_to_message( + file_path) and "JsonStringToMessage" in line: + # Centralize all usage of JSON parsing so it is easier to make changes in JSON parsing + # behavior. 
+ report_error( + "Don't use Protobuf::util::JsonStringToMessage, use TestUtility::loadFromJson.") - if os.system("%s -lint=fix -mode=fix %s" % (BUILDIFIER_PATH, file_path)) != 0: - error_messages += ["buildifier rewrite failed for file: %s" % file_path] - return error_messages + if self.is_in_subdir(file_path, 'source') and file_path.endswith('.cc') and \ + ('.counterFromString(' in line or '.gaugeFromString(' in line or \ + '.histogramFromString(' in line or '.textReadoutFromString(' in line or \ + '->counterFromString(' in line or '->gaugeFromString(' in line or \ + '->histogramFromString(' in line or '->textReadoutFromString(' in line): + report_error( + "Don't lookup stats by name at runtime; use StatName saved during construction") - def check_build_path(self, file_path): - error_messages = [] + if MANGLED_PROTOBUF_NAME_REGEX.search(line): + report_error("Don't use mangled Protobuf names for enum constants") - if not self.is_build_fixer_excluded_file(file_path) and not self.is_api_file( - file_path) and not self.is_starlark_file(file_path) and not self.is_workspace_file( - file_path): - command = "%s %s | diff %s -" % (ENVOY_BUILD_FIXER_PATH, file_path, file_path) - error_messages += self.execute_command(command, "envoy_build_fixer check failed", file_path) - - if self.is_build_file(file_path) and (file_path.startswith(self.api_prefix + "envoy") or - file_path.startswith(self.api_shadow_root + "envoy")): - found = False - for line in self.read_lines(file_path): - if "api_proto_package(" in line: - found = True - break - if not found: - error_messages += ["API build file does not provide api_proto_package()"] - - command = "%s -mode=diff %s" % (BUILDIFIER_PATH, file_path) - error_messages += self.execute_command(command, "buildifier check failed", file_path) - error_messages += self.check_file_contents(file_path, self.check_build_line) - return error_messages - - def fix_source_path(self, file_path): - self.evaluate_lines(file_path, self.fix_source_line) + hist_m = HISTOGRAM_SI_SUFFIX_REGEX.search(line) + if hist_m and not self.allow_listed_for_histogram_si_suffix(hist_m.group(0)): + report_error( + "Don't suffix histogram names with the unit symbol, " + "it's already part of the histogram object and unit-supporting sinks can use this information natively, " + "other sinks can add the suffix automatically on flush should they prefer to do so." + ) - error_messages = [] + if not self.allow_listed_for_std_regex(file_path) and "std::regex" in line: + report_error( + "Don't use std::regex in code that handles untrusted input. Use RegexMatcher") + + if not self.allow_listed_for_grpc_init(file_path): + grpc_init_or_shutdown = line.find("grpc_init()") + grpc_shutdown = line.find("grpc_shutdown()") + if grpc_init_or_shutdown == -1 or (grpc_shutdown != -1 and + grpc_shutdown < grpc_init_or_shutdown): + grpc_init_or_shutdown = grpc_shutdown + if grpc_init_or_shutdown != -1: + comment = line.find("// ") + if comment == -1 or comment > grpc_init_or_shutdown: + report_error( + "Don't call grpc_init() or grpc_shutdown() directly, instantiate " + + "Grpc::GoogleGrpcContext. See #8282") + + if not self.whitelisted_for_memcpy(file_path) and \ + not ("test/" in file_path) and \ + ("memcpy(" in line) and \ + not ("NOLINT(safe-memcpy)" in line): + report_error( + "Don't call memcpy() directly; use safeMemcpy, safeMemcpyUnsafeSrc, safeMemcpyUnsafeDst or MemBlockBuilder instead." 
+ ) + + if self.deny_listed_for_exceptions(file_path): + # Skpping cases where 'throw' is a substring of a symbol like in "foothrowBar". + if "throw" in line.split(): + comment_match = COMMENT_REGEX.search(line) + if comment_match is None or comment_match.start(0) > line.find("throw"): + report_error("Don't introduce throws into exception-free files, use error " + + "statuses instead.") + + if "lua_pushlightuserdata" in line: + report_error( + "Don't use lua_pushlightuserdata, since it can cause unprotected error in call to" + + "Lua API (bad light userdata pointer) on ARM64 architecture. See " + + "https://github.com/LuaJIT/LuaJIT/issues/450#issuecomment-433659873 for details.") + + if file_path.endswith(PROTO_SUFFIX): + exclude_path = ['v1', 'v2', 'generated_api_shadow'] + result = PROTO_VALIDATION_STRING.search(line) + if result is not None: + if not any(x in file_path for x in exclude_path): + report_error("min_bytes is DEPRECATED, Use min_len.") + + def check_build_line(self, line, file_path, report_error): + if "@bazel_tools" in line and not (self.is_starlark_file(file_path) or + file_path.startswith("./bazel/") or + "python/runfiles" in line): + report_error( + "unexpected @bazel_tools reference, please indirect via a definition in //bazel") + if not self.allow_listed_for_protobuf_deps(file_path) and '"protobuf"' in line: + report_error("unexpected direct external dependency on protobuf, use " + "//source/common/protobuf instead.") + if (self.envoy_build_rule_check and not self.is_starlark_file(file_path) and + not self.is_workspace_file(file_path) and + not self.is_external_build_file(file_path) and "@envoy//" in line): + report_error("Superfluous '@envoy//' prefix") + if not self.allow_listed_for_build_urls(file_path) and (" urls = " in line or + " url = " in line): + report_error("Only repository_locations.bzl may contains URL references") + + def fix_build_line(self, file_path, line, line_number): + if (self.envoy_build_rule_check and not self.is_starlark_file(file_path) and + not self.is_workspace_file(file_path) and + not self.is_external_build_file(file_path)): + line = line.replace("@envoy//", "//") + return line + + def fix_build_path(self, file_path): + self.evaluate_lines(file_path, functools.partial(self.fix_build_line, file_path)) + + error_messages = [] + + # TODO(htuch): Add API specific BUILD fixer script. 
+ if not self.is_build_fixer_excluded_file(file_path) and not self.is_api_file( + file_path) and not self.is_starlark_file(file_path) and not self.is_workspace_file( + file_path): + if os.system("%s %s %s" % (ENVOY_BUILD_FIXER_PATH, file_path, file_path)) != 0: + error_messages += ["envoy_build_fixer rewrite failed for file: %s" % file_path] + + if os.system("%s -lint=fix -mode=fix %s" % (BUILDIFIER_PATH, file_path)) != 0: + error_messages += ["buildifier rewrite failed for file: %s" % file_path] + return error_messages + + def check_build_path(self, file_path): + error_messages = [] + + if not self.is_build_fixer_excluded_file(file_path) and not self.is_api_file( + file_path) and not self.is_starlark_file(file_path) and not self.is_workspace_file( + file_path): + command = "%s %s | diff %s -" % (ENVOY_BUILD_FIXER_PATH, file_path, file_path) + error_messages += self.execute_command(command, "envoy_build_fixer check failed", + file_path) + + if self.is_build_file(file_path) and (file_path.startswith(self.api_prefix + "envoy") or + file_path.startswith(self.api_shadow_root + "envoy")): + found = False + for line in self.read_lines(file_path): + if "api_proto_package(" in line: + found = True + break + if not found: + error_messages += ["API build file does not provide api_proto_package()"] + + command = "%s -mode=diff %s" % (BUILDIFIER_PATH, file_path) + error_messages += self.execute_command(command, "buildifier check failed", file_path) + error_messages += self.check_file_contents(file_path, self.check_build_line) + return error_messages + + def fix_source_path(self, file_path): + self.evaluate_lines(file_path, self.fix_source_line) + + error_messages = [] + + if not file_path.endswith(DOCS_SUFFIX): + if not file_path.endswith(PROTO_SUFFIX): + error_messages += self.fix_header_order(file_path) + error_messages += self.clang_format(file_path) + if file_path.endswith(PROTO_SUFFIX) and self.is_api_file(file_path): + package_name, error_message = self.package_name_for_proto(file_path) + if package_name is None: + error_messages += error_message + return error_messages + + def check_source_path(self, file_path): + error_messages = self.check_file_contents(file_path, self.check_source_line) + + if not file_path.endswith(DOCS_SUFFIX): + if not file_path.endswith(PROTO_SUFFIX): + error_messages += self.check_namespace(file_path) + command = ("%s --include_dir_order %s --path %s | diff %s -" % + (HEADER_ORDER_PATH, self.include_dir_order, file_path, file_path)) + error_messages += self.execute_command(command, "header_order.py check failed", + file_path) + command = ("%s %s | diff %s -" % (CLANG_FORMAT_PATH, file_path, file_path)) + error_messages += self.execute_command(command, "clang-format check failed", file_path) + + if file_path.endswith(PROTO_SUFFIX) and self.is_api_file(file_path): + package_name, error_message = self.package_name_for_proto(file_path) + if package_name is None: + error_messages += error_message + return error_messages + + # Example target outputs are: + # - "26,27c26" + # - "12,13d13" + # - "7a8,9" + def execute_command(self, + command, + error_message, + file_path, + regex=re.compile(r"^(\d+)[a|c|d]?\d*(?:,\d+[a|c|d]?\d*)?$")): + try: + output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT).strip() + if output: + return output.decode('utf-8').split("\n") + return [] + except subprocess.CalledProcessError as e: + if (e.returncode != 0 and e.returncode != 1): + return ["ERROR: something went wrong while executing: %s" % e.cmd] + # In case we can't 
find any line numbers, record an error message first. + error_messages = ["%s for file: %s" % (error_message, file_path)] + for line in e.output.decode('utf-8').splitlines(): + for num in regex.findall(line): + error_messages.append(" %s:%s" % (file_path, num)) + return error_messages + + def fix_header_order(self, file_path): + command = "%s --rewrite --include_dir_order %s --path %s" % ( + HEADER_ORDER_PATH, self.include_dir_order, file_path) + if os.system(command) != 0: + return ["header_order.py rewrite error: %s" % (file_path)] + return [] - if not file_path.endswith(DOCS_SUFFIX): - if not file_path.endswith(PROTO_SUFFIX): - error_messages += self.fix_header_order(file_path) - error_messages += self.clang_format(file_path) - if file_path.endswith(PROTO_SUFFIX) and self.is_api_file(file_path): - package_name, error_message = self.package_name_for_proto(file_path) - if package_name is None: - error_messages += error_message - return error_messages - - def check_source_path(self, file_path): - error_messages = self.check_file_contents(file_path, self.check_source_line) - - if not file_path.endswith(DOCS_SUFFIX): - if not file_path.endswith(PROTO_SUFFIX): - error_messages += self.check_namespace(file_path) - command = ("%s --include_dir_order %s --path %s | diff %s -" % - (HEADER_ORDER_PATH, self.include_dir_order, file_path, file_path)) - error_messages += self.execute_command(command, "header_order.py check failed", file_path) - command = ("%s %s | diff %s -" % (CLANG_FORMAT_PATH, file_path, file_path)) - error_messages += self.execute_command(command, "clang-format check failed", file_path) - - if file_path.endswith(PROTO_SUFFIX) and self.is_api_file(file_path): - package_name, error_message = self.package_name_for_proto(file_path) - if package_name is None: - error_messages += error_message - return error_messages - - # Example target outputs are: - # - "26,27c26" - # - "12,13d13" - # - "7a8,9" - def execute_command(self, - command, - error_message, - file_path, - regex=re.compile(r"^(\d+)[a|c|d]?\d*(?:,\d+[a|c|d]?\d*)?$")): - try: - output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT).strip() - if output: - return output.decode('utf-8').split("\n") - return [] - except subprocess.CalledProcessError as e: - if (e.returncode != 0 and e.returncode != 1): - return ["ERROR: something went wrong while executing: %s" % e.cmd] - # In case we can't find any line numbers, record an error message first. 
- error_messages = ["%s for file: %s" % (error_message, file_path)] - for line in e.output.decode('utf-8').splitlines(): - for num in regex.findall(line): - error_messages.append(" %s:%s" % (file_path, num)) - return error_messages - - def fix_header_order(self, file_path): - command = "%s --rewrite --include_dir_order %s --path %s" % (HEADER_ORDER_PATH, - self.include_dir_order, file_path) - if os.system(command) != 0: - return ["header_order.py rewrite error: %s" % (file_path)] - return [] - - def clang_format(self, file_path): - command = "%s -i %s" % (CLANG_FORMAT_PATH, file_path) - if os.system(command) != 0: - return ["clang-format rewrite error: %s" % (file_path)] - return [] - - def check_format(self, file_path): - if file_path.startswith(EXCLUDED_PREFIXES): - return [] - - if not file_path.endswith(SUFFIXES): - return [] + def clang_format(self, file_path): + command = "%s -i %s" % (CLANG_FORMAT_PATH, file_path) + if os.system(command) != 0: + return ["clang-format rewrite error: %s" % (file_path)] + return [] - error_messages = [] - # Apply fixes first, if asked, and then run checks. If we wind up attempting to fix - # an issue, but there's still an error, that's a problem. - try_to_fix = self.operation_type == "fix" - if self.is_build_file(file_path) or self.is_starlark_file(file_path) or self.is_workspace_file( - file_path): - if try_to_fix: - error_messages += self.fix_build_path(file_path) - error_messages += self.check_build_path(file_path) - else: - if try_to_fix: - error_messages += self.fix_source_path(file_path) - error_messages += self.check_source_path(file_path) - - if error_messages: - return ["From %s" % file_path] + error_messages - return error_messages - - def check_format_return_trace_on_error(self, file_path): - """Run check_format and return the traceback of any exception.""" - try: - return self.check_format(file_path) - except: - return traceback.format_exc().split("\n") - - def check_owners(self, dir_name, owned_directories, error_messages): - """Checks to make sure a given directory is present either in CODEOWNERS or OWNED_EXTENSIONS + def check_format(self, file_path): + if file_path.startswith(EXCLUDED_PREFIXES): + return [] + + if not file_path.endswith(SUFFIXES): + return [] + + error_messages = [] + # Apply fixes first, if asked, and then run checks. If we wind up attempting to fix + # an issue, but there's still an error, that's a problem. + try_to_fix = self.operation_type == "fix" + if self.is_build_file(file_path) or self.is_starlark_file( + file_path) or self.is_workspace_file(file_path): + if try_to_fix: + error_messages += self.fix_build_path(file_path) + error_messages += self.check_build_path(file_path) + else: + if try_to_fix: + error_messages += self.fix_source_path(file_path) + error_messages += self.check_source_path(file_path) + + if error_messages: + return ["From %s" % file_path] + error_messages + return error_messages + + def check_format_return_trace_on_error(self, file_path): + """Run check_format and return the traceback of any exception.""" + try: + return self.check_format(file_path) + except: + return traceback.format_exc().split("\n") + + def check_owners(self, dir_name, owned_directories, error_messages): + """Checks to make sure a given directory is present either in CODEOWNERS or OWNED_EXTENSIONS Args: dir_name: the directory being checked. owned_directories: directories currently listed in CODEOWNERS. error_messages: where to put an error message for new unowned directories. 
""" - found = False - for owned in owned_directories: - if owned.startswith(dir_name) or dir_name.startswith(owned): - found = True - if not found and dir_name not in UNOWNED_EXTENSIONS: - error_messages.append("New directory %s appears to not have owners in CODEOWNERS" % dir_name) - - def check_api_shadow_starlark_files(self, file_path, error_messages): - command = "diff -u " - command += file_path + " " - api_shadow_starlark_path = self.api_shadow_root + re.sub(r"\./api/", '', file_path) - command += api_shadow_starlark_path - - error_message = self.execute_command(command, "invalid .bzl in generated_api_shadow", file_path) - if self.operation_type == "check": - error_messages += error_message - elif self.operation_type == "fix" and len(error_message) != 0: - shutil.copy(file_path, api_shadow_starlark_path) - - return error_messages - - def check_format_visitor(self, arg, dir_name, names): - """Run check_format in parallel for the given files. + found = False + for owned in owned_directories: + if owned.startswith(dir_name) or dir_name.startswith(owned): + found = True + if not found and dir_name not in UNOWNED_EXTENSIONS: + error_messages.append("New directory %s appears to not have owners in CODEOWNERS" % + dir_name) + + def check_api_shadow_starlark_files(self, file_path, error_messages): + command = "diff -u " + command += file_path + " " + api_shadow_starlark_path = self.api_shadow_root + re.sub(r"\./api/", '', file_path) + command += api_shadow_starlark_path + + error_message = self.execute_command(command, "invalid .bzl in generated_api_shadow", + file_path) + if self.operation_type == "check": + error_messages += error_message + elif self.operation_type == "fix" and len(error_message) != 0: + shutil.copy(file_path, api_shadow_starlark_path) + + return error_messages + + def check_format_visitor(self, arg, dir_name, names): + """Run check_format in parallel for the given files. Args: arg: a tuple (pool, result_list, owned_directories, error_messages) pool and result_list are for starting tasks asynchronously. @@ -1065,170 +1083,173 @@ def check_format_visitor(self, arg, dir_name, names): names: a list of file names. """ - # Unpack the multiprocessing.Pool process pool and list of results. Since - # python lists are passed as references, this is used to collect the list of - # async results (futures) from running check_format and passing them back to - # the caller. - pool, result_list, owned_directories, error_messages = arg - - # Sanity check CODEOWNERS. This doesn't need to be done in a multi-threaded - # manner as it is a small and limited list. - source_prefix = './source/' - full_prefix = './source/extensions/' - # Check to see if this directory is a subdir under /source/extensions - # Also ignore top level directories under /source/extensions since we don't - # need owners for source/extensions/access_loggers etc, just the subdirectories. 
- if dir_name.startswith(full_prefix) and '/' in dir_name[len(full_prefix):]: - self.check_owners(dir_name[len(source_prefix):], owned_directories, error_messages) - - for file_name in names: - if dir_name.startswith("./api") and self.is_starlark_file(file_name): - result = pool.apply_async(self.check_api_shadow_starlark_files, - args=(dir_name + "/" + file_name, error_messages)) - result_list.append(result) - result = pool.apply_async(self.check_format_return_trace_on_error, - args=(dir_name + "/" + file_name,)) - result_list.append(result) - - # check_error_messages iterates over the list with error messages and prints - # errors and returns a bool based on whether there were any errors. - def check_error_messages(self, error_messages): - if error_messages: - for e in error_messages: - print("ERROR: %s" % e) - return True - return False - - def whitelisted_for_memcpy(self, file_path): - return file_path in MEMCPY_WHITELIST + # Unpack the multiprocessing.Pool process pool and list of results. Since + # python lists are passed as references, this is used to collect the list of + # async results (futures) from running check_format and passing them back to + # the caller. + pool, result_list, owned_directories, error_messages = arg + + # Sanity check CODEOWNERS. This doesn't need to be done in a multi-threaded + # manner as it is a small and limited list. + source_prefix = './source/' + full_prefix = './source/extensions/' + # Check to see if this directory is a subdir under /source/extensions + # Also ignore top level directories under /source/extensions since we don't + # need owners for source/extensions/access_loggers etc, just the subdirectories. + if dir_name.startswith(full_prefix) and '/' in dir_name[len(full_prefix):]: + self.check_owners(dir_name[len(source_prefix):], owned_directories, error_messages) + + for file_name in names: + if dir_name.startswith("./api") and self.is_starlark_file(file_name): + result = pool.apply_async(self.check_api_shadow_starlark_files, + args=(dir_name + "/" + file_name, error_messages)) + result_list.append(result) + result = pool.apply_async(self.check_format_return_trace_on_error, + args=(dir_name + "/" + file_name,)) + result_list.append(result) + + # check_error_messages iterates over the list with error messages and prints + # errors and returns a bool based on whether there were any errors. + def check_error_messages(self, error_messages): + if error_messages: + for e in error_messages: + print("ERROR: %s" % e) + return True + return False + + def whitelisted_for_memcpy(self, file_path): + return file_path in MEMCPY_WHITELIST if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Check or fix file format.") - parser.add_argument("operation_type", - type=str, - choices=["check", "fix"], - help="specify if the run should 'check' or 'fix' format.") - parser.add_argument( - "target_path", - type=str, - nargs="?", - default=".", - help="specify the root directory for the script to recurse over. 
Default '.'.") - parser.add_argument("--add-excluded-prefixes", - type=str, - nargs="+", - help="exclude additional prefixes.") - parser.add_argument("-j", - "--num-workers", - type=int, - default=multiprocessing.cpu_count(), - help="number of worker processes to use; defaults to one per core.") - parser.add_argument("--api-prefix", type=str, default="./api/", help="path of the API tree.") - parser.add_argument("--api-shadow-prefix", - type=str, - default="./generated_api_shadow/", - help="path of the shadow API tree.") - parser.add_argument("--skip_envoy_build_rule_check", - action="store_true", - help="skip checking for '@envoy//' prefix in build rules.") - parser.add_argument("--namespace_check", - type=str, - nargs="?", - default="Envoy", - help="specify namespace check string. Default 'Envoy'.") - parser.add_argument("--namespace_check_excluded_paths", - type=str, - nargs="+", - default=[], - help="exclude paths from the namespace_check.") - parser.add_argument("--build_fixer_check_excluded_paths", - type=str, - nargs="+", - default=[], - help="exclude paths from envoy_build_fixer check.") - parser.add_argument("--bazel_tools_check_excluded_paths", - type=str, - nargs="+", - default=[], - help="exclude paths from bazel_tools check.") - parser.add_argument("--include_dir_order", - type=str, - default=",".join(common.include_dir_order()), - help="specify the header block include directory order.") - args = parser.parse_args() - if args.add_excluded_prefixes: - EXCLUDED_PREFIXES += tuple(args.add_excluded_prefixes) - format_checker = FormatChecker(args) - - # Check whether all needed external tools are available. - ct_error_messages = format_checker.check_tools() - if format_checker.check_error_messages(ct_error_messages): - sys.exit(1) - - # Returns the list of directories with owners listed in CODEOWNERS. May append errors to - # error_messages. - def owned_directories(error_messages): - owned = [] - maintainers = [ - '@mattklein123', '@htuch', '@alyssawilk', '@zuercher', '@lizan', '@snowp', '@asraa', - '@yanavlasov', '@junr03', '@dio', '@jmarantz', '@antoniovicente' - ] - - try: - with open('./CODEOWNERS') as f: - for line in f: - # If this line is of the form "extensions/... @owner1 @owner2" capture the directory - # name and store it in the list of directories with documented owners. - m = EXTENSIONS_CODEOWNERS_REGEX.search(line) - if m is not None and not line.startswith('#'): - owned.append(m.group(1).strip()) - owners = re.findall('@\S+', m.group(2).strip()) - if len(owners) < 2: - error_messages.append("Extensions require at least 2 owners in CODEOWNERS:\n" - " {}".format(line)) - maintainer = len(set(owners).intersection(set(maintainers))) > 0 - if not maintainer: - error_messages.append("Extensions require at least one maintainer OWNER:\n" - " {}".format(line)) - - return owned - except IOError: - return [] # for the check format tests. - - # Calculate the list of owned directories once per run. - error_messages = [] - owned_directories = owned_directories(error_messages) - - if os.path.isfile(args.target_path): - error_messages += format_checker.check_format("./" + args.target_path) - else: - results = [] - - def pooled_check_format(path_predicate): - pool = multiprocessing.Pool(processes=args.num_workers) - # For each file in target_path, start a new task in the pool and collect the - # results (results is passed by reference, and is used as an output). 
- for root, _, files in os.walk(args.target_path): - format_checker.check_format_visitor((pool, results, owned_directories, error_messages), - root, [f for f in files if path_predicate(f)]) - - # Close the pool to new tasks, wait for all of the running tasks to finish, - # then collect the error messages. - pool.close() - pool.join() - - # We first run formatting on non-BUILD files, since the BUILD file format - # requires analysis of srcs/hdrs in the BUILD file, and we don't want these - # to be rewritten by other multiprocessing pooled processes. - pooled_check_format(lambda f: not format_checker.is_build_file(f)) - pooled_check_format(lambda f: format_checker.is_build_file(f)) - - error_messages += sum((r.get() for r in results), []) - - if format_checker.check_error_messages(error_messages): - print("ERROR: check format failed. run 'tools/code_format/check_format.py fix'") - sys.exit(1) - - if args.operation_type == "check": - print("PASS") + parser = argparse.ArgumentParser(description="Check or fix file format.") + parser.add_argument("operation_type", + type=str, + choices=["check", "fix"], + help="specify if the run should 'check' or 'fix' format.") + parser.add_argument( + "target_path", + type=str, + nargs="?", + default=".", + help="specify the root directory for the script to recurse over. Default '.'.") + parser.add_argument("--add-excluded-prefixes", + type=str, + nargs="+", + help="exclude additional prefixes.") + parser.add_argument("-j", + "--num-workers", + type=int, + default=multiprocessing.cpu_count(), + help="number of worker processes to use; defaults to one per core.") + parser.add_argument("--api-prefix", type=str, default="./api/", help="path of the API tree.") + parser.add_argument("--api-shadow-prefix", + type=str, + default="./generated_api_shadow/", + help="path of the shadow API tree.") + parser.add_argument("--skip_envoy_build_rule_check", + action="store_true", + help="skip checking for '@envoy//' prefix in build rules.") + parser.add_argument("--namespace_check", + type=str, + nargs="?", + default="Envoy", + help="specify namespace check string. Default 'Envoy'.") + parser.add_argument("--namespace_check_excluded_paths", + type=str, + nargs="+", + default=[], + help="exclude paths from the namespace_check.") + parser.add_argument("--build_fixer_check_excluded_paths", + type=str, + nargs="+", + default=[], + help="exclude paths from envoy_build_fixer check.") + parser.add_argument("--bazel_tools_check_excluded_paths", + type=str, + nargs="+", + default=[], + help="exclude paths from bazel_tools check.") + parser.add_argument("--include_dir_order", + type=str, + default=",".join(common.include_dir_order()), + help="specify the header block include directory order.") + args = parser.parse_args() + if args.add_excluded_prefixes: + EXCLUDED_PREFIXES += tuple(args.add_excluded_prefixes) + format_checker = FormatChecker(args) + + # Check whether all needed external tools are available. + ct_error_messages = format_checker.check_tools() + if format_checker.check_error_messages(ct_error_messages): + sys.exit(1) + + # Returns the list of directories with owners listed in CODEOWNERS. May append errors to + # error_messages. + def owned_directories(error_messages): + owned = [] + maintainers = [ + '@mattklein123', '@htuch', '@alyssawilk', '@zuercher', '@lizan', '@snowp', '@asraa', + '@yanavlasov', '@junr03', '@dio', '@jmarantz', '@antoniovicente' + ] + + try: + with open('./CODEOWNERS') as f: + for line in f: + # If this line is of the form "extensions/... 
@owner1 @owner2" capture the directory + # name and store it in the list of directories with documented owners. + m = EXTENSIONS_CODEOWNERS_REGEX.search(line) + if m is not None and not line.startswith('#'): + owned.append(m.group(1).strip()) + owners = re.findall('@\S+', m.group(2).strip()) + if len(owners) < 2: + error_messages.append( + "Extensions require at least 2 owners in CODEOWNERS:\n" + " {}".format(line)) + maintainer = len(set(owners).intersection(set(maintainers))) > 0 + if not maintainer: + error_messages.append( + "Extensions require at least one maintainer OWNER:\n" + " {}".format(line)) + + return owned + except IOError: + return [] # for the check format tests. + + # Calculate the list of owned directories once per run. + error_messages = [] + owned_directories = owned_directories(error_messages) + + if os.path.isfile(args.target_path): + error_messages += format_checker.check_format("./" + args.target_path) + else: + results = [] + + def pooled_check_format(path_predicate): + pool = multiprocessing.Pool(processes=args.num_workers) + # For each file in target_path, start a new task in the pool and collect the + # results (results is passed by reference, and is used as an output). + for root, _, files in os.walk(args.target_path): + format_checker.check_format_visitor( + (pool, results, owned_directories, error_messages), root, + [f for f in files if path_predicate(f)]) + + # Close the pool to new tasks, wait for all of the running tasks to finish, + # then collect the error messages. + pool.close() + pool.join() + + # We first run formatting on non-BUILD files, since the BUILD file format + # requires analysis of srcs/hdrs in the BUILD file, and we don't want these + # to be rewritten by other multiprocessing pooled processes. + pooled_check_format(lambda f: not format_checker.is_build_file(f)) + pooled_check_format(lambda f: format_checker.is_build_file(f)) + + error_messages += sum((r.get() for r in results), []) + + if format_checker.check_error_messages(error_messages): + print("ERROR: check format failed. run 'tools/code_format/check_format.py fix'") + sys.exit(1) + + if args.operation_type == "check": + print("PASS") diff --git a/tools/code_format/check_format_test_helper.py b/tools/code_format/check_format_test_helper.py index fe671b98bcb0..030dc665404e 100755 --- a/tools/code_format/check_format_test_helper.py +++ b/tools/code_format/check_format_test_helper.py @@ -26,313 +26,315 @@ # the comamnd run and the status code as well as the stdout, and returning # all of that to the caller. 
def run_check_format(operation, filename): - command = check_format + " " + operation + " " + filename - status, stdout, stderr = run_command(command) - return (command, status, stdout + stderr) + command = check_format + " " + operation + " " + filename + status, stdout, stderr = run_command(command) + return (command, status, stdout + stderr) def get_input_file(filename, extra_input_files=None): - files_to_copy = [filename] - if extra_input_files is not None: - files_to_copy.extend(extra_input_files) - for f in files_to_copy: - infile = os.path.join(src, f) - directory = os.path.dirname(f) - if not directory == '' and not os.path.isdir(directory): - os.makedirs(directory) - shutil.copyfile(infile, f) - return filename + files_to_copy = [filename] + if extra_input_files is not None: + files_to_copy.extend(extra_input_files) + for f in files_to_copy: + infile = os.path.join(src, f) + directory = os.path.dirname(f) + if not directory == '' and not os.path.isdir(directory): + os.makedirs(directory) + shutil.copyfile(infile, f) + return filename # Attempts to fix file, returning a 4-tuple: the command, input file name, # output filename, captured stdout as an array of lines, and the error status # code. def fix_file_helper(filename, extra_input_files=None): - command, status, stdout = run_check_format( - "fix", get_input_file(filename, extra_input_files=extra_input_files)) - infile = os.path.join(src, filename) - return command, infile, filename, status, stdout + command, status, stdout = run_check_format( + "fix", get_input_file(filename, extra_input_files=extra_input_files)) + infile = os.path.join(src, filename) + return command, infile, filename, status, stdout # Attempts to fix a file, returning the status code and the generated output. # If the fix was successful, the diff is returned as a string-array. If the file # was not fixable, the error-messages are returned as a string-array. 
def fix_file_expecting_success(file, extra_input_files=None): - command, infile, outfile, status, stdout = fix_file_helper(file, - extra_input_files=extra_input_files) - if status != 0: - print("FAILED: " + infile) - emit_stdout_as_error(stdout) - return 1 - status, stdout, stderr = run_command('diff ' + outfile + ' ' + infile + '.gold') - if status != 0: - print("FAILED: " + infile) - emit_stdout_as_error(stdout + stderr) - return 1 - return 0 + command, infile, outfile, status, stdout = fix_file_helper(file, + extra_input_files=extra_input_files) + if status != 0: + print("FAILED: " + infile) + emit_stdout_as_error(stdout) + return 1 + status, stdout, stderr = run_command('diff ' + outfile + ' ' + infile + '.gold') + if status != 0: + print("FAILED: " + infile) + emit_stdout_as_error(stdout + stderr) + return 1 + return 0 def fix_file_expecting_no_change(file): - command, infile, outfile, status, stdout = fix_file_helper(file) - if status != 0: - return 1 - status, stdout, stderr = run_command('diff ' + outfile + ' ' + infile) - if status != 0: - logging.error(file + ': expected file to remain unchanged') - return 1 - return 0 + command, infile, outfile, status, stdout = fix_file_helper(file) + if status != 0: + return 1 + status, stdout, stderr = run_command('diff ' + outfile + ' ' + infile) + if status != 0: + logging.error(file + ': expected file to remain unchanged') + return 1 + return 0 def emit_stdout_as_error(stdout): - logging.error("\n".join(stdout)) + logging.error("\n".join(stdout)) def expect_error(filename, status, stdout, expected_substring): - if status == 0: - logging.error("%s: Expected failure `%s`, but succeeded" % (filename, expected_substring)) + if status == 0: + logging.error("%s: Expected failure `%s`, but succeeded" % (filename, expected_substring)) + return 1 + for line in stdout: + if expected_substring in line: + return 0 + logging.error("%s: Could not find '%s' in:\n" % (filename, expected_substring)) + emit_stdout_as_error(stdout) return 1 - for line in stdout: - if expected_substring in line: - return 0 - logging.error("%s: Could not find '%s' in:\n" % (filename, expected_substring)) - emit_stdout_as_error(stdout) - return 1 def fix_file_expecting_failure(filename, expected_substring): - command, infile, outfile, status, stdout = fix_file_helper(filename) - return expect_error(filename, status, stdout, expected_substring) + command, infile, outfile, status, stdout = fix_file_helper(filename) + return expect_error(filename, status, stdout, expected_substring) def check_file_expecting_error(filename, expected_substring, extra_input_files=None): - command, status, stdout = run_check_format( - "check", get_input_file(filename, extra_input_files=extra_input_files)) - return expect_error(filename, status, stdout, expected_substring) + command, status, stdout = run_check_format( + "check", get_input_file(filename, extra_input_files=extra_input_files)) + return expect_error(filename, status, stdout, expected_substring) def check_and_fix_error(filename, expected_substring, extra_input_files=None): - errors = check_file_expecting_error(filename, - expected_substring, - extra_input_files=extra_input_files) - errors += fix_file_expecting_success(filename, extra_input_files=extra_input_files) - return errors + errors = check_file_expecting_error(filename, + expected_substring, + extra_input_files=extra_input_files) + errors += fix_file_expecting_success(filename, extra_input_files=extra_input_files) + return errors def check_tool_not_found_error(): - # Temporarily 
change PATH to test the error about lack of external tools. - oldPath = os.environ["PATH"] - os.environ["PATH"] = "/sbin:/usr/sbin" - clang_format = os.getenv("CLANG_FORMAT", "clang-format-9") - # If CLANG_FORMAT points directly to the binary, skip this test. - if os.path.isfile(clang_format) and os.access(clang_format, os.X_OK): + # Temporarily change PATH to test the error about lack of external tools. + oldPath = os.environ["PATH"] + os.environ["PATH"] = "/sbin:/usr/sbin" + clang_format = os.getenv("CLANG_FORMAT", "clang-format-9") + # If CLANG_FORMAT points directly to the binary, skip this test. + if os.path.isfile(clang_format) and os.access(clang_format, os.X_OK): + os.environ["PATH"] = oldPath + return 0 + errors = check_file_expecting_error("no_namespace_envoy.cc", + "Command %s not found." % clang_format) os.environ["PATH"] = oldPath - return 0 - errors = check_file_expecting_error("no_namespace_envoy.cc", - "Command %s not found." % clang_format) - os.environ["PATH"] = oldPath - return errors + return errors def check_unfixable_error(filename, expected_substring): - errors = check_file_expecting_error(filename, expected_substring) - errors += fix_file_expecting_failure(filename, expected_substring) - return errors + errors = check_file_expecting_error(filename, expected_substring) + errors += fix_file_expecting_failure(filename, expected_substring) + return errors def check_file_expecting_ok(filename): - command, status, stdout = run_check_format("check", get_input_file(filename)) - if status != 0: - logging.error("Expected %s to have no errors; status=%d, output:\n" % (filename, status)) - emit_stdout_as_error(stdout) - return status + fix_file_expecting_no_change(filename) + command, status, stdout = run_check_format("check", get_input_file(filename)) + if status != 0: + logging.error("Expected %s to have no errors; status=%d, output:\n" % (filename, status)) + emit_stdout_as_error(stdout) + return status + fix_file_expecting_no_change(filename) def run_checks(): - errors = 0 - - # The following error is the error about unavailability of external tools. - errors += check_tool_not_found_error() - - # The following errors can be detected but not fixed automatically. - errors += check_unfixable_error("no_namespace_envoy.cc", - "Unable to find Envoy namespace or NOLINT(namespace-envoy)") - errors += check_unfixable_error("mutex.cc", "Don't use <mutex> or <condition_variable*>") - errors += check_unfixable_error("condition_variable.cc", - "Don't use <mutex> or <condition_variable*>") - errors += check_unfixable_error("condition_variable_any.cc", - "Don't use <mutex> or <condition_variable*>") - errors += check_unfixable_error("shared_mutex.cc", "shared_mutex") - errors += check_unfixable_error("shared_mutex.cc", "shared_mutex") - real_time_inject_error = ( - "Don't reference real-world time sources from production code; use injection") - errors += check_unfixable_error("real_time_source.cc", real_time_inject_error) - errors += check_unfixable_error("real_time_system.cc", real_time_inject_error) - errors += check_unfixable_error( - "duration_value.cc", - "Don't use ambiguous duration(value), use an explicit duration type, e.g. 
Event::TimeSystem::Milliseconds(value)" - ) - errors += check_unfixable_error("system_clock.cc", real_time_inject_error) - errors += check_unfixable_error("steady_clock.cc", real_time_inject_error) - errors += check_unfixable_error( - "unpack_to.cc", "Don't use UnpackTo() directly, use MessageUtil::unpackTo() instead") - errors += check_unfixable_error("condvar_wait_for.cc", real_time_inject_error) - errors += check_unfixable_error("sleep.cc", real_time_inject_error) - errors += check_unfixable_error("std_atomic_free_functions.cc", "std::atomic_*") - errors += check_unfixable_error("std_get_time.cc", "std::get_time") - errors += check_unfixable_error("no_namespace_envoy.cc", - "Unable to find Envoy namespace or NOLINT(namespace-envoy)") - errors += check_unfixable_error("bazel_tools.BUILD", "unexpected @bazel_tools reference") - errors += check_unfixable_error("proto.BUILD", - "unexpected direct external dependency on protobuf") - errors += check_unfixable_error("proto_deps.cc", - "unexpected direct dependency on google.protobuf") - errors += check_unfixable_error("attribute_packed.cc", "Don't use __attribute__((packed))") - errors += check_unfixable_error("designated_initializers.cc", "Don't use designated initializers") - errors += check_unfixable_error("elvis_operator.cc", "Don't use the '?:' operator") - errors += check_unfixable_error("testing_test.cc", - "Don't use 'using testing::Test;, elaborate the type instead") - errors += check_unfixable_error( - "serialize_as_string.cc", - "Don't use MessageLite::SerializeAsString for generating deterministic serialization") - errors += check_unfixable_error( - "version_history/current.rst", - "Version history not in alphabetical order (zzzzz vs aaaaa): please check placement of line") - errors += check_unfixable_error( - "version_history/current.rst", - "Version history not in alphabetical order (this vs aaaa): please check placement of line") - errors += check_unfixable_error( - "version_history/current.rst", - "Version history line malformed. Does not match VERSION_HISTORY_NEW_LINE_REGEX in " - "check_format.py") - errors += check_unfixable_error( - "counter_from_string.cc", - "Don't lookup stats by name at runtime; use StatName saved during construction") - errors += check_unfixable_error( - "gauge_from_string.cc", - "Don't lookup stats by name at runtime; use StatName saved during construction") - errors += check_unfixable_error( - "histogram_from_string.cc", - "Don't lookup stats by name at runtime; use StatName saved during construction") - errors += check_unfixable_error( - "regex.cc", "Don't use std::regex in code that handles untrusted input. Use RegexMatcher") - errors += check_unfixable_error( - "grpc_init.cc", - "Don't call grpc_init() or grpc_shutdown() directly, instantiate Grpc::GoogleGrpcContext. " + - "See #8282") - errors += check_unfixable_error( - "grpc_shutdown.cc", - "Don't call grpc_init() or grpc_shutdown() directly, instantiate Grpc::GoogleGrpcContext. 
" + - "See #8282") - errors += check_unfixable_error("clang_format_double_off.cc", "clang-format nested off") - errors += check_unfixable_error("clang_format_trailing_off.cc", "clang-format remains off") - errors += check_unfixable_error("clang_format_double_on.cc", "clang-format nested on") - errors += fix_file_expecting_failure( - "api/missing_package.proto", - "Unable to find package name for proto file: ./api/missing_package.proto") - errors += check_unfixable_error("proto_enum_mangling.cc", - "Don't use mangled Protobuf names for enum constants") - errors += check_unfixable_error("test_naming.cc", - "Test names should be CamelCase, starting with a capital letter") - errors += check_unfixable_error("mock_method_n.cc", "use MOCK_METHOD() instead") - errors += check_unfixable_error("for_each_n.cc", "use an alternative for loop instead") - errors += check_unfixable_error( - "test/register_factory.cc", - "Don't use Registry::RegisterFactory or REGISTER_FACTORY in tests, use " - "Registry::InjectFactory instead.") - errors += check_unfixable_error("strerror.cc", - "Don't use strerror; use Envoy::errorDetails instead") - errors += check_unfixable_error( - "std_unordered_map.cc", "Don't use std::unordered_map; use absl::flat_hash_map instead " + - "or absl::node_hash_map if pointer stability of keys/values is required") - errors += check_unfixable_error( - "std_unordered_set.cc", "Don't use std::unordered_set; use absl::flat_hash_set instead " + - "or absl::node_hash_set if pointer stability of keys/values is required") - errors += check_unfixable_error("std_any.cc", "Don't use std::any; use absl::any instead") - errors += check_unfixable_error("std_get_if.cc", - "Don't use std::get_if; use absl::get_if instead") - errors += check_unfixable_error( - "std_holds_alternative.cc", - "Don't use std::holds_alternative; use absl::holds_alternative instead") - errors += check_unfixable_error("std_make_optional.cc", - "Don't use std::make_optional; use absl::make_optional instead") - errors += check_unfixable_error("std_monostate.cc", - "Don't use std::monostate; use absl::monostate instead") - errors += check_unfixable_error("std_optional.cc", - "Don't use std::optional; use absl::optional instead") - errors += check_unfixable_error("std_string_view.cc", - "Don't use std::string_view; use absl::string_view instead") - errors += check_unfixable_error("std_variant.cc", - "Don't use std::variant; use absl::variant instead") - errors += check_unfixable_error("std_visit.cc", "Don't use std::visit; use absl::visit instead") - errors += check_unfixable_error( - "throw.cc", "Don't introduce throws into exception-free files, use error statuses instead.") - errors += check_unfixable_error("pgv_string.proto", "min_bytes is DEPRECATED, Use min_len.") - errors += check_file_expecting_ok("commented_throw.cc") - errors += check_unfixable_error("repository_url.bzl", - "Only repository_locations.bzl may contains URL references") - errors += check_unfixable_error("repository_urls.bzl", - "Only repository_locations.bzl may contains URL references") - - # The following files have errors that can be automatically fixed. 
- errors += check_and_fix_error("over_enthusiastic_spaces.cc", - "./over_enthusiastic_spaces.cc:3: over-enthusiastic spaces") - errors += check_and_fix_error("extra_enthusiastic_spaces.cc", - "./extra_enthusiastic_spaces.cc:3: over-enthusiastic spaces") - errors += check_and_fix_error("angle_bracket_include.cc", - "envoy includes should not have angle brackets") - errors += check_and_fix_error("proto_style.cc", "incorrect protobuf type reference") - errors += check_and_fix_error("long_line.cc", "clang-format check failed") - errors += check_and_fix_error("header_order.cc", "header_order.py check failed") - errors += check_and_fix_error("clang_format_on.cc", - "./clang_format_on.cc:7: over-enthusiastic spaces") - # Validate that a missing license is added. - errors += check_and_fix_error("license.BUILD", "envoy_build_fixer check failed") - # Validate that an incorrect license is replaced and reordered. - errors += check_and_fix_error("update_license.BUILD", "envoy_build_fixer check failed") - # Validate that envoy_package() is added where there is an envoy_* rule occurring. - errors += check_and_fix_error("add_envoy_package.BUILD", "envoy_build_fixer check failed") - # Validate that we don't add envoy_package() when no envoy_* rule. - errors += check_file_expecting_ok("skip_envoy_package.BUILD") - # Validate that we clean up gratuitous blank lines. - errors += check_and_fix_error("canonical_spacing.BUILD", "envoy_build_fixer check failed") - # Validate that unused loads are removed. - errors += check_and_fix_error("remove_unused_loads.BUILD", "envoy_build_fixer check failed") - # Validate that API proto package deps are computed automagically. - errors += check_and_fix_error("canonical_api_deps.BUILD", - "envoy_build_fixer check failed", - extra_input_files=[ - "canonical_api_deps.cc", "canonical_api_deps.h", - "canonical_api_deps.other.cc" - ]) - errors += check_and_fix_error("bad_envoy_build_sys_ref.BUILD", "Superfluous '@envoy//' prefix") - errors += check_and_fix_error("proto_format.proto", "clang-format check failed") - errors += check_and_fix_error( - "cpp_std.cc", - "term absl::make_unique< should be replaced with standard library term std::make_unique<") - errors += check_and_fix_error("code_conventions.cc", - "term .Times(1); should be replaced with preferred term ;") - - errors += check_file_expecting_ok("real_time_source_override.cc") - errors += check_file_expecting_ok("duration_value_zero.cc") - errors += check_file_expecting_ok("time_system_wait_for.cc") - errors += check_file_expecting_ok("clang_format_off.cc") - return errors + errors = 0 + + # The following error is the error about unavailability of external tools. + errors += check_tool_not_found_error() + + # The following errors can be detected but not fixed automatically. 
+ errors += check_unfixable_error("no_namespace_envoy.cc", + "Unable to find Envoy namespace or NOLINT(namespace-envoy)") + errors += check_unfixable_error("mutex.cc", "Don't use or ") + errors += check_unfixable_error("condition_variable.cc", + "Don't use or ") + errors += check_unfixable_error("condition_variable_any.cc", + "Don't use or ") + errors += check_unfixable_error("shared_mutex.cc", "shared_mutex") + errors += check_unfixable_error("shared_mutex.cc", "shared_mutex") + real_time_inject_error = ( + "Don't reference real-world time sources from production code; use injection") + errors += check_unfixable_error("real_time_source.cc", real_time_inject_error) + errors += check_unfixable_error("real_time_system.cc", real_time_inject_error) + errors += check_unfixable_error( + "duration_value.cc", + "Don't use ambiguous duration(value), use an explicit duration type, e.g. Event::TimeSystem::Milliseconds(value)" + ) + errors += check_unfixable_error("system_clock.cc", real_time_inject_error) + errors += check_unfixable_error("steady_clock.cc", real_time_inject_error) + errors += check_unfixable_error( + "unpack_to.cc", "Don't use UnpackTo() directly, use MessageUtil::unpackTo() instead") + errors += check_unfixable_error("condvar_wait_for.cc", real_time_inject_error) + errors += check_unfixable_error("sleep.cc", real_time_inject_error) + errors += check_unfixable_error("std_atomic_free_functions.cc", "std::atomic_*") + errors += check_unfixable_error("std_get_time.cc", "std::get_time") + errors += check_unfixable_error("no_namespace_envoy.cc", + "Unable to find Envoy namespace or NOLINT(namespace-envoy)") + errors += check_unfixable_error("bazel_tools.BUILD", "unexpected @bazel_tools reference") + errors += check_unfixable_error("proto.BUILD", + "unexpected direct external dependency on protobuf") + errors += check_unfixable_error("proto_deps.cc", + "unexpected direct dependency on google.protobuf") + errors += check_unfixable_error("attribute_packed.cc", "Don't use __attribute__((packed))") + errors += check_unfixable_error("designated_initializers.cc", + "Don't use designated initializers") + errors += check_unfixable_error("elvis_operator.cc", "Don't use the '?:' operator") + errors += check_unfixable_error("testing_test.cc", + "Don't use 'using testing::Test;, elaborate the type instead") + errors += check_unfixable_error( + "serialize_as_string.cc", + "Don't use MessageLite::SerializeAsString for generating deterministic serialization") + errors += check_unfixable_error( + "version_history/current.rst", + "Version history not in alphabetical order (zzzzz vs aaaaa): please check placement of line" + ) + errors += check_unfixable_error( + "version_history/current.rst", + "Version history not in alphabetical order (this vs aaaa): please check placement of line") + errors += check_unfixable_error( + "version_history/current.rst", + "Version history line malformed. Does not match VERSION_HISTORY_NEW_LINE_REGEX in " + "check_format.py") + errors += check_unfixable_error( + "counter_from_string.cc", + "Don't lookup stats by name at runtime; use StatName saved during construction") + errors += check_unfixable_error( + "gauge_from_string.cc", + "Don't lookup stats by name at runtime; use StatName saved during construction") + errors += check_unfixable_error( + "histogram_from_string.cc", + "Don't lookup stats by name at runtime; use StatName saved during construction") + errors += check_unfixable_error( + "regex.cc", "Don't use std::regex in code that handles untrusted input. 
Use RegexMatcher") + errors += check_unfixable_error( + "grpc_init.cc", + "Don't call grpc_init() or grpc_shutdown() directly, instantiate Grpc::GoogleGrpcContext. " + + "See #8282") + errors += check_unfixable_error( + "grpc_shutdown.cc", + "Don't call grpc_init() or grpc_shutdown() directly, instantiate Grpc::GoogleGrpcContext. " + + "See #8282") + errors += check_unfixable_error("clang_format_double_off.cc", "clang-format nested off") + errors += check_unfixable_error("clang_format_trailing_off.cc", "clang-format remains off") + errors += check_unfixable_error("clang_format_double_on.cc", "clang-format nested on") + errors += fix_file_expecting_failure( + "api/missing_package.proto", + "Unable to find package name for proto file: ./api/missing_package.proto") + errors += check_unfixable_error("proto_enum_mangling.cc", + "Don't use mangled Protobuf names for enum constants") + errors += check_unfixable_error( + "test_naming.cc", "Test names should be CamelCase, starting with a capital letter") + errors += check_unfixable_error("mock_method_n.cc", "use MOCK_METHOD() instead") + errors += check_unfixable_error("for_each_n.cc", "use an alternative for loop instead") + errors += check_unfixable_error( + "test/register_factory.cc", + "Don't use Registry::RegisterFactory or REGISTER_FACTORY in tests, use " + "Registry::InjectFactory instead.") + errors += check_unfixable_error("strerror.cc", + "Don't use strerror; use Envoy::errorDetails instead") + errors += check_unfixable_error( + "std_unordered_map.cc", "Don't use std::unordered_map; use absl::flat_hash_map instead " + + "or absl::node_hash_map if pointer stability of keys/values is required") + errors += check_unfixable_error( + "std_unordered_set.cc", "Don't use std::unordered_set; use absl::flat_hash_set instead " + + "or absl::node_hash_set if pointer stability of keys/values is required") + errors += check_unfixable_error("std_any.cc", "Don't use std::any; use absl::any instead") + errors += check_unfixable_error("std_get_if.cc", + "Don't use std::get_if; use absl::get_if instead") + errors += check_unfixable_error( + "std_holds_alternative.cc", + "Don't use std::holds_alternative; use absl::holds_alternative instead") + errors += check_unfixable_error( + "std_make_optional.cc", "Don't use std::make_optional; use absl::make_optional instead") + errors += check_unfixable_error("std_monostate.cc", + "Don't use std::monostate; use absl::monostate instead") + errors += check_unfixable_error("std_optional.cc", + "Don't use std::optional; use absl::optional instead") + errors += check_unfixable_error("std_string_view.cc", + "Don't use std::string_view; use absl::string_view instead") + errors += check_unfixable_error("std_variant.cc", + "Don't use std::variant; use absl::variant instead") + errors += check_unfixable_error("std_visit.cc", "Don't use std::visit; use absl::visit instead") + errors += check_unfixable_error( + "throw.cc", "Don't introduce throws into exception-free files, use error statuses instead.") + errors += check_unfixable_error("pgv_string.proto", "min_bytes is DEPRECATED, Use min_len.") + errors += check_file_expecting_ok("commented_throw.cc") + errors += check_unfixable_error("repository_url.bzl", + "Only repository_locations.bzl may contains URL references") + errors += check_unfixable_error("repository_urls.bzl", + "Only repository_locations.bzl may contains URL references") + + # The following files have errors that can be automatically fixed. 
+ errors += check_and_fix_error("over_enthusiastic_spaces.cc", + "./over_enthusiastic_spaces.cc:3: over-enthusiastic spaces") + errors += check_and_fix_error("extra_enthusiastic_spaces.cc", + "./extra_enthusiastic_spaces.cc:3: over-enthusiastic spaces") + errors += check_and_fix_error("angle_bracket_include.cc", + "envoy includes should not have angle brackets") + errors += check_and_fix_error("proto_style.cc", "incorrect protobuf type reference") + errors += check_and_fix_error("long_line.cc", "clang-format check failed") + errors += check_and_fix_error("header_order.cc", "header_order.py check failed") + errors += check_and_fix_error("clang_format_on.cc", + "./clang_format_on.cc:7: over-enthusiastic spaces") + # Validate that a missing license is added. + errors += check_and_fix_error("license.BUILD", "envoy_build_fixer check failed") + # Validate that an incorrect license is replaced and reordered. + errors += check_and_fix_error("update_license.BUILD", "envoy_build_fixer check failed") + # Validate that envoy_package() is added where there is an envoy_* rule occurring. + errors += check_and_fix_error("add_envoy_package.BUILD", "envoy_build_fixer check failed") + # Validate that we don't add envoy_package() when no envoy_* rule. + errors += check_file_expecting_ok("skip_envoy_package.BUILD") + # Validate that we clean up gratuitous blank lines. + errors += check_and_fix_error("canonical_spacing.BUILD", "envoy_build_fixer check failed") + # Validate that unused loads are removed. + errors += check_and_fix_error("remove_unused_loads.BUILD", "envoy_build_fixer check failed") + # Validate that API proto package deps are computed automagically. + errors += check_and_fix_error("canonical_api_deps.BUILD", + "envoy_build_fixer check failed", + extra_input_files=[ + "canonical_api_deps.cc", "canonical_api_deps.h", + "canonical_api_deps.other.cc" + ]) + errors += check_and_fix_error("bad_envoy_build_sys_ref.BUILD", "Superfluous '@envoy//' prefix") + errors += check_and_fix_error("proto_format.proto", "clang-format check failed") + errors += check_and_fix_error( + "cpp_std.cc", + "term absl::make_unique< should be replaced with standard library term std::make_unique<") + errors += check_and_fix_error("code_conventions.cc", + "term .Times(1); should be replaced with preferred term ;") + + errors += check_file_expecting_ok("real_time_source_override.cc") + errors += check_file_expecting_ok("duration_value_zero.cc") + errors += check_file_expecting_ok("time_system_wait_for.cc") + errors += check_file_expecting_ok("clang_format_off.cc") + return errors if __name__ == "__main__": - parser = argparse.ArgumentParser(description='tester for check_format.py.') - parser.add_argument('--log', choices=['INFO', 'WARN', 'ERROR'], default='INFO') - args = parser.parse_args() - logging.basicConfig(format='%(message)s', level=args.log) - - # Now create a temp directory to copy the input files, so we can fix them - # without actually fixing our testdata. This requires chdiring to the temp - # directory, so it's annoying to comingle check-tests and fix-tests. 
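The test driver continued in the hunk below copies its input files into a temporary directory and chdirs there, so in-place fixes land in the scratch copy rather than in the checked-in testdata. A minimal, standalone sketch of that pattern (run_checks here is an illustrative stand-in for the real check runner):

import os
import tempfile

def run_in_scratch_dir(run_checks):
    # Execute all checks from a throwaway working directory; any files fixed
    # in place are fixed inside the scratch copy, never in the repository.
    with tempfile.TemporaryDirectory() as tmp:
        os.chdir(tmp)
        return run_checks()

failures = run_in_scratch_dir(lambda: 0)  # 0 failures for the no-op runner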
- with tempfile.TemporaryDirectory() as tmp: - os.chdir(tmp) - errors = run_checks() - - if errors != 0: - logging.error("%d FAILURES" % errors) - exit(1) - logging.warning("PASS") + parser = argparse.ArgumentParser(description='tester for check_format.py.') + parser.add_argument('--log', choices=['INFO', 'WARN', 'ERROR'], default='INFO') + args = parser.parse_args() + logging.basicConfig(format='%(message)s', level=args.log) + + # Now create a temp directory to copy the input files, so we can fix them + # without actually fixing our testdata. This requires chdiring to the temp + # directory, so it's annoying to comingle check-tests and fix-tests. + with tempfile.TemporaryDirectory() as tmp: + os.chdir(tmp) + errors = run_checks() + + if errors != 0: + logging.error("%d FAILURES" % errors) + exit(1) + logging.warning("PASS") diff --git a/tools/code_format/common.py b/tools/code_format/common.py index b57198a6d253..1ea2d344d79e 100644 --- a/tools/code_format/common.py +++ b/tools/code_format/common.py @@ -1,10 +1,10 @@ def include_dir_order(): - return ( - "envoy", - "common", - "source", - "exe", - "server", - "extensions", - "test", - ) + return ( + "envoy", + "common", + "source", + "exe", + "server", + "extensions", + "test", + ) diff --git a/tools/code_format/envoy_build_fixer.py b/tools/code_format/envoy_build_fixer.py index dd0cb1796b69..e00615d5cb69 100755 --- a/tools/code_format/envoy_build_fixer.py +++ b/tools/code_format/envoy_build_fixer.py @@ -45,83 +45,83 @@ class EnvoyBuildFixerError(Exception): - pass + pass # Run Buildozer commands on a string representing a BUILD file. def run_buildozer(cmds, contents): - with tempfile.NamedTemporaryFile(mode='w') as cmd_file: - # We send the BUILD contents to buildozer on stdin and receive the - # transformed BUILD on stdout. The commands are provided in a file. - cmd_input = '\n'.join('%s|-:%s' % (cmd, target) for cmd, target in cmds) - cmd_file.write(cmd_input) - cmd_file.flush() - r = subprocess.run([BUILDOZER_PATH, '-stdout', '-f', cmd_file.name], - input=contents.encode(), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - # Buildozer uses 3 for success but no change (0 is success and changed). - if r.returncode != 0 and r.returncode != 3: - raise EnvoyBuildFixerError('buildozer execution failed: %s' % r) - # Sometimes buildozer feels like returning nothing when the transform is a - # nop. - if not r.stdout: - return contents - return r.stdout.decode('utf-8') + with tempfile.NamedTemporaryFile(mode='w') as cmd_file: + # We send the BUILD contents to buildozer on stdin and receive the + # transformed BUILD on stdout. The commands are provided in a file. + cmd_input = '\n'.join('%s|-:%s' % (cmd, target) for cmd, target in cmds) + cmd_file.write(cmd_input) + cmd_file.flush() + r = subprocess.run([BUILDOZER_PATH, '-stdout', '-f', cmd_file.name], + input=contents.encode(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + # Buildozer uses 3 for success but no change (0 is success and changed). + if r.returncode != 0 and r.returncode != 3: + raise EnvoyBuildFixerError('buildozer execution failed: %s' % r) + # Sometimes buildozer feels like returning nothing when the transform is a + # nop. + if not r.stdout: + return contents + return r.stdout.decode('utf-8') # Add an Apache 2 license and envoy_package() import and rule as needed. 
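run_buildozer above drives buildozer entirely through pipes: the BUILD text is sent on stdin, the edit commands are passed via a temporary command file, and the transformed BUILD comes back on stdout. A standalone sketch of that interaction, assuming a buildozer binary is available on PATH:

import subprocess
import tempfile

def buildozer_transform(cmds, contents, buildozer='buildozer'):
    # cmds is a list of (command, target) pairs, e.g. ('remove deps :foo', 'my_lib').
    with tempfile.NamedTemporaryFile(mode='w') as cmd_file:
        cmd_file.write('\n'.join('%s|-:%s' % (cmd, target) for cmd, target in cmds))
        cmd_file.flush()
        r = subprocess.run([buildozer, '-stdout', '-f', cmd_file.name],
                           input=contents.encode(),
                           stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE)
        # Buildozer returns 0 on success-with-change and 3 on success-without-change,
        # and may print nothing at all when the transform is a no-op.
        if r.returncode not in (0, 3):
            raise RuntimeError('buildozer failed: %s' % r)
        return r.stdout.decode('utf-8') if r.stdout else contents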
def fix_package_and_license(path, contents): - regex_to_use = PACKAGE_LOAD_BLOCK_REGEX - package_string = 'envoy_package' - - if 'source/extensions' in path: - regex_to_use = EXTENSION_PACKAGE_LOAD_BLOCK_REGEX - package_string = 'envoy_extension_package' - - # Ensure we have an envoy_package import load if this is a real Envoy package. We also allow - # the prefix to be overridden if envoy is included in a larger workspace. - if re.search(ENVOY_RULE_REGEX, contents): - new_load = 'new_load {}//bazel:envoy_build_system.bzl %s' % package_string - contents = run_buildozer([ - (new_load.format(os.getenv("ENVOY_BAZEL_PREFIX", "")), '__pkg__'), - ], contents) - # Envoy package is inserted after the load block containing the - # envoy_package import. - package_and_parens = package_string + '()' - if package_and_parens[:-1] not in contents: - contents = re.sub(regex_to_use, r'\1\n%s\n\n' % package_and_parens, contents) - if package_and_parens not in contents: - raise EnvoyBuildFixerError('Unable to insert %s' % package_and_parens) - - # Delete old licenses. - if re.search(OLD_LICENSES_REGEX, contents): - contents = re.sub(OLD_LICENSES_REGEX, '', contents) - # Add canonical Apache 2 license. - contents = LICENSE_STRING + contents - return contents + regex_to_use = PACKAGE_LOAD_BLOCK_REGEX + package_string = 'envoy_package' + + if 'source/extensions' in path: + regex_to_use = EXTENSION_PACKAGE_LOAD_BLOCK_REGEX + package_string = 'envoy_extension_package' + + # Ensure we have an envoy_package import load if this is a real Envoy package. We also allow + # the prefix to be overridden if envoy is included in a larger workspace. + if re.search(ENVOY_RULE_REGEX, contents): + new_load = 'new_load {}//bazel:envoy_build_system.bzl %s' % package_string + contents = run_buildozer([ + (new_load.format(os.getenv("ENVOY_BAZEL_PREFIX", "")), '__pkg__'), + ], contents) + # Envoy package is inserted after the load block containing the + # envoy_package import. + package_and_parens = package_string + '()' + if package_and_parens[:-1] not in contents: + contents = re.sub(regex_to_use, r'\1\n%s\n\n' % package_and_parens, contents) + if package_and_parens not in contents: + raise EnvoyBuildFixerError('Unable to insert %s' % package_and_parens) + + # Delete old licenses. + if re.search(OLD_LICENSES_REGEX, contents): + contents = re.sub(OLD_LICENSES_REGEX, '', contents) + # Add canonical Apache 2 license. + contents = LICENSE_STRING + contents + return contents # Run Buildifier commands on a string with lint mode. def buildifier_lint(contents): - r = subprocess.run([BUILDIFIER_PATH, '-lint=fix', '-mode=fix', '-type=build'], - input=contents.encode(), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - if r.returncode != 0: - raise EnvoyBuildFixerError('buildozer execution failed: %s' % r) - return r.stdout.decode('utf-8') + r = subprocess.run([BUILDIFIER_PATH, '-lint=fix', '-mode=fix', '-type=build'], + input=contents.encode(), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + if r.returncode != 0: + raise EnvoyBuildFixerError('buildozer execution failed: %s' % r) + return r.stdout.decode('utf-8') # Find all the API headers in a C++ source file. 
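buildifier_lint above follows the same pipe pattern in fix mode; stripped to its essentials it is a single subprocess call (again assuming a buildifier binary on PATH):

import subprocess

def buildifier_fix(contents, buildifier='buildifier'):
    # Run buildifier's lint fixer over the BUILD contents via stdin/stdout.
    r = subprocess.run([buildifier, '-lint=fix', '-mode=fix', '-type=build'],
                       input=contents.encode(),
                       stdout=subprocess.PIPE,
                       check=True)
    return r.stdout.decode('utf-8')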
def find_api_headers(source_path): - api_hdrs = set([]) - contents = pathlib.Path(source_path).read_text(encoding='utf8') - for line in contents.split('\n'): - match = re.match(API_INCLUDE_REGEX, line) - if match: - api_hdrs.add(match.group(1)) - return api_hdrs + api_hdrs = set([]) + contents = pathlib.Path(source_path).read_text(encoding='utf8') + for line in contents.split('\n'): + match = re.match(API_INCLUDE_REGEX, line) + if match: + api_hdrs.add(match.group(1)) + return api_hdrs # Infer and adjust rule dependencies in BUILD files for @envoy_api proto @@ -133,73 +133,75 @@ def find_api_headers(source_path): # compilation database and full build of Envoy, envoy_build_fixer.py is run # under check_format, which should be fast for developers. def fix_api_deps(path, contents): - source_dirname = os.path.dirname(path) - buildozer_out = run_buildozer([ - ('print kind name srcs hdrs deps', '*'), - ], contents).strip() - deps_mutation_cmds = [] - for line in buildozer_out.split('\n'): - match = re.match(BUILDOZER_PRINT_REGEX, line) - if not match: - # buildozer might emit complex multiline output when a 'select' or other - # macro is used. We're not smart enough to handle these today and they - # require manual fixup. - # TODO(htuch): investigate using --output_proto on buildozer to be able to - # consume something more usable in this situation. - continue - kind, name, srcs, hdrs, deps = match.groups() - if not name: - continue - source_paths = [] - if srcs != 'missing': - source_paths.extend( - os.path.join(source_dirname, f) - for f in srcs.split() - if f.endswith('.cc') or f.endswith('.h')) - if hdrs != 'missing': - source_paths.extend(os.path.join(source_dirname, f) for f in hdrs.split() if f.endswith('.h')) - api_hdrs = set([]) - for p in source_paths: - # We're not smart enough to infer on generated files. - if os.path.exists(p): - api_hdrs = api_hdrs.union(find_api_headers(p)) - actual_api_deps = set(['@envoy_api//%s:pkg_cc_proto' % h for h in api_hdrs]) - existing_api_deps = set([]) - if deps != 'missing': - existing_api_deps = set([ - d for d in deps.split() if d.startswith('@envoy_api') and d.endswith('pkg_cc_proto') and - d != '@com_github_cncf_udpa//udpa/annotations:pkg_cc_proto' - ]) - deps_to_remove = existing_api_deps.difference(actual_api_deps) - if deps_to_remove: - deps_mutation_cmds.append(('remove deps %s' % ' '.join(deps_to_remove), name)) - deps_to_add = actual_api_deps.difference(existing_api_deps) - if deps_to_add: - deps_mutation_cmds.append(('add deps %s' % ' '.join(deps_to_add), name)) - return run_buildozer(deps_mutation_cmds, contents) + source_dirname = os.path.dirname(path) + buildozer_out = run_buildozer([ + ('print kind name srcs hdrs deps', '*'), + ], contents).strip() + deps_mutation_cmds = [] + for line in buildozer_out.split('\n'): + match = re.match(BUILDOZER_PRINT_REGEX, line) + if not match: + # buildozer might emit complex multiline output when a 'select' or other + # macro is used. We're not smart enough to handle these today and they + # require manual fixup. + # TODO(htuch): investigate using --output_proto on buildozer to be able to + # consume something more usable in this situation. 
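fix_api_deps, continued below, maps the @envoy_api proto headers included by a target's sources onto pkg_cc_proto dependencies. A minimal sketch of that mapping; the regex is illustrative rather than the tool's own API_INCLUDE_REGEX:

import re

# Illustrative stand-in: capture the API package directory from a generated
# proto header include.
API_INCLUDE = re.compile(r'#include "(envoy/[^"]*)/[^/"]+\.pb\.h"')

line = '#include "envoy/config/core/v3/base.pb.h"'
match = API_INCLUDE.match(line)
if match:
    # Prints: @envoy_api//envoy/config/core/v3:pkg_cc_proto
    print('@envoy_api//%s:pkg_cc_proto' % match.group(1))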
+ continue + kind, name, srcs, hdrs, deps = match.groups() + if not name: + continue + source_paths = [] + if srcs != 'missing': + source_paths.extend( + os.path.join(source_dirname, f) + for f in srcs.split() + if f.endswith('.cc') or f.endswith('.h')) + if hdrs != 'missing': + source_paths.extend( + os.path.join(source_dirname, f) for f in hdrs.split() if f.endswith('.h')) + api_hdrs = set([]) + for p in source_paths: + # We're not smart enough to infer on generated files. + if os.path.exists(p): + api_hdrs = api_hdrs.union(find_api_headers(p)) + actual_api_deps = set(['@envoy_api//%s:pkg_cc_proto' % h for h in api_hdrs]) + existing_api_deps = set([]) + if deps != 'missing': + existing_api_deps = set([ + d for d in deps.split() + if d.startswith('@envoy_api') and d.endswith('pkg_cc_proto') and + d != '@com_github_cncf_udpa//udpa/annotations:pkg_cc_proto' + ]) + deps_to_remove = existing_api_deps.difference(actual_api_deps) + if deps_to_remove: + deps_mutation_cmds.append(('remove deps %s' % ' '.join(deps_to_remove), name)) + deps_to_add = actual_api_deps.difference(existing_api_deps) + if deps_to_add: + deps_mutation_cmds.append(('add deps %s' % ' '.join(deps_to_add), name)) + return run_buildozer(deps_mutation_cmds, contents) def fix_build(path): - with open(path, 'r') as f: - contents = f.read() - xforms = [ - functools.partial(fix_package_and_license, path), - functools.partial(fix_api_deps, path), - buildifier_lint, - ] - for xform in xforms: - contents = xform(contents) - return contents + with open(path, 'r') as f: + contents = f.read() + xforms = [ + functools.partial(fix_package_and_license, path), + functools.partial(fix_api_deps, path), + buildifier_lint, + ] + for xform in xforms: + contents = xform(contents) + return contents if __name__ == '__main__': - if len(sys.argv) == 2: - sys.stdout.write(fix_build(sys.argv[1])) - sys.exit(0) - elif len(sys.argv) == 3: - reorderd_source = fix_build(sys.argv[1]) - with open(sys.argv[2], 'w') as f: - f.write(reorderd_source) - sys.exit(0) - print('Usage: %s []' % sys.argv[0]) - sys.exit(1) + if len(sys.argv) == 2: + sys.stdout.write(fix_build(sys.argv[1])) + sys.exit(0) + elif len(sys.argv) == 3: + reorderd_source = fix_build(sys.argv[1]) + with open(sys.argv[2], 'w') as f: + f.write(reorderd_source) + sys.exit(0) + print('Usage: %s []' % sys.argv[0]) + sys.exit(1) diff --git a/tools/code_format/flake8.conf b/tools/code_format/flake8.conf index 3f26ffac0492..229b69928369 100644 --- a/tools/code_format/flake8.conf +++ b/tools/code_format/flake8.conf @@ -2,7 +2,7 @@ [flake8] # TODO(phlax): make this an exclusive list and enable most tests -select = N802 +select = N802,E111 # TODO(phlax): exclude less exclude = build_docs,.git,generated,test,examples diff --git a/tools/code_format/format_python_tools.py b/tools/code_format/format_python_tools.py index f622dcb54ad3..c54779df4595 100644 --- a/tools/code_format/format_python_tools.py +++ b/tools/code_format/format_python_tools.py @@ -9,68 +9,69 @@ def collect_files(): - """Collect all Python files in the tools directory. + """Collect all Python files in the tools directory. Returns: A collection of python files in the tools directory excluding any directories in the EXCLUDE_LIST constant. """ - # TODO: Add ability to collect a specific file or files. - matches = [] - path_parts = os.getcwd().split('/') - dirname = '.' 
- if path_parts[-1] == 'tools': - dirname = '/'.join(path_parts[:-1]) - for root, dirnames, filenames in os.walk(dirname): - dirnames[:] = [d for d in dirnames if d not in EXCLUDE_LIST] - for filename in fnmatch.filter(filenames, '*.py'): - if not filename.endswith('_pb2.py') and not filename.endswith('_pb2_grpc.py'): - matches.append(os.path.join(root, filename)) - return matches + # TODO: Add ability to collect a specific file or files. + matches = [] + path_parts = os.getcwd().split('/') + dirname = '.' + if path_parts[-1] == 'tools': + dirname = '/'.join(path_parts[:-1]) + for root, dirnames, filenames in os.walk(dirname): + dirnames[:] = [d for d in dirnames if d not in EXCLUDE_LIST] + for filename in fnmatch.filter(filenames, '*.py'): + if not filename.endswith('_pb2.py') and not filename.endswith('_pb2_grpc.py'): + matches.append(os.path.join(root, filename)) + return matches def validate_format(fix=False): - """Check the format of python files in the tools directory. + """Check the format of python files in the tools directory. Arguments: fix: a flag to indicate if fixes should be applied. """ - fixes_required = False - failed_update_files = set() - successful_update_files = set() - for python_file in collect_files(): - reformatted_source, encoding, changed = FormatFile(python_file, - style_config='tools/code_format/.style.yapf', - in_place=fix, - print_diff=not fix) - if not fix: - fixes_required = True if changed else fixes_required - if reformatted_source: - print(reformatted_source) - continue - file_list = failed_update_files if reformatted_source else successful_update_files - file_list.add(python_file) - if fix: - display_fix_results(successful_update_files, failed_update_files) - fixes_required = len(failed_update_files) > 0 - return not fixes_required + fixes_required = False + failed_update_files = set() + successful_update_files = set() + for python_file in collect_files(): + reformatted_source, encoding, changed = FormatFile( + python_file, + style_config='tools/code_format/.style.yapf', + in_place=fix, + print_diff=not fix) + if not fix: + fixes_required = True if changed else fixes_required + if reformatted_source: + print(reformatted_source) + continue + file_list = failed_update_files if reformatted_source else successful_update_files + file_list.add(python_file) + if fix: + display_fix_results(successful_update_files, failed_update_files) + fixes_required = len(failed_update_files) > 0 + return not fixes_required def display_fix_results(successful_files, failed_files): - if successful_files: - print('Successfully fixed {} files'.format(len(successful_files))) + if successful_files: + print('Successfully fixed {} files'.format(len(successful_files))) - if failed_files: - print('The following files failed to fix inline:') - for failed_file in failed_files: - print(' - {}'.format(failed_file)) + if failed_files: + print('The following files failed to fix inline:') + for failed_file in failed_files: + print(' - {}'.format(failed_file)) if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Tool to format python files.') - parser.add_argument('action', - choices=['check', 'fix'], - default='check', - help='Fix invalid syntax in files.') - args = parser.parse_args() - is_valid = validate_format(args.action == 'fix') - sys.exit(0 if is_valid else 1) + parser = argparse.ArgumentParser(description='Tool to format python files.') + parser.add_argument('action', + choices=['check', 'fix'], + default='check', + help='Fix invalid syntax in files.') + args = 
parser.parse_args() + is_valid = validate_format(args.action == 'fix') + sys.exit(0 if is_valid else 1) diff --git a/tools/code_format/header_order.py b/tools/code_format/header_order.py index ca4ab569f524..3bca02fc8c14 100755 --- a/tools/code_format/header_order.py +++ b/tools/code_format/header_order.py @@ -20,101 +20,101 @@ def reorder_headers(path): - source = pathlib.Path(path).read_text(encoding='utf-8') - - all_lines = iter(source.split('\n')) - before_includes_lines = [] - includes_lines = [] - after_includes_lines = [] - - # Collect all the lines prior to the first #include in before_includes_lines. - try: - while True: - line = next(all_lines) - if line.startswith('#include'): - includes_lines.append(line) - break - before_includes_lines.append(line) - except StopIteration: - pass - - # Collect all the #include and whitespace lines in includes_lines. - try: - while True: - line = next(all_lines) - if not line: - continue - if not line.startswith('#include'): - after_includes_lines.append(line) - break - includes_lines.append(line) - except StopIteration: - pass - - # Collect the remaining lines in after_includes_lines. - after_includes_lines += list(all_lines) - - # Filter for includes that finds the #include of the header file associated with the source file - # being processed. E.g. if 'path' is source/common/common/hex.cc, this filter matches - # "common/common/hex.h". - def file_header_filter(): - return lambda f: f.endswith('.h"') and path.endswith(f[1:-3] + '.cc') - - def regex_filter(regex): - return lambda f: re.match(regex, f) - - # Filters that define the #include blocks - block_filters = [ - file_header_filter(), - regex_filter('<.*\.h>'), - regex_filter('<.*>'), - ] - for subdir in include_dir_order: - block_filters.append(regex_filter('"' + subdir + '/.*"')) - - blocks = [] - already_included = set([]) - for b in block_filters: - block = [] - for line in includes_lines: - header = line[len('#include '):] - if line not in already_included and b(header): - block.append(line) - already_included.add(line) - if len(block) > 0: - blocks.append(block) - - # Anything not covered by block_filters gets its own block. - misc_headers = list(set(includes_lines).difference(already_included)) - if len(misc_headers) > 0: - blocks.append(misc_headers) - - reordered_includes_lines = '\n\n'.join(['\n'.join(sorted(block)) for block in blocks]) - - if reordered_includes_lines: - reordered_includes_lines += '\n' - - return '\n'.join( - filter(lambda x: x, [ - '\n'.join(before_includes_lines), - reordered_includes_lines, - '\n'.join(after_includes_lines), - ])) + source = pathlib.Path(path).read_text(encoding='utf-8') + + all_lines = iter(source.split('\n')) + before_includes_lines = [] + includes_lines = [] + after_includes_lines = [] + + # Collect all the lines prior to the first #include in before_includes_lines. + try: + while True: + line = next(all_lines) + if line.startswith('#include'): + includes_lines.append(line) + break + before_includes_lines.append(line) + except StopIteration: + pass + + # Collect all the #include and whitespace lines in includes_lines. + try: + while True: + line = next(all_lines) + if not line: + continue + if not line.startswith('#include'): + after_includes_lines.append(line) + break + includes_lines.append(line) + except StopIteration: + pass + + # Collect the remaining lines in after_includes_lines. 
+ after_includes_lines += list(all_lines) + + # Filter for includes that finds the #include of the header file associated with the source file + # being processed. E.g. if 'path' is source/common/common/hex.cc, this filter matches + # "common/common/hex.h". + def file_header_filter(): + return lambda f: f.endswith('.h"') and path.endswith(f[1:-3] + '.cc') + + def regex_filter(regex): + return lambda f: re.match(regex, f) + + # Filters that define the #include blocks + block_filters = [ + file_header_filter(), + regex_filter('<.*\.h>'), + regex_filter('<.*>'), + ] + for subdir in include_dir_order: + block_filters.append(regex_filter('"' + subdir + '/.*"')) + + blocks = [] + already_included = set([]) + for b in block_filters: + block = [] + for line in includes_lines: + header = line[len('#include '):] + if line not in already_included and b(header): + block.append(line) + already_included.add(line) + if len(block) > 0: + blocks.append(block) + + # Anything not covered by block_filters gets its own block. + misc_headers = list(set(includes_lines).difference(already_included)) + if len(misc_headers) > 0: + blocks.append(misc_headers) + + reordered_includes_lines = '\n\n'.join(['\n'.join(sorted(block)) for block in blocks]) + + if reordered_includes_lines: + reordered_includes_lines += '\n' + + return '\n'.join( + filter(lambda x: x, [ + '\n'.join(before_includes_lines), + reordered_includes_lines, + '\n'.join(after_includes_lines), + ])) if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Header reordering.') - parser.add_argument('--path', type=str, help='specify the path to the header file') - parser.add_argument('--rewrite', action='store_true', help='rewrite header file in-place') - parser.add_argument('--include_dir_order', - type=str, - default=','.join(common.include_dir_order()), - help='specify the header block include directory order') - args = parser.parse_args() - target_path = args.path - include_dir_order = args.include_dir_order.split(',') - reorderd_source = reorder_headers(target_path) - if args.rewrite: - pathlib.Path(target_path).write_text(reorderd_source, encoding='utf-8') - else: - sys.stdout.buffer.write(reorderd_source.encode('utf-8')) + parser = argparse.ArgumentParser(description='Header reordering.') + parser.add_argument('--path', type=str, help='specify the path to the header file') + parser.add_argument('--rewrite', action='store_true', help='rewrite header file in-place') + parser.add_argument('--include_dir_order', + type=str, + default=','.join(common.include_dir_order()), + help='specify the header block include directory order') + args = parser.parse_args() + target_path = args.path + include_dir_order = args.include_dir_order.split(',') + reorderd_source = reorder_headers(target_path) + if args.rewrite: + pathlib.Path(target_path).write_text(reorderd_source, encoding='utf-8') + else: + sys.stdout.buffer.write(reorderd_source.encode('utf-8')) diff --git a/tools/code_format/paths.py b/tools/code_format/paths.py index 4a87b15b175e..ea88439e3c26 100644 --- a/tools/code_format/paths.py +++ b/tools/code_format/paths.py @@ -4,10 +4,10 @@ def get_buildifier(): - return os.getenv("BUILDIFIER_BIN") or (os.path.expandvars("$GOPATH/bin/buildifier") - if os.getenv("GOPATH") else shutil.which("buildifier")) + return os.getenv("BUILDIFIER_BIN") or (os.path.expandvars("$GOPATH/bin/buildifier") + if os.getenv("GOPATH") else shutil.which("buildifier")) def get_buildozer(): - return os.getenv("BUILDOZER_BIN") or 
(os.path.expandvars("$GOPATH/bin/buildozer") - if os.getenv("GOPATH") else shutil.which("buildozer")) + return os.getenv("BUILDOZER_BIN") or (os.path.expandvars("$GOPATH/bin/buildozer") + if os.getenv("GOPATH") else shutil.which("buildozer")) diff --git a/tools/code_format/python_flake8.py b/tools/code_format/python_flake8.py index a1912bd13873..21cd3c94e083 100644 --- a/tools/code_format/python_flake8.py +++ b/tools/code_format/python_flake8.py @@ -3,9 +3,9 @@ def main(): - subprocess.run(f"python -m flake8 --config tools/code_format/flake8.conf .".split(), - cwd=sys.argv[1]) + subprocess.run(f"python -m flake8 --config tools/code_format/flake8.conf .".split(), + cwd=sys.argv[1]) if __name__ == "__main__": - main() + main() diff --git a/tools/config_validation/validate_fragment.py b/tools/config_validation/validate_fragment.py index e014b024ea82..bc8a9510b143 100644 --- a/tools/config_validation/validate_fragment.py +++ b/tools/config_validation/validate_fragment.py @@ -23,7 +23,7 @@ def validate_fragment(type_name, fragment): - """Validate a dictionary representing a JSON/YAML fragment against an Envoy API proto3 type. + """Validate a dictionary representing a JSON/YAML fragment against an Envoy API proto3 type. Throws Protobuf errors on parsing exceptions, successful validations produce no result. @@ -34,39 +34,39 @@ def validate_fragment(type_name, fragment): fragment: a dictionary representing the parsed JSON/YAML configuration fragment. """ - json_fragment = json.dumps(fragment) + json_fragment = json.dumps(fragment) - r = runfiles.Create() - all_protos_pb_text_path = r.Rlocation( - 'envoy/tools/type_whisperer/all_protos_with_ext_pb_text.pb_text') - file_desc_set = descriptor_pb2.FileDescriptorSet() - text_format.Parse(pathlib.Path(all_protos_pb_text_path).read_text(), - file_desc_set, - allow_unknown_extension=True) + r = runfiles.Create() + all_protos_pb_text_path = r.Rlocation( + 'envoy/tools/type_whisperer/all_protos_with_ext_pb_text.pb_text') + file_desc_set = descriptor_pb2.FileDescriptorSet() + text_format.Parse(pathlib.Path(all_protos_pb_text_path).read_text(), + file_desc_set, + allow_unknown_extension=True) - pool = descriptor_pool.DescriptorPool() - for f in file_desc_set.file: - pool.Add(f) - desc = pool.FindMessageTypeByName(type_name) - msg = message_factory.MessageFactory(pool=pool).GetPrototype(desc)() - json_format.Parse(json_fragment, msg, descriptor_pool=pool) + pool = descriptor_pool.DescriptorPool() + for f in file_desc_set.file: + pool.Add(f) + desc = pool.FindMessageTypeByName(type_name) + msg = message_factory.MessageFactory(pool=pool).GetPrototype(desc)() + json_format.Parse(json_fragment, msg, descriptor_pool=pool) def parse_args(): - parser = argparse.ArgumentParser( - description='Validate a YAML fragment against an Envoy API proto3 type.') - parser.add_argument( - 'message_type', - help='a string providing the type name, e.g. envoy.config.bootstrap.v3.Bootstrap.') - parser.add_argument('fragment_path', nargs='?', help='Path to a YAML configuration fragment.') - parser.add_argument('-s', required=False, help='YAML configuration fragment.') + parser = argparse.ArgumentParser( + description='Validate a YAML fragment against an Envoy API proto3 type.') + parser.add_argument( + 'message_type', + help='a string providing the type name, e.g. 
envoy.config.bootstrap.v3.Bootstrap.') + parser.add_argument('fragment_path', nargs='?', help='Path to a YAML configuration fragment.') + parser.add_argument('-s', required=False, help='YAML configuration fragment.') - return parser.parse_args() + return parser.parse_args() if __name__ == '__main__': - parsed_args = parse_args() - message_type = parsed_args.message_type - content = parsed_args.s if (parsed_args.fragment_path is None) else pathlib.Path( - parsed_args.fragment_path).read_text() - validate_fragment(message_type, yaml.safe_load(content)) + parsed_args = parse_args() + message_type = parsed_args.message_type + content = parsed_args.s if (parsed_args.fragment_path is None) else pathlib.Path( + parsed_args.fragment_path).read_text() + validate_fragment(message_type, yaml.safe_load(content)) diff --git a/tools/dependency/cve_scan.py b/tools/dependency/cve_scan.py index 2626b9dab5f4..68b8b6603d58 100755 --- a/tools/dependency/cve_scan.py +++ b/tools/dependency/cve_scan.py @@ -62,25 +62,25 @@ class Cpe(namedtuple('CPE', ['part', 'vendor', 'product', 'version'])): - '''Model a subset of CPE fields that are used in CPE matching.''' + '''Model a subset of CPE fields that are used in CPE matching.''' - @classmethod - def from_string(cls, cpe_str): - assert (cpe_str.startswith('cpe:2.3:')) - components = cpe_str.split(':') - assert (len(components) >= 6) - return cls(*components[2:6]) + @classmethod + def from_string(cls, cpe_str): + assert (cpe_str.startswith('cpe:2.3:')) + components = cpe_str.split(':') + assert (len(components) >= 6) + return cls(*components[2:6]) - def __str__(self): - return f'cpe:2.3:{self.part}:{self.vendor}:{self.product}:{self.version}' + def __str__(self): + return f'cpe:2.3:{self.part}:{self.vendor}:{self.product}:{self.version}' - def vendor_normalized(self): - '''Return a normalized CPE where only part and vendor are significant.''' - return Cpe(self.part, self.vendor, '*', '*') + def vendor_normalized(self): + '''Return a normalized CPE where only part and vendor are significant.''' + return Cpe(self.part, self.vendor, '*', '*') def parse_cve_json(cve_json, cves, cpe_revmap): - '''Parse CVE JSON dictionary. + '''Parse CVE JSON dictionary. Args: cve_json: a NIST CVE JSON dictionary. @@ -88,42 +88,42 @@ def parse_cve_json(cve_json, cves, cpe_revmap): cpe_revmap: a reverse map from vendor normalized CPE to CVE ID string. ''' - # This provides an over-approximation of possible CPEs affected by CVE nodes - # metadata; it traverses the entire AND-OR tree and just gathers every CPE - # observed. Generally we expect that most of Envoy's CVE-CPE matches to be - # simple, plus it's interesting to consumers of this data to understand when a - # CPE pops up, even in a conditional setting. 
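The Cpe helper above keys matching off the first four significant fields of a CPE 2.3 string, and candidate CVEs are looked up by a vendor-normalized form. A small worked example with an illustrative CPE string:

cpe_str = 'cpe:2.3:a:f5:nginx:1.17.0'            # illustrative dependency CPE
part, vendor, product, version = cpe_str.split(':')[2:6]
vendor_normalized = 'cpe:2.3:%s:%s:*:*' % (part, vendor)
print(vendor_normalized)                         # cpe:2.3:a:f5:*:*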
- def gather_cpes(nodes, cpe_set): - for node in nodes: - for cpe_match in node.get('cpe_match', []): - cpe_set.add(Cpe.from_string(cpe_match['cpe23Uri'])) - gather_cpes(node.get('children', []), cpe_set) - - for cve in cve_json['CVE_Items']: - cve_id = cve['cve']['CVE_data_meta']['ID'] - description = cve['cve']['description']['description_data'][0]['value'] - cpe_set = set() - gather_cpes(cve['configurations']['nodes'], cpe_set) - if len(cpe_set) == 0: - continue - cvss_v3_score = cve['impact']['baseMetricV3']['cvssV3']['baseScore'] - cvss_v3_severity = cve['impact']['baseMetricV3']['cvssV3']['baseSeverity'] - - def parse_cve_date(date_str): - assert (date_str.endswith('Z')) - return dt.date.fromisoformat(date_str.split('T')[0]) - - published_date = parse_cve_date(cve['publishedDate']) - last_modified_date = parse_cve_date(cve['lastModifiedDate']) - cves[cve_id] = Cve(cve_id, description, cpe_set, cvss_v3_score, cvss_v3_severity, - published_date, last_modified_date) - for cpe in cpe_set: - cpe_revmap[str(cpe.vendor_normalized())].add(cve_id) - return cves, cpe_revmap + # This provides an over-approximation of possible CPEs affected by CVE nodes + # metadata; it traverses the entire AND-OR tree and just gathers every CPE + # observed. Generally we expect that most of Envoy's CVE-CPE matches to be + # simple, plus it's interesting to consumers of this data to understand when a + # CPE pops up, even in a conditional setting. + def gather_cpes(nodes, cpe_set): + for node in nodes: + for cpe_match in node.get('cpe_match', []): + cpe_set.add(Cpe.from_string(cpe_match['cpe23Uri'])) + gather_cpes(node.get('children', []), cpe_set) + + for cve in cve_json['CVE_Items']: + cve_id = cve['cve']['CVE_data_meta']['ID'] + description = cve['cve']['description']['description_data'][0]['value'] + cpe_set = set() + gather_cpes(cve['configurations']['nodes'], cpe_set) + if len(cpe_set) == 0: + continue + cvss_v3_score = cve['impact']['baseMetricV3']['cvssV3']['baseScore'] + cvss_v3_severity = cve['impact']['baseMetricV3']['cvssV3']['baseSeverity'] + + def parse_cve_date(date_str): + assert (date_str.endswith('Z')) + return dt.date.fromisoformat(date_str.split('T')[0]) + + published_date = parse_cve_date(cve['publishedDate']) + last_modified_date = parse_cve_date(cve['lastModifiedDate']) + cves[cve_id] = Cve(cve_id, description, cpe_set, cvss_v3_score, cvss_v3_severity, + published_date, last_modified_date) + for cpe in cpe_set: + cpe_revmap[str(cpe.vendor_normalized())].add(cve_id) + return cves, cpe_revmap def download_cve_data(urls): - '''Download NIST CVE JSON databases from given URLs and parse. + '''Download NIST CVE JSON databases from given URLs and parse. Args: urls: a list of URLs. @@ -131,20 +131,20 @@ def download_cve_data(urls): cves: dictionary mapping CVE ID string to Cve object (output). cpe_revmap: a reverse map from vendor normalized CPE to CVE ID string. 
''' - cves = {} - cpe_revmap = defaultdict(set) - for url in urls: - print(f'Loading NIST CVE database from {url}...') - with urllib.request.urlopen(url) as request: - with gzip.GzipFile(fileobj=request) as json_data: - parse_cve_json(json.loads(json_data.read()), cves, cpe_revmap) - return cves, cpe_revmap + cves = {} + cpe_revmap = defaultdict(set) + for url in urls: + print(f'Loading NIST CVE database from {url}...') + with urllib.request.urlopen(url) as request: + with gzip.GzipFile(fileobj=request) as json_data: + parse_cve_json(json.loads(json_data.read()), cves, cpe_revmap) + return cves, cpe_revmap def format_cve_details(cve, deps): - formatted_deps = ', '.join(sorted(deps)) - wrapped_description = '\n '.join(textwrap.wrap(cve.description)) - return f''' + formatted_deps = ', '.join(sorted(deps)) + wrapped_description = '\n '.join(textwrap.wrap(cve.description)) + return f''' CVE ID: {cve.id} CVSS v3 score: {cve.score} Severity: {cve.severity} @@ -161,7 +161,7 @@ def format_cve_details(cve, deps): def regex_groups_match(regex, lhs, rhs): - '''Do two strings match modulo a regular expression? + '''Do two strings match modulo a regular expression? Args: regex: regular expression @@ -170,16 +170,16 @@ def regex_groups_match(regex, lhs, rhs): Returns: A boolean indicating match. ''' - lhs_match = regex.search(lhs) - if lhs_match: - rhs_match = regex.search(rhs) - if rhs_match and lhs_match.groups() == rhs_match.groups(): - return True - return False + lhs_match = regex.search(lhs) + if lhs_match: + rhs_match = regex.search(rhs) + if rhs_match and lhs_match.groups() == rhs_match.groups(): + return True + return False def cpe_match(cpe, dep_metadata): - '''Heuristically match dependency metadata against CPE. + '''Heuristically match dependency metadata against CPE. We have a number of rules below that should are easy to compute without having to look at the dependency metadata. In the future, with additional access to @@ -195,39 +195,39 @@ def cpe_match(cpe, dep_metadata): Returns: A boolean indicating a match. ''' - dep_cpe = Cpe.from_string(dep_metadata['cpe']) - dep_version = dep_metadata['version'] - # The 'part' and 'vendor' must be an exact match. - if cpe.part != dep_cpe.part: - return False - if cpe.vendor != dep_cpe.vendor: - return False - # We allow Envoy dependency CPEs to wildcard the 'product', this is useful for - # LLVM where multiple product need to be covered. - if dep_cpe.product != '*' and cpe.product != dep_cpe.product: + dep_cpe = Cpe.from_string(dep_metadata['cpe']) + dep_version = dep_metadata['version'] + # The 'part' and 'vendor' must be an exact match. + if cpe.part != dep_cpe.part: + return False + if cpe.vendor != dep_cpe.vendor: + return False + # We allow Envoy dependency CPEs to wildcard the 'product', this is useful for + # LLVM where multiple product need to be covered. + if dep_cpe.product != '*' and cpe.product != dep_cpe.product: + return False + # Wildcard versions always match. + if cpe.version == '*': + return True + # An exact version match is a hit. + if cpe.version == dep_version: + return True + # Allow the 'release_date' dependency metadata to substitute for date. + # TODO(htuch): Consider fuzzier date ranges. + if cpe.version == dep_metadata['release_date']: + return True + # Try a fuzzy date match to deal with versions like fips-20190304 in dependency version. + if regex_groups_match(FUZZY_DATE_RE, dep_version, cpe.version): + return True + # Try a fuzzy semver match to deal with things like 2.1.0-beta3. 
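regex_groups_match above is the basis of the fuzzy version comparison: two strings are considered equivalent when the same regex extracts identical capture groups from both. A self-contained sketch; the pattern below is illustrative, not the FUZZY_SEMVER_RE used by cve_scan.py:

import re

FUZZY = re.compile(r'(\d+)[-_.:](\d+)[-_.:](\d+)')

def groups_match(regex, lhs, rhs):
    lhs_m, rhs_m = regex.search(lhs), regex.search(rhs)
    return bool(lhs_m and rhs_m and lhs_m.groups() == rhs_m.groups())

print(groups_match(FUZZY, '2.1.0-beta3', 'foo-2-1-0-bar'))  # True
print(groups_match(FUZZY, '2.1.0', '2.1.1'))                # False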
+ if regex_groups_match(FUZZY_SEMVER_RE, dep_version, cpe.version): + return True + # Fall-thru. return False - # Wildcard versions always match. - if cpe.version == '*': - return True - # An exact version match is a hit. - if cpe.version == dep_version: - return True - # Allow the 'release_date' dependency metadata to substitute for date. - # TODO(htuch): Consider fuzzier date ranges. - if cpe.version == dep_metadata['release_date']: - return True - # Try a fuzzy date match to deal with versions like fips-20190304 in dependency version. - if regex_groups_match(FUZZY_DATE_RE, dep_version, cpe.version): - return True - # Try a fuzzy semver match to deal with things like 2.1.0-beta3. - if regex_groups_match(FUZZY_SEMVER_RE, dep_version, cpe.version): - return True - # Fall-thru. - return False def cve_match(cve, dep_metadata): - '''Heuristically match dependency metadata against CVE. + '''Heuristically match dependency metadata against CVE. In general, we allow false positives but want to keep the noise low, to avoid the toil around having to populate IGNORES_CVES. @@ -238,27 +238,27 @@ def cve_match(cve, dep_metadata): Returns: A boolean indicating a match. ''' - wildcard_version_match = False - # Consider each CPE attached to the CVE for a match against the dependency CPE. - for cpe in cve.cpes: - if cpe_match(cpe, dep_metadata): - # Wildcard version matches need additional heuristics unrelated to CPE to - # qualify, e.g. last updated date. - if cpe.version == '*': - wildcard_version_match = True - else: - return True - if wildcard_version_match: - # If the CVE was published after the dependency was last updated, it's a - # potential match. - last_dep_update = dt.date.fromisoformat(dep_metadata['release_date']) - if last_dep_update <= cve.published_date: - return True - return False + wildcard_version_match = False + # Consider each CPE attached to the CVE for a match against the dependency CPE. + for cpe in cve.cpes: + if cpe_match(cpe, dep_metadata): + # Wildcard version matches need additional heuristics unrelated to CPE to + # qualify, e.g. last updated date. + if cpe.version == '*': + wildcard_version_match = True + else: + return True + if wildcard_version_match: + # If the CVE was published after the dependency was last updated, it's a + # potential match. + last_dep_update = dt.date.fromisoformat(dep_metadata['release_date']) + if last_dep_update <= cve.published_date: + return True + return False def cve_scan(cves, cpe_revmap, cve_allowlist, repository_locations): - '''Scan for CVEs in a parsed NIST CVE database. + '''Scan for CVEs in a parsed NIST CVE database. Args: cves: CVE dictionary as provided by download_cve_data(). @@ -270,38 +270,40 @@ def cve_scan(cves, cpe_revmap, cve_allowlist, repository_locations): possible_cves: a dictionary mapping CVE IDs to Cve objects. cve_deps: a dictionary mapping CVE IDs to dependency names. 
''' - possible_cves = {} - cve_deps = defaultdict(list) - for dep, metadata in repository_locations.items(): - cpe = metadata.get('cpe', 'N/A') - if cpe == 'N/A': - continue - candidate_cve_ids = cpe_revmap.get(str(Cpe.from_string(cpe).vendor_normalized()), []) - for cve_id in candidate_cve_ids: - cve = cves[cve_id] - if cve.id in cve_allowlist: - continue - if cve_match(cve, metadata): - possible_cves[cve_id] = cve - cve_deps[cve_id].append(dep) - return possible_cves, cve_deps + possible_cves = {} + cve_deps = defaultdict(list) + for dep, metadata in repository_locations.items(): + cpe = metadata.get('cpe', 'N/A') + if cpe == 'N/A': + continue + candidate_cve_ids = cpe_revmap.get(str(Cpe.from_string(cpe).vendor_normalized()), []) + for cve_id in candidate_cve_ids: + cve = cves[cve_id] + if cve.id in cve_allowlist: + continue + if cve_match(cve, metadata): + possible_cves[cve_id] = cve + cve_deps[cve_id].append(dep) + return possible_cves, cve_deps if __name__ == '__main__': - # Allow local overrides for NIST CVE database URLs via args. - urls = sys.argv[1:] - if not urls: - # We only look back a few years, since we shouldn't have any ancient deps. - current_year = dt.datetime.now().year - scan_years = range(2018, current_year + 1) - urls = [ - f'https://nvd.nist.gov/feeds/json/cve/1.1/nvdcve-1.1-{year}.json.gz' for year in scan_years - ] - cves, cpe_revmap = download_cve_data(urls) - possible_cves, cve_deps = cve_scan(cves, cpe_revmap, IGNORES_CVES, - dep_utils.repository_locations()) - if possible_cves: - print('\nBased on heuristic matching with the NIST CVE database, Envoy may be vulnerable to:') - for cve_id in sorted(possible_cves): - print(f'{format_cve_details(possible_cves[cve_id], cve_deps[cve_id])}') - sys.exit(1) + # Allow local overrides for NIST CVE database URLs via args. + urls = sys.argv[1:] + if not urls: + # We only look back a few years, since we shouldn't have any ancient deps. 
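cve_scan above narrows the search by first vendor-normalizing each dependency's CPE and pulling candidate CVE IDs from the reverse map before running the per-CVE match heuristics. A minimal sketch of that lookup with illustrative data:

from collections import defaultdict

cpe_revmap = defaultdict(set)                  # vendor-normalized CPE -> CVE IDs
cpe_revmap['cpe:2.3:a:f5:*:*'].add('CVE-2020-1234')

dep_cpe = 'cpe:2.3:a:f5:nginx:1.17.0'          # illustrative dependency metadata
part, vendor = dep_cpe.split(':')[2:4]
candidates = cpe_revmap.get('cpe:2.3:%s:%s:*:*' % (part, vendor), set())
print(candidates)                              # {'CVE-2020-1234'}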
+ current_year = dt.datetime.now().year + scan_years = range(2018, current_year + 1) + urls = [ + f'https://nvd.nist.gov/feeds/json/cve/1.1/nvdcve-1.1-{year}.json.gz' + for year in scan_years + ] + cves, cpe_revmap = download_cve_data(urls) + possible_cves, cve_deps = cve_scan(cves, cpe_revmap, IGNORES_CVES, + dep_utils.repository_locations()) + if possible_cves: + print( + '\nBased on heuristic matching with the NIST CVE database, Envoy may be vulnerable to:') + for cve_id in sorted(possible_cves): + print(f'{format_cve_details(possible_cves[cve_id], cve_deps[cve_id])}') + sys.exit(1) diff --git a/tools/dependency/cve_scan_test.py b/tools/dependency/cve_scan_test.py index c5889888d293..ea2b1537c8a7 100755 --- a/tools/dependency/cve_scan_test.py +++ b/tools/dependency/cve_scan_test.py @@ -10,276 +10,277 @@ class CveScanTest(unittest.TestCase): - def test_parse_cve_json(self): - cve_json = { - 'CVE_Items': [ - { - 'cve': { - 'CVE_data_meta': { - 'ID': 'CVE-2020-1234' + def test_parse_cve_json(self): + cve_json = { + 'CVE_Items': [ + { + 'cve': { + 'CVE_data_meta': { + 'ID': 'CVE-2020-1234' + }, + 'description': { + 'description_data': [{ + 'value': 'foo' + }] + } }, - 'description': { - 'description_data': [{ - 'value': 'foo' - }] - } - }, - 'configurations': { - 'nodes': [{ - 'cpe_match': [{ - 'cpe23Uri': 'cpe:2.3:a:foo:bar:1.2.3' + 'configurations': { + 'nodes': [{ + 'cpe_match': [{ + 'cpe23Uri': 'cpe:2.3:a:foo:bar:1.2.3' + }], }], - }], - }, - 'impact': { - 'baseMetricV3': { - 'cvssV3': { - 'baseScore': 3.4, - 'baseSeverity': 'LOW' + }, + 'impact': { + 'baseMetricV3': { + 'cvssV3': { + 'baseScore': 3.4, + 'baseSeverity': 'LOW' + } } - } - }, - 'publishedDate': '2020-03-17T00:59Z', - 'lastModifiedDate': '2020-04-17T00:59Z' - }, - { - 'cve': { - 'CVE_data_meta': { - 'ID': 'CVE-2020-1235' }, - 'description': { - 'description_data': [{ - 'value': 'bar' - }] - } + 'publishedDate': '2020-03-17T00:59Z', + 'lastModifiedDate': '2020-04-17T00:59Z' }, - 'configurations': { - 'nodes': [{ - 'cpe_match': [{ - 'cpe23Uri': 'cpe:2.3:a:foo:bar:1.2.3' + { + 'cve': { + 'CVE_data_meta': { + 'ID': 'CVE-2020-1235' + }, + 'description': { + 'description_data': [{ + 'value': 'bar' + }] + } + }, + 'configurations': { + 'nodes': [{ + 'cpe_match': [{ + 'cpe23Uri': 'cpe:2.3:a:foo:bar:1.2.3' + }], + 'children': [ + { + 'cpe_match': [{ + 'cpe23Uri': 'cpe:2.3:a:foo:baz:3.2.3' + }] + }, + { + 'cpe_match': [{ + 'cpe23Uri': 'cpe:2.3:a:foo:*:*' + }, { + 'cpe23Uri': 'cpe:2.3:a:wat:bar:1.2.3' + }] + }, + ], }], - 'children': [ - { - 'cpe_match': [{ - 'cpe23Uri': 'cpe:2.3:a:foo:baz:3.2.3' - }] - }, - { - 'cpe_match': [{ - 'cpe23Uri': 'cpe:2.3:a:foo:*:*' - }, { - 'cpe23Uri': 'cpe:2.3:a:wat:bar:1.2.3' - }] - }, - ], - }], - }, - 'impact': { - 'baseMetricV3': { - 'cvssV3': { - 'baseScore': 9.9, - 'baseSeverity': 'HIGH' + }, + 'impact': { + 'baseMetricV3': { + 'cvssV3': { + 'baseScore': 9.9, + 'baseSeverity': 'HIGH' + } } - } + }, + 'publishedDate': '2020-03-18T00:59Z', + 'lastModifiedDate': '2020-04-18T00:59Z' }, - 'publishedDate': '2020-03-18T00:59Z', - 'lastModifiedDate': '2020-04-18T00:59Z' - }, - ] - } - cves = {} - cpe_revmap = defaultdict(set) - cve_scan.parse_cve_json(cve_json, cves, cpe_revmap) - self.maxDiff = None - self.assertDictEqual( - cves, { - 'CVE-2020-1234': - cve_scan.Cve(id='CVE-2020-1234', - description='foo', - cpes=set([self.build_cpe('cpe:2.3:a:foo:bar:1.2.3')]), - score=3.4, - severity='LOW', - published_date=dt.date(2020, 3, 17), - last_modified_date=dt.date(2020, 4, 17)), - 'CVE-2020-1235': - 
cve_scan.Cve(id='CVE-2020-1235', - description='bar', - cpes=set( - map(self.build_cpe, [ - 'cpe:2.3:a:foo:bar:1.2.3', 'cpe:2.3:a:foo:baz:3.2.3', - 'cpe:2.3:a:foo:*:*', 'cpe:2.3:a:wat:bar:1.2.3' - ])), - score=9.9, - severity='HIGH', - published_date=dt.date(2020, 3, 18), - last_modified_date=dt.date(2020, 4, 18)) - }) - self.assertDictEqual(cpe_revmap, { - 'cpe:2.3:a:foo:*:*': {'CVE-2020-1234', 'CVE-2020-1235'}, - 'cpe:2.3:a:wat:*:*': {'CVE-2020-1235'} - }) + ] + } + cves = {} + cpe_revmap = defaultdict(set) + cve_scan.parse_cve_json(cve_json, cves, cpe_revmap) + self.maxDiff = None + self.assertDictEqual( + cves, { + 'CVE-2020-1234': + cve_scan.Cve(id='CVE-2020-1234', + description='foo', + cpes=set([self.build_cpe('cpe:2.3:a:foo:bar:1.2.3')]), + score=3.4, + severity='LOW', + published_date=dt.date(2020, 3, 17), + last_modified_date=dt.date(2020, 4, 17)), + 'CVE-2020-1235': + cve_scan.Cve(id='CVE-2020-1235', + description='bar', + cpes=set( + map(self.build_cpe, [ + 'cpe:2.3:a:foo:bar:1.2.3', 'cpe:2.3:a:foo:baz:3.2.3', + 'cpe:2.3:a:foo:*:*', 'cpe:2.3:a:wat:bar:1.2.3' + ])), + score=9.9, + severity='HIGH', + published_date=dt.date(2020, 3, 18), + last_modified_date=dt.date(2020, 4, 18)) + }) + self.assertDictEqual( + cpe_revmap, { + 'cpe:2.3:a:foo:*:*': {'CVE-2020-1234', 'CVE-2020-1235'}, + 'cpe:2.3:a:wat:*:*': {'CVE-2020-1235'} + }) - def build_cpe(self, cpe_str): - return cve_scan.Cpe.from_string(cpe_str) + def build_cpe(self, cpe_str): + return cve_scan.Cpe.from_string(cpe_str) - def build_dep(self, cpe_str, version=None, release_date=None): - return {'cpe': cpe_str, 'version': version, 'release_date': release_date} + def build_dep(self, cpe_str, version=None, release_date=None): + return {'cpe': cpe_str, 'version': version, 'release_date': release_date} - def cpe_match(self, cpe_str, dep_cpe_str, version=None, release_date=None): - return cve_scan.cpe_match( - self.build_cpe(cpe_str), - self.build_dep(dep_cpe_str, version=version, release_date=release_date)) + def cpe_match(self, cpe_str, dep_cpe_str, version=None, release_date=None): + return cve_scan.cpe_match( + self.build_cpe(cpe_str), + self.build_dep(dep_cpe_str, version=version, release_date=release_date)) - def test_cpe_match(self): - # Mismatched part - self.assertFalse(self.cpe_match('cpe:2.3:o:foo:bar:*', 'cpe:2.3:a:foo:bar:*')) - # Mismatched vendor - self.assertFalse(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foz:bar:*')) - # Mismatched product - self.assertFalse(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foo:baz:*')) - # Wildcard product - self.assertTrue(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foo:*:*')) - # Wildcard version match - self.assertTrue(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foo:bar:*')) - # Exact version match - self.assertTrue( - self.cpe_match('cpe:2.3:a:foo:bar:1.2.3', 'cpe:2.3:a:foo:bar:*', version='1.2.3')) - # Date version match - self.assertTrue( - self.cpe_match('cpe:2.3:a:foo:bar:2020-03-05', - 'cpe:2.3:a:foo:bar:*', - release_date='2020-03-05')) - fuzzy_version_matches = [ - ('2020-03-05', '2020-03-05'), - ('2020-03-05', '20200305'), - ('2020-03-05', 'foo-20200305-bar'), - ('2020-03-05', 'foo-2020_03_05-bar'), - ('2020-03-05', 'foo-2020-03-05-bar'), - ('1.2.3', '1.2.3'), - ('1.2.3', '1-2-3'), - ('1.2.3', '1_2_3'), - ('1.2.3', '1:2:3'), - ('1.2.3', 'foo-1-2-3-bar'), - ] - for cpe_version, dep_version in fuzzy_version_matches: - self.assertTrue( - self.cpe_match(f'cpe:2.3:a:foo:bar:{cpe_version}', - 'cpe:2.3:a:foo:bar:*', - version=dep_version)) - 
fuzzy_version_no_matches = [ - ('2020-03-05', '2020-3.5'), - ('2020-03-05', '2020--03-05'), - ('1.2.3', '1@2@3'), - ('1.2.3', '1..2.3'), - ] - for cpe_version, dep_version in fuzzy_version_no_matches: - self.assertFalse( - self.cpe_match(f'cpe:2.3:a:foo:bar:{cpe_version}', - 'cpe:2.3:a:foo:bar:*', - version=dep_version)) + def test_cpe_match(self): + # Mismatched part + self.assertFalse(self.cpe_match('cpe:2.3:o:foo:bar:*', 'cpe:2.3:a:foo:bar:*')) + # Mismatched vendor + self.assertFalse(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foz:bar:*')) + # Mismatched product + self.assertFalse(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foo:baz:*')) + # Wildcard product + self.assertTrue(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foo:*:*')) + # Wildcard version match + self.assertTrue(self.cpe_match('cpe:2.3:a:foo:bar:*', 'cpe:2.3:a:foo:bar:*')) + # Exact version match + self.assertTrue( + self.cpe_match('cpe:2.3:a:foo:bar:1.2.3', 'cpe:2.3:a:foo:bar:*', version='1.2.3')) + # Date version match + self.assertTrue( + self.cpe_match('cpe:2.3:a:foo:bar:2020-03-05', + 'cpe:2.3:a:foo:bar:*', + release_date='2020-03-05')) + fuzzy_version_matches = [ + ('2020-03-05', '2020-03-05'), + ('2020-03-05', '20200305'), + ('2020-03-05', 'foo-20200305-bar'), + ('2020-03-05', 'foo-2020_03_05-bar'), + ('2020-03-05', 'foo-2020-03-05-bar'), + ('1.2.3', '1.2.3'), + ('1.2.3', '1-2-3'), + ('1.2.3', '1_2_3'), + ('1.2.3', '1:2:3'), + ('1.2.3', 'foo-1-2-3-bar'), + ] + for cpe_version, dep_version in fuzzy_version_matches: + self.assertTrue( + self.cpe_match(f'cpe:2.3:a:foo:bar:{cpe_version}', + 'cpe:2.3:a:foo:bar:*', + version=dep_version)) + fuzzy_version_no_matches = [ + ('2020-03-05', '2020-3.5'), + ('2020-03-05', '2020--03-05'), + ('1.2.3', '1@2@3'), + ('1.2.3', '1..2.3'), + ] + for cpe_version, dep_version in fuzzy_version_no_matches: + self.assertFalse( + self.cpe_match(f'cpe:2.3:a:foo:bar:{cpe_version}', + 'cpe:2.3:a:foo:bar:*', + version=dep_version)) - def build_cve(self, cve_id, cpes, published_date): - return cve_scan.Cve(cve_id, - description=None, - cpes=cpes, - score=None, - severity=None, - published_date=dt.date.fromisoformat(published_date), - last_modified_date=None) + def build_cve(self, cve_id, cpes, published_date): + return cve_scan.Cve(cve_id, + description=None, + cpes=cpes, + score=None, + severity=None, + published_date=dt.date.fromisoformat(published_date), + last_modified_date=None) - def cve_match(self, cve_id, cpes, published_date, dep_cpe_str, version=None, release_date=None): - return cve_scan.cve_match( - self.build_cve(cve_id, cpes=cpes, published_date=published_date), - self.build_dep(dep_cpe_str, version=version, release_date=release_date)) + def cve_match(self, cve_id, cpes, published_date, dep_cpe_str, version=None, release_date=None): + return cve_scan.cve_match( + self.build_cve(cve_id, cpes=cpes, published_date=published_date), + self.build_dep(dep_cpe_str, version=version, release_date=release_date)) - def test_cve_match(self): - # Empty CPEs, no match - self.assertFalse(self.cve_match('CVE-2020-123', set(), '2020-05-03', 'cpe:2.3:a:foo:bar:*')) - # Wildcard version, stale dependency match - self.assertTrue( - self.cve_match('CVE-2020-123', - set([self.build_cpe('cpe:2.3:a:foo:bar:*')]), - '2020-05-03', - 'cpe:2.3:a:foo:bar:*', - release_date='2020-05-02')) - self.assertTrue( - self.cve_match('CVE-2020-123', - set([self.build_cpe('cpe:2.3:a:foo:bar:*')]), - '2020-05-03', - 'cpe:2.3:a:foo:bar:*', - release_date='2020-05-03')) - # Wildcard version, recently updated - 
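The fuzzy-version cases above pin the matching contract down fairly tightly: the components of a CPE version may be joined by '.', '-', '_', ':' or nothing at all inside the dependency version, while doubled separators or any other character must not match. The snippet below is one regex that satisfies exactly these cases; it illustrates the contract the tests encode and is only an assumption about how cve_scan implements it.

import re


def fuzzy_version_match(cpe_version, dep_version):
    # Split '1.2.3' or '2020-03-05' into its components...
    components = [re.escape(c) for c in re.split(r'[.\-_:]', cpe_version)]
    # ...and allow at most one of . - _ : (or nothing) between neighbours.
    pattern = r'\b' + r'[.\-_:]?'.join(components) + r'\b'
    return re.search(pattern, dep_version) is not None


assert fuzzy_version_match('1.2.3', 'foo-1-2-3-bar')
assert fuzzy_version_match('2020-03-05', '20200305')
assert not fuzzy_version_match('1.2.3', '1..2.3')
assert not fuzzy_version_match('2020-03-05', '2020--03-05')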
self.assertFalse( - self.cve_match('CVE-2020-123', - set([self.build_cpe('cpe:2.3:a:foo:bar:*')]), - '2020-05-03', - 'cpe:2.3:a:foo:bar:*', - release_date='2020-05-04')) - # Version match - self.assertTrue( - self.cve_match('CVE-2020-123', - set([self.build_cpe('cpe:2.3:a:foo:bar:1.2.3')]), - '2020-05-03', - 'cpe:2.3:a:foo:bar:*', - version='1.2.3')) - # Version mismatch - self.assertFalse( - self.cve_match('CVE-2020-123', - set([self.build_cpe('cpe:2.3:a:foo:bar:1.2.3')]), - '2020-05-03', - 'cpe:2.3:a:foo:bar:*', - version='1.2.4', - release_date='2020-05-02')) - # Multiple CPEs, match first, don't match later. - self.assertTrue( - self.cve_match('CVE-2020-123', - set([ - self.build_cpe('cpe:2.3:a:foo:bar:1.2.3'), - self.build_cpe('cpe:2.3:a:foo:baz:3.2.1') - ]), - '2020-05-03', - 'cpe:2.3:a:foo:bar:*', - version='1.2.3')) + def test_cve_match(self): + # Empty CPEs, no match + self.assertFalse(self.cve_match('CVE-2020-123', set(), '2020-05-03', 'cpe:2.3:a:foo:bar:*')) + # Wildcard version, stale dependency match + self.assertTrue( + self.cve_match('CVE-2020-123', + set([self.build_cpe('cpe:2.3:a:foo:bar:*')]), + '2020-05-03', + 'cpe:2.3:a:foo:bar:*', + release_date='2020-05-02')) + self.assertTrue( + self.cve_match('CVE-2020-123', + set([self.build_cpe('cpe:2.3:a:foo:bar:*')]), + '2020-05-03', + 'cpe:2.3:a:foo:bar:*', + release_date='2020-05-03')) + # Wildcard version, recently updated + self.assertFalse( + self.cve_match('CVE-2020-123', + set([self.build_cpe('cpe:2.3:a:foo:bar:*')]), + '2020-05-03', + 'cpe:2.3:a:foo:bar:*', + release_date='2020-05-04')) + # Version match + self.assertTrue( + self.cve_match('CVE-2020-123', + set([self.build_cpe('cpe:2.3:a:foo:bar:1.2.3')]), + '2020-05-03', + 'cpe:2.3:a:foo:bar:*', + version='1.2.3')) + # Version mismatch + self.assertFalse( + self.cve_match('CVE-2020-123', + set([self.build_cpe('cpe:2.3:a:foo:bar:1.2.3')]), + '2020-05-03', + 'cpe:2.3:a:foo:bar:*', + version='1.2.4', + release_date='2020-05-02')) + # Multiple CPEs, match first, don't match later. 
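Taken together, these test_cve_match() cases encode a simple policy: a CVE applies to a dependency when one of its CPEs matches the dependency's CPE and either the versions match, or the CPE version is a wildcard and the dependency's last release does not post-date the CVE publication date. The sketch below restates that policy against the {'cpe', 'version', 'release_date'} dependency shape used by build_dep(); the field handling is simplified and is not necessarily how cve_scan.cve_match()/cpe_match() are written.

import datetime as dt


def parse_cpe(cpe_str):
    # 'cpe:2.3:a:vendor:product:version' -> (part, vendor, product, version).
    # Real CPE 2.3 strings carry more fields; this matches the test fixtures.
    _, _, part, vendor, product, version = cpe_str.split(':')
    return part, vendor, product, version


def cpe_applies(cve_cpe, dep, version_matches):
    part, vendor, product, version = parse_cpe(cve_cpe)
    dep_part, dep_vendor, dep_product, _ = parse_cpe(dep['cpe'])
    if (part, vendor) != (dep_part, dep_vendor):
        return False
    if dep_product != '*' and product != dep_product:
        return False
    if version == '*':
        return True
    # A concrete CPE version must match the dependency version (or, failing
    # that, its release date, as in the 'Date version match' case above).
    target = dep['version'] if dep['version'] is not None else dep['release_date']
    return target is not None and version_matches(version, target)


def cve_applies(cve_published, cve_cpes, dep, version_matches):
    for cve_cpe in cve_cpes:
        if not cpe_applies(cve_cpe, dep, version_matches):
            continue
        if parse_cpe(cve_cpe)[3] != '*':
            return True  # Explicit version match.
        # Wildcard version: only flag dependencies that predate the CVE.
        if dt.date.fromisoformat(dep['release_date']) <= cve_published:
            return True
    return False


# Tiny demo with exact version comparison; the real tool matches fuzzily.
dep = {'cpe': 'cpe:2.3:a:foo:bar:*', 'version': '1.2.3', 'release_date': None}
print(cve_applies(dt.date(2020, 5, 3), {'cpe:2.3:a:foo:bar:1.2.3'},
                  dep, lambda a, b: a == b))  # True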
+ self.assertTrue( + self.cve_match('CVE-2020-123', + set([ + self.build_cpe('cpe:2.3:a:foo:bar:1.2.3'), + self.build_cpe('cpe:2.3:a:foo:baz:3.2.1') + ]), + '2020-05-03', + 'cpe:2.3:a:foo:bar:*', + version='1.2.3')) - def test_cve_scan(self): - cves = { - 'CVE-2020-1234': - self.build_cve( - 'CVE-2020-1234', - set([ - self.build_cpe('cpe:2.3:a:foo:bar:1.2.3'), - self.build_cpe('cpe:2.3:a:foo:baz:3.2.1') - ]), '2020-05-03'), - 'CVE-2020-1235': - self.build_cve( - 'CVE-2020-1235', - set([ - self.build_cpe('cpe:2.3:a:foo:bar:1.2.3'), - self.build_cpe('cpe:2.3:a:foo:baz:3.2.1') + def test_cve_scan(self): + cves = { + 'CVE-2020-1234': + self.build_cve( + 'CVE-2020-1234', + set([ + self.build_cpe('cpe:2.3:a:foo:bar:1.2.3'), + self.build_cpe('cpe:2.3:a:foo:baz:3.2.1') + ]), '2020-05-03'), + 'CVE-2020-1235': + self.build_cve( + 'CVE-2020-1235', + set([ + self.build_cpe('cpe:2.3:a:foo:bar:1.2.3'), + self.build_cpe('cpe:2.3:a:foo:baz:3.2.1') + ]), '2020-05-03'), + 'CVE-2020-1236': + self.build_cve('CVE-2020-1236', set([ + self.build_cpe('cpe:2.3:a:foo:wat:1.2.3'), ]), '2020-05-03'), - 'CVE-2020-1236': - self.build_cve('CVE-2020-1236', set([ - self.build_cpe('cpe:2.3:a:foo:wat:1.2.3'), - ]), '2020-05-03'), - } - cpe_revmap = { - 'cpe:2.3:a:foo:*:*': ['CVE-2020-1234', 'CVE-2020-1235', 'CVE-2020-1236'], - } - cve_allowlist = ['CVE-2020-1235'] - repository_locations = { - 'bar': self.build_dep('cpe:2.3:a:foo:bar:*', version='1.2.3'), - 'baz': self.build_dep('cpe:2.3:a:foo:baz:*', version='3.2.1'), - 'foo': self.build_dep('cpe:2.3:a:foo:*:*', version='1.2.3'), - 'blah': self.build_dep('N/A'), - } - possible_cves, cve_deps = cve_scan.cve_scan(cves, cpe_revmap, cve_allowlist, - repository_locations) - self.assertListEqual(sorted(possible_cves.keys()), ['CVE-2020-1234', 'CVE-2020-1236']) - self.assertDictEqual(cve_deps, { - 'CVE-2020-1234': ['bar', 'baz', 'foo'], - 'CVE-2020-1236': ['foo'] - }) + } + cpe_revmap = { + 'cpe:2.3:a:foo:*:*': ['CVE-2020-1234', 'CVE-2020-1235', 'CVE-2020-1236'], + } + cve_allowlist = ['CVE-2020-1235'] + repository_locations = { + 'bar': self.build_dep('cpe:2.3:a:foo:bar:*', version='1.2.3'), + 'baz': self.build_dep('cpe:2.3:a:foo:baz:*', version='3.2.1'), + 'foo': self.build_dep('cpe:2.3:a:foo:*:*', version='1.2.3'), + 'blah': self.build_dep('N/A'), + } + possible_cves, cve_deps = cve_scan.cve_scan(cves, cpe_revmap, cve_allowlist, + repository_locations) + self.assertListEqual(sorted(possible_cves.keys()), ['CVE-2020-1234', 'CVE-2020-1236']) + self.assertDictEqual(cve_deps, { + 'CVE-2020-1234': ['bar', 'baz', 'foo'], + 'CVE-2020-1236': ['foo'] + }) if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/tools/dependency/exports.py b/tools/dependency/exports.py index 1fd0e59cb886..6a6d4e5763c4 100644 --- a/tools/dependency/exports.py +++ b/tools/dependency/exports.py @@ -8,10 +8,10 @@ # Shared Starlark/Python files must have a .bzl suffix for Starlark import, so # we are forced to do this workaround. 
def load_module(name, path): - spec = spec_from_loader(name, SourceFileLoader(name, path)) - module = module_from_spec(spec) - spec.loader.exec_module(module) - return module + spec = spec_from_loader(name, SourceFileLoader(name, path)) + module = module_from_spec(spec) + spec.loader.exec_module(module) + return module # this is the relative path in a bazel build diff --git a/tools/dependency/generate_external_dep_rst.py b/tools/dependency/generate_external_dep_rst.py index d3d1cc74ae29..48e7c0491260 100755 --- a/tools/dependency/generate_external_dep_rst.py +++ b/tools/dependency/generate_external_dep_rst.py @@ -14,8 +14,8 @@ # Render a CSV table given a list of table headers, widths and list of rows # (each a list of strings). def csv_table(headers, widths, rows): - csv_rows = '\n '.join(', '.join(row) for row in rows) - return f'''.. csv-table:: + csv_rows = '\n '.join(', '.join(row) for row in rows) + return f'''.. csv-table:: :header: {', '.join(headers)} :widths: {', '.join(str(w) for w in widths) } @@ -26,87 +26,87 @@ def csv_table(headers, widths, rows): # Anonymous external RST link for a given URL. def rst_link(text, url): - return f'`{text} <{url}>`__' + return f'`{text} <{url}>`__' # NIST CPE database search URL for a given CPE. def nist_cpe_url(cpe): - encoded_cpe = urllib.parse.quote(cpe) - return f'https://nvd.nist.gov/vuln/search/results?form_type=Advanced&results_type=overview&query={encoded_cpe}&search_type=all' + encoded_cpe = urllib.parse.quote(cpe) + return f'https://nvd.nist.gov/vuln/search/results?form_type=Advanced&results_type=overview&query={encoded_cpe}&search_type=all' # Render version strings human readable. def render_version(version): - # Heuristic, almost certainly a git SHA - if len(version) == 40: - # Abbreviate git SHA - return version[:7] - return version + # Heuristic, almost certainly a git SHA + if len(version) == 40: + # Abbreviate git SHA + return version[:7] + return version def render_title(title): - underline = '~' * len(title) - return f'\n{title}\n{underline}\n\n' + underline = '~' * len(title) + return f'\n{title}\n{underline}\n\n' # Determine the version link URL. If it's GitHub, use some heuristics to figure # out a release tag link, otherwise point to the GitHub tree at the respective # SHA. Otherwise, return the tarball download. def get_version_url(metadata): - # Figure out if it's a GitHub repo. - github_release = dep_utils.get_github_release_from_urls(metadata['urls']) - # If not, direct download link for tarball - if not github_release: - return metadata['urls'][0] - github_repo = f'https://github.com/{github_release.organization}/{github_release.project}' - if github_release.tagged: - # The GitHub version should look like the metadata version, but might have - # something like a "v" prefix. - return f'{github_repo}/releases/tag/{github_release.version}' - assert (metadata['version'] == github_release.version) - return f'{github_repo}/tree/{github_release.version}' + # Figure out if it's a GitHub repo. + github_release = dep_utils.get_github_release_from_urls(metadata['urls']) + # If not, direct download link for tarball + if not github_release: + return metadata['urls'][0] + github_repo = f'https://github.com/{github_release.organization}/{github_release.project}' + if github_release.tagged: + # The GitHub version should look like the metadata version, but might have + # something like a "v" prefix. 
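The load_module() helper in tools/dependency/exports.py above imports a Starlark file as a regular Python module, which works because these shared .bzl files are written to be valid in both languages. A minimal usage sketch, re-declaring the helper so the example is self-contained and assuming it is run from the repository root:

from importlib.machinery import SourceFileLoader
from importlib.util import module_from_spec, spec_from_loader


def load_module(name, path):
    spec = spec_from_loader(name, SourceFileLoader(name, path))
    module = module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# REPOSITORY_LOCATIONS_SPEC is the constant read by the callers shown
# elsewhere in this patch (ossf_scorecard.py, release_dates.py).
locations = load_module('repository_locations', 'bazel/repository_locations.bzl')
print(sorted(locations.REPOSITORY_LOCATIONS_SPEC))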
+ return f'{github_repo}/releases/tag/{github_release.version}' + assert (metadata['version'] == github_release.version) + return f'{github_repo}/tree/{github_release.version}' if __name__ == '__main__': - try: - generated_rst_dir = os.getenv("GENERATED_RST_DIR") or sys.argv[1] - except IndexError: - raise SystemExit( - "Output dir path must be either specified as arg or with GENERATED_RST_DIR env var") - - security_rst_root = os.path.join(generated_rst_dir, "intro/arch_overview/security") - - pathlib.Path(security_rst_root).mkdir(parents=True, exist_ok=True) - - Dep = namedtuple('Dep', ['name', 'sort_name', 'version', 'cpe', 'release_date']) - use_categories = defaultdict(lambda: defaultdict(list)) - # Bin rendered dependencies into per-use category lists. - for k, v in dep_utils.repository_locations().items(): - cpe = v.get('cpe', '') - if cpe == 'N/A': - cpe = '' - if cpe: - cpe = rst_link(cpe, nist_cpe_url(cpe)) - project_name = v['project_name'] - project_url = v['project_url'] - name = rst_link(project_name, project_url) - version = rst_link(render_version(v['version']), get_version_url(v)) - release_date = v['release_date'] - dep = Dep(name, project_name.lower(), version, cpe, release_date) - for category in v['use_category']: - for ext in v.get('extensions', ['core']): - use_categories[category][ext].append(dep) - - def csv_row(dep): - return [dep.name, dep.version, dep.release_date, dep.cpe] - - # Generate per-use category RST with CSV tables. - for category, exts in use_categories.items(): - content = '' - for ext_name, deps in sorted(exts.items()): - if ext_name != 'core': - content += render_title(ext_name) - output_path = pathlib.Path(security_rst_root, f'external_dep_{category}.rst') - content += csv_table(['Name', 'Version', 'Release date', 'CPE'], [2, 1, 1, 2], - [csv_row(dep) for dep in sorted(deps, key=lambda d: d.sort_name)]) - output_path.write_text(content) + try: + generated_rst_dir = os.getenv("GENERATED_RST_DIR") or sys.argv[1] + except IndexError: + raise SystemExit( + "Output dir path must be either specified as arg or with GENERATED_RST_DIR env var") + + security_rst_root = os.path.join(generated_rst_dir, "intro/arch_overview/security") + + pathlib.Path(security_rst_root).mkdir(parents=True, exist_ok=True) + + Dep = namedtuple('Dep', ['name', 'sort_name', 'version', 'cpe', 'release_date']) + use_categories = defaultdict(lambda: defaultdict(list)) + # Bin rendered dependencies into per-use category lists. + for k, v in dep_utils.repository_locations().items(): + cpe = v.get('cpe', '') + if cpe == 'N/A': + cpe = '' + if cpe: + cpe = rst_link(cpe, nist_cpe_url(cpe)) + project_name = v['project_name'] + project_url = v['project_url'] + name = rst_link(project_name, project_url) + version = rst_link(render_version(v['version']), get_version_url(v)) + release_date = v['release_date'] + dep = Dep(name, project_name.lower(), version, cpe, release_date) + for category in v['use_category']: + for ext in v.get('extensions', ['core']): + use_categories[category][ext].append(dep) + + def csv_row(dep): + return [dep.name, dep.version, dep.release_date, dep.cpe] + + # Generate per-use category RST with CSV tables. 
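For reference, the csv-table directive that csv_table() near the top of this file produces looks as follows for a small, made-up input; the helper is re-declared locally (with two-space body indentation assumed) so the example runs on its own.

def csv_table(headers, widths, rows):
    csv_rows = '\n  '.join(', '.join(row) for row in rows)
    return f""".. csv-table::
  :header: {', '.join(headers)}
  :widths: {', '.join(str(w) for w in widths)}

  {csv_rows}
"""


print(csv_table(['Name', 'Version'], [2, 1],
                [['zlib', '1.2.11'], ['c-ares', '1.16.1']]))
# Output:
# .. csv-table::
#   :header: Name, Version
#   :widths: 2, 1
#
#   zlib, 1.2.11
#   c-ares, 1.16.1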
+ for category, exts in use_categories.items(): + content = '' + for ext_name, deps in sorted(exts.items()): + if ext_name != 'core': + content += render_title(ext_name) + output_path = pathlib.Path(security_rst_root, f'external_dep_{category}.rst') + content += csv_table(['Name', 'Version', 'Release date', 'CPE'], [2, 1, 1, 2], + [csv_row(dep) for dep in sorted(deps, key=lambda d: d.sort_name)]) + output_path.write_text(content) diff --git a/tools/dependency/ossf_scorecard.py b/tools/dependency/ossf_scorecard.py index 35567592dcdb..edab186950f8 100755 --- a/tools/dependency/ossf_scorecard.py +++ b/tools/dependency/ossf_scorecard.py @@ -43,103 +43,104 @@ # Thrown on errors related to release date. class OssfScorecardError(Exception): - pass + pass # We skip build, test, etc. def is_scored_use_category(use_category): - return len( - set(use_category).intersection([ - 'dataplane_core', 'dataplane_ext', 'controlplane', 'observability_core', - 'observability_ext' - ])) > 0 + return len( + set(use_category).intersection([ + 'dataplane_core', 'dataplane_ext', 'controlplane', 'observability_core', + 'observability_ext' + ])) > 0 def score(scorecard_path, repository_locations): - results = {} - for dep, metadata in sorted(repository_locations.items()): - if not is_scored_use_category(metadata['use_category']): - continue - results_key = metadata['project_name'] - formatted_name = '=HYPERLINK("%s", "%s")' % (metadata['project_url'], results_key) - github_project_url = utils.get_github_project_url(metadata['urls']) - if not github_project_url: - na = 'Not Scorecard compatible' - results[results_key] = Scorecard(name=formatted_name, - contributors=na, - active=na, - ci_tests=na, - pull_requests=na, - code_review=na, - fuzzing=na, - security_policy=na, - releases=na) - continue - raw_scorecard = json.loads( - sp.check_output( - [scorecard_path, f'--repo={github_project_url}', '--show-details', '--format=json'])) - checks = {c['CheckName']: c for c in raw_scorecard['Checks']} - - # Generic check format. - def _format(key): - score = checks[key] - status = score['Pass'] - confidence = score['Confidence'] - return f'{status} ({confidence})' - - # Releases need to be extracted from Signed-Releases. 
- def release_format(): - score = checks['Signed-Releases'] - if score['Pass']: - return _format('Signed-Releases') - details = score['Details'] - release_found = details is not None and any('release found:' in d for d in details) - if release_found: - return 'True (10)' - else: - return 'False (10)' - - results[results_key] = Scorecard(name=formatted_name, - contributors=_format('Contributors'), - active=_format('Active'), - ci_tests=_format('CI-Tests'), - pull_requests=_format('Pull-Requests'), - code_review=_format('Code-Review'), - fuzzing=_format('Fuzzing'), - security_policy=_format('Security-Policy'), - releases=release_format()) - print(raw_scorecard) - print(results[results_key]) - return results + results = {} + for dep, metadata in sorted(repository_locations.items()): + if not is_scored_use_category(metadata['use_category']): + continue + results_key = metadata['project_name'] + formatted_name = '=HYPERLINK("%s", "%s")' % (metadata['project_url'], results_key) + github_project_url = utils.get_github_project_url(metadata['urls']) + if not github_project_url: + na = 'Not Scorecard compatible' + results[results_key] = Scorecard(name=formatted_name, + contributors=na, + active=na, + ci_tests=na, + pull_requests=na, + code_review=na, + fuzzing=na, + security_policy=na, + releases=na) + continue + raw_scorecard = json.loads( + sp.check_output( + [scorecard_path, f'--repo={github_project_url}', '--show-details', + '--format=json'])) + checks = {c['CheckName']: c for c in raw_scorecard['Checks']} + + # Generic check format. + def _format(key): + score = checks[key] + status = score['Pass'] + confidence = score['Confidence'] + return f'{status} ({confidence})' + + # Releases need to be extracted from Signed-Releases. + def release_format(): + score = checks['Signed-Releases'] + if score['Pass']: + return _format('Signed-Releases') + details = score['Details'] + release_found = details is not None and any('release found:' in d for d in details) + if release_found: + return 'True (10)' + else: + return 'False (10)' + + results[results_key] = Scorecard(name=formatted_name, + contributors=_format('Contributors'), + active=_format('Active'), + ci_tests=_format('CI-Tests'), + pull_requests=_format('Pull-Requests'), + code_review=_format('Code-Review'), + fuzzing=_format('Fuzzing'), + security_policy=_format('Security-Policy'), + releases=release_format()) + print(raw_scorecard) + print(results[results_key]) + return results def print_csv_results(csv_output_path, results): - headers = Scorecard._fields - with open(csv_output_path, 'w') as f: - writer = csv.writer(f) - writer.writerow(headers) - for name in sorted(results): - writer.writerow(getattr(results[name], h) for h in headers) + headers = Scorecard._fields + with open(csv_output_path, 'w') as f: + writer = csv.writer(f) + writer.writerow(headers) + for name in sorted(results): + writer.writerow(getattr(results[name], h) for h in headers) if __name__ == '__main__': - if len(sys.argv) != 4: - print( - 'Usage: %s ' - % sys.argv[0]) - sys.exit(1) - access_token = os.getenv('GITHUB_AUTH_TOKEN') - if not access_token: - print('Missing GITHUB_AUTH_TOKEN') - sys.exit(1) - path = sys.argv[1] - scorecard_path = sys.argv[2] - csv_output_path = sys.argv[3] - spec_loader = exports.repository_locations_utils.load_repository_locations_spec - path_module = exports.load_module('repository_locations', path) - try: - results = score(scorecard_path, spec_loader(path_module.REPOSITORY_LOCATIONS_SPEC)) - print_csv_results(csv_output_path, results) - except 
OssfScorecardError as e: - print(f'An error occurred while processing {path}, please verify the correctness of the ' - f'metadata: {e}') + if len(sys.argv) != 4: + print( + 'Usage: %s ' + % sys.argv[0]) + sys.exit(1) + access_token = os.getenv('GITHUB_AUTH_TOKEN') + if not access_token: + print('Missing GITHUB_AUTH_TOKEN') + sys.exit(1) + path = sys.argv[1] + scorecard_path = sys.argv[2] + csv_output_path = sys.argv[3] + spec_loader = exports.repository_locations_utils.load_repository_locations_spec + path_module = exports.load_module('repository_locations', path) + try: + results = score(scorecard_path, spec_loader(path_module.REPOSITORY_LOCATIONS_SPEC)) + print_csv_results(csv_output_path, results) + except OssfScorecardError as e: + print(f'An error occurred while processing {path}, please verify the correctness of the ' + f'metadata: {e}') diff --git a/tools/dependency/release_dates.py b/tools/dependency/release_dates.py index 5afdccd51c4c..ecfc2ca1af09 100644 --- a/tools/dependency/release_dates.py +++ b/tools/dependency/release_dates.py @@ -22,88 +22,89 @@ # Thrown on errors related to release date. class ReleaseDateError(Exception): - pass + pass # Format a datetime object as UTC YYYY-MM-DD. def format_utc_date(date): - # We only handle naive datetime objects right now, which is what PyGithub - # appears to be handing us. - assert (date.tzinfo is None) - return date.date().isoformat() + # We only handle naive datetime objects right now, which is what PyGithub + # appears to be handing us. + assert (date.tzinfo is None) + return date.date().isoformat() # Obtain latest release version and compare against metadata version, warn on # mismatch. def verify_and_print_latest_release(dep, repo, metadata_version, release_date): - try: - latest_release = repo.get_latest_release() - if latest_release.created_at > release_date and latest_release.tag_name != metadata_version: - print(f'*WARNING* {dep} has a newer release than {metadata_version}@<{release_date}>: ' - f'{latest_release.tag_name}@<{latest_release.created_at}>') - except github.UnknownObjectException: - pass + try: + latest_release = repo.get_latest_release() + if latest_release.created_at > release_date and latest_release.tag_name != metadata_version: + print(f'*WARNING* {dep} has a newer release than {metadata_version}@<{release_date}>: ' + f'{latest_release.tag_name}@<{latest_release.created_at}>') + except github.UnknownObjectException: + pass # Print GitHub release date, throw ReleaseDateError on mismatch with metadata release date. def verify_and_print_release_date(dep, github_release_date, metadata_release_date): - mismatch = '' - iso_release_date = format_utc_date(github_release_date) - print(f'{dep} has a GitHub release date {iso_release_date}') - if iso_release_date != metadata_release_date: - raise ReleaseDateError(f'Mismatch with metadata release date of {metadata_release_date}') + mismatch = '' + iso_release_date = format_utc_date(github_release_date) + print(f'{dep} has a GitHub release date {iso_release_date}') + if iso_release_date != metadata_release_date: + raise ReleaseDateError(f'Mismatch with metadata release date of {metadata_release_date}') # Extract release date from GitHub API. 
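The score() helper in ossf_scorecard.py above drives the OSSF Scorecard binary once per GitHub dependency. Stripped down to a single repository, the invocation and the per-check summary it builds look roughly like this; the binary path and repository are placeholders, while the flags and JSON field names mirror the call above.

import json
import subprocess as sp


def summarize_scorecard(scorecard_path, github_project_url):
    raw = sp.check_output(
        [scorecard_path, f'--repo={github_project_url}', '--show-details', '--format=json'])
    checks = {c['CheckName']: c for c in json.loads(raw)['Checks']}
    # Same "Pass (Confidence)" rendering used by _format() above.
    return {name: f"{c['Pass']} ({c['Confidence']})" for name, c in checks.items()}


# Requires the scorecard binary and a GITHUB_AUTH_TOKEN in the environment.
print(summarize_scorecard('./scorecard', 'https://github.com/envoyproxy/envoy'))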
def get_release_date(repo, metadata_version, github_release): - if github_release.tagged: - tags = repo.get_tags() - for tag in tags: - if tag.name == github_release.version: - return tag.commit.commit.committer.date - return None - else: - assert (metadata_version == github_release.version) - commit = repo.get_commit(github_release.version) - return commit.commit.committer.date + if github_release.tagged: + tags = repo.get_tags() + for tag in tags: + if tag.name == github_release.version: + return tag.commit.commit.committer.date + return None + else: + assert (metadata_version == github_release.version) + commit = repo.get_commit(github_release.version) + return commit.commit.committer.date # Verify release dates in metadata against GitHub API. def verify_and_print_release_dates(repository_locations, github_instance): - for dep, metadata in sorted(repository_locations.items()): - release_date = None - # Obtain release information from GitHub API. - github_release = utils.get_github_release_from_urls(metadata['urls']) - if not github_release: - print(f'{dep} is not a GitHub repository') - continue - repo = github_instance.get_repo(f'{github_release.organization}/{github_release.project}') - release_date = get_release_date(repo, metadata['version'], github_release) - if release_date: - # Check whether there is a more recent version and warn if necessary. - verify_and_print_latest_release(dep, repo, github_release.version, release_date) - # Verify that the release date in metadata and GitHub correspond, - # otherwise throw ReleaseDateError. - verify_and_print_release_date(dep, release_date, metadata['release_date']) - else: - raise ReleaseDateError(f'{dep} is a GitHub repository with no no inferrable release date') + for dep, metadata in sorted(repository_locations.items()): + release_date = None + # Obtain release information from GitHub API. + github_release = utils.get_github_release_from_urls(metadata['urls']) + if not github_release: + print(f'{dep} is not a GitHub repository') + continue + repo = github_instance.get_repo(f'{github_release.organization}/{github_release.project}') + release_date = get_release_date(repo, metadata['version'], github_release) + if release_date: + # Check whether there is a more recent version and warn if necessary. + verify_and_print_latest_release(dep, repo, github_release.version, release_date) + # Verify that the release date in metadata and GitHub correspond, + # otherwise throw ReleaseDateError. 
+ verify_and_print_release_date(dep, release_date, metadata['release_date']) + else: + raise ReleaseDateError( + f'{dep} is a GitHub repository with no no inferrable release date') if __name__ == '__main__': - if len(sys.argv) != 2: - print('Usage: %s ' % sys.argv[0]) - sys.exit(1) - access_token = os.getenv('GITHUB_TOKEN') - if not access_token: - print('Missing GITHUB_TOKEN') - sys.exit(1) - path = sys.argv[1] - spec_loader = exports.repository_locations_utils.load_repository_locations_spec - path_module = exports.load_module('repository_locations', path) - try: - verify_and_print_release_dates(spec_loader(path_module.REPOSITORY_LOCATIONS_SPEC), - github.Github(access_token)) - except ReleaseDateError as e: - print(f'An error occurred while processing {path}, please verify the correctness of the ' - f'metadata: {e}') - sys.exit(1) + if len(sys.argv) != 2: + print('Usage: %s ' % sys.argv[0]) + sys.exit(1) + access_token = os.getenv('GITHUB_TOKEN') + if not access_token: + print('Missing GITHUB_TOKEN') + sys.exit(1) + path = sys.argv[1] + spec_loader = exports.repository_locations_utils.load_repository_locations_spec + path_module = exports.load_module('repository_locations', path) + try: + verify_and_print_release_dates(spec_loader(path_module.REPOSITORY_LOCATIONS_SPEC), + github.Github(access_token)) + except ReleaseDateError as e: + print(f'An error occurred while processing {path}, please verify the correctness of the ' + f'metadata: {e}') + sys.exit(1) diff --git a/tools/dependency/utils.py b/tools/dependency/utils.py index 3646d81c5fb1..ba45ad5224bd 100644 --- a/tools/dependency/utils.py +++ b/tools/dependency/utils.py @@ -9,20 +9,20 @@ # All repository location metadata in the Envoy repository. def repository_locations(): - spec_loader = repository_locations_utils.load_repository_locations_spec - locations = spec_loader(envoy_repository_locations.REPOSITORY_LOCATIONS_SPEC) - locations.update(spec_loader(api_repository_locations.REPOSITORY_LOCATIONS_SPEC)) - return locations + spec_loader = repository_locations_utils.load_repository_locations_spec + locations = spec_loader(envoy_repository_locations.REPOSITORY_LOCATIONS_SPEC) + locations.update(spec_loader(api_repository_locations.REPOSITORY_LOCATIONS_SPEC)) + return locations # Obtain GitHub project URL from a list of URLs. def get_github_project_url(urls): - for url in urls: - if not url.startswith('https://github.com/'): - continue - components = url.split('/') - return f'https://github.com/{components[3]}/{components[4]}' - return None + for url in urls: + if not url.startswith('https://github.com/'): + continue + components = url.split('/') + return f'https://github.com/{components[3]}/{components[4]}' + return None # Information releated to a GitHub release version. @@ -33,26 +33,26 @@ def get_github_project_url(urls): # so, use heuristics to extract the release version and repo details, return # this, otherwise return None. def get_github_release_from_urls(urls): - for url in urls: - if not url.startswith('https://github.com/'): - continue - components = url.split('/') - if components[5] == 'archive': - # Only support .tar.gz, .zip today. Figure out the release tag from this - # filename. - if components[6].endswith('.tar.gz'): - github_version = components[6][:-len('.tar.gz')] - else: - assert (components[6].endswith('.zip')) - github_version = components[6][:-len('.zip')] - else: - # Release tag is a path component. 
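get_release_date() above resolves a tagged GitHub release to the committer date of the commit the tag points at, using PyGithub. Reduced to one repository and one tag (both placeholders here), the same calls look like this:

import os

import github


def tag_commit_date(repo_full_name, tag_name, access_token):
    repo = github.Github(access_token).get_repo(repo_full_name)
    for tag in repo.get_tags():
        if tag.name == tag_name:
            return tag.commit.commit.committer.date
    return None


date = tag_commit_date('envoyproxy/envoy', 'v1.17.1', os.getenv('GITHUB_TOKEN'))
print(date.date().isoformat() if date else 'tag not found')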
- assert (components[5] == 'releases') - github_version = components[7] - # If it's not a GH hash, it's a tagged release. - tagged_release = len(github_version) != 40 - return GitHubRelease(organization=components[3], - project=components[4], - version=github_version, - tagged=tagged_release) - return None + for url in urls: + if not url.startswith('https://github.com/'): + continue + components = url.split('/') + if components[5] == 'archive': + # Only support .tar.gz, .zip today. Figure out the release tag from this + # filename. + if components[6].endswith('.tar.gz'): + github_version = components[6][:-len('.tar.gz')] + else: + assert (components[6].endswith('.zip')) + github_version = components[6][:-len('.zip')] + else: + # Release tag is a path component. + assert (components[5] == 'releases') + github_version = components[7] + # If it's not a GH hash, it's a tagged release. + tagged_release = len(github_version) != 40 + return GitHubRelease(organization=components[3], + project=components[4], + version=github_version, + tagged=tagged_release) + return None diff --git a/tools/dependency/validate.py b/tools/dependency/validate.py index 9a11046fa528..02c1657c6103 100755 --- a/tools/dependency/validate.py +++ b/tools/dependency/validate.py @@ -16,10 +16,10 @@ # Shared Starlark/Python files must have a .bzl suffix for Starlark import, so # we are forced to do this workaround. def load_module(name, path): - spec = spec_from_loader(name, SourceFileLoader(name, path)) - module = module_from_spec(spec) - spec.loader.exec_module(module) - return module + spec = spec_from_loader(name, SourceFileLoader(name, path)) + module = module_from_spec(spec) + spec.loader.exec_module(module) + return module envoy_repository_locations = load_module('envoy_repository_locations', @@ -53,28 +53,28 @@ def load_module(name, path): # "Test only" section of # docs/root/intro/arch_overview/security/external_deps.rst. def test_only_ignore(dep): - # Rust - if dep.startswith('raze__'): - return True - # Java - if dep.startswith('remotejdk'): - return True - # Python (pip3) - if '_pip3' in dep: - return True - return False + # Rust + if dep.startswith('raze__'): + return True + # Java + if dep.startswith('remotejdk'): + return True + # Python (pip3) + if '_pip3' in dep: + return True + return False class DependencyError(Exception): - """Error in dependency relationships.""" - pass + """Error in dependency relationships.""" + pass class DependencyInfo(object): - """Models dependency info in bazel/repositories.bzl.""" + """Models dependency info in bazel/repositories.bzl.""" - def deps_by_use_category(self, use_category): - """Find the set of external dependencies in a given use_category. + def deps_by_use_category(self, use_category): + """Find the set of external dependencies in a given use_category. Args: use_category: string providing use_category. @@ -82,11 +82,11 @@ def deps_by_use_category(self, use_category): Returns: Set of dependency identifiers that match use_category. """ - return set(name for name, metadata in REPOSITORY_LOCATIONS_SPEC.items() - if use_category in metadata['use_category']) + return set(name for name, metadata in REPOSITORY_LOCATIONS_SPEC.items() + if use_category in metadata['use_category']) - def get_metadata(self, dependency): - """Obtain repository metadata for a dependency. + def get_metadata(self, dependency): + """Obtain repository metadata for a dependency. Args: dependency: string providing dependency identifier. 
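A worked example of the URL heuristics in get_github_release_from_urls() above, assuming tools/dependency is on sys.path so the module imports as utils; the googletest URL is only an illustration:

import utils

release = utils.get_github_release_from_urls(
    ['https://github.com/google/googletest/archive/release-1.10.0.tar.gz'])
# The '.tar.gz' suffix is stripped from the archive filename, and anything
# that is not a 40-character commit hash is treated as a tagged release.
assert release.organization == 'google'
assert release.project == 'googletest'
assert release.version == 'release-1.10.0'
assert release.tagged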
@@ -95,24 +95,26 @@ def get_metadata(self, dependency): A dictionary with the repository metadata as defined in bazel/repository_locations.bzl. """ - return REPOSITORY_LOCATIONS_SPEC.get(dependency) + return REPOSITORY_LOCATIONS_SPEC.get(dependency) class BuildGraph(object): - """Models the Bazel build graph.""" - - def __init__(self, ignore_deps=IGNORE_DEPS, repository_locations_spec=REPOSITORY_LOCATIONS_SPEC): - self._ignore_deps = ignore_deps - # Reverse map from untracked dependencies implied by other deps back to the dep. - self._implied_untracked_deps_revmap = {} - for dep, metadata in repository_locations_spec.items(): - implied_untracked_deps = metadata.get('implied_untracked_deps', []) - for untracked_dep in implied_untracked_deps: - assert (untracked_dep not in self._implied_untracked_deps_revmap) - self._implied_untracked_deps_revmap[untracked_dep] = dep - - def query_external_deps(self, *targets): - """Query the build graph for transitive external dependencies. + """Models the Bazel build graph.""" + + def __init__(self, + ignore_deps=IGNORE_DEPS, + repository_locations_spec=REPOSITORY_LOCATIONS_SPEC): + self._ignore_deps = ignore_deps + # Reverse map from untracked dependencies implied by other deps back to the dep. + self._implied_untracked_deps_revmap = {} + for dep, metadata in repository_locations_spec.items(): + implied_untracked_deps = metadata.get('implied_untracked_deps', []) + for untracked_dep in implied_untracked_deps: + assert (untracked_dep not in self._implied_untracked_deps_revmap) + self._implied_untracked_deps_revmap[untracked_dep] = dep + + def query_external_deps(self, *targets): + """Query the build graph for transitive external dependencies. Args: targets: Bazel targets. @@ -120,88 +122,90 @@ def query_external_deps(self, *targets): Returns: A set of dependency identifiers that are reachable from targets. """ - deps_query = ' union '.join(f'deps({l})' for l in targets) - try: - deps = subprocess.check_output(['bazel', 'query', deps_query], - stderr=subprocess.PIPE).decode().splitlines() - except subprocess.CalledProcessError as exc: - print( - f'Bazel query failed with error code {exc.returncode} and std error: {exc.stderr.decode()}' - ) - raise exc - ext_deps = set() - implied_untracked_deps = set() - for d in deps: - match = BAZEL_QUERY_EXTERNAL_DEP_RE.match(d) - if match: - ext_dep = match.group(1) - if ext_dep in self._ignore_deps: - continue - # If the dependency is untracked, add the source dependency that loaded - # it transitively. - if ext_dep in self._implied_untracked_deps_revmap: - ext_dep = self._implied_untracked_deps_revmap[ext_dep] - ext_deps.add(ext_dep) - return set(ext_deps) - - def list_extensions(self): - """List all extensions. + deps_query = ' union '.join(f'deps({l})' for l in targets) + try: + deps = subprocess.check_output(['bazel', 'query', deps_query], + stderr=subprocess.PIPE).decode().splitlines() + except subprocess.CalledProcessError as exc: + print( + f'Bazel query failed with error code {exc.returncode} and std error: {exc.stderr.decode()}' + ) + raise exc + ext_deps = set() + implied_untracked_deps = set() + for d in deps: + match = BAZEL_QUERY_EXTERNAL_DEP_RE.match(d) + if match: + ext_dep = match.group(1) + if ext_dep in self._ignore_deps: + continue + # If the dependency is untracked, add the source dependency that loaded + # it transitively. 
+ if ext_dep in self._implied_untracked_deps_revmap: + ext_dep = self._implied_untracked_deps_revmap[ext_dep] + ext_deps.add(ext_dep) + return set(ext_deps) + + def list_extensions(self): + """List all extensions. Returns: Dictionary items from source/extensions/extensions_build_config.bzl. """ - return extensions_build_config.EXTENSIONS.items() + return extensions_build_config.EXTENSIONS.items() class Validator(object): - """Collection of validation methods.""" + """Collection of validation methods.""" - def __init__(self, dep_info, build_graph): - self._dep_info = dep_info - self._build_graph = build_graph - self._queried_core_deps = build_graph.query_external_deps( - '//source/exe:envoy_main_common_with_core_extensions_lib') + def __init__(self, dep_info, build_graph): + self._dep_info = dep_info + self._build_graph = build_graph + self._queried_core_deps = build_graph.query_external_deps( + '//source/exe:envoy_main_common_with_core_extensions_lib') - def validate_build_graph_structure(self): - """Validate basic assumptions about dependency relationship in the build graph. + def validate_build_graph_structure(self): + """Validate basic assumptions about dependency relationship in the build graph. Raises: DependencyError: on a dependency validation error. """ - print('Validating build dependency structure...') - queried_core_ext_deps = self._build_graph.query_external_deps( - '//source/exe:envoy_main_common_with_core_extensions_lib', '//source/extensions/...') - queried_all_deps = self._build_graph.query_external_deps('//source/...') - if queried_all_deps != queried_core_ext_deps: - raise DependencyError('Invalid build graph structure. deps(//source/...) != ' - 'deps(//source/exe:envoy_main_common_with_core_extensions_lib) ' - 'union deps(//source/extensions/...)') - - def validate_test_only_deps(self): - """Validate that test-only dependencies aren't included in //source/... + print('Validating build dependency structure...') + queried_core_ext_deps = self._build_graph.query_external_deps( + '//source/exe:envoy_main_common_with_core_extensions_lib', '//source/extensions/...') + queried_all_deps = self._build_graph.query_external_deps('//source/...') + if queried_all_deps != queried_core_ext_deps: + raise DependencyError('Invalid build graph structure. deps(//source/...) != ' + 'deps(//source/exe:envoy_main_common_with_core_extensions_lib) ' + 'union deps(//source/extensions/...)') + + def validate_test_only_deps(self): + """Validate that test-only dependencies aren't included in //source/... Raises: DependencyError: on a dependency validation error. """ - print('Validating test-only dependencies...') - # Validate that //source doesn't depend on test_only - queried_source_deps = self._build_graph.query_external_deps('//source/...') - expected_test_only_deps = self._dep_info.deps_by_use_category('test_only') - bad_test_only_deps = expected_test_only_deps.intersection(queried_source_deps) - if len(bad_test_only_deps) > 0: - raise DependencyError(f'//source depends on test-only dependencies: {bad_test_only_deps}') - # Validate that //test deps additional to those of //source are captured in - # test_only. 
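query_external_deps() above shells out to bazel query and keeps only external repositories. A stripped-down sketch for a single target is shown below; the regex stands in for BAZEL_QUERY_EXTERNAL_DEP_RE, which is defined elsewhere in validate.py and not visible in this hunk, so its exact shape is an assumption.

import re
import subprocess

EXTERNAL_DEP_RE = re.compile(r'@(\w+)//')  # assumed shape of the real regex


def external_deps(target):
    output = subprocess.check_output(['bazel', 'query', f'deps({target})'],
                                     stderr=subprocess.PIPE).decode()
    deps = set()
    for line in output.splitlines():
        match = EXTERNAL_DEP_RE.match(line)
        if match:
            deps.add(match.group(1))
    return deps


# Example target; any Bazel label or wildcard such as //source/... works.
print(sorted(external_deps('//source/exe:envoy-static')))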
- test_only_deps = self._build_graph.query_external_deps('//test/...') - source_deps = self._build_graph.query_external_deps('//source/...') - marginal_test_deps = test_only_deps.difference(source_deps) - bad_test_deps = marginal_test_deps.difference(expected_test_only_deps) - unknown_bad_test_deps = [dep for dep in bad_test_deps if not test_only_ignore(dep)] - if len(unknown_bad_test_deps) > 0: - raise DependencyError(f'Missing deps in test_only "use_category": {unknown_bad_test_deps}') - - def validate_data_plane_core_deps(self): - """Validate dataplane_core dependencies. + print('Validating test-only dependencies...') + # Validate that //source doesn't depend on test_only + queried_source_deps = self._build_graph.query_external_deps('//source/...') + expected_test_only_deps = self._dep_info.deps_by_use_category('test_only') + bad_test_only_deps = expected_test_only_deps.intersection(queried_source_deps) + if len(bad_test_only_deps) > 0: + raise DependencyError( + f'//source depends on test-only dependencies: {bad_test_only_deps}') + # Validate that //test deps additional to those of //source are captured in + # test_only. + test_only_deps = self._build_graph.query_external_deps('//test/...') + source_deps = self._build_graph.query_external_deps('//source/...') + marginal_test_deps = test_only_deps.difference(source_deps) + bad_test_deps = marginal_test_deps.difference(expected_test_only_deps) + unknown_bad_test_deps = [dep for dep in bad_test_deps if not test_only_ignore(dep)] + if len(unknown_bad_test_deps) > 0: + raise DependencyError( + f'Missing deps in test_only "use_category": {unknown_bad_test_deps}') + + def validate_data_plane_core_deps(self): + """Validate dataplane_core dependencies. Check that we at least tag as dataplane_core dependencies that match some well-known targets for the data-plane. @@ -209,28 +213,29 @@ def validate_data_plane_core_deps(self): Raises: DependencyError: on a dependency validation error. """ - print('Validating data-plane dependencies...') - # Necessary but not sufficient for dataplane. With some refactoring we could - # probably have more precise tagging of dataplane/controlplane/other deps in - # these paths. - queried_dataplane_core_min_deps = self._build_graph.query_external_deps( - '//source/common/api/...', '//source/common/buffer/...', '//source/common/chromium_url/...', - '//source/common/crypto/...', '//source/common/conn_pool/...', - '//source/common/formatter/...', '//source/common/http/...', '//source/common/ssl/...', - '//source/common/tcp/...', '//source/common/tcp_proxy/...', '//source/common/network/...') - # It's hard to disentangle API and dataplane today. - expected_dataplane_core_deps = self._dep_info.deps_by_use_category('dataplane_core').union( - self._dep_info.deps_by_use_category('api')) - bad_dataplane_core_deps = queried_dataplane_core_min_deps.difference( - expected_dataplane_core_deps) - if len(bad_dataplane_core_deps) > 0: - raise DependencyError( - f'Observed dataplane core deps {queried_dataplane_core_min_deps} is not covered by ' - f'"use_category" implied core deps {expected_dataplane_core_deps}: {bad_dataplane_core_deps} ' - 'are missing') - - def validate_control_plane_deps(self): - """Validate controlplane dependencies. + print('Validating data-plane dependencies...') + # Necessary but not sufficient for dataplane. With some refactoring we could + # probably have more precise tagging of dataplane/controlplane/other deps in + # these paths. 
+ queried_dataplane_core_min_deps = self._build_graph.query_external_deps( + '//source/common/api/...', '//source/common/buffer/...', + '//source/common/chromium_url/...', '//source/common/crypto/...', + '//source/common/conn_pool/...', '//source/common/formatter/...', + '//source/common/http/...', '//source/common/ssl/...', '//source/common/tcp/...', + '//source/common/tcp_proxy/...', '//source/common/network/...') + # It's hard to disentangle API and dataplane today. + expected_dataplane_core_deps = self._dep_info.deps_by_use_category('dataplane_core').union( + self._dep_info.deps_by_use_category('api')) + bad_dataplane_core_deps = queried_dataplane_core_min_deps.difference( + expected_dataplane_core_deps) + if len(bad_dataplane_core_deps) > 0: + raise DependencyError( + f'Observed dataplane core deps {queried_dataplane_core_min_deps} is not covered by ' + f'"use_category" implied core deps {expected_dataplane_core_deps}: {bad_dataplane_core_deps} ' + 'are missing') + + def validate_control_plane_deps(self): + """Validate controlplane dependencies. Check that we at least tag as controlplane dependencies that match some well-known targets for @@ -239,25 +244,25 @@ def validate_control_plane_deps(self): Raises: DependencyError: on a dependency validation error. """ - print('Validating control-plane dependencies...') - # Necessary but not sufficient for controlplane. With some refactoring we could - # probably have more precise tagging of dataplane/controlplane/other deps in - # these paths. - queried_controlplane_core_min_deps = self._build_graph.query_external_deps( - '//source/common/config/...') - # Controlplane will always depend on API. - expected_controlplane_core_deps = self._dep_info.deps_by_use_category('controlplane').union( - self._dep_info.deps_by_use_category('api')) - bad_controlplane_core_deps = queried_controlplane_core_min_deps.difference( - expected_controlplane_core_deps) - if len(bad_controlplane_core_deps) > 0: - raise DependencyError( - f'Observed controlplane core deps {queried_controlplane_core_min_deps} is not covered ' - 'by "use_category" implied core deps {expected_controlplane_core_deps}: ' - '{bad_controlplane_core_deps} are missing') - - def validate_extension_deps(self, name, target): - """Validate that extensions are correctly declared for dataplane_ext and observability_ext. + print('Validating control-plane dependencies...') + # Necessary but not sufficient for controlplane. With some refactoring we could + # probably have more precise tagging of dataplane/controlplane/other deps in + # these paths. + queried_controlplane_core_min_deps = self._build_graph.query_external_deps( + '//source/common/config/...') + # Controlplane will always depend on API. + expected_controlplane_core_deps = self._dep_info.deps_by_use_category('controlplane').union( + self._dep_info.deps_by_use_category('api')) + bad_controlplane_core_deps = queried_controlplane_core_min_deps.difference( + expected_controlplane_core_deps) + if len(bad_controlplane_core_deps) > 0: + raise DependencyError( + f'Observed controlplane core deps {queried_controlplane_core_min_deps} is not covered ' + 'by "use_category" implied core deps {expected_controlplane_core_deps}: ' + '{bad_controlplane_core_deps} are missing') + + def validate_extension_deps(self, name, target): + """Validate that extensions are correctly declared for dataplane_ext and observability_ext. Args: name: extension name. 
@@ -266,49 +271,52 @@ def validate_extension_deps(self, name, target): Raises: DependencyError: on a dependency validation error. """ - print(f'Validating extension {name} dependencies...') - queried_deps = self._build_graph.query_external_deps(target) - marginal_deps = queried_deps.difference(self._queried_core_deps) - expected_deps = [] - for d in marginal_deps: - metadata = self._dep_info.get_metadata(d) - if metadata: - use_category = metadata['use_category'] - valid_use_category = any( - c in use_category for c in ['dataplane_ext', 'observability_ext', 'other', 'api']) - if not valid_use_category: - raise DependencyError( - f'Extensions {name} depends on {d} with "use_category" not including ' - '["dataplane_ext", "observability_ext", "api", "other"]') - if 'extensions' in metadata: - allowed_extensions = metadata['extensions'] - if name not in allowed_extensions: - raise DependencyError( - f'Extension {name} depends on {d} but {d} does not list {name} in its allowlist') - - def validate_all(self): - """Collection of all validations. + print(f'Validating extension {name} dependencies...') + queried_deps = self._build_graph.query_external_deps(target) + marginal_deps = queried_deps.difference(self._queried_core_deps) + expected_deps = [] + for d in marginal_deps: + metadata = self._dep_info.get_metadata(d) + if metadata: + use_category = metadata['use_category'] + valid_use_category = any( + c in use_category + for c in ['dataplane_ext', 'observability_ext', 'other', 'api']) + if not valid_use_category: + raise DependencyError( + f'Extensions {name} depends on {d} with "use_category" not including ' + '["dataplane_ext", "observability_ext", "api", "other"]') + if 'extensions' in metadata: + allowed_extensions = metadata['extensions'] + if name not in allowed_extensions: + raise DependencyError( + f'Extension {name} depends on {d} but {d} does not list {name} in its allowlist' + ) + + def validate_all(self): + """Collection of all validations. Raises: DependencyError: on a dependency validation error. """ - self.validate_build_graph_structure() - self.validate_test_only_deps() - self.validate_data_plane_core_deps() - self.validate_control_plane_deps() - # Validate the marginal dependencies introduced for each extension. - for name, target in sorted(build_graph.list_extensions()): - target_all = EXTENSION_LABEL_RE.match(target).group(1) + '/...' - self.validate_extension_deps(name, target_all) + self.validate_build_graph_structure() + self.validate_test_only_deps() + self.validate_data_plane_core_deps() + self.validate_control_plane_deps() + # Validate the marginal dependencies introduced for each extension. + for name, target in sorted(build_graph.list_extensions()): + target_all = EXTENSION_LABEL_RE.match(target).group(1) + '/...' 
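validate_all() above widens each extension's build target to a package wildcard before querying its dependencies. EXTENSION_LABEL_RE itself is defined elsewhere in validate.py and is not visible in this hunk; the regex below is an assumed stand-in that illustrates the rewrite:

import re

EXTENSION_LABEL_RE = re.compile(r'^(//source/extensions/.*):')  # assumed


def extension_package_wildcard(target):
    # '//source/extensions/filters/http/buffer:config'
    #     -> '//source/extensions/filters/http/buffer/...'
    return EXTENSION_LABEL_RE.match(target).group(1) + '/...'


assert (extension_package_wildcard('//source/extensions/filters/http/buffer:config')
        == '//source/extensions/filters/http/buffer/...')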
+ self.validate_extension_deps(name, target_all) if __name__ == '__main__': - dep_info = DependencyInfo() - build_graph = BuildGraph() - validator = Validator(dep_info, build_graph) - try: - validator.validate_all() - except DependencyError as e: - print('Dependency validation failed, please check metadata in bazel/repository_locations.bzl') - print(e) - sys.exit(1) + dep_info = DependencyInfo() + build_graph = BuildGraph() + validator = Validator(dep_info, build_graph) + try: + validator.validate_all() + except DependencyError as e: + print( + 'Dependency validation failed, please check metadata in bazel/repository_locations.bzl') + print(e) + sys.exit(1) diff --git a/tools/dependency/validate_test.py b/tools/dependency/validate_test.py index 090a6f0142a4..1b4f0a989dbd 100755 --- a/tools/dependency/validate_test.py +++ b/tools/dependency/validate_test.py @@ -7,126 +7,130 @@ class FakeDependencyInfo(object): - """validate.DependencyInfo fake.""" + """validate.DependencyInfo fake.""" - def __init__(self, deps): - self._deps = deps + def __init__(self, deps): + self._deps = deps - def deps_by_use_category(self, use_category): - return set(n for n, m in self._deps.items() if use_category in m['use_category']) + def deps_by_use_category(self, use_category): + return set(n for n, m in self._deps.items() if use_category in m['use_category']) - def get_metadata(self, dependency): - return self._deps.get(dependency) + def get_metadata(self, dependency): + return self._deps.get(dependency) class FakeBuildGraph(object): - """validate.BuildGraph fake.""" + """validate.BuildGraph fake.""" - def __init__(self, reachable_deps, extensions): - self._reachable_deps = reachable_deps - self._extensions = extensions + def __init__(self, reachable_deps, extensions): + self._reachable_deps = reachable_deps + self._extensions = extensions - def query_external_deps(self, *targets): - return set(sum((self._reachable_deps.get(t, []) for t in targets), [])) + def query_external_deps(self, *targets): + return set(sum((self._reachable_deps.get(t, []) for t in targets), [])) - def list_extensions(self): - return self._extensions + def list_extensions(self): + return self._extensions def fake_dep(use_category, extensions=[]): - return {'use_category': use_category, 'extensions': extensions} + return {'use_category': use_category, 'extensions': extensions} class ValidateTest(unittest.TestCase): - def build_validator(self, deps, reachable_deps, extensions=[]): - return validate.Validator(FakeDependencyInfo(deps), FakeBuildGraph(reachable_deps, extensions)) - - def test_valid_build_graph_structure(self): - validator = self.build_validator({}, { - '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'], - '//source/extensions/...': ['b'], - '//source/...': ['a', 'b'] - }) - validator.validate_build_graph_structure() - - def test_invalid_build_graph_structure(self): - validator = self.build_validator({}, { - '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'], - '//source/extensions/...': ['b'], - '//source/...': ['a', 'b', 'c'] - }) - self.assertRaises(validate.DependencyError, lambda: validator.validate_build_graph_structure()) - - def test_valid_test_only_deps(self): - validator = self.build_validator({'a': fake_dep('dataplane_core')}, {'//source/...': ['a']}) - validator.validate_test_only_deps() - validator = self.build_validator({'a': fake_dep('test_only')}, {'//test/...': ['a', 'b__pip3']}) - validator.validate_test_only_deps() - - def test_invalid_test_only_deps(self): - validator = 
self.build_validator({'a': fake_dep('test_only')}, {'//source/...': ['a']}) - self.assertRaises(validate.DependencyError, lambda: validator.validate_test_only_deps()) - validator = self.build_validator({'a': fake_dep('test_only')}, {'//test/...': ['b']}) - self.assertRaises(validate.DependencyError, lambda: validator.validate_test_only_deps()) - - def test_valid_dataplane_core_deps(self): - validator = self.build_validator({'a': fake_dep('dataplane_core')}, - {'//source/common/http/...': ['a']}) - validator.validate_data_plane_core_deps() - - def test_invalid_dataplane_core_deps(self): - validator = self.build_validator({'a': fake_dep('controlplane')}, - {'//source/common/http/...': ['a']}) - self.assertRaises(validate.DependencyError, lambda: validator.validate_data_plane_core_deps()) - - def test_valid_controlplane_deps(self): - validator = self.build_validator({'a': fake_dep('controlplane')}, - {'//source/common/config/...': ['a']}) - validator.validate_control_plane_deps() - - def test_invalid_controlplane_deps(self): - validator = self.build_validator({'a': fake_dep('other')}, - {'//source/common/config/...': ['a']}) - self.assertRaises(validate.DependencyError, lambda: validator.validate_control_plane_deps()) - - def test_valid_extension_deps(self): - validator = self.build_validator( - { - 'a': fake_dep('controlplane'), - 'b': fake_dep('dataplane_ext', ['foo']) - }, { - '//source/extensions/foo/...': ['a', 'b'], - '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'] - }) - validator.validate_extension_deps('foo', '//source/extensions/foo/...') - - def test_invalid_extension_deps_wrong_category(self): - validator = self.build_validator( - { - 'a': fake_dep('controlplane'), - 'b': fake_dep('controlplane', ['foo']) - }, { - '//source/extensions/foo/...': ['a', 'b'], - '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'] + def build_validator(self, deps, reachable_deps, extensions=[]): + return validate.Validator(FakeDependencyInfo(deps), + FakeBuildGraph(reachable_deps, extensions)) + + def test_valid_build_graph_structure(self): + validator = self.build_validator({}, { + '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'], + '//source/extensions/...': ['b'], + '//source/...': ['a', 'b'] }) - self.assertRaises( - validate.DependencyError, - lambda: validator.validate_extension_deps('foo', '//source/extensions/foo/...')) - - def test_invalid_extension_deps_allowlist(self): - validator = self.build_validator( - { - 'a': fake_dep('controlplane'), - 'b': fake_dep('dataplane_ext', ['bar']) - }, { - '//source/extensions/foo/...': ['a', 'b'], - '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'] + validator.validate_build_graph_structure() + + def test_invalid_build_graph_structure(self): + validator = self.build_validator({}, { + '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'], + '//source/extensions/...': ['b'], + '//source/...': ['a', 'b', 'c'] }) - self.assertRaises( - validate.DependencyError, - lambda: validator.validate_extension_deps('foo', '//source/extensions/foo/...')) + self.assertRaises(validate.DependencyError, + lambda: validator.validate_build_graph_structure()) + + def test_valid_test_only_deps(self): + validator = self.build_validator({'a': fake_dep('dataplane_core')}, {'//source/...': ['a']}) + validator.validate_test_only_deps() + validator = self.build_validator({'a': fake_dep('test_only')}, + {'//test/...': ['a', 'b__pip3']}) + validator.validate_test_only_deps() + + def test_invalid_test_only_deps(self): 
+ validator = self.build_validator({'a': fake_dep('test_only')}, {'//source/...': ['a']}) + self.assertRaises(validate.DependencyError, lambda: validator.validate_test_only_deps()) + validator = self.build_validator({'a': fake_dep('test_only')}, {'//test/...': ['b']}) + self.assertRaises(validate.DependencyError, lambda: validator.validate_test_only_deps()) + + def test_valid_dataplane_core_deps(self): + validator = self.build_validator({'a': fake_dep('dataplane_core')}, + {'//source/common/http/...': ['a']}) + validator.validate_data_plane_core_deps() + + def test_invalid_dataplane_core_deps(self): + validator = self.build_validator({'a': fake_dep('controlplane')}, + {'//source/common/http/...': ['a']}) + self.assertRaises(validate.DependencyError, + lambda: validator.validate_data_plane_core_deps()) + + def test_valid_controlplane_deps(self): + validator = self.build_validator({'a': fake_dep('controlplane')}, + {'//source/common/config/...': ['a']}) + validator.validate_control_plane_deps() + + def test_invalid_controlplane_deps(self): + validator = self.build_validator({'a': fake_dep('other')}, + {'//source/common/config/...': ['a']}) + self.assertRaises(validate.DependencyError, lambda: validator.validate_control_plane_deps()) + + def test_valid_extension_deps(self): + validator = self.build_validator( + { + 'a': fake_dep('controlplane'), + 'b': fake_dep('dataplane_ext', ['foo']) + }, { + '//source/extensions/foo/...': ['a', 'b'], + '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'] + }) + validator.validate_extension_deps('foo', '//source/extensions/foo/...') + + def test_invalid_extension_deps_wrong_category(self): + validator = self.build_validator( + { + 'a': fake_dep('controlplane'), + 'b': fake_dep('controlplane', ['foo']) + }, { + '//source/extensions/foo/...': ['a', 'b'], + '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'] + }) + self.assertRaises( + validate.DependencyError, + lambda: validator.validate_extension_deps('foo', '//source/extensions/foo/...')) + + def test_invalid_extension_deps_allowlist(self): + validator = self.build_validator( + { + 'a': fake_dep('controlplane'), + 'b': fake_dep('dataplane_ext', ['bar']) + }, { + '//source/extensions/foo/...': ['a', 'b'], + '//source/exe:envoy_main_common_with_core_extensions_lib': ['a'] + }) + self.assertRaises( + validate.DependencyError, + lambda: validator.validate_extension_deps('foo', '//source/extensions/foo/...')) if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/tools/deprecate_features/deprecate_features.py b/tools/deprecate_features/deprecate_features.py index 3da00632ba9f..958d1592ce9e 100644 --- a/tools/deprecate_features/deprecate_features.py +++ b/tools/deprecate_features/deprecate_features.py @@ -10,41 +10,42 @@ # Sorts out the list of deprecated proto fields which should be disallowed and returns a tuple of # email and code changes. def deprecate_proto(): - grep_output = subprocess.check_output('grep -r "deprecated = true" api/*', shell=True) + grep_output = subprocess.check_output('grep -r "deprecated = true" api/*', shell=True) - filenames_and_fields = set() + filenames_and_fields = set() - # Compile the set of deprecated fields and the files they're in, deduping via set. 
- deprecated_regex = re.compile(r'.*\/([^\/]*.proto):[^=]* ([^= ]+) =.*') - for byte_line in grep_output.splitlines(): - line = str(byte_line) - match = deprecated_regex.match(line) - if match: - filenames_and_fields.add(tuple([match.group(1), match.group(2)])) - else: - print('no match in ' + line + ' please address manually!') + # Compile the set of deprecated fields and the files they're in, deduping via set. + deprecated_regex = re.compile(r'.*\/([^\/]*.proto):[^=]* ([^= ]+) =.*') + for byte_line in grep_output.splitlines(): + line = str(byte_line) + match = deprecated_regex.match(line) + if match: + filenames_and_fields.add(tuple([match.group(1), match.group(2)])) + else: + print('no match in ' + line + ' please address manually!') - # Now discard any deprecated features already listed in runtime_features - exiting_deprecated_regex = re.compile(r'.*"envoy.deprecated_features.(.*):(.*)",.*') - with open('source/common/runtime/runtime_features.cc', 'r') as features: - for line in features.readlines(): - match = exiting_deprecated_regex.match(line) - if match: - filenames_and_fields.discard(tuple([match.group(1), match.group(2)])) + # Now discard any deprecated features already listed in runtime_features + exiting_deprecated_regex = re.compile(r'.*"envoy.deprecated_features.(.*):(.*)",.*') + with open('source/common/runtime/runtime_features.cc', 'r') as features: + for line in features.readlines(): + match = exiting_deprecated_regex.match(line) + if match: + filenames_and_fields.discard(tuple([match.group(1), match.group(2)])) - # Finally sort out the code to add to runtime_features.cc and a canned email for envoy-announce. - code_snippets = [] - email_snippets = [] - for (filename, field) in filenames_and_fields: - code_snippets.append(' "envoy.deprecated_features.' + filename + ':' + field + '",\n') - email_snippets.append(field + ' from ' + filename + '\n') - code = ''.join(code_snippets) - email = '' - if email_snippets: - email = ('\nThe following deprecated configuration fields will be disallowed by default:\n' + - ''.join(email_snippets)) + # Finally sort out the code to add to runtime_features.cc and a canned email for envoy-announce. + code_snippets = [] + email_snippets = [] + for (filename, field) in filenames_and_fields: + code_snippets.append(' "envoy.deprecated_features.' + filename + ':' + field + '",\n') + email_snippets.append(field + ' from ' + filename + '\n') + code = ''.join(code_snippets) + email = '' + if email_snippets: + email = ( + '\nThe following deprecated configuration fields will be disallowed by default:\n' + + ''.join(email_snippets)) - return email, code + return email, code # Gather code and suggested email changes. @@ -57,11 +58,11 @@ def deprecate_proto(): print(email) if not input('Apply relevant runtime changes? [yN] ').strip().lower() in ('y', 'yes'): - exit(1) + exit(1) for line in fileinput.FileInput('source/common/runtime/runtime_features.cc', inplace=1): - if 'envoy.deprecated_features.deprecated.proto:is_deprecated_fatal' in line: - line = line.replace(line, line + deprecate_code) - print(line, end='') + if 'envoy.deprecated_features.deprecated.proto:is_deprecated_fatal' in line: + line = line.replace(line, line + deprecate_code) + print(line, end='') print('\nChanges applied. 
Please send the email above to envoy-announce.\n') diff --git a/tools/deprecate_version/deprecate_version.py b/tools/deprecate_version/deprecate_version.py index 780c247e3022..2452f041f9c3 100644 --- a/tools/deprecate_version/deprecate_version.py +++ b/tools/deprecate_version/deprecate_version.py @@ -34,9 +34,9 @@ from git import Repo try: - input = raw_input # Python 2 + input = raw_input # Python 2 except NameError: - pass # Python 3 + pass # Python 3 # Tag issues created with these labels. LABELS = ['deprecation', 'tech debt', 'no stalebot'] @@ -44,148 +44,151 @@ # Errors that happen during issue creation. class DeprecateVersionError(Exception): - pass + pass def get_confirmation(): - """Obtain stdin confirmation to create issues in GH.""" - return input('Creates issues? [yN] ').strip().lower() in ('y', 'yes') + """Obtain stdin confirmation to create issues in GH.""" + return input('Creates issues? [yN] ').strip().lower() in ('y', 'yes') def create_issues(access_token, runtime_and_pr): - """Create issues in GitHub for code to clean up old runtime guarded features. + """Create issues in GitHub for code to clean up old runtime guarded features. Args: access_token: GitHub access token (see comment at top of file). runtime_and_pr: a list of runtime guards and the PRs and commits they were added. """ - git = github.Github(access_token) - repo = git.get_repo('envoyproxy/envoy') - - # Find GitHub label objects for LABELS. - labels = [] - for label in repo.get_labels(): - if label.name in LABELS: - labels.append(label) - if len(labels) != len(LABELS): - raise DeprecateVersionError('Unknown labels (expected %s, got %s)' % (LABELS, labels)) - - issues = [] - for runtime_guard, pr, commit in runtime_and_pr: - # Who is the author? - if pr: - # Extract PR title, number, and author. - pr_info = repo.get_pull(pr) - change_title = pr_info.title - number = ('#%d') % pr - login = pr_info.user.login - else: - # Extract commit message, sha, and author. - # Only keep commit message title (remove description), and truncate to 50 characters. - change_title = commit.message.split('\n')[0][:50] - number = ('commit %s') % commit.hexsha - email = commit.author.email - # Use the commit author's email to search through users for their login. - search_user = git.search_users(email.split('@')[0] + " in:email") - login = search_user[0].login if search_user else None - - title = '%s deprecation' % (runtime_guard) - body = ('Your change %s (%s) introduced a runtime guarded feature. It has been 6 months since ' + git = github.Github(access_token) + repo = git.get_repo('envoyproxy/envoy') + + # Find GitHub label objects for LABELS. + labels = [] + for label in repo.get_labels(): + if label.name in LABELS: + labels.append(label) + if len(labels) != len(LABELS): + raise DeprecateVersionError('Unknown labels (expected %s, got %s)' % (LABELS, labels)) + + issues = [] + for runtime_guard, pr, commit in runtime_and_pr: + # Who is the author? + if pr: + # Extract PR title, number, and author. + pr_info = repo.get_pull(pr) + change_title = pr_info.title + number = ('#%d') % pr + login = pr_info.user.login + else: + # Extract commit message, sha, and author. + # Only keep commit message title (remove description), and truncate to 50 characters. + change_title = commit.message.split('\n')[0][:50] + number = ('commit %s') % commit.hexsha + email = commit.author.email + # Use the commit author's email to search through users for their login. 
+ search_user = git.search_users(email.split('@')[0] + " in:email") + login = search_user[0].login if search_user else None + + title = '%s deprecation' % (runtime_guard) + body = ( + 'Your change %s (%s) introduced a runtime guarded feature. It has been 6 months since ' 'the new code has been exercised by default, so it\'s time to remove the old code ' 'path. This issue tracks source code cleanup so we don\'t forget.') % (number, change_title) - print(title) - print(body) - print(' >> Assigning to %s' % (login or email)) - search_title = '%s in:title' % title - - # TODO(htuch): Figure out how to do this without legacy and faster. - exists = repo.legacy_search_issues('open', search_title) or repo.legacy_search_issues( - 'closed', search_title) - if exists: - print("Issue with %s already exists" % search_title) - print(exists) - print(' >> Issue already exists, not posting!') - else: - issues.append((title, body, login)) - - if not issues: - print('No features to deprecate in this release') - return - - if get_confirmation(): - print('Creating issues...') - for title, body, login in issues: - try: - repo.create_issue(title, body=body, assignees=[login], labels=labels) - except github.GithubException as e: - try: - if login: - body += '\ncc @' + login - repo.create_issue(title, body=body, labels=labels) - print(('unable to assign issue %s to %s. Add them to the Envoy proxy org' - 'and assign it their way.') % (title, login)) - except github.GithubException as e: - print('GithubException while creating issue.') - raise + print(title) + print(body) + print(' >> Assigning to %s' % (login or email)) + search_title = '%s in:title' % title + + # TODO(htuch): Figure out how to do this without legacy and faster. + exists = repo.legacy_search_issues('open', search_title) or repo.legacy_search_issues( + 'closed', search_title) + if exists: + print("Issue with %s already exists" % search_title) + print(exists) + print(' >> Issue already exists, not posting!') + else: + issues.append((title, body, login)) + + if not issues: + print('No features to deprecate in this release') + return + + if get_confirmation(): + print('Creating issues...') + for title, body, login in issues: + try: + repo.create_issue(title, body=body, assignees=[login], labels=labels) + except github.GithubException as e: + try: + if login: + body += '\ncc @' + login + repo.create_issue(title, body=body, labels=labels) + print(('unable to assign issue %s to %s. Add them to the Envoy proxy org' + 'and assign it their way.') % (title, login)) + except github.GithubException as e: + print('GithubException while creating issue.') + raise def get_runtime_and_pr(): - """Returns a list of tuples of [runtime features to deprecate, PR, commit the feature was added] + """Returns a list of tuples of [runtime features to deprecate, PR, commit the feature was added] """ - repo = Repo(os.getcwd()) - - # grep source code looking for reloadable features which are true to find the - # PR they were added. - features_to_flip = [] - - runtime_features = re.compile(r'.*"(envoy.reloadable_features..*)",.*') - - removal_date = date.today() - datetime.timedelta(days=183) - found_test_feature_true = False - - # Walk the blame of runtime_features and look for true runtime features older than 6 months. 
- for commit, lines in repo.blame('HEAD', 'source/common/runtime/runtime_features.cc'): - for line in lines: - match = runtime_features.match(line) - if match: - runtime_guard = match.group(1) - if runtime_guard == 'envoy.reloadable_features.test_feature_false': - print("Found end sentinel\n") - if not found_test_feature_true: - # The script depends on the cc file having the true runtime block - # before the false runtime block. Fail if one isn't found. - print('Failed to find test_feature_true. Script needs fixing') - sys.exit(1) - return features_to_flip - if runtime_guard == 'envoy.reloadable_features.test_feature_true': - found_test_feature_true = True - continue - pr_num = re.search('\(#(\d+)\)', commit.message) - # Some commits may not come from a PR (if they are part of a security point release). - pr = (int(pr_num.group(1))) if pr_num else None - pr_date = date.fromtimestamp(commit.committed_date) - removable = (pr_date < removal_date) - # Add the runtime guard and PR to the list to file issues about. - print('Flag ' + runtime_guard + ' added at ' + str(pr_date) + ' ' + - (removable and 'and is safe to remove' or 'is not ready to remove')) - if removable: - features_to_flip.append((runtime_guard, pr, commit)) - print('Failed to find test_feature_false. Script needs fixing') - sys.exit(1) + repo = Repo(os.getcwd()) + + # grep source code looking for reloadable features which are true to find the + # PR they were added. + features_to_flip = [] + + runtime_features = re.compile(r'.*"(envoy.reloadable_features..*)",.*') + + removal_date = date.today() - datetime.timedelta(days=183) + found_test_feature_true = False + + # Walk the blame of runtime_features and look for true runtime features older than 6 months. + for commit, lines in repo.blame('HEAD', 'source/common/runtime/runtime_features.cc'): + for line in lines: + match = runtime_features.match(line) + if match: + runtime_guard = match.group(1) + if runtime_guard == 'envoy.reloadable_features.test_feature_false': + print("Found end sentinel\n") + if not found_test_feature_true: + # The script depends on the cc file having the true runtime block + # before the false runtime block. Fail if one isn't found. + print('Failed to find test_feature_true. Script needs fixing') + sys.exit(1) + return features_to_flip + if runtime_guard == 'envoy.reloadable_features.test_feature_true': + found_test_feature_true = True + continue + pr_num = re.search('\(#(\d+)\)', commit.message) + # Some commits may not come from a PR (if they are part of a security point release). + pr = (int(pr_num.group(1))) if pr_num else None + pr_date = date.fromtimestamp(commit.committed_date) + removable = (pr_date < removal_date) + # Add the runtime guard and PR to the list to file issues about. + print('Flag ' + runtime_guard + ' added at ' + str(pr_date) + ' ' + + (removable and 'and is safe to remove' or 'is not ready to remove')) + if removable: + features_to_flip.append((runtime_guard, pr, commit)) + print('Failed to find test_feature_false. 
Script needs fixing') + sys.exit(1) if __name__ == '__main__': - runtime_and_pr = get_runtime_and_pr() + runtime_and_pr = get_runtime_and_pr() - if not runtime_and_pr: - print('No code is deprecated.') - sys.exit(0) + if not runtime_and_pr: + print('No code is deprecated.') + sys.exit(0) - access_token = os.getenv('GITHUB_TOKEN') - if not access_token: - print('Missing GITHUB_TOKEN: see instructions in tools/deprecate_version/deprecate_version.py') - sys.exit(1) + access_token = os.getenv('GITHUB_TOKEN') + if not access_token: + print( + 'Missing GITHUB_TOKEN: see instructions in tools/deprecate_version/deprecate_version.py' + ) + sys.exit(1) - create_issues(access_token, runtime_and_pr) + create_issues(access_token, runtime_and_pr) diff --git a/tools/envoy_collect/envoy_collect.py b/tools/envoy_collect/envoy_collect.py index 229f1da15b40..4198b2062f00 100755 --- a/tools/envoy_collect/envoy_collect.py +++ b/tools/envoy_collect/envoy_collect.py @@ -60,11 +60,11 @@ def fetch_url(url): - return urllib.request.urlopen(url).read().decode('utf-8') + return urllib.request.urlopen(url).read().decode('utf-8') def modify_envoy_config(config_path, perf, output_directory): - """Modify Envoy config to support gathering logs, etc. + """Modify Envoy config to support gathering logs, etc. Args: config_path: the command-line specified Envoy config path. @@ -73,38 +73,38 @@ def modify_envoy_config(config_path, perf, output_directory): Returns: (modified Envoy config path, list of additional files to collect) """ - # No modifications yet when in performance profiling mode. - if perf: - return config_path, [] - - # Load original Envoy config. - with open(config_path, 'r') as f: - envoy_config = json.loads(f.read()) - - # Add unconditional access logs for all listeners. - access_log_paths = [] - for n, listener in enumerate(envoy_config['listeners']): - for network_filter in listener['filters']: - if network_filter['name'] == 'http_connection_manager': - config = network_filter['config'] - access_log_path = os.path.join(output_directory, 'access_%d.log' % n) - access_log_config = {'path': access_log_path} - if 'access_log' in config: - config['access_log'].append(access_log_config) - else: - config['access_log'] = [access_log_config] - access_log_paths.append(access_log_path) - - # Write out modified Envoy config. - modified_envoy_config_path = os.path.join(output_directory, 'config.json') - with open(modified_envoy_config_path, 'w') as f: - f.write(json.dumps(envoy_config, indent=2)) - - return modified_envoy_config_path, access_log_paths + # No modifications yet when in performance profiling mode. + if perf: + return config_path, [] + + # Load original Envoy config. + with open(config_path, 'r') as f: + envoy_config = json.loads(f.read()) + + # Add unconditional access logs for all listeners. + access_log_paths = [] + for n, listener in enumerate(envoy_config['listeners']): + for network_filter in listener['filters']: + if network_filter['name'] == 'http_connection_manager': + config = network_filter['config'] + access_log_path = os.path.join(output_directory, 'access_%d.log' % n) + access_log_config = {'path': access_log_path} + if 'access_log' in config: + config['access_log'].append(access_log_config) + else: + config['access_log'] = [access_log_config] + access_log_paths.append(access_log_path) + + # Write out modified Envoy config. 
+ modified_envoy_config_path = os.path.join(output_directory, 'config.json') + with open(modified_envoy_config_path, 'w') as f: + f.write(json.dumps(envoy_config, indent=2)) + + return modified_envoy_config_path, access_log_paths def run_envoy(envoy_shcmd_args, envoy_log_path, admin_address_path, dump_handlers_paths): - """Run Envoy subprocess and trigger admin endpoint gathering on SIGINT. + """Run Envoy subprocess and trigger admin endpoint gathering on SIGINT. Args: envoy_shcmd_args: list of Envoy subprocess args. @@ -115,133 +115,133 @@ def run_envoy(envoy_shcmd_args, envoy_log_path, admin_address_path, dump_handler Returns: The Envoy subprocess exit code. """ - envoy_shcmd = ' '.join(map(pipes.quote, envoy_shcmd_args)) - print(envoy_shcmd) - - # Some process setup stuff to ensure the child process gets cleaned up properly if the - # collector dies and doesn't get its signals implicitly. - def envoy_preexec_fn(): - os.setpgrp() - libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True) - libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM) - - # Launch Envoy, register for SIGINT, and wait for the child process to exit. - with open(envoy_log_path, 'w') as envoy_log: - envoy_proc = sp.Popen(envoy_shcmd, - stdin=sp.PIPE, - stderr=envoy_log, - preexec_fn=envoy_preexec_fn, - shell=True) - - def signal_handler(signum, frame): - # The read is deferred until the signal so that the Envoy process gets a - # chance to write the file out. - with open(admin_address_path, 'r') as f: - admin_address = 'http://%s' % f.read() - # Fetch from the admin endpoint. - for handler, path in dump_handlers_paths.items(): - handler_url = '%s/%s' % (admin_address, handler) - print('Fetching %s' % handler_url) - with open(path, 'w') as f: - f.write(fetch_url(handler_url)) - # Send SIGINT to Envoy process, it should exit and execution will - # continue from the envoy_proc.wait() below. - print('Sending Envoy process (PID=%d) SIGINT...' % envoy_proc.pid) - envoy_proc.send_signal(signal.SIGINT) - - signal.signal(signal.SIGINT, signal_handler) - return envoy_proc.wait() + envoy_shcmd = ' '.join(map(pipes.quote, envoy_shcmd_args)) + print(envoy_shcmd) + + # Some process setup stuff to ensure the child process gets cleaned up properly if the + # collector dies and doesn't get its signals implicitly. + def envoy_preexec_fn(): + os.setpgrp() + libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True) + libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM) + + # Launch Envoy, register for SIGINT, and wait for the child process to exit. + with open(envoy_log_path, 'w') as envoy_log: + envoy_proc = sp.Popen(envoy_shcmd, + stdin=sp.PIPE, + stderr=envoy_log, + preexec_fn=envoy_preexec_fn, + shell=True) + + def signal_handler(signum, frame): + # The read is deferred until the signal so that the Envoy process gets a + # chance to write the file out. + with open(admin_address_path, 'r') as f: + admin_address = 'http://%s' % f.read() + # Fetch from the admin endpoint. + for handler, path in dump_handlers_paths.items(): + handler_url = '%s/%s' % (admin_address, handler) + print('Fetching %s' % handler_url) + with open(path, 'w') as f: + f.write(fetch_url(handler_url)) + # Send SIGINT to Envoy process, it should exit and execution will + # continue from the envoy_proc.wait() below. + print('Sending Envoy process (PID=%d) SIGINT...' 
% envoy_proc.pid) + envoy_proc.send_signal(signal.SIGINT) + + signal.signal(signal.SIGINT, signal_handler) + return envoy_proc.wait() def envoy_collect(parse_result, unknown_args): - """Run Envoy and collect its artifacts. + """Run Envoy and collect its artifacts. Args: parse_result: Namespace object with envoy_collect.py's args. unknown_args: list of remaining args to pass to Envoy binary. """ - # Are we in performance mode? Otherwise, debug. - perf = parse_result.performance - return_code = 1 # Non-zero default return. - envoy_tmpdir = tempfile.mkdtemp(prefix='envoy-collect-tmp-') - # Try and do stuff with envoy_tmpdir, rm -rf regardless of success/failure. - try: - # Setup Envoy config and determine the paths of the files we're going to - # generate. - modified_envoy_config_path, access_log_paths = modify_envoy_config( - parse_result.config_path, perf, envoy_tmpdir) - dump_handlers_paths = {h: os.path.join(envoy_tmpdir, '%s.txt' % h) for h in DUMP_HANDLERS} - envoy_log_path = os.path.join(envoy_tmpdir, 'envoy.log') - # The manifest of files that will be placed in the output .tar. - manifest = access_log_paths + list( - dump_handlers_paths.values()) + [modified_envoy_config_path, envoy_log_path] - # This is where we will find out where the admin endpoint is listening. - admin_address_path = os.path.join(envoy_tmpdir, 'admin_address.txt') - - # Only run under 'perf record' in performance mode. - if perf: - perf_data_path = os.path.join(envoy_tmpdir, 'perf.data') - manifest.append(perf_data_path) - perf_record_args = [ - PERF_PATH, - 'record', - '-o', - perf_data_path, - '-g', - '--', - ] - else: - perf_record_args = [] - - # This is how we will invoke the wrapped envoy. - envoy_shcmd_args = perf_record_args + [ - parse_result.envoy_binary, - '-c', - modified_envoy_config_path, - '-l', - 'error' if perf else 'trace', - '--admin-address-path', - admin_address_path, - ] + unknown_args[1:] - - # Run the Envoy process (under 'perf record' if needed). - return_code = run_envoy(envoy_shcmd_args, envoy_log_path, admin_address_path, - dump_handlers_paths) - - # Collect manifest files and tar them. - with tarfile.TarFile(parse_result.output_path, 'w') as output_tar: - for path in manifest: - if os.path.exists(path): - print('Adding %s to archive' % path) - output_tar.add(path, arcname=os.path.basename(path)) + # Are we in performance mode? Otherwise, debug. + perf = parse_result.performance + return_code = 1 # Non-zero default return. + envoy_tmpdir = tempfile.mkdtemp(prefix='envoy-collect-tmp-') + # Try and do stuff with envoy_tmpdir, rm -rf regardless of success/failure. + try: + # Setup Envoy config and determine the paths of the files we're going to + # generate. + modified_envoy_config_path, access_log_paths = modify_envoy_config( + parse_result.config_path, perf, envoy_tmpdir) + dump_handlers_paths = {h: os.path.join(envoy_tmpdir, '%s.txt' % h) for h in DUMP_HANDLERS} + envoy_log_path = os.path.join(envoy_tmpdir, 'envoy.log') + # The manifest of files that will be placed in the output .tar. + manifest = access_log_paths + list( + dump_handlers_paths.values()) + [modified_envoy_config_path, envoy_log_path] + # This is where we will find out where the admin endpoint is listening. + admin_address_path = os.path.join(envoy_tmpdir, 'admin_address.txt') + + # Only run under 'perf record' in performance mode. 
+ if perf: + perf_data_path = os.path.join(envoy_tmpdir, 'perf.data') + manifest.append(perf_data_path) + perf_record_args = [ + PERF_PATH, + 'record', + '-o', + perf_data_path, + '-g', + '--', + ] else: - print('%s not found' % path) - - print('Wrote Envoy artifacts to %s' % parse_result.output_path) - finally: - shutil.rmtree(envoy_tmpdir) - return return_code + perf_record_args = [] + + # This is how we will invoke the wrapped envoy. + envoy_shcmd_args = perf_record_args + [ + parse_result.envoy_binary, + '-c', + modified_envoy_config_path, + '-l', + 'error' if perf else 'trace', + '--admin-address-path', + admin_address_path, + ] + unknown_args[1:] + + # Run the Envoy process (under 'perf record' if needed). + return_code = run_envoy(envoy_shcmd_args, envoy_log_path, admin_address_path, + dump_handlers_paths) + + # Collect manifest files and tar them. + with tarfile.TarFile(parse_result.output_path, 'w') as output_tar: + for path in manifest: + if os.path.exists(path): + print('Adding %s to archive' % path) + output_tar.add(path, arcname=os.path.basename(path)) + else: + print('%s not found' % path) + + print('Wrote Envoy artifacts to %s' % parse_result.output_path) + finally: + shutil.rmtree(envoy_tmpdir) + return return_code if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Envoy wrapper to collect stats/log/profile.') - default_output_path = 'envoy-%s.tar' % datetime.datetime.now().isoformat('-') - parser.add_argument('--output-path', default=default_output_path, help='path to output .tar.') - # We either need to interpret or override these, so we declare them in - # envoy_collect.py and always parse and present them again when invoking - # Envoy. - parser.add_argument('--config-path', - '-c', - required=True, - help='Path to Envoy configuration file.') - parser.add_argument('--log-level', - '-l', - help='Envoy log level. This will be overridden when invoking Envoy.') - # envoy_collect specific args. - parser.add_argument('--performance', - action='store_true', - help='Performance mode (collect perf trace, minimize log verbosity).') - parser.add_argument('--envoy-binary', - default=DEFAULT_ENVOY_PATH, - help='Path to Envoy binary (%s by default).' % DEFAULT_ENVOY_PATH) - sys.exit(envoy_collect(*parser.parse_known_args(sys.argv))) + parser = argparse.ArgumentParser(description='Envoy wrapper to collect stats/log/profile.') + default_output_path = 'envoy-%s.tar' % datetime.datetime.now().isoformat('-') + parser.add_argument('--output-path', default=default_output_path, help='path to output .tar.') + # We either need to interpret or override these, so we declare them in + # envoy_collect.py and always parse and present them again when invoking + # Envoy. + parser.add_argument('--config-path', + '-c', + required=True, + help='Path to Envoy configuration file.') + parser.add_argument('--log-level', + '-l', + help='Envoy log level. This will be overridden when invoking Envoy.') + # envoy_collect specific args. + parser.add_argument('--performance', + action='store_true', + help='Performance mode (collect perf trace, minimize log verbosity).') + parser.add_argument('--envoy-binary', + default=DEFAULT_ENVOY_PATH, + help='Path to Envoy binary (%s by default).' 
% DEFAULT_ENVOY_PATH) + sys.exit(envoy_collect(*parser.parse_known_args(sys.argv))) diff --git a/tools/envoy_headersplit/headersplit.py b/tools/envoy_headersplit/headersplit.py index 22dcb68c3967..60661516a179 100644 --- a/tools/envoy_headersplit/headersplit.py +++ b/tools/envoy_headersplit/headersplit.py @@ -17,7 +17,7 @@ def to_filename(classname: str) -> str: - """ + """ maps mock class name (in C++ codes) to filenames under the Envoy naming convention. e.g. map "MockAdminStream" to "admin_stream" @@ -27,17 +27,17 @@ def to_filename(classname: str) -> str: Returns: corresponding file name """ - filename = classname.replace("Mock", "", 1) # Remove only first "Mock" - ret = "" - for index, val in enumerate(filename): - if val.isupper() and index > 0: - ret += "_" - ret += val - return ret.lower() + filename = classname.replace("Mock", "", 1) # Remove only first "Mock" + ret = "" + for index, val in enumerate(filename): + if val.isupper() and index > 0: + ret += "_" + ret += val + return ret.lower() def get_directives(translation_unit: Type[TranslationUnit]) -> str: - """ + """ "extracts" all header includes statements and other directives from the target source code file for instance: @@ -64,17 +64,17 @@ def get_directives(translation_unit: Type[TranslationUnit]) -> str: We choose to return the string instead of list of includes since we will simply copy-paste the include statements into generated headers. Return string seems more convenient """ - cursor = translation_unit.cursor - for descendant in cursor.walk_preorder(): - if descendant.location.file is not None and descendant.location.file.name == cursor.displayname: - filename = descendant.location.file.name - contents = read_file_contents(filename) - return contents[:descendant.extent.start.offset] - return "" + cursor = translation_unit.cursor + for descendant in cursor.walk_preorder(): + if descendant.location.file is not None and descendant.location.file.name == cursor.displayname: + filename = descendant.location.file.name + contents = read_file_contents(filename) + return contents[:descendant.extent.start.offset] + return "" def cursors_in_same_file(cursor: Cursor) -> List[Cursor]: - """ + """ get all child cursors which are pointing to the same file as the input cursor Args: @@ -83,22 +83,22 @@ def cursors_in_same_file(cursor: Cursor) -> List[Cursor]: Returns: a list of cursor """ - cursors = [] - for descendant in cursor.walk_preorder(): - # We don't want Cursors from files other than the input file, - # otherwise we get definitions for every file included - # when clang parsed the input file (i.e. if we don't limit descendant location, - # it will check definitions from included headers and get class definitions like std::string) - if descendant.location.file is None: - continue - if descendant.location.file.name != cursor.displayname: - continue - cursors.append(descendant) - return cursors + cursors = [] + for descendant in cursor.walk_preorder(): + # We don't want Cursors from files other than the input file, + # otherwise we get definitions for every file included + # when clang parsed the input file (i.e. 
if we don't limit descendant location, + # it will check definitions from included headers and get class definitions like std::string) + if descendant.location.file is None: + continue + if descendant.location.file.name != cursor.displayname: + continue + cursors.append(descendant) + return cursors def class_definitions(cursor: Cursor) -> List[Cursor]: - """ + """ extracts all class definitions in the file pointed by cursor. (typical mocks.h) Args: @@ -107,23 +107,23 @@ def class_definitions(cursor: Cursor) -> List[Cursor]: Returns: a list of cursor, each pointing to a class definition. """ - cursors = cursors_in_same_file(cursor) - class_cursors = [] - for descendant in cursors: - # check if descendant is pointing to a class declaration block. - if descendant.kind != CursorKind.CLASS_DECL: - continue - if not descendant.is_definition(): - continue - # check if this class is directly enclosed by a namespace. - if descendant.semantic_parent.kind != CursorKind.NAMESPACE: - continue - class_cursors.append(descendant) - return class_cursors + cursors = cursors_in_same_file(cursor) + class_cursors = [] + for descendant in cursors: + # check if descendant is pointing to a class declaration block. + if descendant.kind != CursorKind.CLASS_DECL: + continue + if not descendant.is_definition(): + continue + # check if this class is directly enclosed by a namespace. + if descendant.semantic_parent.kind != CursorKind.NAMESPACE: + continue + class_cursors.append(descendant) + return class_cursors def class_implementations(cursor: Cursor) -> List[Cursor]: - """ + """ extracts all class implementation in the file pointed by cursor. (typical mocks.cc) Args: @@ -132,21 +132,21 @@ def class_implementations(cursor: Cursor) -> List[Cursor]: Returns: a list of cursor, each pointing to a class implementation. """ - cursors = cursors_in_same_file(cursor) - impl_cursors = [] - for descendant in cursors: - if descendant.kind == CursorKind.NAMESPACE: - continue - # check if descendant is pointing to a class method - if descendant.semantic_parent is None: - continue - if descendant.semantic_parent.kind == CursorKind.CLASS_DECL: - impl_cursors.append(descendant) - return impl_cursors + cursors = cursors_in_same_file(cursor) + impl_cursors = [] + for descendant in cursors: + if descendant.kind == CursorKind.NAMESPACE: + continue + # check if descendant is pointing to a class method + if descendant.semantic_parent is None: + continue + if descendant.semantic_parent.kind == CursorKind.CLASS_DECL: + impl_cursors.append(descendant) + return impl_cursors def extract_definition(cursor: Cursor, classnames: List[str]) -> Tuple[str, str, List[str]]: - """ + """ extracts class definition source code pointed by the cursor parameter. and find dependent mock classes by naming look up. @@ -163,29 +163,29 @@ def extract_definition(cursor: Cursor, classnames: List[str]) -> Tuple[str, str, It can not detect and resolve forward declaration and cyclic dependency. Need to address manually. 
""" - filename = cursor.location.file.name - contents = read_file_contents(filename) - class_name = cursor.spelling - class_defn = contents[cursor.extent.start.offset:cursor.extent.end.offset] + ";" - # need to know enclosing semantic parents (namespaces) - # to generate corresponding definitions - parent_cursor = cursor.semantic_parent - while parent_cursor.kind == CursorKind.NAMESPACE: - if parent_cursor.spelling == "": - break - class_defn = "namespace {} {{\n".format(parent_cursor.spelling) + class_defn + "\n}\n" - parent_cursor = parent_cursor.semantic_parent - # resolve dependency - # by simple naming look up - deps = set() - for classname in classnames: - if classname in class_defn and classname != class_name: - deps.add(classname) - return class_name, class_defn, deps + filename = cursor.location.file.name + contents = read_file_contents(filename) + class_name = cursor.spelling + class_defn = contents[cursor.extent.start.offset:cursor.extent.end.offset] + ";" + # need to know enclosing semantic parents (namespaces) + # to generate corresponding definitions + parent_cursor = cursor.semantic_parent + while parent_cursor.kind == CursorKind.NAMESPACE: + if parent_cursor.spelling == "": + break + class_defn = "namespace {} {{\n".format(parent_cursor.spelling) + class_defn + "\n}\n" + parent_cursor = parent_cursor.semantic_parent + # resolve dependency + # by simple naming look up + deps = set() + for classname in classnames: + if classname in class_defn and classname != class_name: + deps.add(classname) + return class_name, class_defn, deps def get_implline(cursor: Cursor) -> int: - """ + """ finds the first line of implementation source code for class method pointed by the cursor parameter. @@ -206,11 +206,11 @@ def get_implline(cursor: Cursor) -> int: offset from Cursor, we can still get the start line of the corresponding method instead. (We can't get the correct line number for the last line due to skipping function bodies) """ - return cursor.extent.start.line - 1 + return cursor.extent.start.line - 1 def extract_implementations(impl_cursors: List[Cursor], source_code: str) -> Dict[str, str]: - """ + """ extracts method function body for each cursor in list impl_cursors from source code groups those function bodies with class name to help generating the divided {classname}.cc returns a dict maps class name to the concatenation of all its member methods implementations. 
@@ -222,35 +222,35 @@ def extract_implementations(impl_cursors: List[Cursor], source_code: str) -> Dic Returns: classname_to_impl: a dict maps class name to its member methods implementations """ - classname_to_impl = dict() - for i, cursor in enumerate(impl_cursors): - classname = cursor.semantic_parent.spelling - # get first line of function body - implline = get_implline(cursor) - # get last line of function body - if i + 1 < len(impl_cursors): - # i is not the last method, get the start line for the next method - # as the last line of i - impl_end = get_implline(impl_cursors[i + 1]) - impl = "".join(source_code[implline:impl_end]) - else: - # i is the last method, after removing the lines containing close brackets - # for namespaces, the rest should be the function body - offset = 0 - while implline + offset < len(source_code): - if "// namespace" in source_code[implline + offset]: - break - offset += 1 - impl = "".join(source_code[implline:implline + offset]) - if classname in classname_to_impl: - classname_to_impl[classname] += impl + "\n" - else: - classname_to_impl[classname] = impl + "\n" - return classname_to_impl + classname_to_impl = dict() + for i, cursor in enumerate(impl_cursors): + classname = cursor.semantic_parent.spelling + # get first line of function body + implline = get_implline(cursor) + # get last line of function body + if i + 1 < len(impl_cursors): + # i is not the last method, get the start line for the next method + # as the last line of i + impl_end = get_implline(impl_cursors[i + 1]) + impl = "".join(source_code[implline:impl_end]) + else: + # i is the last method, after removing the lines containing close brackets + # for namespaces, the rest should be the function body + offset = 0 + while implline + offset < len(source_code): + if "// namespace" in source_code[implline + offset]: + break + offset += 1 + impl = "".join(source_code[implline:implline + offset]) + if classname in classname_to_impl: + classname_to_impl[classname] += impl + "\n" + else: + classname_to_impl[classname] = impl + "\n" + return classname_to_impl def get_enclosing_namespace(defn: Cursor) -> Tuple[str, str]: - """ + """ retrieves all enclosing namespaces for the class pointed by defn. 
this is necessary to construct the mock class header e.g.: @@ -276,31 +276,31 @@ class MockClass {...} Returns: namespace_prefix, namespace_suffix: a pair of string, representing the enclosing namespaces """ - namespace_prefix = "" - namespace_suffix = "" - parent_cursor = defn.semantic_parent - while parent_cursor.kind == CursorKind.NAMESPACE: - if parent_cursor.spelling == "": - break - namespace_prefix = "namespace {} {{\n".format(parent_cursor.spelling) + namespace_prefix - namespace_suffix += "\n}" - parent_cursor = parent_cursor.semantic_parent - namespace_suffix += "\n" - return namespace_prefix, namespace_suffix + namespace_prefix = "" + namespace_suffix = "" + parent_cursor = defn.semantic_parent + while parent_cursor.kind == CursorKind.NAMESPACE: + if parent_cursor.spelling == "": + break + namespace_prefix = "namespace {} {{\n".format(parent_cursor.spelling) + namespace_prefix + namespace_suffix += "\n}" + parent_cursor = parent_cursor.semantic_parent + namespace_suffix += "\n" + return namespace_prefix, namespace_suffix def read_file_contents(path): - with open(path, "r") as input_file: - return input_file.read() + with open(path, "r") as input_file: + return input_file.read() def write_file_contents(class_name, class_defn, class_impl): - with open("{}.h".format(to_filename(class_name)), "w") as decl_file: - decl_file.write(class_defn) - with open("{}.cc".format(to_filename(class_name)), "w") as impl_file: - impl_file.write(class_impl) - # generating bazel build file, need to fill dependency manually - bazel_text = """ + with open("{}.h".format(to_filename(class_name)), "w") as decl_file: + decl_file.write(class_defn) + with open("{}.cc".format(to_filename(class_name)), "w") as impl_file: + impl_file.write(class_impl) + # generating bazel build file, need to fill dependency manually + bazel_text = """ envoy_cc_mock( name = "{}_mocks", srcs = ["{}.cc"], @@ -310,60 +310,61 @@ def write_file_contents(class_name, class_defn, class_impl): ] ) """.format(to_filename(class_name), to_filename(class_name), to_filename(class_name)) - with open("BUILD", "r+") as bazel_file: - contents = bazel_file.read() - if 'name = "{}_mocks"'.format(to_filename(class_name)) not in contents: - bazel_file.write(bazel_text) + with open("BUILD", "r+") as bazel_file: + contents = bazel_file.read() + if 'name = "{}_mocks"'.format(to_filename(class_name)) not in contents: + bazel_file.write(bazel_text) def main(args): - """ + """ divides the monolithic mock file into different mock class files. 
""" - decl_filename = args["decl"] - impl_filename = args["impl"] - idx = Index.create() - impl_translation_unit = TranslationUnit.from_source( - impl_filename, options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) - impl_includes = get_directives(impl_translation_unit) - decl_translation_unit = idx.parse(decl_filename, ["-x", "c++"]) - defns = class_definitions(decl_translation_unit.cursor) - decl_includes = get_directives(decl_translation_unit) - impl_cursors = class_implementations(impl_translation_unit.cursor) - contents = read_file_contents(impl_filename) - classname_to_impl = extract_implementations(impl_cursors, contents) - classnames = [cursor.spelling for cursor in defns] - for defn in defns: - # writing {class}.h and {classname}.cc - class_name, class_defn, deps = extract_definition(defn, classnames) - includes = "" - for name in deps: - includes += '#include "{}.h"\n'.format(to_filename(name)) - class_defn = decl_includes + includes + class_defn - class_impl = "" - if class_name not in classname_to_impl: - print("Warning: empty class {}".format(class_name)) - else: - impl_include = impl_includes.replace(decl_filename, "{}.h".format(to_filename(class_name))) - # we need to enclose methods with namespaces - namespace_prefix, namespace_suffix = get_enclosing_namespace(defn) - class_impl = impl_include + namespace_prefix + \ - classname_to_impl[class_name] + namespace_suffix - write_file_contents(class_name, class_defn, class_impl) + decl_filename = args["decl"] + impl_filename = args["impl"] + idx = Index.create() + impl_translation_unit = TranslationUnit.from_source( + impl_filename, options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) + impl_includes = get_directives(impl_translation_unit) + decl_translation_unit = idx.parse(decl_filename, ["-x", "c++"]) + defns = class_definitions(decl_translation_unit.cursor) + decl_includes = get_directives(decl_translation_unit) + impl_cursors = class_implementations(impl_translation_unit.cursor) + contents = read_file_contents(impl_filename) + classname_to_impl = extract_implementations(impl_cursors, contents) + classnames = [cursor.spelling for cursor in defns] + for defn in defns: + # writing {class}.h and {classname}.cc + class_name, class_defn, deps = extract_definition(defn, classnames) + includes = "" + for name in deps: + includes += '#include "{}.h"\n'.format(to_filename(name)) + class_defn = decl_includes + includes + class_defn + class_impl = "" + if class_name not in classname_to_impl: + print("Warning: empty class {}".format(class_name)) + else: + impl_include = impl_includes.replace(decl_filename, + "{}.h".format(to_filename(class_name))) + # we need to enclose methods with namespaces + namespace_prefix, namespace_suffix = get_enclosing_namespace(defn) + class_impl = impl_include + namespace_prefix + \ + classname_to_impl[class_name] + namespace_suffix + write_file_contents(class_name, class_defn, class_impl) if __name__ == "__main__": - PARSER = argparse.ArgumentParser() - PARSER.add_argument( - "-d", - "--decl", - default="mocks.h", - help="Path to the monolithic header .h file that needs to be splitted", - ) - PARSER.add_argument( - "-i", - "--impl", - default="mocks.cc", - help="Path to the implementation code .cc file that needs to be splitted", - ) - main(vars(PARSER.parse_args())) + PARSER = argparse.ArgumentParser() + PARSER.add_argument( + "-d", + "--decl", + default="mocks.h", + help="Path to the monolithic header .h file that needs to be splitted", + ) + PARSER.add_argument( + "-i", + "--impl", + default="mocks.cc", + 
help="Path to the implementation code .cc file that needs to be splitted", + ) + main(vars(PARSER.parse_args())) diff --git a/tools/envoy_headersplit/headersplit_test.py b/tools/envoy_headersplit/headersplit_test.py index 962626dbe37f..7515b9ff3276 100644 --- a/tools/envoy_headersplit/headersplit_test.py +++ b/tools/envoy_headersplit/headersplit_test.py @@ -14,30 +14,30 @@ class HeadersplitTest(unittest.TestCase): - # A header contains a simple class print hello world - source_code_hello_world = open("tools/envoy_headersplit/code_corpus/hello.h", "r").read() - # A C++ source code contains definition for several classes - source_class_defn = open("tools/envoy_headersplit/code_corpus/class_defn.h", "r").read() - # almost the same as above, but classes are not enclosed by namespace - source_class_defn_without_namespace = open( - "tools/envoy_headersplit/code_corpus/class_defn_without_namespace.h", "r").read() - # A C++ source code contains method implementations for class_defn.h - source_class_impl = open("tools/envoy_headersplit/code_corpus/class_impl.cc", "r").read() - - def test_to_filename(self): - # Test class name with one "mock" - self.assertEqual(headersplit.to_filename("MockAdminStream"), "admin_stream") - - # Test class name with two "Mock" - self.assertEqual(headersplit.to_filename("MockClusterMockPrioritySet"), - "cluster_mock_priority_set") - - # Test class name with no "Mock" - self.assertEqual(headersplit.to_filename("TestRetryHostPredicateFactory"), - "test_retry_host_predicate_factory") - - def test_get_directives(self): - includes = """// your first c++ program + # A header contains a simple class print hello world + source_code_hello_world = open("tools/envoy_headersplit/code_corpus/hello.h", "r").read() + # A C++ source code contains definition for several classes + source_class_defn = open("tools/envoy_headersplit/code_corpus/class_defn.h", "r").read() + # almost the same as above, but classes are not enclosed by namespace + source_class_defn_without_namespace = open( + "tools/envoy_headersplit/code_corpus/class_defn_without_namespace.h", "r").read() + # A C++ source code contains method implementations for class_defn.h + source_class_impl = open("tools/envoy_headersplit/code_corpus/class_impl.cc", "r").read() + + def test_to_filename(self): + # Test class name with one "mock" + self.assertEqual(headersplit.to_filename("MockAdminStream"), "admin_stream") + + # Test class name with two "Mock" + self.assertEqual(headersplit.to_filename("MockClusterMockPrioritySet"), + "cluster_mock_priority_set") + + # Test class name with no "Mock" + self.assertEqual(headersplit.to_filename("TestRetryHostPredicateFactory"), + "test_retry_host_predicate_factory") + + def test_get_directives(self): + includes = """// your first c++ program // NOLINT(namespace-envoy) #include @@ -46,58 +46,59 @@ def test_get_directives(self): #include "foo/bar" """ - translation_unit_hello_world = TranslationUnit.from_source( - "tools/envoy_headersplit/code_corpus/hello.h", - options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) - self.assertEqual(headersplit.get_directives(translation_unit_hello_world), includes) - - def test_class_definitions(self): - idx = Index.create() - translation_unit_class_defn = idx.parse("tools/envoy_headersplit/code_corpus/class_defn.h", - ["-x", "c++"]) - defns_cursors = headersplit.class_definitions(translation_unit_class_defn.cursor) - defns_names = [cursor.spelling for cursor in defns_cursors] - self.assertEqual(defns_names, ["Foo", "Bar", "FooBar", "DeadBeaf"]) - idx = 
Index.create() - translation_unit_class_defn = idx.parse( - "tools/envoy_headersplit/code_corpus/class_defn_without_namespace.h", ["-x", "c++"]) - defns_cursors = headersplit.class_definitions(translation_unit_class_defn.cursor) - defns_names = [cursor.spelling for cursor in defns_cursors] - self.assertEqual(defns_names, []) - - def test_class_implementations(self): - translation_unit_class_impl = TranslationUnit.from_source( - "tools/envoy_headersplit/code_corpus/class_impl.cc", - options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) - impls_cursors = headersplit.class_implementations(translation_unit_class_impl.cursor) - impls_names = [cursor.spelling for cursor in impls_cursors] - self.assertEqual(impls_names, ["getFoo", "val", "DeadBeaf"]) - - def test_class_implementations_error(self): - # LibClang will fail in parse this source file (it's modified from the original - # test/server/mocks.cc from Envoy repository) if we don't add flag PARSE_SKIP_FUNCTION_BODIES - # to ignore function bodies. - impl_translation_unit = TranslationUnit.from_source( - "tools/envoy_headersplit/code_corpus/fail_mocks.cc") - impls_cursors = headersplit.class_implementations(impl_translation_unit.cursor) - # impls_name is not complete in this case - impls_names = [cursor.spelling for cursor in impls_cursors] - # LibClang will stop parsing at - # MockListenerComponentFactory::MockListenerComponentFactory() - # : socket_(std::make_shared>()) { - # ^ - # Since parsing stops early, we will have incomplete method list. - # The reason is not clear, however, this issue can be addressed by adding parsing flag to - # ignore function body - - # get correct list of member methods - impl_translation_unit_correct = TranslationUnit.from_source( - "tools/envoy_headersplit/code_corpus/fail_mocks.cc", - options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) - impls_cursors_correct = headersplit.class_implementations(impl_translation_unit_correct.cursor) - impls_names_correct = [cursor.spelling for cursor in impls_cursors_correct] - self.assertNotEqual(impls_names, impls_names_correct) + translation_unit_hello_world = TranslationUnit.from_source( + "tools/envoy_headersplit/code_corpus/hello.h", + options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) + self.assertEqual(headersplit.get_directives(translation_unit_hello_world), includes) + + def test_class_definitions(self): + idx = Index.create() + translation_unit_class_defn = idx.parse("tools/envoy_headersplit/code_corpus/class_defn.h", + ["-x", "c++"]) + defns_cursors = headersplit.class_definitions(translation_unit_class_defn.cursor) + defns_names = [cursor.spelling for cursor in defns_cursors] + self.assertEqual(defns_names, ["Foo", "Bar", "FooBar", "DeadBeaf"]) + idx = Index.create() + translation_unit_class_defn = idx.parse( + "tools/envoy_headersplit/code_corpus/class_defn_without_namespace.h", ["-x", "c++"]) + defns_cursors = headersplit.class_definitions(translation_unit_class_defn.cursor) + defns_names = [cursor.spelling for cursor in defns_cursors] + self.assertEqual(defns_names, []) + + def test_class_implementations(self): + translation_unit_class_impl = TranslationUnit.from_source( + "tools/envoy_headersplit/code_corpus/class_impl.cc", + options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) + impls_cursors = headersplit.class_implementations(translation_unit_class_impl.cursor) + impls_names = [cursor.spelling for cursor in impls_cursors] + self.assertEqual(impls_names, ["getFoo", "val", "DeadBeaf"]) + + def test_class_implementations_error(self): + # LibClang will fail in 
parse this source file (it's modified from the original + # test/server/mocks.cc from Envoy repository) if we don't add flag PARSE_SKIP_FUNCTION_BODIES + # to ignore function bodies. + impl_translation_unit = TranslationUnit.from_source( + "tools/envoy_headersplit/code_corpus/fail_mocks.cc") + impls_cursors = headersplit.class_implementations(impl_translation_unit.cursor) + # impls_name is not complete in this case + impls_names = [cursor.spelling for cursor in impls_cursors] + # LibClang will stop parsing at + # MockListenerComponentFactory::MockListenerComponentFactory() + # : socket_(std::make_shared>()) { + # ^ + # Since parsing stops early, we will have incomplete method list. + # The reason is not clear, however, this issue can be addressed by adding parsing flag to + # ignore function body + + # get correct list of member methods + impl_translation_unit_correct = TranslationUnit.from_source( + "tools/envoy_headersplit/code_corpus/fail_mocks.cc", + options=TranslationUnit.PARSE_SKIP_FUNCTION_BODIES) + impls_cursors_correct = headersplit.class_implementations( + impl_translation_unit_correct.cursor) + impls_names_correct = [cursor.spelling for cursor in impls_cursors_correct] + self.assertNotEqual(impls_names, impls_names_correct) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/tools/envoy_headersplit/replace_includes.py b/tools/envoy_headersplit/replace_includes.py index 29e51f8b1b82..f418cd37b743 100644 --- a/tools/envoy_headersplit/replace_includes.py +++ b/tools/envoy_headersplit/replace_includes.py @@ -19,7 +19,7 @@ class Server::MockAdmin. def to_classname(filename: str) -> str: - """ + """ maps divided mock class file name to class names inverse function of headersplit.to_filename e.g. map "test/mocks/server/admin_stream.h" to "MockAdminStream" @@ -30,13 +30,13 @@ def to_classname(filename: str) -> str: Returns: corresponding class name """ - classname_tokens = filename.split('/')[-1].replace('.h', '').split('_') - classname = "Mock" + ''.join(map(lambda x: x[:1].upper() + x[1:], classname_tokens)) - return classname + classname_tokens = filename.split('/')[-1].replace('.h', '').split('_') + classname = "Mock" + ''.join(map(lambda x: x[:1].upper() + x[1:], classname_tokens)) + return classname def to_bazelname(filename: str, mockname: str) -> str: - """ + """ maps divided mock class file name to bazel target name e.g. 
map "test/mocks/server/admin_stream.h" to "//test/mocks/server:admin_stream_mocks" @@ -47,13 +47,13 @@ def to_bazelname(filename: str, mockname: str) -> str: Returns: corresponding bazel target name """ - bazelname = "//test/mocks/{}:".format(mockname) - bazelname += filename.split('/')[-1].replace('.h', '') + '_mocks'.format(mockname) - return bazelname + bazelname = "//test/mocks/{}:".format(mockname) + bazelname += filename.split('/')[-1].replace('.h', '') + '_mocks'.format(mockname) + return bazelname def get_filenames(mockname: str) -> List[str]: - """ + """ scans all headers in test/mocks/{mockname}, return corresponding file names Args: @@ -62,59 +62,60 @@ def get_filenames(mockname: str) -> List[str]: Returns: List of file name for the headers in test/mock/{mocksname} """ - dir = Path("test/mocks/{}/".format(mockname)) - filenames = list(map(str, dir.glob('*.h'))) - return filenames + dir = Path("test/mocks/{}/".format(mockname)) + filenames = list(map(str, dir.glob('*.h'))) + return filenames def replace_includes(mockname): - filenames = get_filenames(mockname) - classnames = [to_classname(filename) for filename in filenames] - p = Path('./test') - changed_list = [] # list of test code that been refactored - # walk through all files and check files that contains "{mockname}/mocks.h" - # don't forget change dependency on bazel - for test_file in p.glob('**/*.cc'): - replace_includes = "" - used_mock_header = False - bazel_targets = "" - with test_file.open() as f: - content = f.read() - if '#include "test/mocks/{}/mocks.h"'.format(mockname) in content: - used_mock_header = True + filenames = get_filenames(mockname) + classnames = [to_classname(filename) for filename in filenames] + p = Path('./test') + changed_list = [] # list of test code that been refactored + # walk through all files and check files that contains "{mockname}/mocks.h" + # don't forget change dependency on bazel + for test_file in p.glob('**/*.cc'): replace_includes = "" - for classname in classnames: - if classname in content: - # replace mocks.h with mock class header used by this test library - # limitation: if some class names in classnames are substrings of others, this part - # will bring over-inclusion e.g. if we have MockCluster and MockClusterFactory, and - # the source code only used MockClusterFactory, then the result code will also include - # MockCluster since it also shows in the file. 
- # TODO: use clang to analysis class usage instead by simple find and replace - replace_includes += '#include "test/mocks/{}/{}.h"\n'.format( - mockname, to_filename(classname)) - bazel_targets += '"{}",'.format(to_bazelname(to_filename(classname), mockname)) - if used_mock_header: - changed_list.append(str(test_file.relative_to(Path('.'))) + '\n') - with test_file.open(mode='w') as f: - f.write( - content.replace('#include "test/mocks/{}/mocks.h"\n'.format(mockname), - replace_includes)) - with (test_file.parent / 'BUILD').open() as f: - # write building files - content = f.read() - split_content = content.split(test_file.name) - split_content[1] = split_content[1].replace( - '"//test/mocks/{}:{}_mocks",'.format(mockname, mockname), bazel_targets, 1) - content = split_content[0] + test_file.name + split_content[1] - with (test_file.parent / 'BUILD').open('w') as f: - f.write(content) - with open("changed.txt", "w") as f: - f.writelines(changed_list) + used_mock_header = False + bazel_targets = "" + with test_file.open() as f: + content = f.read() + if '#include "test/mocks/{}/mocks.h"'.format(mockname) in content: + used_mock_header = True + replace_includes = "" + for classname in classnames: + if classname in content: + # replace mocks.h with mock class header used by this test library + # limitation: if some class names in classnames are substrings of others, this part + # will bring over-inclusion e.g. if we have MockCluster and MockClusterFactory, and + # the source code only used MockClusterFactory, then the result code will also include + # MockCluster since it also shows in the file. + # TODO: use clang to analysis class usage instead by simple find and replace + replace_includes += '#include "test/mocks/{}/{}.h"\n'.format( + mockname, to_filename(classname)) + bazel_targets += '"{}",'.format( + to_bazelname(to_filename(classname), mockname)) + if used_mock_header: + changed_list.append(str(test_file.relative_to(Path('.'))) + '\n') + with test_file.open(mode='w') as f: + f.write( + content.replace('#include "test/mocks/{}/mocks.h"\n'.format(mockname), + replace_includes)) + with (test_file.parent / 'BUILD').open() as f: + # write building files + content = f.read() + split_content = content.split(test_file.name) + split_content[1] = split_content[1].replace( + '"//test/mocks/{}:{}_mocks",'.format(mockname, mockname), bazel_targets, 1) + content = split_content[0] + test_file.name + split_content[1] + with (test_file.parent / 'BUILD').open('w') as f: + f.write(content) + with open("changed.txt", "w") as f: + f.writelines(changed_list) if __name__ == '__main__': - PARSER = argparse.ArgumentParser() - PARSER.add_argument('-m', '--mockname', default="server", help="mock folder that been divided") - mockname = vars(PARSER.parse_args())['mockname'] - replace_includes(mockname) + PARSER = argparse.ArgumentParser() + PARSER.add_argument('-m', '--mockname', default="server", help="mock folder that been divided") + mockname = vars(PARSER.parse_args())['mockname'] + replace_includes(mockname) diff --git a/tools/envoy_headersplit/replace_includes_test.py b/tools/envoy_headersplit/replace_includes_test.py index e066deb5db5c..f386e2f698fb 100644 --- a/tools/envoy_headersplit/replace_includes_test.py +++ b/tools/envoy_headersplit/replace_includes_test.py @@ -10,62 +10,64 @@ class ReplaceIncludesTest(unittest.TestCase): - def test_to_classname(self): - # Test file name with whole path - self.assertEqual(replace_includes.to_classname("test/mocks/server/admin_stream.h"), - "MockAdminStream") - # 
Test file name without .h extension - self.assertEqual(replace_includes.to_classname("cluster_mock_priority_set"), - "MockClusterMockPrioritySet") + def test_to_classname(self): + # Test file name with whole path + self.assertEqual(replace_includes.to_classname("test/mocks/server/admin_stream.h"), + "MockAdminStream") + # Test file name without .h extension + self.assertEqual(replace_includes.to_classname("cluster_mock_priority_set"), + "MockClusterMockPrioritySet") - def test_to_bazelname(self): - # Test file name with whole path - self.assertEqual(replace_includes.to_bazelname("test/mocks/server/admin_stream.h", "server"), - "//test/mocks/server:admin_stream_mocks") - # Test file name without .h extension - self.assertEqual(replace_includes.to_bazelname("cluster_mock_priority_set", "upstream"), - "//test/mocks/upstream:cluster_mock_priority_set_mocks") + def test_to_bazelname(self): + # Test file name with whole path + self.assertEqual( + replace_includes.to_bazelname("test/mocks/server/admin_stream.h", "server"), + "//test/mocks/server:admin_stream_mocks") + # Test file name without .h extension + self.assertEqual(replace_includes.to_bazelname("cluster_mock_priority_set", "upstream"), + "//test/mocks/upstream:cluster_mock_priority_set_mocks") - class FakeDir(): - # fake directory to test get_filenames - def glob(self, _): - return [ - Path("test/mocks/server/admin_stream.h"), - Path("test/mocks/server/admin.h"), - Path("test/mocks/upstream/cluster_manager.h") - ] + class FakeDir(): + # fake directory to test get_filenames + def glob(self, _): + return [ + Path("test/mocks/server/admin_stream.h"), + Path("test/mocks/server/admin.h"), + Path("test/mocks/upstream/cluster_manager.h") + ] - @mock.patch("replace_includes.Path", return_value=FakeDir()) - def test_get_filenames(self, mock_Path): - self.assertEqual(replace_includes.get_filenames("sever"), [ - "test/mocks/server/admin_stream.h", "test/mocks/server/admin.h", - "test/mocks/upstream/cluster_manager.h" - ]) + @mock.patch("replace_includes.Path", return_value=FakeDir()) + def test_get_filenames(self, mock_Path): + self.assertEqual(replace_includes.get_filenames("sever"), [ + "test/mocks/server/admin_stream.h", "test/mocks/server/admin.h", + "test/mocks/upstream/cluster_manager.h" + ]) - def test_replace_includes(self): - fake_source_code = open("tools/envoy_headersplit/code_corpus/fake_source_code.cc", "r").read() - fake_build_file = open("tools/envoy_headersplit/code_corpus/fake_build", "r").read() - os.mkdir("test") - os.mkdir("test/mocks") - os.mkdir("test/mocks/upstream") - open("test/mocks/upstream/cluster_manager.h", "a").close() - with open("test/async_client_impl_test.cc", "w") as f: - f.write(fake_source_code) - with open("test/BUILD", "w") as f: - f.write(fake_build_file) - replace_includes.replace_includes("upstream") - source_code = "" - build_file = "" - with open("test/async_client_impl_test.cc", "r") as f: - source_code = f.read() - with open("test/BUILD", "r") as f: - build_file = f.read() - self.assertEqual(source_code, - fake_source_code.replace("upstream/mocks", "upstream/cluster_manager")) - self.assertEqual( - build_file, - fake_build_file.replace("upstream:upstream_mocks", "upstream:cluster_manager_mocks")) + def test_replace_includes(self): + fake_source_code = open("tools/envoy_headersplit/code_corpus/fake_source_code.cc", + "r").read() + fake_build_file = open("tools/envoy_headersplit/code_corpus/fake_build", "r").read() + os.mkdir("test") + os.mkdir("test/mocks") + os.mkdir("test/mocks/upstream") + 
open("test/mocks/upstream/cluster_manager.h", "a").close() + with open("test/async_client_impl_test.cc", "w") as f: + f.write(fake_source_code) + with open("test/BUILD", "w") as f: + f.write(fake_build_file) + replace_includes.replace_includes("upstream") + source_code = "" + build_file = "" + with open("test/async_client_impl_test.cc", "r") as f: + source_code = f.read() + with open("test/BUILD", "r") as f: + build_file = f.read() + self.assertEqual(source_code, + fake_source_code.replace("upstream/mocks", "upstream/cluster_manager")) + self.assertEqual( + build_file, + fake_build_file.replace("upstream:upstream_mocks", "upstream:cluster_manager_mocks")) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/tools/extensions/generate_extension_db.py b/tools/extensions/generate_extension_db.py index 1ba5048f41bf..e12e7b0211a1 100644 --- a/tools/extensions/generate_extension_db.py +++ b/tools/extensions/generate_extension_db.py @@ -28,8 +28,8 @@ ENVOY_SRCDIR = os.getenv('ENVOY_SRCDIR', '/source') if not os.path.exists(ENVOY_SRCDIR): - raise SystemExit( - "Envoy source must either be located at /source, or ENVOY_SRCDIR env var must be set") + raise SystemExit( + "Envoy source must either be located at /source, or ENVOY_SRCDIR env var must be set") # source/extensions/extensions_build_config.bzl must have a .bzl suffix for Starlark # import, so we are forced to do this workaround. @@ -42,91 +42,91 @@ class ExtensionDbError(Exception): - pass + pass def is_missing(value): - return value == '(missing)' + return value == '(missing)' def num_read_filters_fuzzed(): - data = pathlib.Path( - os.path.join( - ENVOY_SRCDIR, - 'test/extensions/filters/network/common/fuzz/uber_per_readfilter.cc')).read_text() - # Hack-ish! We only search the first 50 lines to capture the filters in filterNames(). - return len(re.findall('NetworkFilterNames::get()', ''.join(data.splitlines()[:50]))) + data = pathlib.Path( + os.path.join( + ENVOY_SRCDIR, + 'test/extensions/filters/network/common/fuzz/uber_per_readfilter.cc')).read_text() + # Hack-ish! We only search the first 50 lines to capture the filters in filterNames(). + return len(re.findall('NetworkFilterNames::get()', ''.join(data.splitlines()[:50]))) def num_robust_to_downstream_network_filters(db): - # Count number of network filters robust to untrusted downstreams. - return len([ - ext for ext, data in db.items() - if 'network' in ext and data['security_posture'] == 'robust_to_untrusted_downstream' - ]) + # Count number of network filters robust to untrusted downstreams. + return len([ + ext for ext, data in db.items() + if 'network' in ext and data['security_posture'] == 'robust_to_untrusted_downstream' + ]) def get_extension_metadata(target): - if not BUILDOZER_PATH: - raise ExtensionDbError('Buildozer not found!') - r = subprocess.run( - [BUILDOZER_PATH, '-stdout', 'print security_posture status undocumented category', target], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - rout = r.stdout.decode('utf-8').strip().split(' ') - security_posture, status, undocumented = rout[:3] - categories = ' '.join(rout[3:]) - if is_missing(security_posture): - raise ExtensionDbError( - 'Missing security posture for %s. Please make sure the target is an envoy_cc_extension and security_posture is set' - % target) - if is_missing(categories): - raise ExtensionDbError( - 'Missing extension category for %s. 
Please make sure the target is an envoy_cc_extension and category is set' - % target) - # evaluate tuples/lists - # wrap strings in a list - categories = (ast.literal_eval(categories) if - ('[' in categories or '(' in categories) else [categories]) - return { - 'security_posture': security_posture, - 'undocumented': False if is_missing(undocumented) else bool(undocumented), - 'status': 'stable' if is_missing(status) else status, - 'categories': categories, - } + if not BUILDOZER_PATH: + raise ExtensionDbError('Buildozer not found!') + r = subprocess.run( + [BUILDOZER_PATH, '-stdout', 'print security_posture status undocumented category', target], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + rout = r.stdout.decode('utf-8').strip().split(' ') + security_posture, status, undocumented = rout[:3] + categories = ' '.join(rout[3:]) + if is_missing(security_posture): + raise ExtensionDbError( + 'Missing security posture for %s. Please make sure the target is an envoy_cc_extension and security_posture is set' + % target) + if is_missing(categories): + raise ExtensionDbError( + 'Missing extension category for %s. Please make sure the target is an envoy_cc_extension and category is set' + % target) + # evaluate tuples/lists + # wrap strings in a list + categories = (ast.literal_eval(categories) if + ('[' in categories or '(' in categories) else [categories]) + return { + 'security_posture': security_posture, + 'undocumented': False if is_missing(undocumented) else bool(undocumented), + 'status': 'stable' if is_missing(status) else status, + 'categories': categories, + } if __name__ == '__main__': - try: - output_path = os.getenv("EXTENSION_DB_PATH") or sys.argv[1] - except IndexError: - raise SystemExit( - "Output path must be either specified as arg or with EXTENSION_DB_PATH env var") - - extension_db = {} - # Include all extensions from source/extensions/extensions_build_config.bzl - all_extensions = {} - all_extensions.update(extensions_build_config.EXTENSIONS) - for extension, target in all_extensions.items(): - extension_db[extension] = get_extension_metadata(target) - if num_robust_to_downstream_network_filters(extension_db) != num_read_filters_fuzzed(): - raise ExtensionDbError('Check that all network filters robust against untrusted' - 'downstreams are fuzzed by adding them to filterNames() in' - 'test/extensions/filters/network/common/uber_per_readfilter.cc') - # The TLS and generic upstream extensions are hard-coded into the build, so - # not in source/extensions/extensions_build_config.bzl - # TODO(mattklein123): Read these special keys from all_extensions.bzl or a shared location to - # avoid duplicate logic. 
- extension_db['envoy.transport_sockets.tls'] = get_extension_metadata( - '//source/extensions/transport_sockets/tls:config') - extension_db['envoy.upstreams.http.generic'] = get_extension_metadata( - '//source/extensions/upstreams/http/generic:config') - extension_db['envoy.upstreams.tcp.generic'] = get_extension_metadata( - '//source/extensions/upstreams/tcp/generic:config') - extension_db['envoy.upstreams.http.http_protocol_options'] = get_extension_metadata( - '//source/extensions/upstreams/http:config') - extension_db['envoy.request_id.uuid'] = get_extension_metadata( - '//source/extensions/request_id/uuid:config') - - pathlib.Path(os.path.dirname(output_path)).mkdir(parents=True, exist_ok=True) - pathlib.Path(output_path).write_text(json.dumps(extension_db)) + try: + output_path = os.getenv("EXTENSION_DB_PATH") or sys.argv[1] + except IndexError: + raise SystemExit( + "Output path must be either specified as arg or with EXTENSION_DB_PATH env var") + + extension_db = {} + # Include all extensions from source/extensions/extensions_build_config.bzl + all_extensions = {} + all_extensions.update(extensions_build_config.EXTENSIONS) + for extension, target in all_extensions.items(): + extension_db[extension] = get_extension_metadata(target) + if num_robust_to_downstream_network_filters(extension_db) != num_read_filters_fuzzed(): + raise ExtensionDbError('Check that all network filters robust against untrusted' + 'downstreams are fuzzed by adding them to filterNames() in' + 'test/extensions/filters/network/common/uber_per_readfilter.cc') + # The TLS and generic upstream extensions are hard-coded into the build, so + # not in source/extensions/extensions_build_config.bzl + # TODO(mattklein123): Read these special keys from all_extensions.bzl or a shared location to + # avoid duplicate logic. 
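# For orientation only (editor's illustration, not part of this patch): each value stored in
# extension_db below has the shape returned by get_extension_metadata() above; the concrete
# values in this sample entry are invented.
sample_entry = {
    "security_posture": "robust_to_untrusted_downstream",
    "undocumented": False,
    "status": "stable",
    "categories": ["envoy.filters.network"],
}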
+ extension_db['envoy.transport_sockets.tls'] = get_extension_metadata( + '//source/extensions/transport_sockets/tls:config') + extension_db['envoy.upstreams.http.generic'] = get_extension_metadata( + '//source/extensions/upstreams/http/generic:config') + extension_db['envoy.upstreams.tcp.generic'] = get_extension_metadata( + '//source/extensions/upstreams/tcp/generic:config') + extension_db['envoy.upstreams.http.http_protocol_options'] = get_extension_metadata( + '//source/extensions/upstreams/http:config') + extension_db['envoy.request_id.uuid'] = get_extension_metadata( + '//source/extensions/request_id/uuid:config') + + pathlib.Path(os.path.dirname(output_path)).mkdir(parents=True, exist_ok=True) + pathlib.Path(output_path).write_text(json.dumps(extension_db)) diff --git a/tools/extensions/generate_extension_rst.py b/tools/extensions/generate_extension_rst.py index 31c36dbdcc6d..40966c8e02a2 100644 --- a/tools/extensions/generate_extension_rst.py +++ b/tools/extensions/generate_extension_rst.py @@ -11,40 +11,42 @@ def format_item(extension, metadata): - if metadata['undocumented']: - item = '* %s' % extension - else: - item = '* :ref:`%s `' % (extension, extension) - if metadata['status'] == 'alpha': - item += ' (alpha)' - return item + if metadata['undocumented']: + item = '* %s' % extension + else: + item = '* :ref:`%s `' % (extension, extension) + if metadata['status'] == 'alpha': + item += ' (alpha)' + return item if __name__ == '__main__': - try: - generated_rst_dir = os.environ["GENERATED_RST_DIR"] - except KeyError: - raise SystemExit("Path to an output directory must be specified with GENERATED_RST_DIR env var") - security_rst_root = os.path.join(generated_rst_dir, "intro/arch_overview/security") - - try: - extension_db_path = os.environ["EXTENSION_DB_PATH"] - except KeyError: - raise SystemExit("Path to a json extension db must be specified with EXTENSION_DB_PATH env var") - if not os.path.exists(extension_db_path): - subprocess.run("tools/extensions/generate_extension_db".split(), check=True) - extension_db = json.loads(pathlib.Path(extension_db_path).read_text()) - - pathlib.Path(security_rst_root).mkdir(parents=True, exist_ok=True) - - security_postures = defaultdict(list) - for extension, metadata in extension_db.items(): - security_postures[metadata['security_posture']].append(extension) - - for sp, extensions in security_postures.items(): - output_path = pathlib.Path(security_rst_root, 'secpos_%s.rst' % sp) - content = '\n'.join( - format_item(extension, extension_db[extension]) - for extension in sorted(extensions) - if extension_db[extension]['status'] != 'wip') - output_path.write_text(content) + try: + generated_rst_dir = os.environ["GENERATED_RST_DIR"] + except KeyError: + raise SystemExit( + "Path to an output directory must be specified with GENERATED_RST_DIR env var") + security_rst_root = os.path.join(generated_rst_dir, "intro/arch_overview/security") + + try: + extension_db_path = os.environ["EXTENSION_DB_PATH"] + except KeyError: + raise SystemExit( + "Path to a json extension db must be specified with EXTENSION_DB_PATH env var") + if not os.path.exists(extension_db_path): + subprocess.run("tools/extensions/generate_extension_db".split(), check=True) + extension_db = json.loads(pathlib.Path(extension_db_path).read_text()) + + pathlib.Path(security_rst_root).mkdir(parents=True, exist_ok=True) + + security_postures = defaultdict(list) + for extension, metadata in extension_db.items(): + security_postures[metadata['security_posture']].append(extension) + + for 
sp, extensions in security_postures.items(): + output_path = pathlib.Path(security_rst_root, 'secpos_%s.rst' % sp) + content = '\n'.join( + format_item(extension, extension_db[extension]) + for extension in sorted(extensions) + if extension_db[extension]['status'] != 'wip') + output_path.write_text(content) diff --git a/tools/find_related_envoy_files.py b/tools/find_related_envoy_files.py index 60acbcc30ab4..3c0578f25444 100755 --- a/tools/find_related_envoy_files.py +++ b/tools/find_related_envoy_files.py @@ -38,17 +38,17 @@ # is trouble. envoy_index = fname.rfind(ENVOY_ROOT) if envoy_index == -1: - sys.exit(0) + sys.exit(0) envoy_index += len(ENVOY_ROOT) absolute_location = fname[0:envoy_index] # "/path/to/gitroot/envoy/" path = fname[envoy_index:] path_elements = path.split("/") if len(path_elements) < 3: - sys.exit(0) + sys.exit(0) leaf = path_elements[len(path_elements) - 1] dot = leaf.rfind(".") if dot == -1 or dot == len(leaf) - 1: - sys.exit(0) + sys.exit(0) ext = leaf[dot:] @@ -56,36 +56,36 @@ # is emitted if the input path or extension does not match the expected pattern, # or if the file doesn't exist. def emit(source_path, dest_path, source_ending, dest_ending): - if fname.endswith(source_ending) and path.startswith(source_path + "/"): - path_len = len(path) - len(source_path) - len(source_ending) - new_path = (absolute_location + dest_path + path[len(source_path):-len(source_ending)] + - dest_ending) - if os.path.isfile(new_path): - print(new_path) + if fname.endswith(source_ending) and path.startswith(source_path + "/"): + path_len = len(path) - len(source_path) - len(source_ending) + new_path = (absolute_location + dest_path + path[len(source_path):-len(source_ending)] + + dest_ending) + if os.path.isfile(new_path): + print(new_path) # Depending on which type of file is passed into the script: test, cc, # h, or interface, emit any related ones in cyclic order. 
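# Worked example (editor's illustration, not part of this patch) of the emit() mapping defined
# above, assuming SOURCE_ROOT == "source", TEST_ROOT == "test" and a checkout under
# /work/envoy, as the names suggest:
#   fname = "/work/envoy/source/common/buffer/buffer_impl.cc"
#   emit(SOURCE_ROOT, TEST_ROOT, ".cc", "_test.cc")   # used in the dispatch below
# would print "/work/envoy/test/common/buffer/buffer_impl_test.cc", provided that file exists.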
root = path_elements[0] if root == TEST_ROOT: - emit("test/common", INTERFACE_REAL_ROOT, "_impl_test.cc", ".h") - emit(TEST_ROOT, SOURCE_ROOT, "_test.cc", ".cc") - emit(TEST_ROOT, SOURCE_ROOT, "_test.cc", ".h") - emit(TEST_ROOT, TEST_ROOT, ".cc", ".h") - emit(TEST_ROOT, TEST_ROOT, ".cc", "_test.cc") - emit(TEST_ROOT, TEST_ROOT, ".h", "_test.cc") - emit(TEST_ROOT, TEST_ROOT, ".h", ".cc") - emit(TEST_ROOT, TEST_ROOT, "_test.cc", ".cc") - emit(TEST_ROOT, TEST_ROOT, "_test.cc", ".h") + emit("test/common", INTERFACE_REAL_ROOT, "_impl_test.cc", ".h") + emit(TEST_ROOT, SOURCE_ROOT, "_test.cc", ".cc") + emit(TEST_ROOT, SOURCE_ROOT, "_test.cc", ".h") + emit(TEST_ROOT, TEST_ROOT, ".cc", ".h") + emit(TEST_ROOT, TEST_ROOT, ".cc", "_test.cc") + emit(TEST_ROOT, TEST_ROOT, ".h", "_test.cc") + emit(TEST_ROOT, TEST_ROOT, ".h", ".cc") + emit(TEST_ROOT, TEST_ROOT, "_test.cc", ".cc") + emit(TEST_ROOT, TEST_ROOT, "_test.cc", ".h") elif root == SOURCE_ROOT and ext == ".cc": - emit(SOURCE_ROOT, SOURCE_ROOT, ".cc", ".h") - emit(SOURCE_ROOT, TEST_ROOT, ".cc", "_test.cc") - emit("source/common", INTERFACE_REAL_ROOT, "_impl.cc", ".h") + emit(SOURCE_ROOT, SOURCE_ROOT, ".cc", ".h") + emit(SOURCE_ROOT, TEST_ROOT, ".cc", "_test.cc") + emit("source/common", INTERFACE_REAL_ROOT, "_impl.cc", ".h") elif root == SOURCE_ROOT and ext == ".h": - emit(SOURCE_ROOT, TEST_ROOT, ".h", "_test.cc") - emit("source/common", INTERFACE_REAL_ROOT, "_impl.h", ".h") - emit(SOURCE_ROOT, SOURCE_ROOT, ".h", ".cc") + emit(SOURCE_ROOT, TEST_ROOT, ".h", "_test.cc") + emit("source/common", INTERFACE_REAL_ROOT, "_impl.h", ".h") + emit(SOURCE_ROOT, SOURCE_ROOT, ".h", ".cc") elif root == INTERFACE_SYNTHETIC_ROOT: - emit(INTERFACE_SYNTHETIC_ROOT, "source/common", ".h", "_impl.cc") - emit(INTERFACE_SYNTHETIC_ROOT, "source/common", ".h", "_impl.h") - emit(INTERFACE_SYNTHETIC_ROOT, "test/common", ".h", "_impl_test.cc") + emit(INTERFACE_SYNTHETIC_ROOT, "source/common", ".h", "_impl.cc") + emit(INTERFACE_SYNTHETIC_ROOT, "source/common", ".h", "_impl.h") + emit(INTERFACE_SYNTHETIC_ROOT, "test/common", ".h", "_impl_test.cc") diff --git a/tools/gen_compilation_database.py b/tools/gen_compilation_database.py index d1d1ef0a5a83..1249ca7a9ead 100755 --- a/tools/gen_compilation_database.py +++ b/tools/gen_compilation_database.py @@ -12,90 +12,90 @@ # This method is equivalent to https://github.com/grailbio/bazel-compilation-database/blob/master/generate.sh def generate_compilation_database(args): - # We need to download all remote outputs for generated source code. This option lives here to override those - # specified in bazelrc. - bazel_options = shlex.split(os.environ.get("BAZEL_BUILD_OPTIONS", "")) + [ - "--config=compdb", - "--remote_download_outputs=all", - ] + # We need to download all remote outputs for generated source code. This option lives here to override those + # specified in bazelrc. 
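# Quick illustration (editor's note, not part of this patch) of how the option list below is
# assembled; the BAZEL_BUILD_OPTIONS value here is hypothetical:
#   BAZEL_BUILD_OPTIONS="--config=libc++ --jobs=8"
#   shlex.split(...) + ["--config=compdb", "--remote_download_outputs=all"]
#   -> ['--config=libc++', '--jobs=8', '--config=compdb', '--remote_download_outputs=all']
# which is what lets the remote-download setting here take effect over anything set in
# bazelrc, as the comment above notes.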
+ bazel_options = shlex.split(os.environ.get("BAZEL_BUILD_OPTIONS", "")) + [ + "--config=compdb", + "--remote_download_outputs=all", + ] - subprocess.check_call(["bazel", "build"] + bazel_options + [ - "--aspects=@bazel_compdb//:aspects.bzl%compilation_database_aspect", - "--output_groups=compdb_files,header_files" - ] + args.bazel_targets) + subprocess.check_call(["bazel", "build"] + bazel_options + [ + "--aspects=@bazel_compdb//:aspects.bzl%compilation_database_aspect", + "--output_groups=compdb_files,header_files" + ] + args.bazel_targets) - execroot = subprocess.check_output(["bazel", "info", "execution_root"] + - bazel_options).decode().strip() + execroot = subprocess.check_output(["bazel", "info", "execution_root"] + + bazel_options).decode().strip() - compdb = [] - for compdb_file in Path(execroot).glob("**/*.compile_commands.json"): - compdb.extend(json.loads("[" + compdb_file.read_text().replace("__EXEC_ROOT__", execroot) + - "]")) - return compdb + compdb = [] + for compdb_file in Path(execroot).glob("**/*.compile_commands.json"): + compdb.extend( + json.loads("[" + compdb_file.read_text().replace("__EXEC_ROOT__", execroot) + "]")) + return compdb def is_header(filename): - for ext in (".h", ".hh", ".hpp", ".hxx"): - if filename.endswith(ext): - return True - return False + for ext in (".h", ".hh", ".hpp", ".hxx"): + if filename.endswith(ext): + return True + return False def is_compile_target(target, args): - filename = target["file"] - if not args.include_headers and is_header(filename): - return False + filename = target["file"] + if not args.include_headers and is_header(filename): + return False - if not args.include_genfiles: - if filename.startswith("bazel-out/"): - return False + if not args.include_genfiles: + if filename.startswith("bazel-out/"): + return False - if not args.include_external: - if filename.startswith("external/"): - return False + if not args.include_external: + if filename.startswith("external/"): + return False - return True + return True def modify_compile_command(target, args): - cc, options = target["command"].split(" ", 1) + cc, options = target["command"].split(" ", 1) - # Workaround for bazel added C++11 options, those doesn't affect build itself but - # clang-tidy will misinterpret them. - options = options.replace("-std=c++0x ", "") - options = options.replace("-std=c++11 ", "") + # Workaround for bazel added C++11 options, those doesn't affect build itself but + # clang-tidy will misinterpret them. + options = options.replace("-std=c++0x ", "") + options = options.replace("-std=c++11 ", "") - if args.vscode: - # Visual Studio Code doesn't seem to like "-iquote". Replace it with - # old-style "-I". - options = options.replace("-iquote ", "-I ") + if args.vscode: + # Visual Studio Code doesn't seem to like "-iquote". Replace it with + # old-style "-I". + options = options.replace("-iquote ", "-I ") - if is_header(target["file"]): - options += " -Wno-pragma-once-outside-header -Wno-unused-const-variable" - options += " -Wno-unused-function" - if not target["file"].startswith("external/"): - # *.h file is treated as C header by default while our headers files are all C++17. - options = "-x c++ -std=c++17 -fexceptions " + options + if is_header(target["file"]): + options += " -Wno-pragma-once-outside-header -Wno-unused-const-variable" + options += " -Wno-unused-function" + if not target["file"].startswith("external/"): + # *.h file is treated as C header by default while our headers files are all C++17. 
+ options = "-x c++ -std=c++17 -fexceptions " + options - target["command"] = " ".join([cc, options]) - return target + target["command"] = " ".join([cc, options]) + return target def fix_compilation_database(args, db): - db = [modify_compile_command(target, args) for target in db if is_compile_target(target, args)] + db = [modify_compile_command(target, args) for target in db if is_compile_target(target, args)] - with open("compile_commands.json", "w") as db_file: - json.dump(db, db_file, indent=2) + with open("compile_commands.json", "w") as db_file: + json.dump(db, db_file, indent=2) if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Generate JSON compilation database') - parser.add_argument('--include_external', action='store_true') - parser.add_argument('--include_genfiles', action='store_true') - parser.add_argument('--include_headers', action='store_true') - parser.add_argument('--vscode', action='store_true') - parser.add_argument('bazel_targets', - nargs='*', - default=["//source/...", "//test/...", "//tools/..."]) - args = parser.parse_args() - fix_compilation_database(args, generate_compilation_database(args)) + parser = argparse.ArgumentParser(description='Generate JSON compilation database') + parser.add_argument('--include_external', action='store_true') + parser.add_argument('--include_genfiles', action='store_true') + parser.add_argument('--include_headers', action='store_true') + parser.add_argument('--vscode', action='store_true') + parser.add_argument('bazel_targets', + nargs='*', + default=["//source/...", "//test/...", "//tools/..."]) + args = parser.parse_args() + fix_compilation_database(args, generate_compilation_database(args)) diff --git a/tools/github/sync_assignable.py b/tools/github/sync_assignable.py index 2db0e8169a5d..5bbf15234e5c 100644 --- a/tools/github/sync_assignable.py +++ b/tools/github/sync_assignable.py @@ -18,36 +18,36 @@ def get_confirmation(): - """Obtain stdin confirmation to add users in GH.""" - return input('Add users to envoyproxy/assignable ? [yN] ').strip().lower() in ('y', 'yes') + """Obtain stdin confirmation to add users in GH.""" + return input('Add users to envoyproxy/assignable ? 
[yN] ').strip().lower() in ('y', 'yes') def sync_assignable(access_token): - organization = github.Github(access_token).get_organization('envoyproxy') - team = organization.get_team_by_slug('assignable') - organization_members = set(organization.get_members()) - assignable_members = set(team.get_members()) - missing = organization_members.difference(assignable_members) + organization = github.Github(access_token).get_organization('envoyproxy') + team = organization.get_team_by_slug('assignable') + organization_members = set(organization.get_members()) + assignable_members = set(team.get_members()) + missing = organization_members.difference(assignable_members) - if not missing: - print('envoyproxy/assignable is consistent with organization membership.') - return 0 + if not missing: + print('envoyproxy/assignable is consistent with organization membership.') + return 0 - print('The following organization members are missing from envoyproxy/assignable:') - for m in missing: - print(m.login) + print('The following organization members are missing from envoyproxy/assignable:') + for m in missing: + print(m.login) - if not get_confirmation(): - return 1 + if not get_confirmation(): + return 1 - for m in missing: - team.add_membership(m, 'member') + for m in missing: + team.add_membership(m, 'member') if __name__ == '__main__': - access_token = os.getenv('GITHUB_TOKEN') - if not access_token: - print('Missing GITHUB_TOKEN') - sys.exit(1) + access_token = os.getenv('GITHUB_TOKEN') + if not access_token: + print('Missing GITHUB_TOKEN') + sys.exit(1) - sys.exit(sync_assignable(access_token)) + sys.exit(sync_assignable(access_token)) diff --git a/tools/print_dependencies.py b/tools/print_dependencies.py index 4c0fefb1a03f..fd17c36fd647 100755 --- a/tools/print_dependencies.py +++ b/tools/print_dependencies.py @@ -14,41 +14,41 @@ def print_deps(deps): - print(json.dumps(deps, sort_keys=True, indent=2)) + print(json.dumps(deps, sort_keys=True, indent=2)) if __name__ == '__main__': - deps = [] + deps = [] - DEPS.REPOSITORY_LOCATIONS.update(API_DEPS.REPOSITORY_LOCATIONS) + DEPS.REPOSITORY_LOCATIONS.update(API_DEPS.REPOSITORY_LOCATIONS) - for key, loc in DEPS.REPOSITORY_LOCATIONS.items(): - deps.append({ - 'identifier': key, - 'file-sha256': loc.get('sha256'), - 'file-url': loc.get('urls')[0], - 'file-prefix': loc.get('strip_prefix', ''), - }) + for key, loc in DEPS.REPOSITORY_LOCATIONS.items(): + deps.append({ + 'identifier': key, + 'file-sha256': loc.get('sha256'), + 'file-url': loc.get('urls')[0], + 'file-prefix': loc.get('strip_prefix', ''), + }) - deps = sorted(deps, key=lambda k: k['identifier']) + deps = sorted(deps, key=lambda k: k['identifier']) - # Print all dependencies if a target is unspecified - if len(sys.argv) == 1: - print_deps(deps) - exit(0) + # Print all dependencies if a target is unspecified + if len(sys.argv) == 1: + print_deps(deps) + exit(0) - # Bazel target to print - target = sys.argv[1] - output = subprocess.check_output(['bazel', 'query', 'deps(%s)' % target]) + # Bazel target to print + target = sys.argv[1] + output = subprocess.check_output(['bazel', 'query', 'deps(%s)' % target]) - repos = set() + repos = set() - # Gather the explicit list of repositories - repo_regex = re.compile('^@(.*)\/\/') - for line in output.split('\n'): - match = repo_regex.match(line) - if match: - repos.add(match.group(1)) + # Gather the explicit list of repositories + repo_regex = re.compile('^@(.*)\/\/') + for line in output.split('\n'): + match = repo_regex.match(line) + if match: + 
repos.add(match.group(1)) - deps = filter(lambda dep: dep['identifier'] in repos, deps) - print_deps(deps) + deps = filter(lambda dep: dep['identifier'] in repos, deps) + print_deps(deps) diff --git a/tools/proto_format/active_protos_gen.py b/tools/proto_format/active_protos_gen.py index fde0f9961af2..45e2cab570c9 100755 --- a/tools/proto_format/active_protos_gen.py +++ b/tools/proto_format/active_protos_gen.py @@ -35,33 +35,33 @@ # Key sort function to achieve consistent results with buildifier. def build_order_key(key): - return key.replace(':', '!') + return key.replace(':', '!') def deps_format(pkgs): - if not pkgs: - return '' - return '\n'.join( - ' "//%s:pkg",' % p.replace('.', '/') for p in sorted(pkgs, key=build_order_key)) + '\n' + if not pkgs: + return '' + return '\n'.join(' "//%s:pkg",' % p.replace('.', '/') + for p in sorted(pkgs, key=build_order_key)) + '\n' # Find packages with a given package version status in a given API tree root. def find_pkgs(package_version_status, api_root): - try: - active_files = subprocess.check_output( - ['grep', '-l', '-r', - 'package_version_status = %s;' % package_version_status, - api_root]).decode().strip().split('\n') - api_protos = [f for f in active_files if f.endswith('.proto')] - except subprocess.CalledProcessError: - api_protos = [] - return set([os.path.dirname(p)[len(api_root) + 1:] for p in api_protos]) + try: + active_files = subprocess.check_output( + ['grep', '-l', '-r', + 'package_version_status = %s;' % package_version_status, + api_root]).decode().strip().split('\n') + api_protos = [f for f in active_files if f.endswith('.proto')] + except subprocess.CalledProcessError: + api_protos = [] + return set([os.path.dirname(p)[len(api_root) + 1:] for p in api_protos]) if __name__ == '__main__': - api_root = sys.argv[1] - active_pkgs = find_pkgs('ACTIVE', api_root) - frozen_pkgs = find_pkgs('FROZEN', api_root) - sys.stdout.write( - BUILD_FILE_TEMPLATE.substitute(active_pkgs=deps_format(active_pkgs), - frozen_pkgs=deps_format(frozen_pkgs))) + api_root = sys.argv[1] + active_pkgs = find_pkgs('ACTIVE', api_root) + frozen_pkgs = find_pkgs('FROZEN', api_root) + sys.stdout.write( + BUILD_FILE_TEMPLATE.substitute(active_pkgs=deps_format(active_pkgs), + frozen_pkgs=deps_format(frozen_pkgs))) diff --git a/tools/proto_format/proto_sync.py b/tools/proto_format/proto_sync.py index f80b32f33cd9..6bcb1d6dc636 100755 --- a/tools/proto_format/proto_sync.py +++ b/tools/proto_format/proto_sync.py @@ -59,44 +59,44 @@ class ProtoSyncError(Exception): - pass + pass class RequiresReformatError(ProtoSyncError): - def __init__(self, message): - super(RequiresReformatError, self).__init__( - '%s; either run ./ci/do_ci.sh fix_format or ./tools/proto_format/proto_format.sh fix to reformat.\n' - % message) + def __init__(self, message): + super(RequiresReformatError, self).__init__( + '%s; either run ./ci/do_ci.sh fix_format or ./tools/proto_format/proto_format.sh fix to reformat.\n' + % message) def get_directory_from_package(package): - """Get directory path from package name or full qualified message name + """Get directory path from package name or full qualified message name Args: package: the full qualified name of package or message. """ - return '/'.join(s for s in package.split('.') if s and s[0].islower()) + return '/'.join(s for s in package.split('.') if s and s[0].islower()) def get_destination_path(src): - """Obtain destination path from a proto file path by reading its package statement. 
+ """Obtain destination path from a proto file path by reading its package statement. Args: src: source path """ - src_path = pathlib.Path(src) - contents = src_path.read_text(encoding='utf8') - matches = re.findall(PACKAGE_REGEX, contents) - if len(matches) != 1: - raise RequiresReformatError("Expect {} has only one package declaration but has {}".format( - src, len(matches))) - return pathlib.Path(get_directory_from_package( - matches[0])).joinpath(src_path.name.split('.')[0] + ".proto") + src_path = pathlib.Path(src) + contents = src_path.read_text(encoding='utf8') + matches = re.findall(PACKAGE_REGEX, contents) + if len(matches) != 1: + raise RequiresReformatError("Expect {} has only one package declaration but has {}".format( + src, len(matches))) + return pathlib.Path(get_directory_from_package( + matches[0])).joinpath(src_path.name.split('.')[0] + ".proto") def get_abs_rel_destination_path(dst_root, src): - """Obtain absolute path from a proto file path combined with destination root. + """Obtain absolute path from a proto file path combined with destination root. Creates the parent directory if necessary. @@ -104,46 +104,46 @@ def get_abs_rel_destination_path(dst_root, src): dst_root: destination root path. src: source path. """ - rel_dst_path = get_destination_path(src) - dst = dst_root.joinpath(rel_dst_path) - dst.parent.mkdir(0o755, parents=True, exist_ok=True) - return dst, rel_dst_path + rel_dst_path = get_destination_path(src) + dst = dst_root.joinpath(rel_dst_path) + dst.parent.mkdir(0o755, parents=True, exist_ok=True) + return dst, rel_dst_path def proto_print(src, dst): - """Pretty-print FileDescriptorProto to a destination file. + """Pretty-print FileDescriptorProto to a destination file. Args: src: source path for FileDescriptorProto. dst: destination path for formatted proto. """ - print('proto_print %s' % dst) - subprocess.check_output([ - 'bazel-bin/tools/protoxform/protoprint', src, - str(dst), - './bazel-bin/tools/protoxform/protoprint.runfiles/envoy/tools/type_whisperer/api_type_db.pb_text' - ]) + print('proto_print %s' % dst) + subprocess.check_output([ + 'bazel-bin/tools/protoxform/protoprint', src, + str(dst), + './bazel-bin/tools/protoxform/protoprint.runfiles/envoy/tools/type_whisperer/api_type_db.pb_text' + ]) def merge_active_shadow(active_src, shadow_src, dst): - """Merge active/shadow FileDescriptorProto to a destination file. + """Merge active/shadow FileDescriptorProto to a destination file. Args: active_src: source path for active FileDescriptorProto. shadow_src: source path for active FileDescriptorProto. dst: destination path for FileDescriptorProto. """ - print('merge_active_shadow %s' % dst) - subprocess.check_output([ - 'bazel-bin/tools/protoxform/merge_active_shadow', - active_src, - shadow_src, - dst, - ]) + print('merge_active_shadow %s' % dst) + subprocess.check_output([ + 'bazel-bin/tools/protoxform/merge_active_shadow', + active_src, + shadow_src, + dst, + ]) def sync_proto_file(dst_srcs): - """Pretty-print a proto descriptor from protoxform.py Bazel cache artifacts." + """Pretty-print a proto descriptor from protoxform.py Bazel cache artifacts." In the case where we are generating an Envoy internal shadow, it may be necessary to combine the current active proto, subject to hand editing, with @@ -153,36 +153,36 @@ def sync_proto_file(dst_srcs): Args: dst_srcs: destination/sources path tuple. """ - dst, srcs = dst_srcs - assert (len(srcs) > 0) - # If we only have one candidate source for a destination, just pretty-print. 
- if len(srcs) == 1: - src = srcs[0] - proto_print(src, dst) - else: - # We should only see an active and next major version candidate from - # previous version today. - assert (len(srcs) == 2) - shadow_srcs = [ - s for s in srcs if s.endswith('.next_major_version_candidate.envoy_internal.proto') - ] - active_src = [s for s in srcs if s.endswith('active_or_frozen.proto')][0] - # If we're building the shadow, we need to combine the next major version - # candidate shadow with the potentially hand edited active version. - if len(shadow_srcs) > 0: - assert (len(shadow_srcs) == 1) - with tempfile.NamedTemporaryFile() as f: - merge_active_shadow(active_src, shadow_srcs[0], f.name) - proto_print(f.name, dst) + dst, srcs = dst_srcs + assert (len(srcs) > 0) + # If we only have one candidate source for a destination, just pretty-print. + if len(srcs) == 1: + src = srcs[0] + proto_print(src, dst) else: - proto_print(active_src, dst) - src = active_src - rel_dst_path = get_destination_path(src) - return ['//%s:pkg' % str(rel_dst_path.parent)] + # We should only see an active and next major version candidate from + # previous version today. + assert (len(srcs) == 2) + shadow_srcs = [ + s for s in srcs if s.endswith('.next_major_version_candidate.envoy_internal.proto') + ] + active_src = [s for s in srcs if s.endswith('active_or_frozen.proto')][0] + # If we're building the shadow, we need to combine the next major version + # candidate shadow with the potentially hand edited active version. + if len(shadow_srcs) > 0: + assert (len(shadow_srcs) == 1) + with tempfile.NamedTemporaryFile() as f: + merge_active_shadow(active_src, shadow_srcs[0], f.name) + proto_print(f.name, dst) + else: + proto_print(active_src, dst) + src = active_src + rel_dst_path = get_destination_path(src) + return ['//%s:pkg' % str(rel_dst_path.parent)] def get_import_deps(proto_path): - """Obtain the Bazel dependencies for the import paths from a .proto file. + """Obtain the Bazel dependencies for the import paths from a .proto file. Args: proto_path: path to .proto. @@ -190,41 +190,42 @@ def get_import_deps(proto_path): Returns: A list of Bazel targets reflecting the imports in the .proto at proto_path. """ - imports = [] - with open(proto_path, 'r', encoding='utf8') as f: - for line in f: - match = re.match(IMPORT_REGEX, line) - if match: - import_path = match.group(1) - # We can ignore imports provided implicitly by api_proto_package(). - if any(import_path.startswith(p) for p in API_BUILD_SYSTEM_IMPORT_PREFIXES): - continue - # Special case handling for UDPA annotations. - if import_path.startswith('udpa/annotations/'): - imports.append('@com_github_cncf_udpa//udpa/annotations:pkg') - continue - # Special case handling for UDPA core. - if import_path.startswith('xds/core/v3/'): - imports.append('@com_github_cncf_udpa//xds/core/v3:pkg') - continue - # Explicit remapping for external deps, compute paths for envoy/*. - if import_path in external_proto_deps.EXTERNAL_PROTO_IMPORT_BAZEL_DEP_MAP: - imports.append(external_proto_deps.EXTERNAL_PROTO_IMPORT_BAZEL_DEP_MAP[import_path]) - continue - if import_path.startswith('envoy/'): - # Ignore package internal imports. 
- if os.path.dirname(proto_path).endswith(os.path.dirname(import_path)): - continue - imports.append('//%s:pkg' % os.path.dirname(import_path)) - continue - raise ProtoSyncError( - 'Unknown import path mapping for %s, please update the mappings in tools/proto_format/proto_sync.py.\n' - % import_path) - return imports + imports = [] + with open(proto_path, 'r', encoding='utf8') as f: + for line in f: + match = re.match(IMPORT_REGEX, line) + if match: + import_path = match.group(1) + # We can ignore imports provided implicitly by api_proto_package(). + if any(import_path.startswith(p) for p in API_BUILD_SYSTEM_IMPORT_PREFIXES): + continue + # Special case handling for UDPA annotations. + if import_path.startswith('udpa/annotations/'): + imports.append('@com_github_cncf_udpa//udpa/annotations:pkg') + continue + # Special case handling for UDPA core. + if import_path.startswith('xds/core/v3/'): + imports.append('@com_github_cncf_udpa//xds/core/v3:pkg') + continue + # Explicit remapping for external deps, compute paths for envoy/*. + if import_path in external_proto_deps.EXTERNAL_PROTO_IMPORT_BAZEL_DEP_MAP: + imports.append( + external_proto_deps.EXTERNAL_PROTO_IMPORT_BAZEL_DEP_MAP[import_path]) + continue + if import_path.startswith('envoy/'): + # Ignore package internal imports. + if os.path.dirname(proto_path).endswith(os.path.dirname(import_path)): + continue + imports.append('//%s:pkg' % os.path.dirname(import_path)) + continue + raise ProtoSyncError( + 'Unknown import path mapping for %s, please update the mappings in tools/proto_format/proto_sync.py.\n' + % import_path) + return imports def get_previous_message_type_deps(proto_path): - """Obtain the Bazel dependencies for the previous version of messages in a .proto file. + """Obtain the Bazel dependencies for the previous version of messages in a .proto file. We need to link in earlier proto descriptors to support Envoy reflection upgrades. @@ -234,17 +235,17 @@ def get_previous_message_type_deps(proto_path): Returns: A list of Bazel targets reflecting the previous message types in the .proto at proto_path. """ - contents = pathlib.Path(proto_path).read_text(encoding='utf8') - matches = re.findall(PREVIOUS_MESSAGE_TYPE_REGEX, contents) - deps = [] - for m in matches: - target = '//%s:pkg' % get_directory_from_package(m) - deps.append(target) - return deps + contents = pathlib.Path(proto_path).read_text(encoding='utf8') + matches = re.findall(PREVIOUS_MESSAGE_TYPE_REGEX, contents) + deps = [] + for m in matches: + target = '//%s:pkg' % get_directory_from_package(m) + deps.append(target) + return deps def has_services(proto_path): - """Does a .proto file have any service definitions? + """Does a .proto file have any service definitions? Args: proto_path: path to .proto. @@ -252,20 +253,20 @@ def has_services(proto_path): Returns: True iff there are service definitions in the .proto at proto_path. """ - with open(proto_path, 'r', encoding='utf8') as f: - for line in f: - if re.match(SERVICE_REGEX, line): - return True - return False + with open(proto_path, 'r', encoding='utf8') as f: + for line in f: + if re.match(SERVICE_REGEX, line): + return True + return False # Key sort function to achieve consistent results with buildifier. def build_order_key(key): - return key.replace(':', '!') + return key.replace(':', '!') def build_file_contents(root, files): - """Compute the canonical BUILD contents for an api/ proto directory. + """Compute the canonical BUILD contents for an api/ proto directory. Args: root: base path to directory. 
@@ -274,65 +275,65 @@ def build_file_contents(root, files): Returns: A string containing the canonical BUILD file content for root. """ - import_deps = set(sum([get_import_deps(os.path.join(root, f)) for f in files], [])) - history_deps = set(sum([get_previous_message_type_deps(os.path.join(root, f)) for f in files], - [])) - deps = import_deps.union(history_deps) - _has_services = any(has_services(os.path.join(root, f)) for f in files) - fields = [] - if _has_services: - fields.append(' has_services = True,') - if deps: - if len(deps) == 1: - formatted_deps = '"%s"' % list(deps)[0] - else: - formatted_deps = '\n' + '\n'.join( - ' "%s",' % dep for dep in sorted(deps, key=build_order_key)) + '\n ' - fields.append(' deps = [%s],' % formatted_deps) - formatted_fields = '\n' + '\n'.join(fields) + '\n' if fields else '' - return BUILD_FILE_TEMPLATE.substitute(fields=formatted_fields) + import_deps = set(sum([get_import_deps(os.path.join(root, f)) for f in files], [])) + history_deps = set( + sum([get_previous_message_type_deps(os.path.join(root, f)) for f in files], [])) + deps = import_deps.union(history_deps) + _has_services = any(has_services(os.path.join(root, f)) for f in files) + fields = [] + if _has_services: + fields.append(' has_services = True,') + if deps: + if len(deps) == 1: + formatted_deps = '"%s"' % list(deps)[0] + else: + formatted_deps = '\n' + '\n'.join( + ' "%s",' % dep for dep in sorted(deps, key=build_order_key)) + '\n ' + fields.append(' deps = [%s],' % formatted_deps) + formatted_fields = '\n' + '\n'.join(fields) + '\n' if fields else '' + return BUILD_FILE_TEMPLATE.substitute(fields=formatted_fields) def sync_build_files(cmd, dst_root): - """Diff or in-place update api/ BUILD files. + """Diff or in-place update api/ BUILD files. Args: cmd: 'check' or 'fix'. """ - for root, dirs, files in os.walk(str(dst_root)): - is_proto_dir = any(f.endswith('.proto') for f in files) - if not is_proto_dir: - continue - build_contents = build_file_contents(root, files) - build_path = os.path.join(root, 'BUILD') - with open(build_path, 'w') as f: - f.write(build_contents) + for root, dirs, files in os.walk(str(dst_root)): + is_proto_dir = any(f.endswith('.proto') for f in files) + if not is_proto_dir: + continue + build_contents = build_file_contents(root, files) + build_path = os.path.join(root, 'BUILD') + with open(build_path, 'w') as f: + f.write(build_contents) def generate_current_api_dir(api_dir, dst_dir): - """Helper function to generate original API repository to be compared with diff. + """Helper function to generate original API repository to be compared with diff. This copies the original API repository and deletes file we don't want to compare. Args: api_dir: the original api directory dst_dir: the api directory to be compared in temporary directory """ - dst = dst_dir.joinpath("envoy") - shutil.copytree(str(api_dir.joinpath("envoy")), str(dst)) + dst = dst_dir.joinpath("envoy") + shutil.copytree(str(api_dir.joinpath("envoy")), str(dst)) - for p in dst.glob('**/*.md'): - p.unlink() - # envoy.service.auth.v2alpha exist for compatibility while we don't run in protoxform - # so we ignore it here. - shutil.rmtree(str(dst.joinpath("service", "auth", "v2alpha"))) + for p in dst.glob('**/*.md'): + p.unlink() + # envoy.service.auth.v2alpha exist for compatibility while we don't run in protoxform + # so we ignore it here. 
+ shutil.rmtree(str(dst.joinpath("service", "auth", "v2alpha"))) def git_status(path): - return subprocess.check_output(['git', 'status', '--porcelain', str(path)]).decode() + return subprocess.check_output(['git', 'status', '--porcelain', str(path)]).decode() def git_modified_files(path, suffix): - """Obtain a list of modified files since the last commit merged by GitHub. + """Obtain a list of modified files since the last commit merged by GitHub. Args: path: path to examine. @@ -340,14 +341,14 @@ def git_modified_files(path, suffix): Return: A list of strings providing the paths of modified files in the repo. """ - try: - modified_files = subprocess.check_output( - ['tools/git/modified_since_last_github_commit.sh', 'api', 'proto']).decode().split() - return modified_files - except subprocess.CalledProcessError as e: - if e.returncode == 1: - return [] - raise + try: + modified_files = subprocess.check_output( + ['tools/git/modified_since_last_github_commit.sh', 'api', 'proto']).decode().split() + return modified_files + except subprocess.CalledProcessError as e: + if e.returncode == 1: + return [] + raise # If we're not forcing format, i.e. FORCE_PROTO_FORMAT=yes, in the environment, @@ -355,102 +356,106 @@ def git_modified_files(path, suffix): # heuristics. This saves a ton of time, since proto format and sync is not # running under Bazel and can't do change detection. def should_sync(path, api_proto_modified_files, py_tools_modified_files): - if os.getenv('FORCE_PROTO_FORMAT') == 'yes': - return True - # If tools change, safest thing to do is rebuild everything. - if len(py_tools_modified_files) > 0: - return True - # Check to see if the basename of the file has been modified since the last - # GitHub commit. If so, rebuild. This is safe and conservative across package - # migrations in v3 and v4alpha; we could achieve a lower rate of false - # positives if we examined package migration annotations, at the expense of - # complexity. - for p in api_proto_modified_files: - if os.path.basename(p) in path: - return True - # Otherwise we can safely skip syncing. - return False + if os.getenv('FORCE_PROTO_FORMAT') == 'yes': + return True + # If tools change, safest thing to do is rebuild everything. + if len(py_tools_modified_files) > 0: + return True + # Check to see if the basename of the file has been modified since the last + # GitHub commit. If so, rebuild. This is safe and conservative across package + # migrations in v3 and v4alpha; we could achieve a lower rate of false + # positives if we examined package migration annotations, at the expense of + # complexity. + for p in api_proto_modified_files: + if os.path.basename(p) in path: + return True + # Otherwise we can safely skip syncing. 
+ return False def sync(api_root, mode, labels, shadow): - api_proto_modified_files = git_modified_files('api', 'proto') - py_tools_modified_files = git_modified_files('tools', 'py') - with tempfile.TemporaryDirectory() as tmp: - dst_dir = pathlib.Path(tmp).joinpath("b") - paths = [] - for label in labels: - paths.append(utils.bazel_bin_path_for_output_artifact(label, '.active_or_frozen.proto')) - paths.append( - utils.bazel_bin_path_for_output_artifact( - label, '.next_major_version_candidate.envoy_internal.proto' - if shadow else '.next_major_version_candidate.proto')) - dst_src_paths = defaultdict(list) - for path in paths: - if os.stat(path).st_size > 0: - abs_dst_path, rel_dst_path = get_abs_rel_destination_path(dst_dir, path) - if should_sync(path, api_proto_modified_files, py_tools_modified_files): - dst_src_paths[abs_dst_path].append(path) - else: - print('Skipping sync of %s' % path) - src_path = str(pathlib.Path(api_root, rel_dst_path)) - shutil.copy(src_path, abs_dst_path) - with mp.Pool() as p: - pkg_deps = p.map(sync_proto_file, dst_src_paths.items()) - sync_build_files(mode, dst_dir) - - current_api_dir = pathlib.Path(tmp).joinpath("a") - current_api_dir.mkdir(0o755, True, True) - api_root_path = pathlib.Path(api_root) - generate_current_api_dir(api_root_path, current_api_dir) - - # These support files are handled manually. - for f in [ - 'envoy/annotations/resource.proto', 'envoy/annotations/deprecation.proto', - 'envoy/annotations/BUILD' - ]: - copy_dst_dir = pathlib.Path(dst_dir, os.path.dirname(f)) - copy_dst_dir.mkdir(exist_ok=True) - shutil.copy(str(pathlib.Path(api_root, f)), str(copy_dst_dir)) - - diff = subprocess.run(['diff', '-Npur', "a", "b"], cwd=tmp, stdout=subprocess.PIPE).stdout - - if diff.strip(): - if mode == "check": - print("Please apply following patch to directory '{}'".format(api_root), file=sys.stderr) - print(diff.decode(), file=sys.stderr) - sys.exit(1) - if mode == "fix": - _git_status = git_status(api_root) - if _git_status: - print('git status indicates a dirty API tree:\n%s' % _git_status) - print( - 'Proto formatting may overwrite or delete files in the above list with no git backup.' - ) - if input('Continue? [yN] ').strip().lower() != 'y': - sys.exit(1) - src_files = set(str(p.relative_to(current_api_dir)) for p in current_api_dir.rglob('*')) - dst_files = set(str(p.relative_to(dst_dir)) for p in dst_dir.rglob('*')) - deleted_files = src_files.difference(dst_files) - if deleted_files: - print('The following files will be deleted: %s' % sorted(deleted_files)) - print( - 'If this is not intended, please see https://github.com/envoyproxy/envoy/blob/main/api/STYLE.md#adding-an-extension-configuration-to-the-api.' - ) - if input('Delete files? 
[yN] ').strip().lower() == 'y': - subprocess.run(['patch', '-p1'], input=diff, cwd=str(api_root_path.resolve())) - else: - sys.exit(1) - else: - subprocess.run(['patch', '-p1'], input=diff, cwd=str(api_root_path.resolve())) + api_proto_modified_files = git_modified_files('api', 'proto') + py_tools_modified_files = git_modified_files('tools', 'py') + with tempfile.TemporaryDirectory() as tmp: + dst_dir = pathlib.Path(tmp).joinpath("b") + paths = [] + for label in labels: + paths.append(utils.bazel_bin_path_for_output_artifact(label, '.active_or_frozen.proto')) + paths.append( + utils.bazel_bin_path_for_output_artifact( + label, '.next_major_version_candidate.envoy_internal.proto' + if shadow else '.next_major_version_candidate.proto')) + dst_src_paths = defaultdict(list) + for path in paths: + if os.stat(path).st_size > 0: + abs_dst_path, rel_dst_path = get_abs_rel_destination_path(dst_dir, path) + if should_sync(path, api_proto_modified_files, py_tools_modified_files): + dst_src_paths[abs_dst_path].append(path) + else: + print('Skipping sync of %s' % path) + src_path = str(pathlib.Path(api_root, rel_dst_path)) + shutil.copy(src_path, abs_dst_path) + with mp.Pool() as p: + pkg_deps = p.map(sync_proto_file, dst_src_paths.items()) + sync_build_files(mode, dst_dir) + + current_api_dir = pathlib.Path(tmp).joinpath("a") + current_api_dir.mkdir(0o755, True, True) + api_root_path = pathlib.Path(api_root) + generate_current_api_dir(api_root_path, current_api_dir) + + # These support files are handled manually. + for f in [ + 'envoy/annotations/resource.proto', 'envoy/annotations/deprecation.proto', + 'envoy/annotations/BUILD' + ]: + copy_dst_dir = pathlib.Path(dst_dir, os.path.dirname(f)) + copy_dst_dir.mkdir(exist_ok=True) + shutil.copy(str(pathlib.Path(api_root, f)), str(copy_dst_dir)) + + diff = subprocess.run(['diff', '-Npur', "a", "b"], cwd=tmp, stdout=subprocess.PIPE).stdout + + if diff.strip(): + if mode == "check": + print("Please apply following patch to directory '{}'".format(api_root), + file=sys.stderr) + print(diff.decode(), file=sys.stderr) + sys.exit(1) + if mode == "fix": + _git_status = git_status(api_root) + if _git_status: + print('git status indicates a dirty API tree:\n%s' % _git_status) + print( + 'Proto formatting may overwrite or delete files in the above list with no git backup.' + ) + if input('Continue? [yN] ').strip().lower() != 'y': + sys.exit(1) + src_files = set( + str(p.relative_to(current_api_dir)) for p in current_api_dir.rglob('*')) + dst_files = set(str(p.relative_to(dst_dir)) for p in dst_dir.rglob('*')) + deleted_files = src_files.difference(dst_files) + if deleted_files: + print('The following files will be deleted: %s' % sorted(deleted_files)) + print( + 'If this is not intended, please see https://github.com/envoyproxy/envoy/blob/main/api/STYLE.md#adding-an-extension-configuration-to-the-api.' + ) + if input('Delete files? 
[yN] ').strip().lower() == 'y': + subprocess.run(['patch', '-p1'], + input=diff, + cwd=str(api_root_path.resolve())) + else: + sys.exit(1) + else: + subprocess.run(['patch', '-p1'], input=diff, cwd=str(api_root_path.resolve())) if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--mode', choices=['check', 'fix']) - parser.add_argument('--api_root', default='./api') - parser.add_argument('--api_shadow_root', default='./generated_api_shadow') - parser.add_argument('labels', nargs='*') - args = parser.parse_args() - - sync(args.api_root, args.mode, args.labels, False) - sync(args.api_shadow_root, args.mode, args.labels, True) + parser = argparse.ArgumentParser() + parser.add_argument('--mode', choices=['check', 'fix']) + parser.add_argument('--api_root', default='./api') + parser.add_argument('--api_shadow_root', default='./generated_api_shadow') + parser.add_argument('labels', nargs='*') + args = parser.parse_args() + + sync(args.api_root, args.mode, args.labels, False) + sync(args.api_shadow_root, args.mode, args.labels, True) diff --git a/tools/protodoc/generate_empty.py b/tools/protodoc/generate_empty.py index db3aafdc24df..347f022ac5ef 100644 --- a/tools/protodoc/generate_empty.py +++ b/tools/protodoc/generate_empty.py @@ -23,26 +23,27 @@ def generate_empty_extension_docs(extension, details, api_extensions_root): - extension_root = pathlib.Path(details['path']) - path = pathlib.Path(api_extensions_root, extension_root, 'empty', extension_root.name + '.rst') - path.parent.mkdir(parents=True, exist_ok=True) - description = details.get('description', '') - reflink = '' - if 'ref' in details: - reflink = '%s %s.' % (details['title'], - protodoc.format_internal_link('configuration overview', details['ref'])) - content = EMPTY_EXTENSION_DOCS_TEMPLATE.substitute(header=protodoc.format_header( - '=', details['title']), - description=description, - reflink=reflink, - extension=protodoc.format_extension(extension)) - path.write_text(content) + extension_root = pathlib.Path(details['path']) + path = pathlib.Path(api_extensions_root, extension_root, 'empty', extension_root.name + '.rst') + path.parent.mkdir(parents=True, exist_ok=True) + description = details.get('description', '') + reflink = '' + if 'ref' in details: + reflink = '%s %s.' 
% (details['title'], + protodoc.format_internal_link('configuration overview', + details['ref'])) + content = EMPTY_EXTENSION_DOCS_TEMPLATE.substitute( + header=protodoc.format_header('=', details['title']), + description=description, + reflink=reflink, + extension=protodoc.format_extension(extension)) + path.write_text(content) if __name__ == '__main__': - empty_extensions_path = sys.argv[1] - api_extensions_root = sys.argv[2] + empty_extensions_path = sys.argv[1] + api_extensions_root = sys.argv[2] - empty_extensions = json.loads(pathlib.Path(empty_extensions_path).read_text()) - for extension, details in empty_extensions.items(): - generate_empty_extension_docs(extension, details, api_extensions_root) + empty_extensions = json.loads(pathlib.Path(empty_extensions_path).read_text()) + for extension, details in empty_extensions.items(): + generate_empty_extension_docs(extension, details, api_extensions_root) diff --git a/tools/protodoc/protodoc.py b/tools/protodoc/protodoc.py index 42e88769c314..7593f6f92f6e 100755 --- a/tools/protodoc/protodoc.py +++ b/tools/protodoc/protodoc.py @@ -121,21 +121,21 @@ # create an index of extension categories from extension db EXTENSION_CATEGORIES = {} for _k, _v in EXTENSION_DB.items(): - for _cat in _v['categories']: - EXTENSION_CATEGORIES.setdefault(_cat, []).append(_k) + for _cat in _v['categories']: + EXTENSION_CATEGORIES.setdefault(_cat, []).append(_k) class ProtodocError(Exception): - """Base error class for the protodoc module.""" + """Base error class for the protodoc module.""" def hide_not_implemented(comment): - """Should a given type_context.Comment be hidden because it is tagged as [#not-implemented-hide:]?""" - return annotations.NOT_IMPLEMENTED_HIDE_ANNOTATION in comment.annotations + """Should a given type_context.Comment be hidden because it is tagged as [#not-implemented-hide:]?""" + return annotations.NOT_IMPLEMENTED_HIDE_ANNOTATION in comment.annotations def github_url(type_context): - """Obtain data plane API Github URL by path from a TypeContext. + """Obtain data plane API Github URL by path from a TypeContext. Args: type_context: type_context.TypeContext for node. @@ -143,14 +143,14 @@ def github_url(type_context): Returns: A string with a corresponding data plane API GitHub Url. """ - if type_context.location is not None: - return DATA_PLANE_API_URL_FMT % (type_context.source_code_info.name, - type_context.location.span[0]) - return '' + if type_context.location is not None: + return DATA_PLANE_API_URL_FMT % (type_context.source_code_info.name, + type_context.location.span[0]) + return '' def format_comment_with_annotations(comment, type_name=''): - """Format a comment string with additional RST for annotations. + """Format a comment string with additional RST for annotations. Args: comment: comment string. @@ -160,20 +160,20 @@ def format_comment_with_annotations(comment, type_name=''): Returns: A string with additional RST from annotations. 
""" - formatted_extension = '' - if annotations.EXTENSION_ANNOTATION in comment.annotations: - extension = comment.annotations[annotations.EXTENSION_ANNOTATION] - formatted_extension = format_extension(extension) - formatted_extension_category = '' - if annotations.EXTENSION_CATEGORY_ANNOTATION in comment.annotations: - for category in comment.annotations[annotations.EXTENSION_CATEGORY_ANNOTATION].split(","): - formatted_extension_category += format_extension_category(category) - comment = annotations.without_annotations(strip_leading_space(comment.raw) + '\n') - return comment + formatted_extension + formatted_extension_category + formatted_extension = '' + if annotations.EXTENSION_ANNOTATION in comment.annotations: + extension = comment.annotations[annotations.EXTENSION_ANNOTATION] + formatted_extension = format_extension(extension) + formatted_extension_category = '' + if annotations.EXTENSION_CATEGORY_ANNOTATION in comment.annotations: + for category in comment.annotations[annotations.EXTENSION_CATEGORY_ANNOTATION].split(","): + formatted_extension_category += format_extension_category(category) + comment = annotations.without_annotations(strip_leading_space(comment.raw) + '\n') + return comment + formatted_extension + formatted_extension_category def map_lines(f, s): - """Apply a function across each line in a flat string. + """Apply a function across each line in a flat string. Args: f: A string transform function for a line. @@ -182,29 +182,29 @@ def map_lines(f, s): Returns: A flat string with f applied to each line. """ - return '\n'.join(f(line) for line in s.split('\n')) + return '\n'.join(f(line) for line in s.split('\n')) def indent(spaces, line): - """Indent a string.""" - return ' ' * spaces + line + """Indent a string.""" + return ' ' * spaces + line def indent_lines(spaces, lines): - """Indent a list of strings.""" - return map(functools.partial(indent, spaces), lines) + """Indent a list of strings.""" + return map(functools.partial(indent, spaces), lines) def format_internal_link(text, ref): - return ':ref:`%s <%s>`' % (text, ref) + return ':ref:`%s <%s>`' % (text, ref) def format_external_link(text, ref): - return '`%s <%s>`_' % (text, ref) + return '`%s <%s>`_' % (text, ref) def format_header(style, text): - """Format RST header. + """Format RST header. Args: style: underline style, e.g. '=', '-'. @@ -213,11 +213,11 @@ def format_header(style, text): Returns: RST formatted header. """ - return '%s\n%s\n\n' % (text, style * len(text)) + return '%s\n%s\n\n' % (text, style * len(text)) def format_extension(extension): - """Format extension metadata as RST. + """Format extension metadata as RST. Args: extension: the name of the extension, e.g. com.acme.foo. @@ -225,25 +225,25 @@ def format_extension(extension): Returns: RST formatted extension description. """ - try: - extension_metadata = EXTENSION_DB[extension] - status = EXTENSION_STATUS_VALUES.get(extension_metadata['status'], '') - security_posture = EXTENSION_SECURITY_POSTURES[extension_metadata['security_posture']] - categories = extension_metadata["categories"] - except KeyError as e: - sys.stderr.write( - f"\n\nDid you forget to add '{extension}' to source/extensions/extensions_build_config.bzl?\n\n" - ) - exit(1) # Raising the error buries the above message in tracebacks. 
- - return EXTENSION_TEMPLATE.render(extension=extension, - status=status, - security_posture=security_posture, - categories=categories) + try: + extension_metadata = EXTENSION_DB[extension] + status = EXTENSION_STATUS_VALUES.get(extension_metadata['status'], '') + security_posture = EXTENSION_SECURITY_POSTURES[extension_metadata['security_posture']] + categories = extension_metadata["categories"] + except KeyError as e: + sys.stderr.write( + f"\n\nDid you forget to add '{extension}' to source/extensions/extensions_build_config.bzl?\n\n" + ) + exit(1) # Raising the error buries the above message in tracebacks. + + return EXTENSION_TEMPLATE.render(extension=extension, + status=status, + security_posture=security_posture, + categories=categories) def format_extension_category(extension_category): - """Format extension metadata as RST. + """Format extension metadata as RST. Args: extension_category: the name of the extension_category, e.g. com.acme. @@ -251,16 +251,16 @@ def format_extension_category(extension_category): Returns: RST formatted extension category description. """ - try: - extensions = EXTENSION_CATEGORIES[extension_category] - except KeyError as e: - raise ProtodocError(f"\n\nUnable to find extension category: {extension_category}\n\n") - return EXTENSION_CATEGORY_TEMPLATE.render(category=extension_category, - extensions=sorted(extensions)) + try: + extensions = EXTENSION_CATEGORIES[extension_category] + except KeyError as e: + raise ProtodocError(f"\n\nUnable to find extension category: {extension_category}\n\n") + return EXTENSION_CATEGORY_TEMPLATE.render(category=extension_category, + extensions=sorted(extensions)) def format_header_from_file(style, source_code_info, proto_name): - """Format RST header based on special file level title + """Format RST header based on special file level title Args: style: underline style, e.g. '=', '-'. @@ -271,61 +271,61 @@ def format_header_from_file(style, source_code_info, proto_name): Returns: RST formatted header, and file level comment without page title strings. 
""" - anchor = format_anchor(file_cross_ref_label(proto_name)) - stripped_comment = annotations.without_annotations( - strip_leading_space('\n'.join(c + '\n' for c in source_code_info.file_level_comments))) - formatted_extension = '' - if annotations.EXTENSION_ANNOTATION in source_code_info.file_level_annotations: - extension = source_code_info.file_level_annotations[annotations.EXTENSION_ANNOTATION] - formatted_extension = format_extension(extension) - if annotations.DOC_TITLE_ANNOTATION in source_code_info.file_level_annotations: - return anchor + format_header( - style, source_code_info.file_level_annotations[ - annotations.DOC_TITLE_ANNOTATION]) + formatted_extension, stripped_comment - return anchor + format_header(style, proto_name) + formatted_extension, stripped_comment + anchor = format_anchor(file_cross_ref_label(proto_name)) + stripped_comment = annotations.without_annotations( + strip_leading_space('\n'.join(c + '\n' for c in source_code_info.file_level_comments))) + formatted_extension = '' + if annotations.EXTENSION_ANNOTATION in source_code_info.file_level_annotations: + extension = source_code_info.file_level_annotations[annotations.EXTENSION_ANNOTATION] + formatted_extension = format_extension(extension) + if annotations.DOC_TITLE_ANNOTATION in source_code_info.file_level_annotations: + return anchor + format_header( + style, source_code_info.file_level_annotations[ + annotations.DOC_TITLE_ANNOTATION]) + formatted_extension, stripped_comment + return anchor + format_header(style, proto_name) + formatted_extension, stripped_comment def format_field_type_as_json(type_context, field): - """Format FieldDescriptorProto.Type as a pseudo-JSON string. + """Format FieldDescriptorProto.Type as a pseudo-JSON string. Args: type_context: contextual information for message/enum/field. field: FieldDescriptor proto. Return: RST formatted pseudo-JSON string representation of field type. """ - if type_name_from_fqn(field.type_name) in type_context.map_typenames: - return '"{...}"' - if field.label == field.LABEL_REPEATED: - return '[]' - if field.type == field.TYPE_MESSAGE: - return '"{...}"' - return '"..."' + if type_name_from_fqn(field.type_name) in type_context.map_typenames: + return '"{...}"' + if field.label == field.LABEL_REPEATED: + return '[]' + if field.type == field.TYPE_MESSAGE: + return '"{...}"' + return '"..."' def format_message_as_json(type_context, msg): - """Format a message definition DescriptorProto as a pseudo-JSON block. + """Format a message definition DescriptorProto as a pseudo-JSON block. Args: type_context: contextual information for message/enum/field. msg: message definition DescriptorProto. Return: RST formatted pseudo-JSON string representation of message definition. """ - lines = [] - for index, field in enumerate(msg.field): - field_type_context = type_context.extend_field(index, field.name) - leading_comment = field_type_context.leading_comment - if hide_not_implemented(leading_comment): - continue - lines.append('"%s": %s' % (field.name, format_field_type_as_json(type_context, field))) + lines = [] + for index, field in enumerate(msg.field): + field_type_context = type_context.extend_field(index, field.name) + leading_comment = field_type_context.leading_comment + if hide_not_implemented(leading_comment): + continue + lines.append('"%s": %s' % (field.name, format_field_type_as_json(type_context, field))) - if lines: - return '.. code-block:: json\n\n {\n' + ',\n'.join(indent_lines(4, lines)) + '\n }\n\n' - else: - return '.. 
code-block:: json\n\n {}\n\n' + if lines: + return '.. code-block:: json\n\n {\n' + ',\n'.join(indent_lines(4, lines)) + '\n }\n\n' + else: + return '.. code-block:: json\n\n {}\n\n' def normalize_field_type_name(field_fqn): - """Normalize a fully qualified field type name, e.g. + """Normalize a fully qualified field type name, e.g. .envoy.foo.bar. @@ -335,15 +335,15 @@ def normalize_field_type_name(field_fqn): field_fqn: a fully qualified type name from FieldDescriptorProto.type_name. Return: Normalized type name. """ - if field_fqn.startswith(ENVOY_API_NAMESPACE_PREFIX): - return field_fqn[len(ENVOY_API_NAMESPACE_PREFIX):] - if field_fqn.startswith(ENVOY_PREFIX): - return field_fqn[len(ENVOY_PREFIX):] - return field_fqn + if field_fqn.startswith(ENVOY_API_NAMESPACE_PREFIX): + return field_fqn[len(ENVOY_API_NAMESPACE_PREFIX):] + if field_fqn.startswith(ENVOY_PREFIX): + return field_fqn[len(ENVOY_PREFIX):] + return field_fqn def normalize_type_context_name(type_name): - """Normalize a type name, e.g. + """Normalize a type name, e.g. envoy.foo.bar. @@ -353,19 +353,19 @@ def normalize_type_context_name(type_name): type_name: a name from a TypeContext. Return: Normalized type name. """ - return normalize_field_type_name(qualify_type_name(type_name)) + return normalize_field_type_name(qualify_type_name(type_name)) def qualify_type_name(type_name): - return '.' + type_name + return '.' + type_name def type_name_from_fqn(fqn): - return fqn[1:] + return fqn[1:] def format_field_type(type_context, field): - """Format a FieldDescriptorProto type description. + """Format a FieldDescriptorProto type description. Adds cross-refs for message types. TODO(htuch): Add cross-refs for enums as well. @@ -375,117 +375,120 @@ def format_field_type(type_context, field): field: FieldDescriptor proto. Return: RST formatted field type. 
""" - if field.type_name.startswith(ENVOY_API_NAMESPACE_PREFIX) or field.type_name.startswith( - ENVOY_PREFIX): - type_name = normalize_field_type_name(field.type_name) - if field.type == field.TYPE_MESSAGE: - if type_context.map_typenames and type_name_from_fqn( - field.type_name) in type_context.map_typenames: - return 'map<%s, %s>' % tuple( - map(functools.partial(format_field_type, type_context), - type_context.map_typenames[type_name_from_fqn(field.type_name)])) - return format_internal_link(type_name, message_cross_ref_label(type_name)) - if field.type == field.TYPE_ENUM: - return format_internal_link(type_name, enum_cross_ref_label(type_name)) - elif field.type_name.startswith(WKT_NAMESPACE_PREFIX): - wkt = field.type_name[len(WKT_NAMESPACE_PREFIX):] - return format_external_link( - wkt, 'https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#%s' % - wkt.lower()) - elif field.type_name.startswith(RPC_NAMESPACE_PREFIX): - rpc = field.type_name[len(RPC_NAMESPACE_PREFIX):] - return format_external_link( - rpc, - 'https://cloud.google.com/natural-language/docs/reference/rpc/google.rpc#%s' % rpc.lower()) - elif field.type_name: - return field.type_name - - pretty_type_names = { - field.TYPE_DOUBLE: 'double', - field.TYPE_FLOAT: 'float', - field.TYPE_INT32: 'int32', - field.TYPE_SFIXED32: 'int32', - field.TYPE_SINT32: 'int32', - field.TYPE_FIXED32: 'uint32', - field.TYPE_UINT32: 'uint32', - field.TYPE_INT64: 'int64', - field.TYPE_SFIXED64: 'int64', - field.TYPE_SINT64: 'int64', - field.TYPE_FIXED64: 'uint64', - field.TYPE_UINT64: 'uint64', - field.TYPE_BOOL: 'bool', - field.TYPE_STRING: 'string', - field.TYPE_BYTES: 'bytes', - } - if field.type in pretty_type_names: - return format_external_link(pretty_type_names[field.type], - 'https://developers.google.com/protocol-buffers/docs/proto#scalar') - raise ProtodocError('Unknown field type ' + str(field.type)) + if field.type_name.startswith(ENVOY_API_NAMESPACE_PREFIX) or field.type_name.startswith( + ENVOY_PREFIX): + type_name = normalize_field_type_name(field.type_name) + if field.type == field.TYPE_MESSAGE: + if type_context.map_typenames and type_name_from_fqn( + field.type_name) in type_context.map_typenames: + return 'map<%s, %s>' % tuple( + map(functools.partial(format_field_type, type_context), + type_context.map_typenames[type_name_from_fqn(field.type_name)])) + return format_internal_link(type_name, message_cross_ref_label(type_name)) + if field.type == field.TYPE_ENUM: + return format_internal_link(type_name, enum_cross_ref_label(type_name)) + elif field.type_name.startswith(WKT_NAMESPACE_PREFIX): + wkt = field.type_name[len(WKT_NAMESPACE_PREFIX):] + return format_external_link( + wkt, + 'https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#%s' % + wkt.lower()) + elif field.type_name.startswith(RPC_NAMESPACE_PREFIX): + rpc = field.type_name[len(RPC_NAMESPACE_PREFIX):] + return format_external_link( + rpc, 'https://cloud.google.com/natural-language/docs/reference/rpc/google.rpc#%s' % + rpc.lower()) + elif field.type_name: + return field.type_name + + pretty_type_names = { + field.TYPE_DOUBLE: 'double', + field.TYPE_FLOAT: 'float', + field.TYPE_INT32: 'int32', + field.TYPE_SFIXED32: 'int32', + field.TYPE_SINT32: 'int32', + field.TYPE_FIXED32: 'uint32', + field.TYPE_UINT32: 'uint32', + field.TYPE_INT64: 'int64', + field.TYPE_SFIXED64: 'int64', + field.TYPE_SINT64: 'int64', + field.TYPE_FIXED64: 'uint64', + field.TYPE_UINT64: 'uint64', + field.TYPE_BOOL: 'bool', + field.TYPE_STRING: 
'string', + field.TYPE_BYTES: 'bytes', + } + if field.type in pretty_type_names: + return format_external_link( + pretty_type_names[field.type], + 'https://developers.google.com/protocol-buffers/docs/proto#scalar') + raise ProtodocError('Unknown field type ' + str(field.type)) def strip_leading_space(s): - """Remove leading space in flat comment strings.""" - return map_lines(lambda s: s[1:], s) + """Remove leading space in flat comment strings.""" + return map_lines(lambda s: s[1:], s) def file_cross_ref_label(msg_name): - """File cross reference label.""" - return 'envoy_api_file_%s' % msg_name + """File cross reference label.""" + return 'envoy_api_file_%s' % msg_name def message_cross_ref_label(msg_name): - """Message cross reference label.""" - return 'envoy_api_msg_%s' % msg_name + """Message cross reference label.""" + return 'envoy_api_msg_%s' % msg_name def enum_cross_ref_label(enum_name): - """Enum cross reference label.""" - return 'envoy_api_enum_%s' % enum_name + """Enum cross reference label.""" + return 'envoy_api_enum_%s' % enum_name def field_cross_ref_label(field_name): - """Field cross reference label.""" - return 'envoy_api_field_%s' % field_name + """Field cross reference label.""" + return 'envoy_api_field_%s' % field_name def enum_value_cross_ref_label(enum_value_name): - """Enum value cross reference label.""" - return 'envoy_api_enum_value_%s' % enum_value_name + """Enum value cross reference label.""" + return 'envoy_api_enum_value_%s' % enum_value_name def format_anchor(label): - """Format a label as an Envoy API RST anchor.""" - return '.. _%s:\n\n' % label + """Format a label as an Envoy API RST anchor.""" + return '.. _%s:\n\n' % label def format_security_options(security_option, field, type_context, edge_config): - sections = [] - - if security_option.configure_for_untrusted_downstream: + sections = [] + + if security_option.configure_for_untrusted_downstream: + sections.append( + indent(4, + 'This field should be configured in the presence of untrusted *downstreams*.')) + if security_option.configure_for_untrusted_upstream: + sections.append( + indent(4, 'This field should be configured in the presence of untrusted *upstreams*.')) + if edge_config.note: + sections.append(indent(4, edge_config.note)) + + example_dict = json_format.MessageToDict(edge_config.example) + validate_fragment.validate_fragment(field.type_name[1:], example_dict) + field_name = type_context.name.split('.')[-1] + example = {field_name: example_dict} sections.append( - indent(4, 'This field should be configured in the presence of untrusted *downstreams*.')) - if security_option.configure_for_untrusted_upstream: - sections.append( - indent(4, 'This field should be configured in the presence of untrusted *upstreams*.')) - if edge_config.note: - sections.append(indent(4, edge_config.note)) - - example_dict = json_format.MessageToDict(edge_config.example) - validate_fragment.validate_fragment(field.type_name[1:], example_dict) - field_name = type_context.name.split('.')[-1] - example = {field_name: example_dict} - sections.append( - indent(4, 'Example configuration for untrusted environments:\n\n') + - indent(4, '.. code-block:: yaml\n\n') + - '\n'.join(indent_lines(6, - yaml.dump(example).split('\n')))) + indent(4, 'Example configuration for untrusted environments:\n\n') + + indent(4, '.. code-block:: yaml\n\n') + + '\n'.join(indent_lines(6, + yaml.dump(example).split('\n')))) - return '.. attention::\n' + '\n\n'.join(sections) + return '.. 
attention::\n' + '\n\n'.join(sections) def format_field_as_definition_list_item(outer_type_context, type_context, field, protodoc_manifest): - """Format a FieldDescriptorProto as RST definition list item. + """Format a FieldDescriptorProto as RST definition list item. Args: outer_type_context: contextual information for enclosing message. @@ -496,72 +499,72 @@ def format_field_as_definition_list_item(outer_type_context, type_context, field Returns: RST formatted definition list item. """ - field_annotations = [] - - anchor = format_anchor(field_cross_ref_label(normalize_type_context_name(type_context.name))) - if field.options.HasExtension(validate_pb2.rules): - rule = field.options.Extensions[validate_pb2.rules] - if ((rule.HasField('message') and rule.message.required) or - (rule.HasField('duration') and rule.duration.required) or - (rule.HasField('string') and rule.string.min_len > 0) or - (rule.HasField('string') and rule.string.min_bytes > 0) or - (rule.HasField('repeated') and rule.repeated.min_items > 0)): - field_annotations = ['*REQUIRED*'] - leading_comment = type_context.leading_comment - formatted_leading_comment = format_comment_with_annotations(leading_comment) - if hide_not_implemented(leading_comment): - return '' + field_annotations = [] + + anchor = format_anchor(field_cross_ref_label(normalize_type_context_name(type_context.name))) + if field.options.HasExtension(validate_pb2.rules): + rule = field.options.Extensions[validate_pb2.rules] + if ((rule.HasField('message') and rule.message.required) or + (rule.HasField('duration') and rule.duration.required) or + (rule.HasField('string') and rule.string.min_len > 0) or + (rule.HasField('string') and rule.string.min_bytes > 0) or + (rule.HasField('repeated') and rule.repeated.min_items > 0)): + field_annotations = ['*REQUIRED*'] + leading_comment = type_context.leading_comment + formatted_leading_comment = format_comment_with_annotations(leading_comment) + if hide_not_implemented(leading_comment): + return '' - if field.HasField('oneof_index'): - oneof_context = outer_type_context.extend_oneof(field.oneof_index, - type_context.oneof_names[field.oneof_index]) - oneof_comment = oneof_context.leading_comment - formatted_oneof_comment = format_comment_with_annotations(oneof_comment) - if hide_not_implemented(oneof_comment): - return '' - - # If the oneof only has one field and marked required, mark the field as required. - if len(type_context.oneof_fields[field.oneof_index]) == 1 and type_context.oneof_required[ - field.oneof_index]: - field_annotations = ['*REQUIRED*'] - - if len(type_context.oneof_fields[field.oneof_index]) > 1: - # Fields in oneof shouldn't be marked as required when we have oneof comment below it. - field_annotations = [] - oneof_template = '\nPrecisely one of %s must be set.\n' if type_context.oneof_required[ - field.oneof_index] else '\nOnly one of %s may be set.\n' - formatted_oneof_comment += oneof_template % ', '.join( - format_internal_link( - f, - field_cross_ref_label( - normalize_type_context_name(outer_type_context.extend_field(i, f).name))) - for i, f in type_context.oneof_fields[field.oneof_index]) - else: - formatted_oneof_comment = '' - - # If there is a udpa.annotations.security option, include it after the comment. 
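format_security_options() above boils down to serializing an example config to YAML and nesting it under an indented RST code-block directive inside an attention admonition. A minimal sketch with a hypothetical field example (requires PyYAML):

import yaml

def indent(spaces, line):
    return ' ' * spaces + line

example = {'max_connection_duration': '300s'}   # hypothetical edge-config example
section = (
    indent(4, 'Example configuration for untrusted environments:\n\n')
    + indent(4, '.. code-block:: yaml\n\n')
    + '\n'.join(indent(6, line) for line in yaml.dump(example).split('\n')))
print('.. attention::\n' + section)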
- if field.options.HasExtension(security_pb2.security): - manifest_description = protodoc_manifest.fields.get(type_context.name) - if not manifest_description: - raise ProtodocError('Missing protodoc manifest YAML for %s' % type_context.name) - formatted_security_options = format_security_options( - field.options.Extensions[security_pb2.security], field, type_context, - manifest_description.edge_config) - else: - formatted_security_options = '' - pretty_label_names = { - field.LABEL_OPTIONAL: '', - field.LABEL_REPEATED: '**repeated** ', - } - comment = '(%s) ' % ', '.join( - [pretty_label_names[field.label] + format_field_type(type_context, field)] + - field_annotations) + formatted_leading_comment - return anchor + field.name + '\n' + map_lines(functools.partial( - indent, 2), comment + formatted_oneof_comment) + formatted_security_options + if field.HasField('oneof_index'): + oneof_context = outer_type_context.extend_oneof(field.oneof_index, + type_context.oneof_names[field.oneof_index]) + oneof_comment = oneof_context.leading_comment + formatted_oneof_comment = format_comment_with_annotations(oneof_comment) + if hide_not_implemented(oneof_comment): + return '' + + # If the oneof only has one field and marked required, mark the field as required. + if len(type_context.oneof_fields[field.oneof_index]) == 1 and type_context.oneof_required[ + field.oneof_index]: + field_annotations = ['*REQUIRED*'] + + if len(type_context.oneof_fields[field.oneof_index]) > 1: + # Fields in oneof shouldn't be marked as required when we have oneof comment below it. + field_annotations = [] + oneof_template = '\nPrecisely one of %s must be set.\n' if type_context.oneof_required[ + field.oneof_index] else '\nOnly one of %s may be set.\n' + formatted_oneof_comment += oneof_template % ', '.join( + format_internal_link( + f, + field_cross_ref_label( + normalize_type_context_name(outer_type_context.extend_field(i, f).name))) + for i, f in type_context.oneof_fields[field.oneof_index]) + else: + formatted_oneof_comment = '' + + # If there is a udpa.annotations.security option, include it after the comment. + if field.options.HasExtension(security_pb2.security): + manifest_description = protodoc_manifest.fields.get(type_context.name) + if not manifest_description: + raise ProtodocError('Missing protodoc manifest YAML for %s' % type_context.name) + formatted_security_options = format_security_options( + field.options.Extensions[security_pb2.security], field, type_context, + manifest_description.edge_config) + else: + formatted_security_options = '' + pretty_label_names = { + field.LABEL_OPTIONAL: '', + field.LABEL_REPEATED: '**repeated** ', + } + comment = '(%s) ' % ', '.join( + [pretty_label_names[field.label] + format_field_type(type_context, field)] + + field_annotations) + formatted_leading_comment + return anchor + field.name + '\n' + map_lines(functools.partial( + indent, 2), comment + formatted_oneof_comment) + formatted_security_options def format_message_as_definition_list(type_context, msg, protodoc_manifest): - """Format a DescriptorProto as RST definition list. + """Format a DescriptorProto as RST definition list. Args: type_context: contextual information for message/enum/field. @@ -571,27 +574,28 @@ def format_message_as_definition_list(type_context, msg, protodoc_manifest): Returns: RST formatted definition list item. 
""" - type_context.oneof_fields = defaultdict(list) - type_context.oneof_required = defaultdict(bool) - type_context.oneof_names = defaultdict(list) - for index, field in enumerate(msg.field): - if field.HasField('oneof_index'): - leading_comment = type_context.extend_field(index, field.name).leading_comment - if hide_not_implemented(leading_comment): - continue - type_context.oneof_fields[field.oneof_index].append((index, field.name)) - for index, oneof_decl in enumerate(msg.oneof_decl): - if oneof_decl.options.HasExtension(validate_pb2.required): - type_context.oneof_required[index] = oneof_decl.options.Extensions[validate_pb2.required] - type_context.oneof_names[index] = oneof_decl.name - return '\n'.join( - format_field_as_definition_list_item(type_context, type_context.extend_field( - index, field.name), field, protodoc_manifest) - for index, field in enumerate(msg.field)) + '\n' + type_context.oneof_fields = defaultdict(list) + type_context.oneof_required = defaultdict(bool) + type_context.oneof_names = defaultdict(list) + for index, field in enumerate(msg.field): + if field.HasField('oneof_index'): + leading_comment = type_context.extend_field(index, field.name).leading_comment + if hide_not_implemented(leading_comment): + continue + type_context.oneof_fields[field.oneof_index].append((index, field.name)) + for index, oneof_decl in enumerate(msg.oneof_decl): + if oneof_decl.options.HasExtension(validate_pb2.required): + type_context.oneof_required[index] = oneof_decl.options.Extensions[ + validate_pb2.required] + type_context.oneof_names[index] = oneof_decl.name + return '\n'.join( + format_field_as_definition_list_item( + type_context, type_context.extend_field(index, field.name), field, protodoc_manifest) + for index, field in enumerate(msg.field)) + '\n' def format_enum_value_as_definition_list_item(type_context, enum_value): - """Format a EnumValueDescriptorProto as RST definition list item. + """Format a EnumValueDescriptorProto as RST definition list item. Args: type_context: contextual information for message/enum/field. @@ -600,18 +604,19 @@ def format_enum_value_as_definition_list_item(type_context, enum_value): Returns: RST formatted definition list item. """ - anchor = format_anchor(enum_value_cross_ref_label(normalize_type_context_name(type_context.name))) - default_comment = '*(DEFAULT)* ' if enum_value.number == 0 else '' - leading_comment = type_context.leading_comment - formatted_leading_comment = format_comment_with_annotations(leading_comment) - if hide_not_implemented(leading_comment): - return '' - comment = default_comment + UNICODE_INVISIBLE_SEPARATOR + formatted_leading_comment - return anchor + enum_value.name + '\n' + map_lines(functools.partial(indent, 2), comment) + anchor = format_anchor( + enum_value_cross_ref_label(normalize_type_context_name(type_context.name))) + default_comment = '*(DEFAULT)* ' if enum_value.number == 0 else '' + leading_comment = type_context.leading_comment + formatted_leading_comment = format_comment_with_annotations(leading_comment) + if hide_not_implemented(leading_comment): + return '' + comment = default_comment + UNICODE_INVISIBLE_SEPARATOR + formatted_leading_comment + return anchor + enum_value.name + '\n' + map_lines(functools.partial(indent, 2), comment) def format_enum_as_definition_list(type_context, enum): - """Format a EnumDescriptorProto as RST definition list. + """Format a EnumDescriptorProto as RST definition list. Args: type_context: contextual information for message/enum/field. 
@@ -620,100 +625,102 @@ def format_enum_as_definition_list(type_context, enum): Returns: RST formatted definition list item. """ - return '\n'.join( - format_enum_value_as_definition_list_item( - type_context.extend_enum_value(index, enum_value.name), enum_value) - for index, enum_value in enumerate(enum.value)) + '\n' + return '\n'.join( + format_enum_value_as_definition_list_item( + type_context.extend_enum_value(index, enum_value.name), enum_value) + for index, enum_value in enumerate(enum.value)) + '\n' def format_proto_as_block_comment(proto): - """Format a proto as a RST block comment. + """Format a proto as a RST block comment. Useful in debugging, not usually referenced. """ - return '\n\nproto::\n\n' + map_lines(functools.partial(indent, 2), str(proto)) + '\n' + return '\n\nproto::\n\n' + map_lines(functools.partial(indent, 2), str(proto)) + '\n' class RstFormatVisitor(visitor.Visitor): - """Visitor to generate a RST representation from a FileDescriptor proto. + """Visitor to generate a RST representation from a FileDescriptor proto. See visitor.Visitor for visitor method docs comments. """ - def __init__(self): - r = runfiles.Create() - with open(r.Rlocation('envoy/docs/protodoc_manifest.yaml'), 'r') as f: - # Load as YAML, emit as JSON and then parse as proto to provide type - # checking. - protodoc_manifest_untyped = yaml.safe_load(f.read()) - self.protodoc_manifest = manifest_pb2.Manifest() - json_format.Parse(json.dumps(protodoc_manifest_untyped), self.protodoc_manifest) - - def visit_enum(self, enum_proto, type_context): - normal_enum_type = normalize_type_context_name(type_context.name) - anchor = format_anchor(enum_cross_ref_label(normal_enum_type)) - header = format_header('-', 'Enum %s' % normal_enum_type) - _github_url = github_url(type_context) - proto_link = format_external_link('[%s proto]' % normal_enum_type, _github_url) + '\n\n' - leading_comment = type_context.leading_comment - formatted_leading_comment = format_comment_with_annotations(leading_comment, 'enum') - if hide_not_implemented(leading_comment): - return '' - return anchor + header + proto_link + formatted_leading_comment + format_enum_as_definition_list( - type_context, enum_proto) - - def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): - # Skip messages synthesized to represent map types. - if msg_proto.options.map_entry: - return '' - normal_msg_type = normalize_type_context_name(type_context.name) - anchor = format_anchor(message_cross_ref_label(normal_msg_type)) - header = format_header('-', normal_msg_type) - _github_url = github_url(type_context) - proto_link = format_external_link('[%s proto]' % normal_msg_type, _github_url) + '\n\n' - leading_comment = type_context.leading_comment - formatted_leading_comment = format_comment_with_annotations(leading_comment, 'message') - if hide_not_implemented(leading_comment): - return '' - return anchor + header + proto_link + formatted_leading_comment + format_message_as_json( - type_context, msg_proto) + format_message_as_definition_list( - type_context, msg_proto, - self.protodoc_manifest) + '\n'.join(nested_msgs) + '\n' + '\n'.join(nested_enums) - - def visit_file(self, file_proto, type_context, services, msgs, enums): - has_messages = True - if all(len(msg) == 0 for msg in msgs) and all(len(enum) == 0 for enum in enums): - has_messages = False - - # TODO(mattklein123): The logic in both the doc and transform tool around files without messages - # is confusing and should be cleaned up. 
This is a stop gap to have titles for all proto docs - # in the common case. - if (has_messages and - not annotations.DOC_TITLE_ANNOTATION in type_context.source_code_info.file_level_annotations - and file_proto.name.startswith('envoy')): - raise ProtodocError('Envoy API proto file missing [#protodoc-title:] annotation: {}'.format( - file_proto.name)) - - # Find the earliest detached comment, attribute it to file level. - # Also extract file level titles if any. - header, comment = format_header_from_file('=', type_context.source_code_info, file_proto.name) - # If there are no messages, we don't include in the doc tree (no support for - # service rendering yet). We allow these files to be missing from the - # toctrees. - if not has_messages: - header = ':orphan:\n\n' + header - warnings = '' - if file_proto.options.HasExtension(status_pb2.file_status): - if file_proto.options.Extensions[status_pb2.file_status].work_in_progress: - warnings += ('.. warning::\n This API is work-in-progress and is ' - 'subject to breaking changes.\n\n') - debug_proto = format_proto_as_block_comment(file_proto) - return header + warnings + comment + '\n'.join(msgs) + '\n'.join(enums) # + debug_proto + def __init__(self): + r = runfiles.Create() + with open(r.Rlocation('envoy/docs/protodoc_manifest.yaml'), 'r') as f: + # Load as YAML, emit as JSON and then parse as proto to provide type + # checking. + protodoc_manifest_untyped = yaml.safe_load(f.read()) + self.protodoc_manifest = manifest_pb2.Manifest() + json_format.Parse(json.dumps(protodoc_manifest_untyped), self.protodoc_manifest) + + def visit_enum(self, enum_proto, type_context): + normal_enum_type = normalize_type_context_name(type_context.name) + anchor = format_anchor(enum_cross_ref_label(normal_enum_type)) + header = format_header('-', 'Enum %s' % normal_enum_type) + _github_url = github_url(type_context) + proto_link = format_external_link('[%s proto]' % normal_enum_type, _github_url) + '\n\n' + leading_comment = type_context.leading_comment + formatted_leading_comment = format_comment_with_annotations(leading_comment, 'enum') + if hide_not_implemented(leading_comment): + return '' + return anchor + header + proto_link + formatted_leading_comment + format_enum_as_definition_list( + type_context, enum_proto) + + def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): + # Skip messages synthesized to represent map types. 
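RstFormatVisitor.__init__ above type-checks the manifest by loading YAML, re-emitting it as JSON, and parsing that into a proto. The same pattern, sketched with a generic Struct standing in for manifest_pb2.Manifest() so it runs without the generated module (requires PyYAML and protobuf); the manifest contents are invented.

import json
import yaml
from google.protobuf import json_format, struct_pb2

yaml_text = 'fields:\n  envoy.config.example.v3.Example.key:\n    note: hypothetical note\n'
untyped = yaml.safe_load(yaml_text)            # plain dicts and strings
msg = struct_pb2.Struct()                      # stand-in for manifest_pb2.Manifest()
json_format.Parse(json.dumps(untyped), msg)    # the JSON round-trip gives proto-level checking
print(json_format.MessageToDict(msg)['fields'])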
+ if msg_proto.options.map_entry: + return '' + normal_msg_type = normalize_type_context_name(type_context.name) + anchor = format_anchor(message_cross_ref_label(normal_msg_type)) + header = format_header('-', normal_msg_type) + _github_url = github_url(type_context) + proto_link = format_external_link('[%s proto]' % normal_msg_type, _github_url) + '\n\n' + leading_comment = type_context.leading_comment + formatted_leading_comment = format_comment_with_annotations(leading_comment, 'message') + if hide_not_implemented(leading_comment): + return '' + return anchor + header + proto_link + formatted_leading_comment + format_message_as_json( + type_context, msg_proto) + format_message_as_definition_list( + type_context, msg_proto, + self.protodoc_manifest) + '\n'.join(nested_msgs) + '\n' + '\n'.join(nested_enums) + + def visit_file(self, file_proto, type_context, services, msgs, enums): + has_messages = True + if all(len(msg) == 0 for msg in msgs) and all(len(enum) == 0 for enum in enums): + has_messages = False + + # TODO(mattklein123): The logic in both the doc and transform tool around files without messages + # is confusing and should be cleaned up. This is a stop gap to have titles for all proto docs + # in the common case. + if (has_messages and not annotations.DOC_TITLE_ANNOTATION + in type_context.source_code_info.file_level_annotations and + file_proto.name.startswith('envoy')): + raise ProtodocError( + 'Envoy API proto file missing [#protodoc-title:] annotation: {}'.format( + file_proto.name)) + + # Find the earliest detached comment, attribute it to file level. + # Also extract file level titles if any. + header, comment = format_header_from_file('=', type_context.source_code_info, + file_proto.name) + # If there are no messages, we don't include in the doc tree (no support for + # service rendering yet). We allow these files to be missing from the + # toctrees. + if not has_messages: + header = ':orphan:\n\n' + header + warnings = '' + if file_proto.options.HasExtension(status_pb2.file_status): + if file_proto.options.Extensions[status_pb2.file_status].work_in_progress: + warnings += ('.. warning::\n This API is work-in-progress and is ' + 'subject to breaking changes.\n\n') + debug_proto = format_proto_as_block_comment(file_proto) + return header + warnings + comment + '\n'.join(msgs) + '\n'.join(enums) # + debug_proto def main(): - plugin.plugin([plugin.direct_output_descriptor('.rst', RstFormatVisitor)]) + plugin.plugin([plugin.direct_output_descriptor('.rst', RstFormatVisitor)]) if __name__ == '__main__': - main() + main() diff --git a/tools/protoxform/merge_active_shadow.py b/tools/protoxform/merge_active_shadow.py index fdb7dba49d81..9796793a8391 100644 --- a/tools/protoxform/merge_active_shadow.py +++ b/tools/protoxform/merge_active_shadow.py @@ -29,211 +29,211 @@ # Set reserved_range in target_proto to reflect previous_reserved_range skipping # skip_reserved_numbers. def adjust_reserved_range(target_proto, previous_reserved_range, skip_reserved_numbers): - del target_proto.reserved_range[:] - for rr in previous_reserved_range: - # We can only handle singleton ranges today. - assert ((rr.start == rr.end) or (rr.end == rr.start + 1)) - if rr.start not in skip_reserved_numbers: - target_proto.reserved_range.add().MergeFrom(rr) + del target_proto.reserved_range[:] + for rr in previous_reserved_range: + # We can only handle singleton ranges today. 
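adjust_reserved_range() above rebuilds reserved_range while dropping the singleton ranges whose numbers are being re-introduced from the shadow. A standalone sketch with made-up field numbers:

from google.protobuf import descriptor_pb2

previous = descriptor_pb2.DescriptorProto()
for start in (2, 5, 9):                        # hypothetical reserved singleton field numbers
    rr = previous.reserved_range.add()
    rr.start, rr.end = start, start + 1
skip_reserved_numbers = [5]                    # 5 is being re-introduced from the shadow

target = descriptor_pb2.DescriptorProto(name='Example')
del target.reserved_range[:]
for rr in previous.reserved_range:
    assert rr.end == rr.start + 1              # singleton ranges only
    if rr.start not in skip_reserved_numbers:
        target.reserved_range.add().MergeFrom(rr)
print([r.start for r in target.reserved_range])   # [2, 9]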
+ assert ((rr.start == rr.end) or (rr.end == rr.start + 1)) + if rr.start not in skip_reserved_numbers: + target_proto.reserved_range.add().MergeFrom(rr) # Add dependencies for envoy.annotations.disallowed_by_default def add_deprecation_dependencies(target_proto_dependencies, proto_field, is_enum): - if is_enum: - if proto_field.options.HasExtension(deprecation_pb2.disallowed_by_default_enum) and \ - "envoy/annotations/deprecation.proto" not in target_proto_dependencies: - target_proto_dependencies.append("envoy/annotations/deprecation.proto") - else: - if proto_field.options.HasExtension(deprecation_pb2.disallowed_by_default) and \ - "envoy/annotations/deprecation.proto" not in target_proto_dependencies: - target_proto_dependencies.append("envoy/annotations/deprecation.proto") - if proto_field.type_name == ".google.protobuf.Struct" and \ - "google/protobuf/struct.proto" not in target_proto_dependencies: - target_proto_dependencies.append("google/protobuf/struct.proto") + if is_enum: + if proto_field.options.HasExtension(deprecation_pb2.disallowed_by_default_enum) and \ + "envoy/annotations/deprecation.proto" not in target_proto_dependencies: + target_proto_dependencies.append("envoy/annotations/deprecation.proto") + else: + if proto_field.options.HasExtension(deprecation_pb2.disallowed_by_default) and \ + "envoy/annotations/deprecation.proto" not in target_proto_dependencies: + target_proto_dependencies.append("envoy/annotations/deprecation.proto") + if proto_field.type_name == ".google.protobuf.Struct" and \ + "google/protobuf/struct.proto" not in target_proto_dependencies: + target_proto_dependencies.append("google/protobuf/struct.proto") # Merge active/shadow EnumDescriptorProtos to a fresh target EnumDescriptorProto. def merge_active_shadow_enum(active_proto, shadow_proto, target_proto, target_proto_dependencies): - target_proto.MergeFrom(active_proto) - if not shadow_proto: - return - shadow_values = {v.name: v for v in shadow_proto.value} - skip_reserved_numbers = [] - # For every reserved name, check to see if it's in the shadow, and if so, - # reintroduce in target_proto. - del target_proto.reserved_name[:] - for n in active_proto.reserved_name: - hidden_n = 'hidden_envoy_deprecated_' + n - if hidden_n in shadow_values: - v = shadow_values[hidden_n] - add_deprecation_dependencies(target_proto_dependencies, v, True) - skip_reserved_numbers.append(v.number) - target_proto.value.add().MergeFrom(v) - else: - target_proto.reserved_name.append(n) - adjust_reserved_range(target_proto, active_proto.reserved_range, skip_reserved_numbers) - # Special fixup for deprecation of default enum values. - for tv in target_proto.value: - if tv.name == 'DEPRECATED_AND_UNAVAILABLE_DO_NOT_USE': - for sv in shadow_proto.value: - if sv.number == tv.number: - assert (sv.number == 0) - tv.CopyFrom(sv) + target_proto.MergeFrom(active_proto) + if not shadow_proto: + return + shadow_values = {v.name: v for v in shadow_proto.value} + skip_reserved_numbers = [] + # For every reserved name, check to see if it's in the shadow, and if so, + # reintroduce in target_proto. 
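The lookup that follows keys shadow enum values by the 'hidden_envoy_deprecated_' prefix: a reserved name is re-materialized from the shadow when such a counterpart exists, and stays reserved otherwise. A standalone sketch with hypothetical value names:

from google.protobuf import descriptor_pb2

shadow = descriptor_pb2.EnumDescriptorProto(name='Mode')
sv = shadow.value.add()
sv.name, sv.number = 'hidden_envoy_deprecated_LEGACY', 3

target = descriptor_pb2.EnumDescriptorProto(name='Mode')
shadow_values = {v.name: v for v in shadow.value}
for reserved in ['LEGACY', 'NEVER_SHADOWED']:
    hidden = 'hidden_envoy_deprecated_' + reserved
    if hidden in shadow_values:
        target.value.add().MergeFrom(shadow_values[hidden])   # re-introduce the deprecated value
    else:
        target.reserved_name.append(reserved)                 # keep it reserved
print([v.name for v in target.value], list(target.reserved_name))
# ['hidden_envoy_deprecated_LEGACY'] ['NEVER_SHADOWED']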
+ del target_proto.reserved_name[:] + for n in active_proto.reserved_name: + hidden_n = 'hidden_envoy_deprecated_' + n + if hidden_n in shadow_values: + v = shadow_values[hidden_n] + add_deprecation_dependencies(target_proto_dependencies, v, True) + skip_reserved_numbers.append(v.number) + target_proto.value.add().MergeFrom(v) + else: + target_proto.reserved_name.append(n) + adjust_reserved_range(target_proto, active_proto.reserved_range, skip_reserved_numbers) + # Special fixup for deprecation of default enum values. + for tv in target_proto.value: + if tv.name == 'DEPRECATED_AND_UNAVAILABLE_DO_NOT_USE': + for sv in shadow_proto.value: + if sv.number == tv.number: + assert (sv.number == 0) + tv.CopyFrom(sv) # Adjust source code info comments path to reflect insertions of oneof fields # inside the middle of an existing collection of fields. def adjust_source_code_info(type_context, field_index, field_adjustment): - def has_path_prefix(s, t): - return len(s) <= len(t) and all(p[0] == p[1] for p in zip(s, t)) + def has_path_prefix(s, t): + return len(s) <= len(t) and all(p[0] == p[1] for p in zip(s, t)) - for loc in type_context.source_code_info.proto.location: - if has_path_prefix(type_context.path + [2], loc.path): - path_field_index = len(type_context.path) + 1 - if path_field_index < len(loc.path) and loc.path[path_field_index] >= field_index: - loc.path[path_field_index] += field_adjustment + for loc in type_context.source_code_info.proto.location: + if has_path_prefix(type_context.path + [2], loc.path): + path_field_index = len(type_context.path) + 1 + if path_field_index < len(loc.path) and loc.path[path_field_index] >= field_index: + loc.path[path_field_index] += field_adjustment # Merge active/shadow DescriptorProtos to a fresh target DescriptorProto. def merge_active_shadow_message(type_context, active_proto, shadow_proto, target_proto, target_proto_dependencies): - target_proto.MergeFrom(active_proto) - if not shadow_proto: - return - shadow_fields = {f.name: f for f in shadow_proto.field} - skip_reserved_numbers = [] - # For every reserved name, check to see if it's in the shadow, and if so, - # reintroduce in target_proto. We track both the normal fields we need to add - # back in (extra_simple_fields) and those that belong to oneofs - # (extra_oneof_fields). The latter require special treatment, as we can't just - # append them to the end of the message, they need to be reordered. - extra_simple_fields = [] - extra_oneof_fields = defaultdict(list) # oneof index -> list of fields - del target_proto.reserved_name[:] - for n in active_proto.reserved_name: - hidden_n = 'hidden_envoy_deprecated_' + n - if hidden_n in shadow_fields: - f = shadow_fields[hidden_n] - add_deprecation_dependencies(target_proto_dependencies, f, False) - skip_reserved_numbers.append(f.number) - missing_field = copy.deepcopy(f) - # oneof fields from the shadow need to have their index set to the - # corresponding index in active/target_proto. 
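adjust_source_code_info() above shifts comment locations by testing whether each SourceCodeInfo path starts with the enclosing message's path plus the field tag. A standalone sketch of that prefix test; the example paths are illustrative.

def has_path_prefix(s, t):
    # True when s is an element-wise prefix of t.
    return len(s) <= len(t) and all(a == b for a, b in zip(s, t))

message_path = [4, 0]             # e.g. the first top-level message in a file
field_locations = [[4, 0, 2, 3], [4, 0, 2, 4], [5, 1, 2]]
print([loc for loc in field_locations if has_path_prefix(message_path + [2], loc)])
# [[4, 0, 2, 3], [4, 0, 2, 4]] -- only locations under that message's fields match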
- if missing_field.HasField('oneof_index'): - oneof_name = shadow_proto.oneof_decl[missing_field.oneof_index].name - missing_oneof_index = None - for oneof_index, oneof_decl in enumerate(target_proto.oneof_decl): - if oneof_decl.name == oneof_name: - missing_oneof_index = oneof_index - if missing_oneof_index is None: - missing_oneof_index = len(target_proto.oneof_decl) - target_proto.oneof_decl.add().MergeFrom( - shadow_proto.oneof_decl[missing_field.oneof_index]) - missing_field.oneof_index = missing_oneof_index - extra_oneof_fields[missing_oneof_index].append(missing_field) - else: - extra_simple_fields.append(missing_field) - else: - target_proto.reserved_name.append(n) - # Copy existing fields, as we need to nuke them. - existing_fields = copy.deepcopy(target_proto.field) - del target_proto.field[:] - # Rebuild fields, taking into account extra_oneof_fields. protoprint.py - # expects that oneof fields are consecutive, so need to sort for this. - current_oneof_index = None - - def append_extra_oneof_fields(current_oneof_index, last_oneof_field_index): - # Add fields from extra_oneof_fields for current_oneof_index. - for oneof_f in extra_oneof_fields[current_oneof_index]: - target_proto.field.add().MergeFrom(oneof_f) - field_adjustment = len(extra_oneof_fields[current_oneof_index]) - # Fixup the comments in source code info. Note that this is really - # inefficient, O(N^2) in the worst case, but since we have relatively few - # deprecated fields, is the easiest to implement method. - if last_oneof_field_index is not None: - adjust_source_code_info(type_context, last_oneof_field_index, field_adjustment) - del extra_oneof_fields[current_oneof_index] - return field_adjustment - - field_index = 0 - for f in existing_fields: + target_proto.MergeFrom(active_proto) + if not shadow_proto: + return + shadow_fields = {f.name: f for f in shadow_proto.field} + skip_reserved_numbers = [] + # For every reserved name, check to see if it's in the shadow, and if so, + # reintroduce in target_proto. We track both the normal fields we need to add + # back in (extra_simple_fields) and those that belong to oneofs + # (extra_oneof_fields). The latter require special treatment, as we can't just + # append them to the end of the message, they need to be reordered. + extra_simple_fields = [] + extra_oneof_fields = defaultdict(list) # oneof index -> list of fields + del target_proto.reserved_name[:] + for n in active_proto.reserved_name: + hidden_n = 'hidden_envoy_deprecated_' + n + if hidden_n in shadow_fields: + f = shadow_fields[hidden_n] + add_deprecation_dependencies(target_proto_dependencies, f, False) + skip_reserved_numbers.append(f.number) + missing_field = copy.deepcopy(f) + # oneof fields from the shadow need to have their index set to the + # corresponding index in active/target_proto. 
+ if missing_field.HasField('oneof_index'): + oneof_name = shadow_proto.oneof_decl[missing_field.oneof_index].name + missing_oneof_index = None + for oneof_index, oneof_decl in enumerate(target_proto.oneof_decl): + if oneof_decl.name == oneof_name: + missing_oneof_index = oneof_index + if missing_oneof_index is None: + missing_oneof_index = len(target_proto.oneof_decl) + target_proto.oneof_decl.add().MergeFrom( + shadow_proto.oneof_decl[missing_field.oneof_index]) + missing_field.oneof_index = missing_oneof_index + extra_oneof_fields[missing_oneof_index].append(missing_field) + else: + extra_simple_fields.append(missing_field) + else: + target_proto.reserved_name.append(n) + # Copy existing fields, as we need to nuke them. + existing_fields = copy.deepcopy(target_proto.field) + del target_proto.field[:] + # Rebuild fields, taking into account extra_oneof_fields. protoprint.py + # expects that oneof fields are consecutive, so need to sort for this. + current_oneof_index = None + + def append_extra_oneof_fields(current_oneof_index, last_oneof_field_index): + # Add fields from extra_oneof_fields for current_oneof_index. + for oneof_f in extra_oneof_fields[current_oneof_index]: + target_proto.field.add().MergeFrom(oneof_f) + field_adjustment = len(extra_oneof_fields[current_oneof_index]) + # Fixup the comments in source code info. Note that this is really + # inefficient, O(N^2) in the worst case, but since we have relatively few + # deprecated fields, is the easiest to implement method. + if last_oneof_field_index is not None: + adjust_source_code_info(type_context, last_oneof_field_index, field_adjustment) + del extra_oneof_fields[current_oneof_index] + return field_adjustment + + field_index = 0 + for f in existing_fields: + if current_oneof_index is not None: + field_oneof_index = f.oneof_index if f.HasField('oneof_index') else None + # Are we exiting the oneof? If so, add the respective extra_one_fields. + if field_oneof_index != current_oneof_index: + field_index += append_extra_oneof_fields(current_oneof_index, field_index) + current_oneof_index = field_oneof_index + elif f.HasField('oneof_index'): + current_oneof_index = f.oneof_index + target_proto.field.add().MergeFrom(f) + field_index += 1 if current_oneof_index is not None: - field_oneof_index = f.oneof_index if f.HasField('oneof_index') else None - # Are we exiting the oneof? If so, add the respective extra_one_fields. - if field_oneof_index != current_oneof_index: - field_index += append_extra_oneof_fields(current_oneof_index, field_index) - current_oneof_index = field_oneof_index - elif f.HasField('oneof_index'): - current_oneof_index = f.oneof_index - target_proto.field.add().MergeFrom(f) - field_index += 1 - if current_oneof_index is not None: - # No need to adjust source code info here, since there are no comments for - # trailing deprecated fields, so just set field index to None. - append_extra_oneof_fields(current_oneof_index, None) - # Non-oneof fields are easy to treat, we just append them to the existing - # fields. They don't get any comments, but that's fine in the generated - # shadows. - for f in extra_simple_fields: - target_proto.field.add().MergeFrom(f) - for oneof_index in sorted(extra_oneof_fields.keys()): - for f in extra_oneof_fields[oneof_index]: - target_proto.field.add().MergeFrom(f) - # Same is true for oneofs that are exclusively from the shadow. 
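The merge above relies on a naming convention between the two inputs: a name that is reserved in the active proto reappears in the shadow proto with the 'hidden_envoy_deprecated_' prefix, and its field number is then excluded from the merged reserved ranges. A small standalone sketch of that lookup, with the field name and number invented for illustration:

from google.protobuf import descriptor_pb2

shadow = descriptor_pb2.DescriptorProto(name='Example')
hidden = shadow.field.add(name='hidden_envoy_deprecated_foo', number=3)

active = descriptor_pb2.DescriptorProto(name='Example')
active.reserved_name.append('foo')
r = active.reserved_range.add()
r.start = 3
r.end = 4

shadow_fields = {f.name: f for f in shadow.field}
skip_reserved_numbers = []
for n in active.reserved_name:
    f = shadow_fields.get('hidden_envoy_deprecated_' + n)
    if f is not None:
        # The deprecated field comes back into the merged message, so its
        # number must not stay reserved.
        skip_reserved_numbers.append(f.number)

print(skip_reserved_numbers)  # [3] -> the 3..4 reserved range is dropped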
- adjust_reserved_range(target_proto, active_proto.reserved_range, skip_reserved_numbers) - # Visit nested message types - del target_proto.nested_type[:] - shadow_msgs = {msg.name: msg for msg in shadow_proto.nested_type} - for index, msg in enumerate(active_proto.nested_type): - merge_active_shadow_message( - type_context.extend_nested_message(index, msg.name, msg.options.deprecated), msg, - shadow_msgs.get(msg.name), target_proto.nested_type.add(), target_proto_dependencies) - # Visit nested enum types - del target_proto.enum_type[:] - shadow_enums = {msg.name: msg for msg in shadow_proto.enum_type} - for enum in active_proto.enum_type: - merge_active_shadow_enum(enum, shadow_enums.get(enum.name), target_proto.enum_type.add(), - target_proto_dependencies) - # Ensure target has any deprecated sub-message types in case they are needed. - active_msg_names = set([msg.name for msg in active_proto.nested_type]) - for msg in shadow_proto.nested_type: - if msg.name not in active_msg_names: - target_proto.nested_type.add().MergeFrom(msg) + # No need to adjust source code info here, since there are no comments for + # trailing deprecated fields, so just set field index to None. + append_extra_oneof_fields(current_oneof_index, None) + # Non-oneof fields are easy to treat, we just append them to the existing + # fields. They don't get any comments, but that's fine in the generated + # shadows. + for f in extra_simple_fields: + target_proto.field.add().MergeFrom(f) + for oneof_index in sorted(extra_oneof_fields.keys()): + for f in extra_oneof_fields[oneof_index]: + target_proto.field.add().MergeFrom(f) + # Same is true for oneofs that are exclusively from the shadow. + adjust_reserved_range(target_proto, active_proto.reserved_range, skip_reserved_numbers) + # Visit nested message types + del target_proto.nested_type[:] + shadow_msgs = {msg.name: msg for msg in shadow_proto.nested_type} + for index, msg in enumerate(active_proto.nested_type): + merge_active_shadow_message( + type_context.extend_nested_message(index, msg.name, msg.options.deprecated), msg, + shadow_msgs.get(msg.name), target_proto.nested_type.add(), target_proto_dependencies) + # Visit nested enum types + del target_proto.enum_type[:] + shadow_enums = {msg.name: msg for msg in shadow_proto.enum_type} + for enum in active_proto.enum_type: + merge_active_shadow_enum(enum, shadow_enums.get(enum.name), target_proto.enum_type.add(), + target_proto_dependencies) + # Ensure target has any deprecated sub-message types in case they are needed. + active_msg_names = set([msg.name for msg in active_proto.nested_type]) + for msg in shadow_proto.nested_type: + if msg.name not in active_msg_names: + target_proto.nested_type.add().MergeFrom(msg) # Merge active/shadow FileDescriptorProtos, returning the resulting FileDescriptorProto. 
def merge_active_shadow_file(active_file_proto, shadow_file_proto): - target_file_proto = copy.deepcopy(active_file_proto) - source_code_info = api_type_context.SourceCodeInfo(target_file_proto.name, - target_file_proto.source_code_info) - package_type_context = api_type_context.TypeContext(source_code_info, target_file_proto.package) - # Visit message types - del target_file_proto.message_type[:] - shadow_msgs = {msg.name: msg for msg in shadow_file_proto.message_type} - for index, msg in enumerate(active_file_proto.message_type): - merge_active_shadow_message( - package_type_context.extend_message(index, msg.name, msg.options.deprecated), msg, - shadow_msgs.get(msg.name), target_file_proto.message_type.add(), - target_file_proto.dependency) - # Visit enum types - del target_file_proto.enum_type[:] - shadow_enums = {msg.name: msg for msg in shadow_file_proto.enum_type} - for enum in active_file_proto.enum_type: - merge_active_shadow_enum(enum, shadow_enums.get(enum.name), target_file_proto.enum_type.add(), - target_file_proto.dependency) - # Ensure target has any deprecated message types in case they are needed. - active_msg_names = set([msg.name for msg in active_file_proto.message_type]) - for msg in shadow_file_proto.message_type: - if msg.name not in active_msg_names: - target_file_proto.message_type.add().MergeFrom(msg) - return target_file_proto + target_file_proto = copy.deepcopy(active_file_proto) + source_code_info = api_type_context.SourceCodeInfo(target_file_proto.name, + target_file_proto.source_code_info) + package_type_context = api_type_context.TypeContext(source_code_info, target_file_proto.package) + # Visit message types + del target_file_proto.message_type[:] + shadow_msgs = {msg.name: msg for msg in shadow_file_proto.message_type} + for index, msg in enumerate(active_file_proto.message_type): + merge_active_shadow_message( + package_type_context.extend_message(index, msg.name, msg.options.deprecated), msg, + shadow_msgs.get(msg.name), target_file_proto.message_type.add(), + target_file_proto.dependency) + # Visit enum types + del target_file_proto.enum_type[:] + shadow_enums = {msg.name: msg for msg in shadow_file_proto.enum_type} + for enum in active_file_proto.enum_type: + merge_active_shadow_enum(enum, shadow_enums.get(enum.name), + target_file_proto.enum_type.add(), target_file_proto.dependency) + # Ensure target has any deprecated message types in case they are needed. 
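A tiny standalone illustration of the recovery step referenced just above (and exercised by testmerge_active_shadow_file_missing later in this patch): a message type that exists only in the shadow file is copied into the merged file. Names are invented for illustration.

from google.protobuf import descriptor_pb2

active = descriptor_pb2.FileDescriptorProto(name='example.proto')
shadow = descriptor_pb2.FileDescriptorProto(name='example.proto')
shadow.message_type.add().name = 'DeprecatedOnlyInShadow'

target = descriptor_pb2.FileDescriptorProto()
target.CopyFrom(active)
active_names = {m.name for m in active.message_type}
for m in shadow.message_type:
    if m.name not in active_names:
        target.message_type.add().MergeFrom(m)

print([m.name for m in target.message_type])  # ['DeprecatedOnlyInShadow']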
+ active_msg_names = set([msg.name for msg in active_file_proto.message_type]) + for msg in shadow_file_proto.message_type: + if msg.name not in active_msg_names: + target_file_proto.message_type.add().MergeFrom(msg) + return target_file_proto if __name__ == '__main__': - active_src, shadow_src, dst = sys.argv[1:] - active_proto = descriptor_pb2.FileDescriptorProto() - text_format.Merge(pathlib.Path(active_src).read_text(), active_proto) - shadow_proto = descriptor_pb2.FileDescriptorProto() - text_format.Merge(pathlib.Path(shadow_src).read_text(), shadow_proto) - pathlib.Path(dst).write_text(str(merge_active_shadow_file(active_proto, shadow_proto))) + active_src, shadow_src, dst = sys.argv[1:] + active_proto = descriptor_pb2.FileDescriptorProto() + text_format.Merge(pathlib.Path(active_src).read_text(), active_proto) + shadow_proto = descriptor_pb2.FileDescriptorProto() + text_format.Merge(pathlib.Path(shadow_src).read_text(), shadow_proto) + pathlib.Path(dst).write_text(str(merge_active_shadow_file(active_proto, shadow_proto))) diff --git a/tools/protoxform/merge_active_shadow_test.py b/tools/protoxform/merge_active_shadow_test.py index 42c1d683c11b..6aee32ec3104 100644 --- a/tools/protoxform/merge_active_shadow_test.py +++ b/tools/protoxform/merge_active_shadow_test.py @@ -9,20 +9,20 @@ class MergeActiveShadowTest(unittest.TestCase): - # Dummy type context for tests that don't care about this. - def fake_type_context(self): - fake_source_code_info = descriptor_pb2.SourceCodeInfo() - source_code_info = api_type_context.SourceCodeInfo('fake', fake_source_code_info) - return api_type_context.TypeContext(source_code_info, 'fake_package') + # Dummy type context for tests that don't care about this. + def fake_type_context(self): + fake_source_code_info = descriptor_pb2.SourceCodeInfo() + source_code_info = api_type_context.SourceCodeInfo('fake', fake_source_code_info) + return api_type_context.TypeContext(source_code_info, 'fake_package') - # Poor man's text proto equivalence. Tensorflow has better tools for this, - # i.e. assertProto2Equal. - def assert_text_proto_eq(self, lhs, rhs): - self.assertMultiLineEqual(lhs.strip(), rhs.strip()) + # Poor man's text proto equivalence. Tensorflow has better tools for this, + # i.e. assertProto2Equal. 
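Mirroring the __main__ block above, the merge can also be driven programmatically from text-format FileDescriptorProtos. A rough sketch; the descriptor file names are placeholders and the import path is an assumption about how the module is laid out on PYTHONPATH:

import pathlib

from google.protobuf import descriptor_pb2, text_format

# Assumed import path; the module lives at tools/protoxform/merge_active_shadow.py.
from tools.protoxform import merge_active_shadow

active = descriptor_pb2.FileDescriptorProto()
text_format.Merge(pathlib.Path('active.pb_text').read_text(), active)
shadow = descriptor_pb2.FileDescriptorProto()
text_format.Merge(pathlib.Path('shadow.pb_text').read_text(), shadow)

merged = merge_active_shadow.merge_active_shadow_file(active, shadow)
pathlib.Path('merged.pb_text').write_text(str(merged))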
+ def assert_text_proto_eq(self, lhs, rhs): + self.assertMultiLineEqual(lhs.strip(), rhs.strip()) - def testadjust_reserved_range(self): - """adjust_reserved_range removes specified skip_reserved_numbers.""" - desc_pb_text = """ + def testadjust_reserved_range(self): + """adjust_reserved_range removes specified skip_reserved_numbers.""" + desc_pb_text = """ reserved_range { start: 41 end: 41 @@ -40,11 +40,11 @@ def testadjust_reserved_range(self): end: 51 } """ - desc = descriptor_pb2.DescriptorProto() - text_format.Merge(desc_pb_text, desc) - target = descriptor_pb2.DescriptorProto() - merge_active_shadow.adjust_reserved_range(target, desc.reserved_range, [42, 43]) - target_pb_text = """ + desc = descriptor_pb2.DescriptorProto() + text_format.Merge(desc_pb_text, desc) + target = descriptor_pb2.DescriptorProto() + merge_active_shadow.adjust_reserved_range(target, desc.reserved_range, [42, 43]) + target_pb_text = """ reserved_range { start: 41 end: 41 @@ -54,11 +54,11 @@ def testadjust_reserved_range(self): end: 51 } """ - self.assert_text_proto_eq(target_pb_text, str(target)) + self.assert_text_proto_eq(target_pb_text, str(target)) - def testmerge_active_shadow_enum(self): - """merge_active_shadow_enum recovers shadow values.""" - active_pb_text = """ + def testmerge_active_shadow_enum(self): + """merge_active_shadow_enum recovers shadow values.""" + active_pb_text = """ value { number: 1 name: "foo" @@ -77,9 +77,9 @@ def testmerge_active_shadow_enum(self): end: 3 } """ - active_proto = descriptor_pb2.EnumDescriptorProto() - text_format.Merge(active_pb_text, active_proto) - shadow_pb_text = """ + active_proto = descriptor_pb2.EnumDescriptorProto() + text_format.Merge(active_pb_text, active_proto) + shadow_pb_text = """ value { number: 1 name: "foo" @@ -101,13 +101,13 @@ def testmerge_active_shadow_enum(self): name: "hidden_envoy_deprecated_huh" } """ - shadow_proto = descriptor_pb2.EnumDescriptorProto() - text_format.Merge(shadow_pb_text, shadow_proto) - target_proto = descriptor_pb2.EnumDescriptorProto() - target_proto_dependencies = [] - merge_active_shadow.merge_active_shadow_enum(active_proto, shadow_proto, target_proto, - target_proto_dependencies) - target_pb_text = """ + shadow_proto = descriptor_pb2.EnumDescriptorProto() + text_format.Merge(shadow_pb_text, shadow_proto) + target_proto = descriptor_pb2.EnumDescriptorProto() + target_proto_dependencies = [] + merge_active_shadow.merge_active_shadow_enum(active_proto, shadow_proto, target_proto, + target_proto_dependencies) + target_pb_text = """ value { name: "foo" number: 1 @@ -125,11 +125,11 @@ def testmerge_active_shadow_enum(self): number: 2 } """ - self.assert_text_proto_eq(target_pb_text, str(target_proto)) + self.assert_text_proto_eq(target_pb_text, str(target_proto)) - def testmerge_active_shadow_message_comments(self): - """merge_active_shadow_message preserves comment field correspondence.""" - active_pb_text = """ + def testmerge_active_shadow_message_comments(self): + """merge_active_shadow_message preserves comment field correspondence.""" + active_pb_text = """ field { number: 9 name: "oneof_1_0" @@ -179,9 +179,9 @@ def testmerge_active_shadow_message_comments(self): name: "oneof_3" } """ - active_proto = descriptor_pb2.DescriptorProto() - text_format.Merge(active_pb_text, active_proto) - active_source_code_info_text = """ + active_proto = descriptor_pb2.DescriptorProto() + text_format.Merge(active_pb_text, active_proto) + active_source_code_info_text = """ location { path: [4, 1, 2, 4] leading_comments: "field_4" @@ 
-219,9 +219,9 @@ def testmerge_active_shadow_message_comments(self): leading_comments: "ignore_1" } """ - active_source_code_info = descriptor_pb2.SourceCodeInfo() - text_format.Merge(active_source_code_info_text, active_source_code_info) - shadow_pb_text = """ + active_source_code_info = descriptor_pb2.SourceCodeInfo() + text_format.Merge(active_source_code_info_text, active_source_code_info) + shadow_pb_text = """ field { number: 10 name: "hidden_envoy_deprecated_missing_oneof_field_0" @@ -253,16 +253,16 @@ def testmerge_active_shadow_message_comments(self): name: "oneof_3" } """ - shadow_proto = descriptor_pb2.DescriptorProto() - text_format.Merge(shadow_pb_text, shadow_proto) - target_proto = descriptor_pb2.DescriptorProto() - source_code_info = api_type_context.SourceCodeInfo('fake', active_source_code_info) - fake_type_context = api_type_context.TypeContext(source_code_info, 'fake_package') - target_proto_dependencies = [] - merge_active_shadow.merge_active_shadow_message( - fake_type_context.extend_message(1, "foo", False), active_proto, shadow_proto, target_proto, - target_proto_dependencies) - target_pb_text = """ + shadow_proto = descriptor_pb2.DescriptorProto() + text_format.Merge(shadow_pb_text, shadow_proto) + target_proto = descriptor_pb2.DescriptorProto() + source_code_info = api_type_context.SourceCodeInfo('fake', active_source_code_info) + fake_type_context = api_type_context.TypeContext(source_code_info, 'fake_package') + target_proto_dependencies = [] + merge_active_shadow.merge_active_shadow_message( + fake_type_context.extend_message(1, "foo", False), active_proto, shadow_proto, + target_proto, target_proto_dependencies) + target_pb_text = """ field { name: "oneof_1_0" number: 9 @@ -327,7 +327,7 @@ def testmerge_active_shadow_message_comments(self): name: "some_removed_oneof" } """ - target_source_code_info_text = """ + target_source_code_info_text = """ location { path: 4 path: 1 @@ -389,14 +389,14 @@ def testmerge_active_shadow_message_comments(self): leading_comments: "ignore_1" } """ - self.maxDiff = None - self.assert_text_proto_eq(target_pb_text, str(target_proto)) - self.assert_text_proto_eq(target_source_code_info_text, - str(fake_type_context.source_code_info.proto)) + self.maxDiff = None + self.assert_text_proto_eq(target_pb_text, str(target_proto)) + self.assert_text_proto_eq(target_source_code_info_text, + str(fake_type_context.source_code_info.proto)) - def testmerge_active_shadow_message(self): - """merge_active_shadow_message recovers shadow fields with oneofs.""" - active_pb_text = """ + def testmerge_active_shadow_message(self): + """merge_active_shadow_message recovers shadow fields with oneofs.""" + active_pb_text = """ field { number: 1 name: "foo" @@ -429,9 +429,9 @@ def testmerge_active_shadow_message(self): name: "some_oneof" } """ - active_proto = descriptor_pb2.DescriptorProto() - text_format.Merge(active_pb_text, active_proto) - shadow_pb_text = """ + active_proto = descriptor_pb2.DescriptorProto() + text_format.Merge(active_pb_text, active_proto) + shadow_pb_text = """ field { number: 1 name: "foo" @@ -462,14 +462,14 @@ def testmerge_active_shadow_message(self): name: "some_oneof" } """ - shadow_proto = descriptor_pb2.DescriptorProto() - text_format.Merge(shadow_pb_text, shadow_proto) - target_proto = descriptor_pb2.DescriptorProto() - target_proto_dependencies = [] - merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, - shadow_proto, target_proto, - target_proto_dependencies) - target_pb_text = """ + 
shadow_proto = descriptor_pb2.DescriptorProto() + text_format.Merge(shadow_pb_text, shadow_proto) + target_proto = descriptor_pb2.DescriptorProto() + target_proto_dependencies = [] + merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, + shadow_proto, target_proto, + target_proto_dependencies) + target_pb_text = """ field { name: "foo" number: 1 @@ -515,74 +515,74 @@ def testmerge_active_shadow_message(self): end: 3 } """ - self.assert_text_proto_eq(target_pb_text, str(target_proto)) - self.assertEqual(target_proto_dependencies[0], 'envoy/annotations/deprecation.proto') + self.assert_text_proto_eq(target_pb_text, str(target_proto)) + self.assertEqual(target_proto_dependencies[0], 'envoy/annotations/deprecation.proto') - def testmerge_active_shadow_message_no_shadow_message(self): - """merge_active_shadow_message doesn't require a shadow message for new nested active messages.""" - active_proto = descriptor_pb2.DescriptorProto() - shadow_proto = descriptor_pb2.DescriptorProto() - active_proto.nested_type.add().name = 'foo' - target_proto = descriptor_pb2.DescriptorProto() - target_proto_dependencies = [] - merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, - shadow_proto, target_proto, - target_proto_dependencies) - self.assertEqual(target_proto.nested_type[0].name, 'foo') + def testmerge_active_shadow_message_no_shadow_message(self): + """merge_active_shadow_message doesn't require a shadow message for new nested active messages.""" + active_proto = descriptor_pb2.DescriptorProto() + shadow_proto = descriptor_pb2.DescriptorProto() + active_proto.nested_type.add().name = 'foo' + target_proto = descriptor_pb2.DescriptorProto() + target_proto_dependencies = [] + merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, + shadow_proto, target_proto, + target_proto_dependencies) + self.assertEqual(target_proto.nested_type[0].name, 'foo') - def testmerge_active_shadow_message_no_shadow_enum(self): - """merge_active_shadow_message doesn't require a shadow enum for new nested active enums.""" - active_proto = descriptor_pb2.DescriptorProto() - shadow_proto = descriptor_pb2.DescriptorProto() - active_proto.enum_type.add().name = 'foo' - target_proto = descriptor_pb2.DescriptorProto() - target_proto_dependencies = [] - merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, - shadow_proto, target_proto, - target_proto_dependencies) - self.assertEqual(target_proto.enum_type[0].name, 'foo') + def testmerge_active_shadow_message_no_shadow_enum(self): + """merge_active_shadow_message doesn't require a shadow enum for new nested active enums.""" + active_proto = descriptor_pb2.DescriptorProto() + shadow_proto = descriptor_pb2.DescriptorProto() + active_proto.enum_type.add().name = 'foo' + target_proto = descriptor_pb2.DescriptorProto() + target_proto_dependencies = [] + merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, + shadow_proto, target_proto, + target_proto_dependencies) + self.assertEqual(target_proto.enum_type[0].name, 'foo') - def testmerge_active_shadow_message_missing(self): - """merge_active_shadow_message recovers missing messages from shadow.""" - active_proto = descriptor_pb2.DescriptorProto() - shadow_proto = descriptor_pb2.DescriptorProto() - shadow_proto.nested_type.add().name = 'foo' - target_proto = descriptor_pb2.DescriptorProto() - target_proto_dependencies = [] - 
merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, - shadow_proto, target_proto, - target_proto_dependencies) - self.assertEqual(target_proto.nested_type[0].name, 'foo') + def testmerge_active_shadow_message_missing(self): + """merge_active_shadow_message recovers missing messages from shadow.""" + active_proto = descriptor_pb2.DescriptorProto() + shadow_proto = descriptor_pb2.DescriptorProto() + shadow_proto.nested_type.add().name = 'foo' + target_proto = descriptor_pb2.DescriptorProto() + target_proto_dependencies = [] + merge_active_shadow.merge_active_shadow_message(self.fake_type_context(), active_proto, + shadow_proto, target_proto, + target_proto_dependencies) + self.assertEqual(target_proto.nested_type[0].name, 'foo') - def testmerge_active_shadow_file_missing(self): - """merge_active_shadow_file recovers missing messages from shadow.""" - active_proto = descriptor_pb2.FileDescriptorProto() - shadow_proto = descriptor_pb2.FileDescriptorProto() - shadow_proto.message_type.add().name = 'foo' - target_proto = descriptor_pb2.DescriptorProto() - target_proto = merge_active_shadow.merge_active_shadow_file(active_proto, shadow_proto) - self.assertEqual(target_proto.message_type[0].name, 'foo') + def testmerge_active_shadow_file_missing(self): + """merge_active_shadow_file recovers missing messages from shadow.""" + active_proto = descriptor_pb2.FileDescriptorProto() + shadow_proto = descriptor_pb2.FileDescriptorProto() + shadow_proto.message_type.add().name = 'foo' + target_proto = descriptor_pb2.DescriptorProto() + target_proto = merge_active_shadow.merge_active_shadow_file(active_proto, shadow_proto) + self.assertEqual(target_proto.message_type[0].name, 'foo') - def testmerge_active_shadow_file_no_shadow_message(self): - """merge_active_shadow_file doesn't require a shadow message for new active messages.""" - active_proto = descriptor_pb2.FileDescriptorProto() - shadow_proto = descriptor_pb2.FileDescriptorProto() - active_proto.message_type.add().name = 'foo' - target_proto = descriptor_pb2.DescriptorProto() - target_proto = merge_active_shadow.merge_active_shadow_file(active_proto, shadow_proto) - self.assertEqual(target_proto.message_type[0].name, 'foo') + def testmerge_active_shadow_file_no_shadow_message(self): + """merge_active_shadow_file doesn't require a shadow message for new active messages.""" + active_proto = descriptor_pb2.FileDescriptorProto() + shadow_proto = descriptor_pb2.FileDescriptorProto() + active_proto.message_type.add().name = 'foo' + target_proto = descriptor_pb2.DescriptorProto() + target_proto = merge_active_shadow.merge_active_shadow_file(active_proto, shadow_proto) + self.assertEqual(target_proto.message_type[0].name, 'foo') - def testmerge_active_shadow_file_no_shadow_enum(self): - """merge_active_shadow_file doesn't require a shadow enum for new active enums.""" - active_proto = descriptor_pb2.FileDescriptorProto() - shadow_proto = descriptor_pb2.FileDescriptorProto() - active_proto.enum_type.add().name = 'foo' - target_proto = descriptor_pb2.DescriptorProto() - target_proto = merge_active_shadow.merge_active_shadow_file(active_proto, shadow_proto) - self.assertEqual(target_proto.enum_type[0].name, 'foo') + def testmerge_active_shadow_file_no_shadow_enum(self): + """merge_active_shadow_file doesn't require a shadow enum for new active enums.""" + active_proto = descriptor_pb2.FileDescriptorProto() + shadow_proto = descriptor_pb2.FileDescriptorProto() + active_proto.enum_type.add().name = 'foo' + target_proto = 
descriptor_pb2.DescriptorProto() + target_proto = merge_active_shadow.merge_active_shadow_file(active_proto, shadow_proto) + self.assertEqual(target_proto.enum_type[0].name, 'foo') # TODO(htuch): add some test for recursion. if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/tools/protoxform/migrate.py b/tools/protoxform/migrate.py index 7aa7ba9ed325..61a53f3a2e10 100644 --- a/tools/protoxform/migrate.py +++ b/tools/protoxform/migrate.py @@ -19,227 +19,230 @@ class UpgradeVisitor(visitor.Visitor): - """Visitor to generate an upgraded proto from a FileDescriptor proto. + """Visitor to generate an upgraded proto from a FileDescriptor proto. See visitor.Visitor for visitor method docs comments. """ - def __init__(self, n, typedb, envoy_internal_shadow, package_version_status): - self._base_version = n - self._typedb = typedb - self._envoy_internal_shadow = envoy_internal_shadow - self._package_version_status = package_version_status + def __init__(self, n, typedb, envoy_internal_shadow, package_version_status): + self._base_version = n + self._typedb = typedb + self._envoy_internal_shadow = envoy_internal_shadow + self._package_version_status = package_version_status - def _upgraded_comment(self, c): + def _upgraded_comment(self, c): - def upgrade_type(match): - # We're upgrading a type within a RST anchor reference here. These are - # stylized and match the output format of tools/protodoc. We need to do - # some special handling of field/enum values, and also the normalization - # that was performed in v2 for envoy.api.v2 types. - label_ref_type, label_normalized_type_name, section_ref_type, section_normalized_type_name = match.groups( - ) - if label_ref_type is not None: - ref_type = label_ref_type - normalized_type_name = label_normalized_type_name - else: - ref_type = section_ref_type - normalized_type_name = section_normalized_type_name - if ref_type == 'field' or ref_type == 'enum_value': - normalized_type_name, residual = normalized_type_name.rsplit('.', 1) - else: - residual = '' - type_name = 'envoy.' + normalized_type_name - api_v2_type_name = 'envoy.api.v2.' + normalized_type_name - if type_name in self._typedb.types: - type_desc = self._typedb.types[type_name] - else: - # We need to deal with envoy.api.* normalization in the v2 API. We won't - # need this in v3+, so rather than churn docs, we just have this workaround. - type_desc = self._typedb.types[api_v2_type_name] - repl_type = type_desc.next_version_type_name[ - len('envoy.'):] if type_desc.next_version_type_name else normalized_type_name - # TODO(htuch): this should really either go through the type database or - # via the descriptor pool and annotations, but there are only two of these - # we need for the initial v2 -> v3 docs cut, so hard coding for now. - # Tracked at https://github.com/envoyproxy/envoy/issues/9734. - if repl_type == 'config.route.v3.RouteAction': - if residual == 'host_rewrite': - residual = 'host_rewrite_literal' - elif residual == 'auto_host_rewrite_header': - residual = 'auto_host_rewrite' - new_ref = 'envoy_api_%s_%s%s' % (ref_type, repl_type, '.' + residual if residual else '') - if label_ref_type is not None: - return '<%s>' % new_ref - else: - return ':ref:`%s`' % new_ref + def upgrade_type(match): + # We're upgrading a type within a RST anchor reference here. These are + # stylized and match the output format of tools/protodoc. We need to do + # some special handling of field/enum values, and also the normalization + # that was performed in v2 for envoy.api.v2 types. 
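To make the comment-upgrade step above concrete: _upgraded_comment rewrites RST cross-references embedded in proto comments via re.sub with a replacement callback. The sketch below uses a simplified stand-in pattern and a hand-written type mapping purely for demonstration; the real module uses ENVOY_COMMENT_WITH_TYPE_REGEX (defined elsewhere in this file) and the protoxform type database.

import re

# Hypothetical next-version mapping; illustration only.
NEXT_VERSION = {'config.route.v3.RouteAction': 'config.route.v4alpha.RouteAction'}
# Simplified stand-in for ENVOY_COMMENT_WITH_TYPE_REGEX.
REF_RE = re.compile(r':ref:`envoy_api_(msg|field|enum_value|enum)_([\w.]+)`')

def upgrade_ref(match):
    ref_type, name = match.groups()
    residual = ''
    if ref_type in ('field', 'enum_value'):
        # Field/enum-value refs carry a trailing leaf name after the type.
        name, leaf = name.rsplit('.', 1)
        residual = '.' + leaf
    return ':ref:`envoy_api_%s_%s%s`' % (ref_type, NEXT_VERSION.get(name, name), residual)

comment = 'See :ref:`envoy_api_field_config.route.v3.RouteAction.timeout`.'
print(REF_RE.sub(upgrade_ref, comment))
# See :ref:`envoy_api_field_config.route.v4alpha.RouteAction.timeout`.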
+ label_ref_type, label_normalized_type_name, section_ref_type, section_normalized_type_name = match.groups( + ) + if label_ref_type is not None: + ref_type = label_ref_type + normalized_type_name = label_normalized_type_name + else: + ref_type = section_ref_type + normalized_type_name = section_normalized_type_name + if ref_type == 'field' or ref_type == 'enum_value': + normalized_type_name, residual = normalized_type_name.rsplit('.', 1) + else: + residual = '' + type_name = 'envoy.' + normalized_type_name + api_v2_type_name = 'envoy.api.v2.' + normalized_type_name + if type_name in self._typedb.types: + type_desc = self._typedb.types[type_name] + else: + # We need to deal with envoy.api.* normalization in the v2 API. We won't + # need this in v3+, so rather than churn docs, we just have this workaround. + type_desc = self._typedb.types[api_v2_type_name] + repl_type = type_desc.next_version_type_name[ + len('envoy.'):] if type_desc.next_version_type_name else normalized_type_name + # TODO(htuch): this should really either go through the type database or + # via the descriptor pool and annotations, but there are only two of these + # we need for the initial v2 -> v3 docs cut, so hard coding for now. + # Tracked at https://github.com/envoyproxy/envoy/issues/9734. + if repl_type == 'config.route.v3.RouteAction': + if residual == 'host_rewrite': + residual = 'host_rewrite_literal' + elif residual == 'auto_host_rewrite_header': + residual = 'auto_host_rewrite' + new_ref = 'envoy_api_%s_%s%s' % (ref_type, repl_type, + '.' + residual if residual else '') + if label_ref_type is not None: + return '<%s>' % new_ref + else: + return ':ref:`%s`' % new_ref - return re.sub(ENVOY_COMMENT_WITH_TYPE_REGEX, upgrade_type, c) + return re.sub(ENVOY_COMMENT_WITH_TYPE_REGEX, upgrade_type, c) - def _upgraded_post_method(self, m): - return re.sub(r'^/v%d/' % self._base_version, '/v%d/' % (self._base_version + 1), m) + def _upgraded_post_method(self, m): + return re.sub(r'^/v%d/' % self._base_version, '/v%d/' % (self._base_version + 1), m) - # Upgraded type using canonical type naming, e.g. foo.bar. - def _upgraded_type_canonical(self, t): - if not t.startswith('envoy'): - return t - type_desc = self._typedb.types[t] - if type_desc.next_version_type_name: - return type_desc.next_version_type_name - return t + # Upgraded type using canonical type naming, e.g. foo.bar. + def _upgraded_type_canonical(self, t): + if not t.startswith('envoy'): + return t + type_desc = self._typedb.types[t] + if type_desc.next_version_type_name: + return type_desc.next_version_type_name + return t - # Upgraded type using internal type naming, e.g. .foo.bar. - def _upgraded_type(self, t): - if not t.startswith('.envoy'): - return t - return '.' + self._upgraded_type_canonical(t[1:]) + # Upgraded type using internal type naming, e.g. .foo.bar. + def _upgraded_type(self, t): + if not t.startswith('.envoy'): + return t + return '.' + self._upgraded_type_canonical(t[1:]) - def _deprecate(self, proto, field_or_value): - """Deprecate a field or value in a message/enum proto. + def _deprecate(self, proto, field_or_value): + """Deprecate a field or value in a message/enum proto. Args: proto: DescriptorProto or EnumDescriptorProto message. field_or_value: field or value inside proto. 
""" - if self._envoy_internal_shadow: - field_or_value.name = 'hidden_envoy_deprecated_' + field_or_value.name - else: - reserved = proto.reserved_range.add() - reserved.start = field_or_value.number - reserved.end = field_or_value.number + 1 - proto.reserved_name.append(field_or_value.name) - options.add_hide_option(field_or_value.options) + if self._envoy_internal_shadow: + field_or_value.name = 'hidden_envoy_deprecated_' + field_or_value.name + else: + reserved = proto.reserved_range.add() + reserved.start = field_or_value.number + reserved.end = field_or_value.number + 1 + proto.reserved_name.append(field_or_value.name) + options.add_hide_option(field_or_value.options) - def _rename(self, proto, migrate_annotation): - """Rename a field/enum/service/message + def _rename(self, proto, migrate_annotation): + """Rename a field/enum/service/message Args: proto: DescriptorProto or corresponding proto message migrate_annotation: udpa.annotations.MigrateAnnotation message """ - if migrate_annotation.rename: - proto.name = migrate_annotation.rename - migrate_annotation.rename = "" + if migrate_annotation.rename: + proto.name = migrate_annotation.rename + migrate_annotation.rename = "" - def _oneof_promotion(self, msg_proto, field_proto, migrate_annotation): - """Promote a field to a oneof. + def _oneof_promotion(self, msg_proto, field_proto, migrate_annotation): + """Promote a field to a oneof. Args: msg_proto: DescriptorProto for message containing field. field_proto: FieldDescriptorProto for field. migrate_annotation: udpa.annotations.FieldMigrateAnnotation message """ - if migrate_annotation.oneof_promotion: - oneof_index = -1 - for n, oneof_decl in enumerate(msg_proto.oneof_decl): - if oneof_decl.name == migrate_annotation.oneof_promotion: - oneof_index = n - if oneof_index == -1: - oneof_index = len(msg_proto.oneof_decl) - oneof_decl = msg_proto.oneof_decl.add() - oneof_decl.name = migrate_annotation.oneof_promotion - field_proto.oneof_index = oneof_index - migrate_annotation.oneof_promotion = "" + if migrate_annotation.oneof_promotion: + oneof_index = -1 + for n, oneof_decl in enumerate(msg_proto.oneof_decl): + if oneof_decl.name == migrate_annotation.oneof_promotion: + oneof_index = n + if oneof_index == -1: + oneof_index = len(msg_proto.oneof_decl) + oneof_decl = msg_proto.oneof_decl.add() + oneof_decl.name = migrate_annotation.oneof_promotion + field_proto.oneof_index = oneof_index + migrate_annotation.oneof_promotion = "" - def visit_service(self, service_proto, type_context): - upgraded_proto = copy.deepcopy(service_proto) - for m in upgraded_proto.method: - if m.options.HasExtension(annotations_pb2.http): - http_options = m.options.Extensions[annotations_pb2.http] - # TODO(htuch): figure out a more systematic approach using the type DB - # to service upgrade. 
- http_options.post = self._upgraded_post_method(http_options.post) - m.input_type = self._upgraded_type(m.input_type) - m.output_type = self._upgraded_type(m.output_type) - if service_proto.options.HasExtension(resource_pb2.resource): - upgraded_proto.options.Extensions[resource_pb2.resource].type = self._upgraded_type_canonical( - service_proto.options.Extensions[resource_pb2.resource].type) - return upgraded_proto + def visit_service(self, service_proto, type_context): + upgraded_proto = copy.deepcopy(service_proto) + for m in upgraded_proto.method: + if m.options.HasExtension(annotations_pb2.http): + http_options = m.options.Extensions[annotations_pb2.http] + # TODO(htuch): figure out a more systematic approach using the type DB + # to service upgrade. + http_options.post = self._upgraded_post_method(http_options.post) + m.input_type = self._upgraded_type(m.input_type) + m.output_type = self._upgraded_type(m.output_type) + if service_proto.options.HasExtension(resource_pb2.resource): + upgraded_proto.options.Extensions[ + resource_pb2.resource].type = self._upgraded_type_canonical( + service_proto.options.Extensions[resource_pb2.resource].type) + return upgraded_proto - def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): - upgraded_proto = copy.deepcopy(msg_proto) - if upgraded_proto.options.deprecated and not self._envoy_internal_shadow: - options.add_hide_option(upgraded_proto.options) - options.set_versioning_annotation(upgraded_proto.options, type_context.name) - # Mark deprecated fields as ready for deletion by protoxform. - for f in upgraded_proto.field: - if f.options.deprecated: - self._deprecate(upgraded_proto, f) - if self._envoy_internal_shadow: - # When shadowing, we use the upgraded version of types (which should - # themselves also be shadowed), to allow us to avoid unnecessary - # references to the previous version (and complexities around - # upgrading during API boosting). - f.type_name = self._upgraded_type(f.type_name) - else: - # Make sure the type name is erased so it isn't picked up by protoxform - # when computing deps. - f.type_name = "" - else: - f.type_name = self._upgraded_type(f.type_name) - if f.options.HasExtension(migrate_pb2.field_migrate): - field_migrate = f.options.Extensions[migrate_pb2.field_migrate] - self._rename(f, field_migrate) - self._oneof_promotion(upgraded_proto, f, field_migrate) - # Upgrade nested messages. - del upgraded_proto.nested_type[:] - upgraded_proto.nested_type.extend(nested_msgs) - # Upgrade enums. - del upgraded_proto.enum_type[:] - upgraded_proto.enum_type.extend(nested_enums) - return upgraded_proto + def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): + upgraded_proto = copy.deepcopy(msg_proto) + if upgraded_proto.options.deprecated and not self._envoy_internal_shadow: + options.add_hide_option(upgraded_proto.options) + options.set_versioning_annotation(upgraded_proto.options, type_context.name) + # Mark deprecated fields as ready for deletion by protoxform. + for f in upgraded_proto.field: + if f.options.deprecated: + self._deprecate(upgraded_proto, f) + if self._envoy_internal_shadow: + # When shadowing, we use the upgraded version of types (which should + # themselves also be shadowed), to allow us to avoid unnecessary + # references to the previous version (and complexities around + # upgrading during API boosting). + f.type_name = self._upgraded_type(f.type_name) + else: + # Make sure the type name is erased so it isn't picked up by protoxform + # when computing deps. 
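The _deprecate step invoked here has two outcomes, depending on whether the shadow is being generated: in shadow mode the field is kept but renamed with the 'hidden_envoy_deprecated_' prefix, otherwise its number and name are reserved in the enclosing message (the real method additionally tags the field with the protoxform hide option). A standalone sketch of the non-shadow branch, with a made-up field:

from google.protobuf import descriptor_pb2

def deprecate(proto, field, envoy_internal_shadow):
    # Mirrors the branch structure of UpgradeVisitor._deprecate, minus the
    # protoxform hide-option bookkeeping.
    if envoy_internal_shadow:
        field.name = 'hidden_envoy_deprecated_' + field.name
    else:
        reserved = proto.reserved_range.add()
        reserved.start = field.number
        reserved.end = field.number + 1
        proto.reserved_name.append(field.name)

msg = descriptor_pb2.DescriptorProto(name='Example')
field = msg.field.add(name='old_field', number=7)
field.options.deprecated = True

deprecate(msg, field, envoy_internal_shadow=False)
print(list(msg.reserved_name), msg.reserved_range[0].start)  # ['old_field'] 7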
+ f.type_name = "" + else: + f.type_name = self._upgraded_type(f.type_name) + if f.options.HasExtension(migrate_pb2.field_migrate): + field_migrate = f.options.Extensions[migrate_pb2.field_migrate] + self._rename(f, field_migrate) + self._oneof_promotion(upgraded_proto, f, field_migrate) + # Upgrade nested messages. + del upgraded_proto.nested_type[:] + upgraded_proto.nested_type.extend(nested_msgs) + # Upgrade enums. + del upgraded_proto.enum_type[:] + upgraded_proto.enum_type.extend(nested_enums) + return upgraded_proto - def visit_enum(self, enum_proto, type_context): - upgraded_proto = copy.deepcopy(enum_proto) - if upgraded_proto.options.deprecated and not self._envoy_internal_shadow: - options.add_hide_option(upgraded_proto.options) - for v in upgraded_proto.value: - if v.options.deprecated: - # We need special handling for the zero field, as proto3 needs some value - # here. - if v.number == 0 and not self._envoy_internal_shadow: - v.name = 'DEPRECATED_AND_UNAVAILABLE_DO_NOT_USE' - else: - # Mark deprecated enum values as ready for deletion by protoxform. - self._deprecate(upgraded_proto, v) - elif v.options.HasExtension(migrate_pb2.enum_value_migrate): - self._rename(v, v.options.Extensions[migrate_pb2.enum_value_migrate]) - return upgraded_proto + def visit_enum(self, enum_proto, type_context): + upgraded_proto = copy.deepcopy(enum_proto) + if upgraded_proto.options.deprecated and not self._envoy_internal_shadow: + options.add_hide_option(upgraded_proto.options) + for v in upgraded_proto.value: + if v.options.deprecated: + # We need special handling for the zero field, as proto3 needs some value + # here. + if v.number == 0 and not self._envoy_internal_shadow: + v.name = 'DEPRECATED_AND_UNAVAILABLE_DO_NOT_USE' + else: + # Mark deprecated enum values as ready for deletion by protoxform. + self._deprecate(upgraded_proto, v) + elif v.options.HasExtension(migrate_pb2.enum_value_migrate): + self._rename(v, v.options.Extensions[migrate_pb2.enum_value_migrate]) + return upgraded_proto - def visit_file(self, file_proto, type_context, services, msgs, enums): - upgraded_proto = copy.deepcopy(file_proto) - # Upgrade imports. - upgraded_proto.dependency[:] = [ - dependency for dependency in upgraded_proto.dependency - if dependency not in ("udpa/annotations/migrate.proto") - ] - # Upgrade package. - upgraded_proto.package = self._typedb.next_version_protos[upgraded_proto.name].qualified_package - upgraded_proto.name = self._typedb.next_version_protos[upgraded_proto.name].proto_path - upgraded_proto.options.ClearExtension(migrate_pb2.file_migrate) - upgraded_proto.options.Extensions[ - status_pb2.file_status].package_version_status = self._package_version_status - # Upgrade comments. - for location in upgraded_proto.source_code_info.location: - location.leading_comments = self._upgraded_comment(location.leading_comments) - location.trailing_comments = self._upgraded_comment(location.trailing_comments) - for n, c in enumerate(location.leading_detached_comments): - location.leading_detached_comments[n] = self._upgraded_comment(c) - # Upgrade services. - del upgraded_proto.service[:] - upgraded_proto.service.extend(services) - # Upgrade messages. - del upgraded_proto.message_type[:] - upgraded_proto.message_type.extend(msgs) - # Upgrade enums. - del upgraded_proto.enum_type[:] - upgraded_proto.enum_type.extend(enums) + def visit_file(self, file_proto, type_context, services, msgs, enums): + upgraded_proto = copy.deepcopy(file_proto) + # Upgrade imports. 
+ upgraded_proto.dependency[:] = [ + dependency for dependency in upgraded_proto.dependency + if dependency not in ("udpa/annotations/migrate.proto") + ] + # Upgrade package. + upgraded_proto.package = self._typedb.next_version_protos[ + upgraded_proto.name].qualified_package + upgraded_proto.name = self._typedb.next_version_protos[upgraded_proto.name].proto_path + upgraded_proto.options.ClearExtension(migrate_pb2.file_migrate) + upgraded_proto.options.Extensions[ + status_pb2.file_status].package_version_status = self._package_version_status + # Upgrade comments. + for location in upgraded_proto.source_code_info.location: + location.leading_comments = self._upgraded_comment(location.leading_comments) + location.trailing_comments = self._upgraded_comment(location.trailing_comments) + for n, c in enumerate(location.leading_detached_comments): + location.leading_detached_comments[n] = self._upgraded_comment(c) + # Upgrade services. + del upgraded_proto.service[:] + upgraded_proto.service.extend(services) + # Upgrade messages. + del upgraded_proto.message_type[:] + upgraded_proto.message_type.extend(msgs) + # Upgrade enums. + del upgraded_proto.enum_type[:] + upgraded_proto.enum_type.extend(enums) - return upgraded_proto + return upgraded_proto def version_upgrade_xform(n, envoy_internal_shadow, file_proto, params): - """Transform a FileDescriptorProto from vN[alpha\d] to v(N+1). + """Transform a FileDescriptorProto from vN[alpha\d] to v(N+1). Args: n: version N to upgrade from. @@ -250,24 +253,24 @@ def version_upgrade_xform(n, envoy_internal_shadow, file_proto, params): Returns: v(N+1) FileDescriptorProto message. """ - # Load type database. - if params['type_db_path']: - utils.load_type_db(params['type_db_path']) - typedb = utils.get_type_db() - # If this isn't a proto in an upgraded package, return None. - if file_proto.name not in typedb.next_version_protos or not typedb.next_version_protos[ - file_proto.name]: - return None - # Otherwise, this .proto needs upgrading, do it. - freeze = 'extra_args' in params and params['extra_args'] == 'freeze' - existing_pkg_version_status = file_proto.options.Extensions[ - status_pb2.file_status].package_version_status - # Normally, we are generating the NEXT_MAJOR_VERSION_CANDIDATE. However, if - # freezing and previously this was the active major version, the migrated - # version is now the ACTIVE version. - if freeze and existing_pkg_version_status == status_pb2.ACTIVE: - package_version_status = status_pb2.ACTIVE - else: - package_version_status = status_pb2.NEXT_MAJOR_VERSION_CANDIDATE - return traverse.traverse_file( - file_proto, UpgradeVisitor(n, typedb, envoy_internal_shadow, package_version_status)) + # Load type database. + if params['type_db_path']: + utils.load_type_db(params['type_db_path']) + typedb = utils.get_type_db() + # If this isn't a proto in an upgraded package, return None. + if file_proto.name not in typedb.next_version_protos or not typedb.next_version_protos[ + file_proto.name]: + return None + # Otherwise, this .proto needs upgrading, do it. + freeze = 'extra_args' in params and params['extra_args'] == 'freeze' + existing_pkg_version_status = file_proto.options.Extensions[ + status_pb2.file_status].package_version_status + # Normally, we are generating the NEXT_MAJOR_VERSION_CANDIDATE. However, if + # freezing and previously this was the active major version, the migrated + # version is now the ACTIVE version. 
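A small pure-Python sketch of the freeze decision in version_upgrade_xform: when freezing, a package that was already ACTIVE stays ACTIVE in the next version, otherwise the generated package is marked NEXT_MAJOR_VERSION_CANDIDATE. The string values stand in for the udpa PackageVersionStatus constants (status_pb2.ACTIVE and status_pb2.NEXT_MAJOR_VERSION_CANDIDATE) used by the module.

def choose_package_version_status(freeze, existing_status):
    # Freezing an already-active package keeps it ACTIVE; everything else
    # becomes a candidate for the next major version.
    if freeze and existing_status == 'ACTIVE':
        return 'ACTIVE'
    return 'NEXT_MAJOR_VERSION_CANDIDATE'

assert choose_package_version_status(True, 'ACTIVE') == 'ACTIVE'
assert choose_package_version_status(False, 'ACTIVE') == 'NEXT_MAJOR_VERSION_CANDIDATE'
assert choose_package_version_status(True, 'FROZEN') == 'NEXT_MAJOR_VERSION_CANDIDATE'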
+ if freeze and existing_pkg_version_status == status_pb2.ACTIVE: + package_version_status = status_pb2.ACTIVE + else: + package_version_status = status_pb2.NEXT_MAJOR_VERSION_CANDIDATE + return traverse.traverse_file( + file_proto, UpgradeVisitor(n, typedb, envoy_internal_shadow, package_version_status)) diff --git a/tools/protoxform/options.py b/tools/protoxform/options.py index 663ac4f50ae0..41dd5a3e517d 100644 --- a/tools/protoxform/options.py +++ b/tools/protoxform/options.py @@ -4,19 +4,19 @@ def add_hide_option(options): - """Mark message/enum/field/enum value as hidden. + """Mark message/enum/field/enum value as hidden. Hidden messages are ignored when generating output. Args: options: MessageOptions/EnumOptions/FieldOptions/EnumValueOptions message. """ - hide_option = options.uninterpreted_option.add() - hide_option.name.add().name_part = 'protoxform_hide' + hide_option = options.uninterpreted_option.add() + hide_option.name.add().name_part = 'protoxform_hide' def has_hide_option(options): - """Is message/enum/field/enum value hidden? + """Is message/enum/field/enum value hidden? Hidden messages are ignored when generating output. @@ -25,12 +25,12 @@ def has_hide_option(options): Returns: Hidden status. """ - return any( - option.name[0].name_part == 'protoxform_hide' for option in options.uninterpreted_option) + return any( + option.name[0].name_part == 'protoxform_hide' for option in options.uninterpreted_option) def set_versioning_annotation(options, previous_message_type): - """Set the udpa.annotations.versioning option. + """Set the udpa.annotations.versioning option. Used by Envoy to chain back through the message type history. @@ -38,11 +38,11 @@ def set_versioning_annotation(options, previous_message_type): options: MessageOptions message. previous_message_type: string with earlier API type name for the message. """ - options.Extensions[versioning_pb2.versioning].previous_message_type = previous_message_type + options.Extensions[versioning_pb2.versioning].previous_message_type = previous_message_type def get_versioning_annotation(options): - """Get the udpa.annotations.versioning option. + """Get the udpa.annotations.versioning option. Used by Envoy to chain back through the message type history. @@ -51,6 +51,6 @@ def get_versioning_annotation(options): Returns: versioning.Annotation if set otherwise None. """ - if not options.HasExtension(versioning_pb2.versioning): - return None - return options.Extensions[versioning_pb2.versioning] + if not options.HasExtension(versioning_pb2.versioning): + return None + return options.Extensions[versioning_pb2.versioning] diff --git a/tools/protoxform/protoprint.py b/tools/protoxform/protoprint.py index bcd27c0c8bd2..fa71e077099a 100755 --- a/tools/protoxform/protoprint.py +++ b/tools/protoxform/protoprint.py @@ -44,11 +44,11 @@ class ProtoPrintError(Exception): - """Base error class for the protoprint module.""" + """Base error class for the protoprint module.""" def extract_clang_proto_style(clang_format_text): - """Extract a key:value dictionary for proto formatting. + """Extract a key:value dictionary for proto formatting. Args: clang_format_text: text from a .clang-format file. @@ -56,21 +56,21 @@ def extract_clang_proto_style(clang_format_text): Returns: key:value dictionary suitable for passing to clang-format --style. 
""" - lang = None - format_dict = {} - for line in clang_format_text.split('\n'): - if lang is None or lang != 'Proto': - match = re.match('Language:\s+(\w+)', line) - if match: - lang = match.group(1) - continue - match = re.match('(\w+):\s+(\w+)', line) - if match: - key, value = match.groups() - format_dict[key] = value - else: - break - return str(format_dict) + lang = None + format_dict = {} + for line in clang_format_text.split('\n'): + if lang is None or lang != 'Proto': + match = re.match('Language:\s+(\w+)', line) + if match: + lang = match.group(1) + continue + match = re.match('(\w+):\s+(\w+)', line) + if match: + key, value = match.groups() + format_dict[key] = value + else: + break + return str(format_dict) # Ensure we are using the canonical clang-format proto style. @@ -78,7 +78,7 @@ def extract_clang_proto_style(clang_format_text): def clang_format(contents): - """Run proto-style oriented clang-format over given string. + """Run proto-style oriented clang-format over given string. Args: contents: a string with proto contents. @@ -86,15 +86,15 @@ def clang_format(contents): Returns: clang-formatted string """ - return subprocess.run( - ['clang-format', - '--style=%s' % CLANG_FORMAT_STYLE, '--assume-filename=.proto'], - input=contents.encode('utf-8'), - stdout=subprocess.PIPE).stdout + return subprocess.run( + ['clang-format', + '--style=%s' % CLANG_FORMAT_STYLE, '--assume-filename=.proto'], + input=contents.encode('utf-8'), + stdout=subprocess.PIPE).stdout def format_block(block): - """Append \n to a .proto section (e.g. + """Append \n to a .proto section (e.g. comment, message definition, etc.) if non-empty. @@ -104,13 +104,13 @@ def format_block(block): Returns: A string with appropriate whitespace. """ - if block.strip(): - return block + '\n' - return '' + if block.strip(): + return block + '\n' + return '' def format_comments(comments): - """Format a list of comment blocks from SourceCodeInfo. + """Format a list of comment blocks from SourceCodeInfo. Prefixes // to each line, separates blocks by spaces. @@ -122,20 +122,20 @@ def format_comments(comments): A string reprenting the formatted comment blocks. """ - # TODO(htuch): not sure why this is needed, but clang-format does some weird - # stuff with // comment indents when we have these trailing \ - def fixup_trailing_backslash(s): - return s[:-1].rstrip() if s.endswith('\\') else s + # TODO(htuch): not sure why this is needed, but clang-format does some weird + # stuff with // comment indents when we have these trailing \ + def fixup_trailing_backslash(s): + return s[:-1].rstrip() if s.endswith('\\') else s - comments = '\n\n'.join( - '\n'.join(['//%s' % fixup_trailing_backslash(line) - for line in comment.split('\n')[:-1]]) - for comment in comments) - return format_block(comments) + comments = '\n\n'.join( + '\n'.join(['//%s' % fixup_trailing_backslash(line) + for line in comment.split('\n')[:-1]]) + for comment in comments) + return format_block(comments) def create_next_free_field_xform(msg_proto): - """Return the next free field number annotation transformer of a message. + """Return the next free field number annotation transformer of a message. Args: msg_proto: DescriptorProto for message. @@ -143,17 +143,17 @@ def create_next_free_field_xform(msg_proto): Returns: the next free field number annotation transformer. 
""" - next_free = max( - sum([ - [f.number + 1 for f in msg_proto.field], - [rr.end for rr in msg_proto.reserved_range], - [ex.end for ex in msg_proto.extension_range], - ], [1])) - return lambda _: next_free if next_free > NEXT_FREE_FIELD_MIN else None + next_free = max( + sum([ + [f.number + 1 for f in msg_proto.field], + [rr.end for rr in msg_proto.reserved_range], + [ex.end for ex in msg_proto.extension_range], + ], [1])) + return lambda _: next_free if next_free > NEXT_FREE_FIELD_MIN else None def format_type_context_comments(type_context, annotation_xforms=None): - """Format the leading/trailing comments in a given TypeContext. + """Format the leading/trailing comments in a given TypeContext. Args: type_context: contextual information for message/enum/field. @@ -163,16 +163,16 @@ def format_type_context_comments(type_context, annotation_xforms=None): Returns: Tuple of formatted leading and trailing comment blocks. """ - leading_comment = type_context.leading_comment - if annotation_xforms: - leading_comment = leading_comment.get_comment_with_transforms(annotation_xforms) - leading = format_comments(list(type_context.leading_detached_comments) + [leading_comment.raw]) - trailing = format_block(format_comments([type_context.trailing_comment])) - return leading, trailing + leading_comment = type_context.leading_comment + if annotation_xforms: + leading_comment = leading_comment.get_comment_with_transforms(annotation_xforms) + leading = format_comments(list(type_context.leading_detached_comments) + [leading_comment.raw]) + trailing = format_block(format_comments([type_context.trailing_comment])) + return leading, trailing def format_header_from_file(source_code_info, file_proto, empty_file): - """Format proto header. + """Format proto header. Args: source_code_info: SourceCodeInfo object. @@ -182,125 +182,125 @@ def format_header_from_file(source_code_info, file_proto, empty_file): Returns: Formatted proto header as a string. """ - # Load the type database. - typedb = utils.get_type_db() - # Figure out type dependencies in this .proto. - types = Types() - text_format.Merge(traverse.traverse_file(file_proto, type_whisperer.TypeWhispererVisitor()), - types) - type_dependencies = sum([list(t.type_dependencies) for t in types.types.values()], []) - for service in file_proto.service: - for m in service.method: - type_dependencies.extend([m.input_type[1:], m.output_type[1:]]) - # Determine the envoy/ import paths from type deps. - envoy_proto_paths = set( - typedb.types[t].proto_path - for t in type_dependencies - if t.startswith('envoy.') and typedb.types[t].proto_path != file_proto.name) - - def camel_case(s): - return ''.join(t.capitalize() for t in re.split('[\._]', s)) - - package_line = 'package %s;\n' % file_proto.package - file_block = '\n'.join(['syntax = "proto3";\n', package_line]) - - options = descriptor_pb2.FileOptions() - - options.java_outer_classname = camel_case(os.path.basename(file_proto.name)) - for msg in file_proto.message_type: - if msg.name == options.java_outer_classname: - # This is a workaround for Java outer class names that would otherwise - # conflict with types defined within the same proto file, see - # https://github.com/envoyproxy/envoy/pull/13378. - # TODO: in next major version, make this consistent. - options.java_outer_classname += "OuterClass" - - options.java_multiple_files = True - options.java_package = 'io.envoyproxy.' 
+ file_proto.package - - # This is a workaround for C#/Ruby namespace conflicts between packages and - # objects, see https://github.com/envoyproxy/envoy/pull/3854. - # TODO(htuch): remove once v3 fixes this naming issue in - # https://github.com/envoyproxy/envoy/issues/8120. - if file_proto.package in ['envoy.api.v2.listener', 'envoy.api.v2.cluster']: - qualified_package = '.'.join(s.capitalize() for s in file_proto.package.split('.')) + 'NS' - options.csharp_namespace = qualified_package - options.ruby_package = qualified_package - - if file_proto.service: - options.java_generic_services = True - - if file_proto.options.HasExtension(migrate_pb2.file_migrate): - options.Extensions[migrate_pb2.file_migrate].CopyFrom( - file_proto.options.Extensions[migrate_pb2.file_migrate]) - - if file_proto.options.HasExtension( - status_pb2.file_status) and file_proto.package.endswith('alpha'): - options.Extensions[status_pb2.file_status].CopyFrom( - file_proto.options.Extensions[status_pb2.file_status]) - - if not empty_file: - options.Extensions[ - status_pb2.file_status].package_version_status = file_proto.options.Extensions[ - status_pb2.file_status].package_version_status - - options_block = format_options(options) - - requires_versioning_import = any( - protoxform_options.get_versioning_annotation(m.options) for m in file_proto.message_type) - - envoy_imports = list(envoy_proto_paths) - google_imports = [] - infra_imports = [] - misc_imports = [] - public_imports = [] - - for idx, d in enumerate(file_proto.dependency): - if idx in file_proto.public_dependency: - public_imports.append(d) - continue - elif d.startswith('envoy/annotations') or d.startswith('udpa/annotations'): - infra_imports.append(d) - elif d.startswith('envoy/'): - # We ignore existing envoy/ imports, since these are computed explicitly - # from type_dependencies. - pass - elif d.startswith('google/'): - google_imports.append(d) - elif d.startswith('validate/'): - infra_imports.append(d) - elif d in ['udpa/annotations/versioning.proto', 'udpa/annotations/status.proto']: - # Skip, we decide to add this based on requires_versioning_import and options. - pass - else: - misc_imports.append(d) - - if options.HasExtension(status_pb2.file_status): - infra_imports.append('udpa/annotations/status.proto') - - if requires_versioning_import: - infra_imports.append('udpa/annotations/versioning.proto') - - def format_import_block(xs): - if not xs: - return '' - return format_block('\n'.join(sorted('import "%s";' % x for x in set(xs) if x))) - - def format_public_import_block(xs): - if not xs: - return '' - return format_block('\n'.join(sorted('import public "%s";' % x for x in xs))) - - import_block = '\n'.join( - map(format_import_block, [envoy_imports, google_imports, misc_imports, infra_imports])) - import_block += '\n' + format_public_import_block(public_imports) - comment_block = format_comments(source_code_info.file_level_comments) - - return ''.join(map(format_block, [file_block, import_block, options_block, comment_block])) + # Load the type database. + typedb = utils.get_type_db() + # Figure out type dependencies in this .proto. + types = Types() + text_format.Merge(traverse.traverse_file(file_proto, type_whisperer.TypeWhispererVisitor()), + types) + type_dependencies = sum([list(t.type_dependencies) for t in types.types.values()], []) + for service in file_proto.service: + for m in service.method: + type_dependencies.extend([m.input_type[1:], m.output_type[1:]]) + # Determine the envoy/ import paths from type deps. 
+ envoy_proto_paths = set( + typedb.types[t].proto_path + for t in type_dependencies + if t.startswith('envoy.') and typedb.types[t].proto_path != file_proto.name) + + def camel_case(s): + return ''.join(t.capitalize() for t in re.split('[\._]', s)) + + package_line = 'package %s;\n' % file_proto.package + file_block = '\n'.join(['syntax = "proto3";\n', package_line]) + + options = descriptor_pb2.FileOptions() + + options.java_outer_classname = camel_case(os.path.basename(file_proto.name)) + for msg in file_proto.message_type: + if msg.name == options.java_outer_classname: + # This is a workaround for Java outer class names that would otherwise + # conflict with types defined within the same proto file, see + # https://github.com/envoyproxy/envoy/pull/13378. + # TODO: in next major version, make this consistent. + options.java_outer_classname += "OuterClass" + + options.java_multiple_files = True + options.java_package = 'io.envoyproxy.' + file_proto.package + + # This is a workaround for C#/Ruby namespace conflicts between packages and + # objects, see https://github.com/envoyproxy/envoy/pull/3854. + # TODO(htuch): remove once v3 fixes this naming issue in + # https://github.com/envoyproxy/envoy/issues/8120. + if file_proto.package in ['envoy.api.v2.listener', 'envoy.api.v2.cluster']: + qualified_package = '.'.join(s.capitalize() for s in file_proto.package.split('.')) + 'NS' + options.csharp_namespace = qualified_package + options.ruby_package = qualified_package + + if file_proto.service: + options.java_generic_services = True + + if file_proto.options.HasExtension(migrate_pb2.file_migrate): + options.Extensions[migrate_pb2.file_migrate].CopyFrom( + file_proto.options.Extensions[migrate_pb2.file_migrate]) + + if file_proto.options.HasExtension( + status_pb2.file_status) and file_proto.package.endswith('alpha'): + options.Extensions[status_pb2.file_status].CopyFrom( + file_proto.options.Extensions[status_pb2.file_status]) + + if not empty_file: + options.Extensions[ + status_pb2.file_status].package_version_status = file_proto.options.Extensions[ + status_pb2.file_status].package_version_status + + options_block = format_options(options) + + requires_versioning_import = any( + protoxform_options.get_versioning_annotation(m.options) for m in file_proto.message_type) + + envoy_imports = list(envoy_proto_paths) + google_imports = [] + infra_imports = [] + misc_imports = [] + public_imports = [] + + for idx, d in enumerate(file_proto.dependency): + if idx in file_proto.public_dependency: + public_imports.append(d) + continue + elif d.startswith('envoy/annotations') or d.startswith('udpa/annotations'): + infra_imports.append(d) + elif d.startswith('envoy/'): + # We ignore existing envoy/ imports, since these are computed explicitly + # from type_dependencies. + pass + elif d.startswith('google/'): + google_imports.append(d) + elif d.startswith('validate/'): + infra_imports.append(d) + elif d in ['udpa/annotations/versioning.proto', 'udpa/annotations/status.proto']: + # Skip, we decide to add this based on requires_versioning_import and options. 
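# A minimal sketch (hypothetical dependency names) of how the loop above buckets a
# file's imports and how format_import_block, defined a little further below,
# renders each bucket as a sorted, de-duplicated block of import statements.
def render_import_block(deps):
    if not deps:
        return ''
    return '\n'.join(sorted('import "%s";' % d for d in set(deps) if d)) + '\n'

deps = ['google/protobuf/any.proto', 'validate/validate.proto', 'google/protobuf/duration.proto']
google_imports = [d for d in deps if d.startswith('google/')]
infra_imports = [d for d in deps if d.startswith('validate/')]
print(render_import_block(google_imports), end='')
# import "google/protobuf/any.proto";
# import "google/protobuf/duration.proto";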
+ pass + else: + misc_imports.append(d) + + if options.HasExtension(status_pb2.file_status): + infra_imports.append('udpa/annotations/status.proto') + + if requires_versioning_import: + infra_imports.append('udpa/annotations/versioning.proto') + + def format_import_block(xs): + if not xs: + return '' + return format_block('\n'.join(sorted('import "%s";' % x for x in set(xs) if x))) + + def format_public_import_block(xs): + if not xs: + return '' + return format_block('\n'.join(sorted('import public "%s";' % x for x in xs))) + + import_block = '\n'.join( + map(format_import_block, [envoy_imports, google_imports, misc_imports, infra_imports])) + import_block += '\n' + format_public_import_block(public_imports) + comment_block = format_comments(source_code_info.file_level_comments) + + return ''.join(map(format_block, [file_block, import_block, options_block, comment_block])) def normalize_field_type_name(type_context, field_fqn): - """Normalize a fully qualified field type name, e.g. + """Normalize a fully qualified field type name, e.g. .envoy.foo.bar is normalized to foo.bar. @@ -313,80 +313,80 @@ def normalize_field_type_name(type_context, field_fqn): Returns: Normalized type name as a string. """ - if field_fqn.startswith('.'): - # Let's say we have type context namespace a.b.c.d.e and the type we're - # trying to normalize is a.b.d.e. We take (from the end) on package fragment - # at a time, and apply the inner-most evaluation that protoc performs to see - # if we evaluate to the fully qualified type. If so, we're done. It's not - # sufficient to compute common prefix and drop that, since in the above - # example the normalized type name would be d.e, which proto resolves inner - # most as a.b.c.d.e (bad) instead of the intended a.b.d.e. - field_fqn_splits = field_fqn[1:].split('.') - type_context_splits = type_context.name.split('.')[:-1] - remaining_field_fqn_splits = deque(field_fqn_splits[:-1]) - normalized_splits = deque([field_fqn_splits[-1]]) - - if list(remaining_field_fqn_splits)[:1] != type_context_splits[:1] and ( - len(remaining_field_fqn_splits) == 0 or - remaining_field_fqn_splits[0] in type_context_splits[1:]): - # Notice that in some cases it is error-prone to normalize a type name. - # E.g., it would be an error to replace ".external.Type" with "external.Type" - # in the context of "envoy.extensions.type.external.vX.Config". - # In such a context protoc resolves "external.Type" into - # "envoy.extensions.type.external.Type", which is exactly what the use of a - # fully-qualified name ".external.Type" was meant to prevent. - # - # A type SHOULD remain fully-qualified under the following conditions: - # 1. its root package is different from the root package of the context type - # 2. 
EITHER the type doesn't belong to any package at all - # OR its root package has a name that collides with one of the packages - # of the context type - # - # E.g., - # a) although ".some.Type" has a different root package than the context type - # "TopLevelType", it is still safe to normalize it into "some.Type" - # b) although ".google.protobuf.Any" has a different root package than the context type - # "envoy.api.v2.Cluster", it still safe to normalize it into "google.protobuf.Any" - # c) it is error-prone to normalize ".TopLevelType" in the context of "some.Type" - # into "TopLevelType" - # d) it is error-prone to normalize ".external.Type" in the context of - # "envoy.extensions.type.external.vX.Config" into "external.Type" - return field_fqn - - def equivalent_in_type_context(splits): - type_context_splits_tmp = deque(type_context_splits) - while type_context_splits_tmp: - # If we're in a.b.c and the FQN is a.d.Foo, we want to return true once - # we have type_context_splits_tmp as [a] and splits as [d, Foo]. - if list(type_context_splits_tmp) + list(splits) == field_fqn_splits: - return True - # If we're in a.b.c.d.e.f and the FQN is a.b.d.e.Foo, we want to return True - # once we have type_context_splits_tmp as [a] and splits as [b, d, e, Foo], but - # not when type_context_splits_tmp is [a, b, c] and FQN is [d, e, Foo]. - if len(splits) > 1 and '.'.join(type_context_splits_tmp).endswith('.'.join( - list(splits)[:-1])): - return False - type_context_splits_tmp.pop() - return False - - while remaining_field_fqn_splits and not equivalent_in_type_context(normalized_splits): - normalized_splits.appendleft(remaining_field_fqn_splits.pop()) - - # `extensions` is a keyword in proto2, and protoc will throw error if a type name - # starts with `extensions.`. - if normalized_splits[0] == 'extensions': - normalized_splits.appendleft(remaining_field_fqn_splits.pop()) - - return '.'.join(normalized_splits) - return field_fqn + if field_fqn.startswith('.'): + # Let's say we have type context namespace a.b.c.d.e and the type we're + # trying to normalize is a.b.d.e. We take (from the end) on package fragment + # at a time, and apply the inner-most evaluation that protoc performs to see + # if we evaluate to the fully qualified type. If so, we're done. It's not + # sufficient to compute common prefix and drop that, since in the above + # example the normalized type name would be d.e, which proto resolves inner + # most as a.b.c.d.e (bad) instead of the intended a.b.d.e. + field_fqn_splits = field_fqn[1:].split('.') + type_context_splits = type_context.name.split('.')[:-1] + remaining_field_fqn_splits = deque(field_fqn_splits[:-1]) + normalized_splits = deque([field_fqn_splits[-1]]) + + if list(remaining_field_fqn_splits)[:1] != type_context_splits[:1] and ( + len(remaining_field_fqn_splits) == 0 or + remaining_field_fqn_splits[0] in type_context_splits[1:]): + # Notice that in some cases it is error-prone to normalize a type name. + # E.g., it would be an error to replace ".external.Type" with "external.Type" + # in the context of "envoy.extensions.type.external.vX.Config". + # In such a context protoc resolves "external.Type" into + # "envoy.extensions.type.external.Type", which is exactly what the use of a + # fully-qualified name ".external.Type" was meant to prevent. + # + # A type SHOULD remain fully-qualified under the following conditions: + # 1. its root package is different from the root package of the context type + # 2. 
EITHER the type doesn't belong to any package at all + # OR its root package has a name that collides with one of the packages + # of the context type + # + # E.g., + # a) although ".some.Type" has a different root package than the context type + # "TopLevelType", it is still safe to normalize it into "some.Type" + # b) although ".google.protobuf.Any" has a different root package than the context type + # "envoy.api.v2.Cluster", it still safe to normalize it into "google.protobuf.Any" + # c) it is error-prone to normalize ".TopLevelType" in the context of "some.Type" + # into "TopLevelType" + # d) it is error-prone to normalize ".external.Type" in the context of + # "envoy.extensions.type.external.vX.Config" into "external.Type" + return field_fqn + + def equivalent_in_type_context(splits): + type_context_splits_tmp = deque(type_context_splits) + while type_context_splits_tmp: + # If we're in a.b.c and the FQN is a.d.Foo, we want to return true once + # we have type_context_splits_tmp as [a] and splits as [d, Foo]. + if list(type_context_splits_tmp) + list(splits) == field_fqn_splits: + return True + # If we're in a.b.c.d.e.f and the FQN is a.b.d.e.Foo, we want to return True + # once we have type_context_splits_tmp as [a] and splits as [b, d, e, Foo], but + # not when type_context_splits_tmp is [a, b, c] and FQN is [d, e, Foo]. + if len(splits) > 1 and '.'.join(type_context_splits_tmp).endswith('.'.join( + list(splits)[:-1])): + return False + type_context_splits_tmp.pop() + return False + + while remaining_field_fqn_splits and not equivalent_in_type_context(normalized_splits): + normalized_splits.appendleft(remaining_field_fqn_splits.pop()) + + # `extensions` is a keyword in proto2, and protoc will throw error if a type name + # starts with `extensions.`. + if normalized_splits[0] == 'extensions': + normalized_splits.appendleft(remaining_field_fqn_splits.pop()) + + return '.'.join(normalized_splits) + return field_fqn def type_name_from_fqn(fqn): - return fqn[1:] + return fqn[1:] def format_field_type(type_context, field): - """Format a FieldDescriptorProto type description. + """Format a FieldDescriptorProto type description. Args: type_context: contextual information for message/enum/field. @@ -395,43 +395,43 @@ def format_field_type(type_context, field): Returns: Formatted proto field type as string. 
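# A small table of expected normalize_field_type_name behavior, based on the
# editor's reading of the algorithm above; the type-context and type names are
# illustrative examples, not test vectors from the repository.
EXAMPLES = [
    # (enclosing type context name, fully qualified field type, normalized form)
    ('envoy.api.v2.Cluster.hosts', '.envoy.api.v2.core.Address', 'core.Address'),
    ('envoy.api.v2.Cluster.metadata', '.google.protobuf.Any', 'google.protobuf.Any'),
    # Root-package collision: left fully qualified so protoc does not resolve it
    # inner-most as envoy.extensions.type.external.Type.
    ('envoy.extensions.type.external.vX.Config.value', '.external.Type', '.external.Type'),
]
for context, fqn, expected in EXAMPLES:
    print('%-47s %-27s -> %s' % (context, fqn, expected))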
""" - label = 'repeated ' if field.label == field.LABEL_REPEATED else '' - type_name = label + normalize_field_type_name(type_context, field.type_name) - - if field.type == field.TYPE_MESSAGE: - if type_context.map_typenames and type_name_from_fqn( - field.type_name) in type_context.map_typenames: - return 'map<%s, %s>' % tuple( - map(functools.partial(format_field_type, type_context), - type_context.map_typenames[type_name_from_fqn(field.type_name)])) - return type_name - elif field.type_name: - return type_name - - pretty_type_names = { - field.TYPE_DOUBLE: 'double', - field.TYPE_FLOAT: 'float', - field.TYPE_INT32: 'int32', - field.TYPE_SFIXED32: 'int32', - field.TYPE_SINT32: 'int32', - field.TYPE_FIXED32: 'uint32', - field.TYPE_UINT32: 'uint32', - field.TYPE_INT64: 'int64', - field.TYPE_SFIXED64: 'int64', - field.TYPE_SINT64: 'int64', - field.TYPE_FIXED64: 'uint64', - field.TYPE_UINT64: 'uint64', - field.TYPE_BOOL: 'bool', - field.TYPE_STRING: 'string', - field.TYPE_BYTES: 'bytes', - } - if field.type in pretty_type_names: - return label + pretty_type_names[field.type] - raise ProtoPrintError('Unknown field type ' + str(field.type)) + label = 'repeated ' if field.label == field.LABEL_REPEATED else '' + type_name = label + normalize_field_type_name(type_context, field.type_name) + + if field.type == field.TYPE_MESSAGE: + if type_context.map_typenames and type_name_from_fqn( + field.type_name) in type_context.map_typenames: + return 'map<%s, %s>' % tuple( + map(functools.partial(format_field_type, type_context), + type_context.map_typenames[type_name_from_fqn(field.type_name)])) + return type_name + elif field.type_name: + return type_name + + pretty_type_names = { + field.TYPE_DOUBLE: 'double', + field.TYPE_FLOAT: 'float', + field.TYPE_INT32: 'int32', + field.TYPE_SFIXED32: 'int32', + field.TYPE_SINT32: 'int32', + field.TYPE_FIXED32: 'uint32', + field.TYPE_UINT32: 'uint32', + field.TYPE_INT64: 'int64', + field.TYPE_SFIXED64: 'int64', + field.TYPE_SINT64: 'int64', + field.TYPE_FIXED64: 'uint64', + field.TYPE_UINT64: 'uint64', + field.TYPE_BOOL: 'bool', + field.TYPE_STRING: 'string', + field.TYPE_BYTES: 'bytes', + } + if field.type in pretty_type_names: + return label + pretty_type_names[field.type] + raise ProtoPrintError('Unknown field type ' + str(field.type)) def format_service_method(type_context, method): - """Format a service MethodDescriptorProto. + """Format a service MethodDescriptorProto. Args: type_context: contextual information for method. @@ -441,19 +441,19 @@ def format_service_method(type_context, method): Formatted service method as string. 
""" - def format_streaming(s): - return 'stream ' if s else '' + def format_streaming(s): + return 'stream ' if s else '' - leading_comment, trailing_comment = format_type_context_comments(type_context) - return '%srpc %s(%s%s%s) returns (%s%s) {%s}\n' % ( - leading_comment, method.name, trailing_comment, format_streaming( - method.client_streaming), normalize_field_type_name( - type_context, method.input_type), format_streaming(method.server_streaming), - normalize_field_type_name(type_context, method.output_type), format_options(method.options)) + leading_comment, trailing_comment = format_type_context_comments(type_context) + return '%srpc %s(%s%s%s) returns (%s%s) {%s}\n' % ( + leading_comment, method.name, trailing_comment, format_streaming( + method.client_streaming), normalize_field_type_name( + type_context, method.input_type), format_streaming(method.server_streaming), + normalize_field_type_name(type_context, method.output_type), format_options(method.options)) def format_field(type_context, field): - """Format FieldDescriptorProto as a proto field. + """Format FieldDescriptorProto as a proto field. Args: type_context: contextual information for message/enum/field. @@ -462,17 +462,17 @@ def format_field(type_context, field): Returns: Formatted proto field as a string. """ - if protoxform_options.has_hide_option(field.options): - return '' - leading_comment, trailing_comment = format_type_context_comments(type_context) + if protoxform_options.has_hide_option(field.options): + return '' + leading_comment, trailing_comment = format_type_context_comments(type_context) - return '%s%s %s = %d%s;\n%s' % (leading_comment, format_field_type( - type_context, field), field.name, field.number, format_options( - field.options), trailing_comment) + return '%s%s %s = %d%s;\n%s' % (leading_comment, format_field_type( + type_context, field), field.name, field.number, format_options( + field.options), trailing_comment) def format_enum_value(type_context, value): - """Format a EnumValueDescriptorProto as a proto enum value. + """Format a EnumValueDescriptorProto as a proto enum value. Args: type_context: contextual information for message/enum/field. @@ -481,16 +481,16 @@ def format_enum_value(type_context, value): Returns: Formatted proto enum value as a string. """ - if protoxform_options.has_hide_option(value.options): - return '' - leading_comment, trailing_comment = format_type_context_comments(type_context) - formatted_annotations = format_options(value.options) - return '%s%s = %d%s;\n%s' % (leading_comment, value.name, value.number, formatted_annotations, - trailing_comment) + if protoxform_options.has_hide_option(value.options): + return '' + leading_comment, trailing_comment = format_type_context_comments(type_context) + formatted_annotations = format_options(value.options) + return '%s%s = %d%s;\n%s' % (leading_comment, value.name, value.number, formatted_annotations, + trailing_comment) def text_format_value(field, value): - """Format the value as protobuf text format + """Format the value as protobuf text format Args: field: a FieldDescriptor that describes the field @@ -499,13 +499,13 @@ def text_format_value(field, value): Returns: value in protobuf text format """ - out = io.StringIO() - text_format.PrintFieldValue(field, value, out) - return out.getvalue() + out = io.StringIO() + text_format.PrintFieldValue(field, value, out) + return out.getvalue() def format_options(options): - """Format *Options (e.g. + """Format *Options (e.g. MessageOptions, FieldOptions) message. 
@@ -516,30 +516,30 @@ def format_options(options): Formatted options as a string. """ - formatted_options = [] - for option_descriptor, option_value in sorted(options.ListFields(), key=lambda x: x[0].number): - option_name = '({})'.format( - option_descriptor.full_name) if option_descriptor.is_extension else option_descriptor.name - if option_descriptor.message_type and option_descriptor.label != option_descriptor.LABEL_REPEATED: - formatted_options.extend([ - '{}.{} = {}'.format(option_name, subfield.name, text_format_value(subfield, value)) - for subfield, value in option_value.ListFields() - ]) - else: - formatted_options.append('{} = {}'.format(option_name, - text_format_value(option_descriptor, option_value))) - - if formatted_options: - if options.DESCRIPTOR.name in ('EnumValueOptions', 'FieldOptions'): - return '[{}]'.format(','.join(formatted_options)) - else: - return format_block(''.join( - 'option {};\n'.format(formatted_option) for formatted_option in formatted_options)) - return '' + formatted_options = [] + for option_descriptor, option_value in sorted(options.ListFields(), key=lambda x: x[0].number): + option_name = '({})'.format(option_descriptor.full_name + ) if option_descriptor.is_extension else option_descriptor.name + if option_descriptor.message_type and option_descriptor.label != option_descriptor.LABEL_REPEATED: + formatted_options.extend([ + '{}.{} = {}'.format(option_name, subfield.name, text_format_value(subfield, value)) + for subfield, value in option_value.ListFields() + ]) + else: + formatted_options.append('{} = {}'.format( + option_name, text_format_value(option_descriptor, option_value))) + + if formatted_options: + if options.DESCRIPTOR.name in ('EnumValueOptions', 'FieldOptions'): + return '[{}]'.format(','.join(formatted_options)) + else: + return format_block(''.join( + 'option {};\n'.format(formatted_option) for formatted_option in formatted_options)) + return '' def format_reserved(enum_or_msg_proto): - """Format reserved values/names in a [Enum]DescriptorProto. + """Format reserved values/names in a [Enum]DescriptorProto. Args: enum_or_msg_proto: [Enum]DescriptorProto message. @@ -547,107 +547,108 @@ def format_reserved(enum_or_msg_proto): Returns: Formatted enum_or_msg_proto as a string. """ - rrs = copy.deepcopy(enum_or_msg_proto.reserved_range) - # Fixups for singletons that don't seem to always have [inclusive, exclusive) - # format when parsed by protoc. - for rr in rrs: - if rr.start == rr.end: - rr.end += 1 - reserved_fields = format_block( - 'reserved %s;\n' % - ','.join(map(str, sum([list(range(rr.start, rr.end)) for rr in rrs], [])))) if rrs else '' - if enum_or_msg_proto.reserved_name: - reserved_fields += format_block('reserved %s;\n' % - ', '.join('"%s"' % n for n in enum_or_msg_proto.reserved_name)) - return reserved_fields + rrs = copy.deepcopy(enum_or_msg_proto.reserved_range) + # Fixups for singletons that don't seem to always have [inclusive, exclusive) + # format when parsed by protoc. + for rr in rrs: + if rr.start == rr.end: + rr.end += 1 + reserved_fields = format_block( + 'reserved %s;\n' % + ','.join(map(str, sum([list(range(rr.start, rr.end)) for rr in rrs], [])))) if rrs else '' + if enum_or_msg_proto.reserved_name: + reserved_fields += format_block( + 'reserved %s;\n' % ', '.join('"%s"' % n for n in enum_or_msg_proto.reserved_name)) + return reserved_fields class ProtoFormatVisitor(visitor.Visitor): - """Visitor to generate a proto representation from a FileDescriptor proto. 
+ """Visitor to generate a proto representation from a FileDescriptor proto. See visitor.Visitor for visitor method docs comments. """ - def visit_service(self, service_proto, type_context): - leading_comment, trailing_comment = format_type_context_comments(type_context) - methods = '\n'.join( - format_service_method(type_context.extend_method(index, m.name), m) - for index, m in enumerate(service_proto.method)) - options = format_block(format_options(service_proto.options)) - return '%sservice %s {\n%s%s%s\n}\n' % (leading_comment, service_proto.name, options, - trailing_comment, methods) - - def visit_enum(self, enum_proto, type_context): - if protoxform_options.has_hide_option(enum_proto.options): - return '' - leading_comment, trailing_comment = format_type_context_comments(type_context) - formatted_options = format_options(enum_proto.options) - reserved_fields = format_reserved(enum_proto) - values = [ - format_enum_value(type_context.extend_field(index, value.name), value) - for index, value in enumerate(enum_proto.value) - ] - joined_values = ('\n' if any('//' in v for v in values) else '').join(values) - return '%senum %s {\n%s%s%s%s\n}\n' % (leading_comment, enum_proto.name, trailing_comment, - formatted_options, reserved_fields, joined_values) - - def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): - # Skip messages synthesized to represent map types. - if msg_proto.options.map_entry: - return '' - if protoxform_options.has_hide_option(msg_proto.options): - return '' - annotation_xforms = { - annotations.NEXT_FREE_FIELD_ANNOTATION: create_next_free_field_xform(msg_proto) - } - leading_comment, trailing_comment = format_type_context_comments(type_context, - annotation_xforms) - formatted_options = format_options(msg_proto.options) - formatted_enums = format_block('\n'.join(nested_enums)) - formatted_msgs = format_block('\n'.join(nested_msgs)) - reserved_fields = format_reserved(msg_proto) - # Recover the oneof structure. This needs some extra work, since - # DescriptorProto just gives use fields and a oneof_index that can allow - # recovery of the original oneof placement. 
- fields = '' - oneof_index = None - for index, field in enumerate(msg_proto.field): - if oneof_index is not None: - if not field.HasField('oneof_index') or field.oneof_index != oneof_index: - fields += '}\n\n' - oneof_index = None - if oneof_index is None and field.HasField('oneof_index'): - oneof_index = field.oneof_index - assert (oneof_index < len(msg_proto.oneof_decl)) - oneof_proto = msg_proto.oneof_decl[oneof_index] - oneof_leading_comment, oneof_trailing_comment = format_type_context_comments( - type_context.extend_oneof(oneof_index, field.name)) - fields += '%soneof %s {\n%s%s' % (oneof_leading_comment, oneof_proto.name, - oneof_trailing_comment, format_options( - oneof_proto.options)) - fields += format_block(format_field(type_context.extend_field(index, field.name), field)) - if oneof_index is not None: - fields += '}\n\n' - return '%smessage %s {\n%s%s%s%s%s%s\n}\n' % (leading_comment, msg_proto.name, trailing_comment, - formatted_options, formatted_enums, - formatted_msgs, reserved_fields, fields) - - def visit_file(self, file_proto, type_context, services, msgs, enums): - empty_file = len(services) == 0 and len(enums) == 0 and len(msgs) == 0 - header = format_header_from_file(type_context.source_code_info, file_proto, empty_file) - formatted_services = format_block('\n'.join(services)) - formatted_enums = format_block('\n'.join(enums)) - formatted_msgs = format_block('\n'.join(msgs)) - return clang_format(header + formatted_services + formatted_enums + formatted_msgs) + def visit_service(self, service_proto, type_context): + leading_comment, trailing_comment = format_type_context_comments(type_context) + methods = '\n'.join( + format_service_method(type_context.extend_method(index, m.name), m) + for index, m in enumerate(service_proto.method)) + options = format_block(format_options(service_proto.options)) + return '%sservice %s {\n%s%s%s\n}\n' % (leading_comment, service_proto.name, options, + trailing_comment, methods) + + def visit_enum(self, enum_proto, type_context): + if protoxform_options.has_hide_option(enum_proto.options): + return '' + leading_comment, trailing_comment = format_type_context_comments(type_context) + formatted_options = format_options(enum_proto.options) + reserved_fields = format_reserved(enum_proto) + values = [ + format_enum_value(type_context.extend_field(index, value.name), value) + for index, value in enumerate(enum_proto.value) + ] + joined_values = ('\n' if any('//' in v for v in values) else '').join(values) + return '%senum %s {\n%s%s%s%s\n}\n' % (leading_comment, enum_proto.name, trailing_comment, + formatted_options, reserved_fields, joined_values) + + def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): + # Skip messages synthesized to represent map types. + if msg_proto.options.map_entry: + return '' + if protoxform_options.has_hide_option(msg_proto.options): + return '' + annotation_xforms = { + annotations.NEXT_FREE_FIELD_ANNOTATION: create_next_free_field_xform(msg_proto) + } + leading_comment, trailing_comment = format_type_context_comments( + type_context, annotation_xforms) + formatted_options = format_options(msg_proto.options) + formatted_enums = format_block('\n'.join(nested_enums)) + formatted_msgs = format_block('\n'.join(nested_msgs)) + reserved_fields = format_reserved(msg_proto) + # Recover the oneof structure. This needs some extra work, since + # DescriptorProto just gives use fields and a oneof_index that can allow + # recovery of the original oneof placement. 
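# A minimal sketch of the oneof-recovery loop in visit_message: consecutive fields
# sharing a oneof_index are wrapped in a oneof block, and the block is closed as
# soon as a field with a different (or no) index appears. Plain (name, oneof_index)
# tuples stand in for FieldDescriptorProto; the names are hypothetical.
def group_oneofs(fields):
    out, open_idx = [], None
    for name, idx in fields:
        if open_idx is not None and idx != open_idx:
            out.append('}')                       # close the previous oneof block
            open_idx = None
        if open_idx is None and idx is not None:
            out.append('oneof group_%d {' % idx)
            open_idx = idx
        out.append(('  %s;' if open_idx is not None else '%s;') % name)
    if open_idx is not None:
        out.append('}')
    return '\n'.join(out)

print(group_oneofs([('a', None), ('b', 0), ('c', 0), ('d', None)]))
# a;
# oneof group_0 {
#   b;
#   c;
# }
# d;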
+ fields = '' + oneof_index = None + for index, field in enumerate(msg_proto.field): + if oneof_index is not None: + if not field.HasField('oneof_index') or field.oneof_index != oneof_index: + fields += '}\n\n' + oneof_index = None + if oneof_index is None and field.HasField('oneof_index'): + oneof_index = field.oneof_index + assert (oneof_index < len(msg_proto.oneof_decl)) + oneof_proto = msg_proto.oneof_decl[oneof_index] + oneof_leading_comment, oneof_trailing_comment = format_type_context_comments( + type_context.extend_oneof(oneof_index, field.name)) + fields += '%soneof %s {\n%s%s' % (oneof_leading_comment, oneof_proto.name, + oneof_trailing_comment, + format_options(oneof_proto.options)) + fields += format_block(format_field(type_context.extend_field(index, field.name), + field)) + if oneof_index is not None: + fields += '}\n\n' + return '%smessage %s {\n%s%s%s%s%s%s\n}\n' % ( + leading_comment, msg_proto.name, trailing_comment, formatted_options, formatted_enums, + formatted_msgs, reserved_fields, fields) + + def visit_file(self, file_proto, type_context, services, msgs, enums): + empty_file = len(services) == 0 and len(enums) == 0 and len(msgs) == 0 + header = format_header_from_file(type_context.source_code_info, file_proto, empty_file) + formatted_services = format_block('\n'.join(services)) + formatted_enums = format_block('\n'.join(enums)) + formatted_msgs = format_block('\n'.join(msgs)) + return clang_format(header + formatted_services + formatted_enums + formatted_msgs) if __name__ == '__main__': - proto_desc_path = sys.argv[1] - file_proto = descriptor_pb2.FileDescriptorProto() - input_text = pathlib.Path(proto_desc_path).read_text() - if not input_text: - sys.exit(0) - text_format.Merge(input_text, file_proto) - dst_path = pathlib.Path(sys.argv[2]) - utils.load_type_db(sys.argv[3]) - dst_path.write_bytes(traverse.traverse_file(file_proto, ProtoFormatVisitor())) + proto_desc_path = sys.argv[1] + file_proto = descriptor_pb2.FileDescriptorProto() + input_text = pathlib.Path(proto_desc_path).read_text() + if not input_text: + sys.exit(0) + text_format.Merge(input_text, file_proto) + dst_path = pathlib.Path(sys.argv[2]) + utils.load_type_db(sys.argv[3]) + dst_path.write_bytes(traverse.traverse_file(file_proto, ProtoFormatVisitor())) diff --git a/tools/protoxform/protoxform.py b/tools/protoxform/protoxform.py index ed5e4f6aa448..104f27e8d426 100755 --- a/tools/protoxform/protoxform.py +++ b/tools/protoxform/protoxform.py @@ -25,74 +25,74 @@ class ProtoXformError(Exception): - """Base error class for the protoxform module.""" + """Base error class for the protoxform module.""" class ProtoFormatVisitor(visitor.Visitor): - """Visitor to generate a proto representation from a FileDescriptor proto. + """Visitor to generate a proto representation from a FileDescriptor proto. See visitor.Visitor for visitor method docs comments. """ - def __init__(self, active_or_frozen, params): - if params['type_db_path']: - utils.load_type_db(params['type_db_path']) - self._freeze = 'extra_args' in params and params['extra_args'] == 'freeze' - self._active_or_frozen = active_or_frozen - - def visit_service(self, service_proto, type_context): - return None - - def visit_enum(self, enum_proto, type_context): - return None - - def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): - return None - - def visit_file(self, file_proto, type_context, services, msgs, enums): - # Freeze protos that have next major version candidates. 
- typedb = utils.get_type_db() - output_proto = copy.deepcopy(file_proto) - existing_pkg_version_status = output_proto.options.Extensions[ - status_pb2.file_status].package_version_status - empty_file = len(services) == 0 and len(enums) == 0 and len(msgs) == 0 - pkg_version_status_exempt = file_proto.name.startswith('envoy/annotations') or empty_file - # It's a format error not to set package_version_status. - if existing_pkg_version_status == status_pb2.UNKNOWN and not pkg_version_status_exempt: - raise ProtoXformError('package_version_status must be set in %s' % file_proto.name) - # Only update package_version_status for .active_or_frozen.proto, - # migrate.version_upgrade_xform has taken care of next major version - # candidates. - if self._active_or_frozen and not pkg_version_status_exempt: - # Freeze if this is an active package with a next major version. Preserve - # frozen status otherwise. - if self._freeze and typedb.next_version_protos.get(output_proto.name, None): - target_pkg_version_status = status_pb2.FROZEN - elif existing_pkg_version_status == status_pb2.FROZEN: - target_pkg_version_status = status_pb2.FROZEN - else: - assert (existing_pkg_version_status == status_pb2.ACTIVE) - target_pkg_version_status = status_pb2.ACTIVE - output_proto.options.Extensions[ - status_pb2.file_status].package_version_status = target_pkg_version_status - return str(output_proto) + def __init__(self, active_or_frozen, params): + if params['type_db_path']: + utils.load_type_db(params['type_db_path']) + self._freeze = 'extra_args' in params and params['extra_args'] == 'freeze' + self._active_or_frozen = active_or_frozen + + def visit_service(self, service_proto, type_context): + return None + + def visit_enum(self, enum_proto, type_context): + return None + + def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): + return None + + def visit_file(self, file_proto, type_context, services, msgs, enums): + # Freeze protos that have next major version candidates. + typedb = utils.get_type_db() + output_proto = copy.deepcopy(file_proto) + existing_pkg_version_status = output_proto.options.Extensions[ + status_pb2.file_status].package_version_status + empty_file = len(services) == 0 and len(enums) == 0 and len(msgs) == 0 + pkg_version_status_exempt = file_proto.name.startswith('envoy/annotations') or empty_file + # It's a format error not to set package_version_status. + if existing_pkg_version_status == status_pb2.UNKNOWN and not pkg_version_status_exempt: + raise ProtoXformError('package_version_status must be set in %s' % file_proto.name) + # Only update package_version_status for .active_or_frozen.proto, + # migrate.version_upgrade_xform has taken care of next major version + # candidates. + if self._active_or_frozen and not pkg_version_status_exempt: + # Freeze if this is an active package with a next major version. Preserve + # frozen status otherwise. 
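# A minimal sketch of the package_version_status decision being made here: an
# ACTIVE package is frozen only when the "freeze" extra arg was passed and the
# type database lists a next-major-version candidate for the file; FROZEN is
# always preserved. Plain strings stand in for the status_pb2 enum values.
def target_status(existing, freeze_requested, has_next_version_candidate):
    if freeze_requested and has_next_version_candidate:
        return 'FROZEN'
    if existing == 'FROZEN':
        return 'FROZEN'
    assert existing == 'ACTIVE'
    return 'ACTIVE'

assert target_status('ACTIVE', True, True) == 'FROZEN'
assert target_status('ACTIVE', True, False) == 'ACTIVE'
assert target_status('FROZEN', False, False) == 'FROZEN'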
+ if self._freeze and typedb.next_version_protos.get(output_proto.name, None): + target_pkg_version_status = status_pb2.FROZEN + elif existing_pkg_version_status == status_pb2.FROZEN: + target_pkg_version_status = status_pb2.FROZEN + else: + assert (existing_pkg_version_status == status_pb2.ACTIVE) + target_pkg_version_status = status_pb2.ACTIVE + output_proto.options.Extensions[ + status_pb2.file_status].package_version_status = target_pkg_version_status + return str(output_proto) def main(): - plugin.plugin([ - plugin.direct_output_descriptor('.active_or_frozen.proto', - functools.partial(ProtoFormatVisitor, True), - want_params=True), - plugin.OutputDescriptor('.next_major_version_candidate.proto', - functools.partial(ProtoFormatVisitor, False), - functools.partial(migrate.version_upgrade_xform, 2, False), - want_params=True), - plugin.OutputDescriptor('.next_major_version_candidate.envoy_internal.proto', - functools.partial(ProtoFormatVisitor, False), - functools.partial(migrate.version_upgrade_xform, 2, True), - want_params=True) - ]) + plugin.plugin([ + plugin.direct_output_descriptor('.active_or_frozen.proto', + functools.partial(ProtoFormatVisitor, True), + want_params=True), + plugin.OutputDescriptor('.next_major_version_candidate.proto', + functools.partial(ProtoFormatVisitor, False), + functools.partial(migrate.version_upgrade_xform, 2, False), + want_params=True), + plugin.OutputDescriptor('.next_major_version_candidate.envoy_internal.proto', + functools.partial(ProtoFormatVisitor, False), + functools.partial(migrate.version_upgrade_xform, 2, True), + want_params=True) + ]) if __name__ == '__main__': - main() + main() diff --git a/tools/protoxform/protoxform_test_helper.py b/tools/protoxform/protoxform_test_helper.py index 8d6ab2a5b7a7..ff4317d6fb72 100755 --- a/tools/protoxform/protoxform_test_helper.py +++ b/tools/protoxform/protoxform_test_helper.py @@ -12,7 +12,7 @@ def path_and_filename(label): - """Retrieve actual path and filename from bazel label + """Retrieve actual path and filename from bazel label Args: label: bazel label to specify target proto. @@ -20,19 +20,19 @@ def path_and_filename(label): Returns: actual path and filename """ - if label.startswith('/'): - label = label.replace('//', '/', 1) - elif label.startswith('@'): - label = re.sub(r'@.*/', '/', label) - else: - return label - label = label.replace(":", "/") - splitted_label = label.split('/') - return ['/'.join(splitted_label[:len(splitted_label) - 1]), splitted_label[-1]] + if label.startswith('/'): + label = label.replace('//', '/', 1) + elif label.startswith('@'): + label = re.sub(r'@.*/', '/', label) + else: + return label + label = label.replace(":", "/") + splitted_label = label.split('/') + return ['/'.join(splitted_label[:len(splitted_label) - 1]), splitted_label[-1]] def golden_proto_file(path, filename, version): - """Retrieve golden proto file path. In general, those are placed in tools/testdata/protoxform. + """Retrieve golden proto file path. In general, those are placed in tools/testdata/protoxform. Args: path: target proto path @@ -42,27 +42,27 @@ def golden_proto_file(path, filename, version): Returns: actual golden proto absolute path """ - base = "./" - base += path + "/" + filename + "." + version + ".gold" - return os.path.abspath(base) + base = "./" + base += path + "/" + filename + "." + version + ".gold" + return os.path.abspath(base) def proto_print(src, dst): - """Pretty-print FileDescriptorProto to a destination file. + """Pretty-print FileDescriptorProto to a destination file. 
Args: src: source path for FileDescriptorProto. dst: destination path for formatted proto. """ - print('proto_print %s -> %s' % (src, dst)) - subprocess.check_call([ - 'bazel-bin/tools/protoxform/protoprint', src, dst, - './bazel-bin/tools/protoxform/protoprint.runfiles/envoy/tools/type_whisperer/api_type_db.pb_text' - ]) + print('proto_print %s -> %s' % (src, dst)) + subprocess.check_call([ + 'bazel-bin/tools/protoxform/protoprint', src, dst, + './bazel-bin/tools/protoxform/protoprint.runfiles/envoy/tools/type_whisperer/api_type_db.pb_text' + ]) def result_proto_file(cmd, path, tmp, filename, version): - """Retrieve result proto file path. In general, those are placed in bazel artifacts. + """Retrieve result proto file path. In general, those are placed in bazel artifacts. Args: cmd: fix or freeze? @@ -74,17 +74,17 @@ def result_proto_file(cmd, path, tmp, filename, version): Returns: actual result proto absolute path """ - base = "./bazel-bin" - base += os.path.join(path, "%s_protos" % cmd) - base += os.path.join(base, path) - base += "/{0}.{1}.proto".format(filename, version) - dst = os.path.join(tmp, filename) - proto_print(os.path.abspath(base), dst) - return dst + base = "./bazel-bin" + base += os.path.join(path, "%s_protos" % cmd) + base += os.path.join(base, path) + base += "/{0}.{1}.proto".format(filename, version) + dst = os.path.join(tmp, filename) + proto_print(os.path.abspath(base), dst) + return dst def diff(result_file, golden_file): - """Execute diff command with unified form + """Execute diff command with unified form Args: result_file: result proto file @@ -93,15 +93,15 @@ def diff(result_file, golden_file): Returns: output and status code """ - command = 'diff -u ' - command += result_file + ' ' - command += golden_file - status, stdout, stderr = run_command(command) - return [status, stdout, stderr] + command = 'diff -u ' + command += result_file + ' ' + command += golden_file + status, stdout, stderr = run_command(command) + return [status, stdout, stderr] def run(cmd, path, filename, version): - """Run main execution for protoxform test + """Run main execution for protoxform test Args: cmd: fix or freeze? 
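# A minimal sketch of what path_and_filename above does for a '//'-prefixed bazel
# label; the label itself is hypothetical.
label = '//tools/testdata/protoxform:fake_proto.proto'
label = label.replace('//', '/', 1).replace(':', '/')
parts = label.split('/')
path, filename = '/'.join(parts[:-1]), parts[-1]
assert (path, filename) == ('/tools/testdata/protoxform', 'fake_proto.proto')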
@@ -112,34 +112,34 @@ def run(cmd, path, filename, version): Returns: result message extracted from diff command """ - message = "" - with tempfile.TemporaryDirectory() as tmp: - golden_path = golden_proto_file(path, filename, version) - test_path = result_proto_file(cmd, path, tmp, filename, version) - if os.stat(golden_path).st_size == 0 and not os.path.exists(test_path): - return message + message = "" + with tempfile.TemporaryDirectory() as tmp: + golden_path = golden_proto_file(path, filename, version) + test_path = result_proto_file(cmd, path, tmp, filename, version) + if os.stat(golden_path).st_size == 0 and not os.path.exists(test_path): + return message - status, stdout, stderr = diff(golden_path, test_path) + status, stdout, stderr = diff(golden_path, test_path) - if status != 0: - message = '\n'.join([str(line) for line in stdout + stderr]) + if status != 0: + message = '\n'.join([str(line) for line in stdout + stderr]) - return message + return message if __name__ == "__main__": - messages = "" - logging.basicConfig(format='%(message)s') - cmd = sys.argv[1] - for target in sys.argv[2:]: - path, filename = path_and_filename(target) - messages += run(cmd, path, filename, 'active_or_frozen') - messages += run(cmd, path, filename, 'next_major_version_candidate') - messages += run(cmd, path, filename, 'next_major_version_candidate.envoy_internal') - - if len(messages) == 0: - logging.warning("PASS") - sys.exit(0) - else: - logging.error("FAILED:\n{}".format(messages)) - sys.exit(1) + messages = "" + logging.basicConfig(format='%(message)s') + cmd = sys.argv[1] + for target in sys.argv[2:]: + path, filename = path_and_filename(target) + messages += run(cmd, path, filename, 'active_or_frozen') + messages += run(cmd, path, filename, 'next_major_version_candidate') + messages += run(cmd, path, filename, 'next_major_version_candidate.envoy_internal') + + if len(messages) == 0: + logging.warning("PASS") + sys.exit(0) + else: + logging.error("FAILED:\n{}".format(messages)) + sys.exit(1) diff --git a/tools/protoxform/utils.py b/tools/protoxform/utils.py index 57093581c328..cc673e7209ae 100644 --- a/tools/protoxform/utils.py +++ b/tools/protoxform/utils.py @@ -8,12 +8,12 @@ def get_type_db(): - assert _typedb != None - return _typedb + assert _typedb != None + return _typedb def load_type_db(type_db_path): - global _typedb - _typedb = TypeDb() - with open(type_db_path, 'r') as f: - text_format.Merge(f.read(), _typedb) + global _typedb + _typedb = TypeDb() + with open(type_db_path, 'r') as f: + text_format.Merge(f.read(), _typedb) diff --git a/tools/run_command.py b/tools/run_command.py index 8d7a7c867f9a..79343bb6e23f 100644 --- a/tools/run_command.py +++ b/tools/run_command.py @@ -4,7 +4,7 @@ # Echoes and runs an OS command, returning exit status and the captured # stdout and stderr as a string array. 
def run_command(command): - proc = subprocess.run([command], shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + proc = subprocess.run([command], shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - return proc.returncode, proc.stdout.decode('utf-8').split('\n'), proc.stderr.decode( - 'utf-8').split('\n') + return proc.returncode, proc.stdout.decode('utf-8').split('\n'), proc.stderr.decode( + 'utf-8').split('\n') diff --git a/tools/socket_passing.py b/tools/socket_passing.py index 620331aadcfd..fc0321a9b7a1 100755 --- a/tools/socket_passing.py +++ b/tools/socket_passing.py @@ -25,94 +25,97 @@ # Because the hot restart files are yaml but yaml support is not included in # python by default, we parse this fairly manually. def generate_new_config(original_yaml, admin_address, updated_json): - # Get original listener addresses - with open(original_yaml, 'r') as original_file: - sys.stdout.write('Admin address is ' + admin_address + '\n') - try: - admin_conn = http.client.HTTPConnection(admin_address) - admin_conn.request('GET', '/listeners?format=json') - admin_response = admin_conn.getresponse() - if not admin_response.status == 200: - return False - discovered_listeners = json.loads(admin_response.read().decode('utf-8')) - except Exception as e: - sys.stderr.write('Cannot connect to admin: %s\n' % e) - return False - else: - raw_yaml = original_file.readlines() - index = 0 - for discovered in discovered_listeners['listener_statuses']: - replaced = False - if 'pipe' in discovered['local_address']: - path = discovered['local_address']['pipe']['path'] - for index in range(index + 1, len(raw_yaml) - 1): - if 'pipe:' in raw_yaml[index] and 'path:' in raw_yaml[index + 1]: - raw_yaml[index + 1] = re.sub('path:.*', 'path: "' + path + '"', raw_yaml[index + 1]) - replaced = True - break + # Get original listener addresses + with open(original_yaml, 'r') as original_file: + sys.stdout.write('Admin address is ' + admin_address + '\n') + try: + admin_conn = http.client.HTTPConnection(admin_address) + admin_conn.request('GET', '/listeners?format=json') + admin_response = admin_conn.getresponse() + if not admin_response.status == 200: + return False + discovered_listeners = json.loads(admin_response.read().decode('utf-8')) + except Exception as e: + sys.stderr.write('Cannot connect to admin: %s\n' % e) + return False else: - addr = discovered['local_address']['socket_address']['address'] - port = str(discovered['local_address']['socket_address']['port_value']) - if addr[0] == '[': - addr = addr[1:-1] # strip [] from ipv6 address. 
- for index in range(index + 1, len(raw_yaml) - 2): - if ('socket_address:' in raw_yaml[index] and 'address:' in raw_yaml[index + 1] and - 'port_value:' in raw_yaml[index + 2]): - raw_yaml[index + 1] = re.sub('address:.*', 'address: "' + addr + '"', - raw_yaml[index + 1]) - raw_yaml[index + 2] = re.sub('port_value:.*', 'port_value: ' + port, - raw_yaml[index + 2]) - replaced = True - break - if replaced: - sys.stderr.write('replaced listener at line ' + str(index) + ' with ' + str(discovered) + - '\n') - else: - sys.stderr.write('Failed to replace a discovered listener ' + str(discovered) + '\n') - return False - with open(updated_json, 'w') as outfile: - outfile.writelines(raw_yaml) - finally: - admin_conn.close() + raw_yaml = original_file.readlines() + index = 0 + for discovered in discovered_listeners['listener_statuses']: + replaced = False + if 'pipe' in discovered['local_address']: + path = discovered['local_address']['pipe']['path'] + for index in range(index + 1, len(raw_yaml) - 1): + if 'pipe:' in raw_yaml[index] and 'path:' in raw_yaml[index + 1]: + raw_yaml[index + 1] = re.sub('path:.*', 'path: "' + path + '"', + raw_yaml[index + 1]) + replaced = True + break + else: + addr = discovered['local_address']['socket_address']['address'] + port = str(discovered['local_address']['socket_address']['port_value']) + if addr[0] == '[': + addr = addr[1:-1] # strip [] from ipv6 address. + for index in range(index + 1, len(raw_yaml) - 2): + if ('socket_address:' in raw_yaml[index] and + 'address:' in raw_yaml[index + 1] and + 'port_value:' in raw_yaml[index + 2]): + raw_yaml[index + 1] = re.sub('address:.*', 'address: "' + addr + '"', + raw_yaml[index + 1]) + raw_yaml[index + 2] = re.sub('port_value:.*', 'port_value: ' + port, + raw_yaml[index + 2]) + replaced = True + break + if replaced: + sys.stderr.write('replaced listener at line ' + str(index) + ' with ' + + str(discovered) + '\n') + else: + sys.stderr.write('Failed to replace a discovered listener ' + str(discovered) + + '\n') + return False + with open(updated_json, 'w') as outfile: + outfile.writelines(raw_yaml) + finally: + admin_conn.close() - return True + return True if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Replace listener addressses in json file.') - parser.add_argument('-o', - '--original_json', - type=str, - required=True, - help='Path of the original config json file') - parser.add_argument('-a', - '--admin_address_path', - type=str, - required=True, - help='Path of the admin address file') - parser.add_argument('-u', - '--updated_json', - type=str, - required=True, - help='Path to output updated json config file') - args = parser.parse_args() - admin_address_path = args.admin_address_path + parser = argparse.ArgumentParser(description='Replace listener addressses in json file.') + parser.add_argument('-o', + '--original_json', + type=str, + required=True, + help='Path of the original config json file') + parser.add_argument('-a', + '--admin_address_path', + type=str, + required=True, + help='Path of the admin address file') + parser.add_argument('-u', + '--updated_json', + type=str, + required=True, + help='Path to output updated json config file') + args = parser.parse_args() + admin_address_path = args.admin_address_path - # Read admin address from file - counter = 0 - while not os.path.exists(admin_address_path): - time.sleep(1) - counter += 1 - if counter > ADMIN_FILE_TIMEOUT_SECS: - break + # Read admin address from file + counter = 0 + while not os.path.exists(admin_address_path): + 
time.sleep(1) + counter += 1 + if counter > ADMIN_FILE_TIMEOUT_SECS: + break - if not os.path.exists(admin_address_path): - sys.exit(1) + if not os.path.exists(admin_address_path): + sys.exit(1) - with open(admin_address_path, 'r') as admin_address_file: - admin_address = admin_address_file.read() + with open(admin_address_path, 'r') as admin_address_file: + admin_address = admin_address_file.read() - success = generate_new_config(args.original_json, admin_address, args.updated_json) + success = generate_new_config(args.original_json, admin_address, args.updated_json) - if not success: - sys.exit(1) + if not success: + sys.exit(1) diff --git a/tools/spelling/check_spelling_pedantic.py b/tools/spelling/check_spelling_pedantic.py index d7b4b7d4cb77..6ddf70cad321 100755 --- a/tools/spelling/check_spelling_pedantic.py +++ b/tools/spelling/check_spelling_pedantic.py @@ -15,16 +15,16 @@ # Handle function rename between python 2/3. try: - input = raw_input + input = raw_input except NameError: - pass + pass try: - cmp + cmp except NameError: - def cmp(x, y): - return (x > y) - (x < y) + def cmp(x, y): + return (x > y) - (x < y) CURR_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -132,169 +132,169 @@ def cmp(x, y): def red(s): - if COLOR: - return "\33[1;31m" + s + "\033[0m" - return s + if COLOR: + return "\33[1;31m" + s + "\033[0m" + return s def debug(s): - if DEBUG > 0: - print(s) + if DEBUG > 0: + print(s) def debug1(s): - if DEBUG > 1: - print(s) + if DEBUG > 1: + print(s) class SpellChecker: - """Aspell-based spell checker.""" - - def __init__(self, dictionary_file): - self.dictionary_file = dictionary_file - self.aspell = None - self.prefixes = [] - self.suffixes = [] - self.prefix_re = None - self.suffix_re = None - - def start(self): - words, prefixes, suffixes = self.load_dictionary() - - self.prefixes = prefixes - self.suffixes = suffixes - - self.prefix_re = re.compile("(?:\s|^)((%s)-)" % ("|".join(prefixes)), re.IGNORECASE) - self.suffix_re = re.compile("(-(%s))(?:\s|$)" % ("|".join(suffixes)), re.IGNORECASE) - - # Generate aspell personal dictionary. - pws = os.path.join(CURR_DIR, '.aspell.en.pws') - with open(pws, 'w') as f: - f.write("personal_ws-1.1 en %d\n" % (len(words))) - f.writelines(words) - - # Start an aspell process. - aspell_args = ["aspell", "pipe", "--lang=en_US", "--encoding=utf-8", "--personal=" + pws] - self.aspell = subprocess.Popen(aspell_args, - bufsize=4096, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True) - - # Read the version line that aspell emits on startup. - self.aspell.stdout.readline() - - def stop(self): - if not self.aspell: - return - - self.aspell.stdin.close() - self.aspell.wait() - self.aspell = None - - def check(self, line): - if line.strip() == '': - return [] - - self.aspell.poll() - if self.aspell.returncode is not None: - print("aspell quit unexpectedly: return code %d" % (self.aspell.returncode)) - sys.exit(2) - - debug1("ASPELL< %s" % (line)) - - self.aspell.stdin.write(line + os.linesep) - self.aspell.stdin.flush() - - errors = [] - while True: - result = self.aspell.stdout.readline().strip() - debug1("ASPELL> %s" % (result)) - - # Check for end of results. - if result == "": - break - - t = result[0] - if t == "*" or t == "-" or t == "+": - # *: found in dictionary. - # -: found run-together words in dictionary. - # +: found root word in dictionary. - continue - - # & : m1, m2, ... mN, g1, g2, ... - # ? 0 : g1, g2, .... 
- # # - original, rem = result[2:].split(" ", 1) - - if t == "#": - # Not in dictionary, but no suggestions. - errors.append((original, int(rem), [])) - elif t == '&' or t == '?': - # Near misses and/or guesses. - _, rem = rem.split(" ", 1) # Drop N (may be 0). - o, rem = rem.split(": ", 1) # o is offset from start of line. - suggestions = rem.split(", ") - - errors.append((original, int(o), suggestions)) - else: - print("aspell produced unexpected output: %s" % (result)) - sys.exit(2) - - return errors - - def load_dictionary(self): - # Read the custom dictionary. - all_words = [] - with open(self.dictionary_file, 'r') as f: - all_words = f.readlines() - - # Strip comments, invalid words, and blank lines. - words = [w for w in all_words if len(w.strip()) > 0 and re.match(DICTIONARY_WORD, w)] - - suffixes = [w.strip()[1:] for w in all_words if w.startswith('-')] - prefixes = [w.strip()[:-1] for w in all_words if w.strip().endswith('-')] - - # Allow acronyms and abbreviations to be spelled in lowercase. - # (e.g. Convert "HTTP" into "HTTP" and "http" which also matches - # "Http"). - for word in words: - if word.isupper(): - words += word.lower() - - return (words, prefixes, suffixes) - - def add_words(self, additions): - lines = [] - with open(self.dictionary_file, 'r') as f: - lines = f.readlines() - - additions = [w + os.linesep for w in additions] - additions.sort() - - # Insert additions into the lines ignoring comments, suffixes, and blank lines. - idx = 0 - add_idx = 0 - while idx < len(lines) and add_idx < len(additions): - line = lines[idx] - if len(line.strip()) != 0 and line[0] != "#" and line[0] != '-': - c = cmp(additions[add_idx], line) - if c < 0: - lines.insert(idx, additions[add_idx]) - add_idx += 1 - elif c == 0: - add_idx += 1 - idx += 1 - - # Append any remaining additions. - lines += additions[add_idx:] - - with open(self.dictionary_file, 'w') as f: - f.writelines(lines) - - self.stop() - self.start() + """Aspell-based spell checker.""" + + def __init__(self, dictionary_file): + self.dictionary_file = dictionary_file + self.aspell = None + self.prefixes = [] + self.suffixes = [] + self.prefix_re = None + self.suffix_re = None + + def start(self): + words, prefixes, suffixes = self.load_dictionary() + + self.prefixes = prefixes + self.suffixes = suffixes + + self.prefix_re = re.compile("(?:\s|^)((%s)-)" % ("|".join(prefixes)), re.IGNORECASE) + self.suffix_re = re.compile("(-(%s))(?:\s|$)" % ("|".join(suffixes)), re.IGNORECASE) + + # Generate aspell personal dictionary. + pws = os.path.join(CURR_DIR, '.aspell.en.pws') + with open(pws, 'w') as f: + f.write("personal_ws-1.1 en %d\n" % (len(words))) + f.writelines(words) + + # Start an aspell process. + aspell_args = ["aspell", "pipe", "--lang=en_US", "--encoding=utf-8", "--personal=" + pws] + self.aspell = subprocess.Popen(aspell_args, + bufsize=4096, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + + # Read the version line that aspell emits on startup. 
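# A minimal sketch of the aspell "pipe" protocol parsing done in check(): '*', '-'
# and '+' mean the word (or a root/run-together form) is known; '#' is unknown with
# no suggestions; '&' and '?' carry an offset and a suggestion list. The sample
# result lines below are hypothetical.
def parse_aspell_result(result):
    t = result[0]
    if t in ('*', '-', '+'):
        return None                                  # found in dictionary
    original, rem = result[2:].split(' ', 1)
    if t == '#':
        return (original, int(rem), [])              # "# original offset"
    if t in ('&', '?'):
        _, rem = rem.split(' ', 1)                   # drop the suggestion count
        offset, rem = rem.split(': ', 1)
        return (original, int(offset), rem.split(', '))
    raise ValueError('unexpected aspell output: %s' % result)

assert parse_aspell_result('*') is None
assert parse_aspell_result('# confguration 10') == ('confguration', 10, [])
assert parse_aspell_result('& teh 2 5: the, ten') == ('teh', 5, ['the', 'ten'])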
+ self.aspell.stdout.readline() + + def stop(self): + if not self.aspell: + return + + self.aspell.stdin.close() + self.aspell.wait() + self.aspell = None + + def check(self, line): + if line.strip() == '': + return [] + + self.aspell.poll() + if self.aspell.returncode is not None: + print("aspell quit unexpectedly: return code %d" % (self.aspell.returncode)) + sys.exit(2) + + debug1("ASPELL< %s" % (line)) + + self.aspell.stdin.write(line + os.linesep) + self.aspell.stdin.flush() + + errors = [] + while True: + result = self.aspell.stdout.readline().strip() + debug1("ASPELL> %s" % (result)) + + # Check for end of results. + if result == "": + break + + t = result[0] + if t == "*" or t == "-" or t == "+": + # *: found in dictionary. + # -: found run-together words in dictionary. + # +: found root word in dictionary. + continue + + # & : m1, m2, ... mN, g1, g2, ... + # ? 0 : g1, g2, .... + # # + original, rem = result[2:].split(" ", 1) + + if t == "#": + # Not in dictionary, but no suggestions. + errors.append((original, int(rem), [])) + elif t == '&' or t == '?': + # Near misses and/or guesses. + _, rem = rem.split(" ", 1) # Drop N (may be 0). + o, rem = rem.split(": ", 1) # o is offset from start of line. + suggestions = rem.split(", ") + + errors.append((original, int(o), suggestions)) + else: + print("aspell produced unexpected output: %s" % (result)) + sys.exit(2) + + return errors + + def load_dictionary(self): + # Read the custom dictionary. + all_words = [] + with open(self.dictionary_file, 'r') as f: + all_words = f.readlines() + + # Strip comments, invalid words, and blank lines. + words = [w for w in all_words if len(w.strip()) > 0 and re.match(DICTIONARY_WORD, w)] + + suffixes = [w.strip()[1:] for w in all_words if w.startswith('-')] + prefixes = [w.strip()[:-1] for w in all_words if w.strip().endswith('-')] + + # Allow acronyms and abbreviations to be spelled in lowercase. + # (e.g. Convert "HTTP" into "HTTP" and "http" which also matches + # "Http"). + for word in words: + if word.isupper(): + words += word.lower() + + return (words, prefixes, suffixes) + + def add_words(self, additions): + lines = [] + with open(self.dictionary_file, 'r') as f: + lines = f.readlines() + + additions = [w + os.linesep for w in additions] + additions.sort() + + # Insert additions into the lines ignoring comments, suffixes, and blank lines. + idx = 0 + add_idx = 0 + while idx < len(lines) and add_idx < len(additions): + line = lines[idx] + if len(line.strip()) != 0 and line[0] != "#" and line[0] != '-': + c = cmp(additions[add_idx], line) + if c < 0: + lines.insert(idx, additions[add_idx]) + add_idx += 1 + elif c == 0: + add_idx += 1 + idx += 1 + + # Append any remaining additions. + lines += additions[add_idx:] + + with open(self.dictionary_file, 'w') as f: + f.writelines(lines) + + self.stop() + self.start() # Split camel case words and run them through the dictionary. Returns @@ -303,533 +303,540 @@ def add_words(self, additions): # the split words are all spelled correctly, or may be a new set of # errors referencing the misspelled sub-words. def check_camel_case(checker, err): - (word, word_offset, _) = err + (word, word_offset, _) = err - debug("check camel case %s" % (word)) - parts = re.findall(CAMEL_CASE, word) + debug("check camel case %s" % (word)) + parts = re.findall(CAMEL_CASE, word) - # Word is not camel case: the previous result stands. - if len(parts) <= 1: - debug(" -> not camel case") - return [err] + # Word is not camel case: the previous result stands. 
+ if len(parts) <= 1: + debug(" -> not camel case") + return [err] - split_errs = [] - part_offset = 0 - for part in parts: - debug(" -> part: %s" % (part)) - split_err = checker.check(part) - if split_err: - debug(" -> not found in dictionary") - split_errs += [(part, word_offset + part_offset, split_err[0][2])] - part_offset += len(part) + split_errs = [] + part_offset = 0 + for part in parts: + debug(" -> part: %s" % (part)) + split_err = checker.check(part) + if split_err: + debug(" -> not found in dictionary") + split_errs += [(part, word_offset + part_offset, split_err[0][2])] + part_offset += len(part) - return split_errs + return split_errs # Check for affixes and run them through the dictionary again. Returns # a replacement list of errors which may just be the original errors # or empty if an affix was successfully handled. def check_affix(checker, err): - (word, word_offset, _) = err - - debug("check affix %s" % (word)) - - for prefix in checker.prefixes: - debug(" -> try %s" % (prefix)) - if word.lower().startswith(prefix.lower()): - root = word[len(prefix):] - if root != '': - debug(" -> check %s" % (root)) - root_err = checker.check(root) - if not root_err: - debug(" -> ok") - return [] - - for suffix in checker.suffixes: - if word.lower().endswith(suffix.lower()): - root = word[:-len(suffix)] - if root != '': - debug(" -> try %s" % (root)) - root_err = checker.check(root) - if not root_err: - debug(" -> ok") - return [] - - return [err] + (word, word_offset, _) = err + + debug("check affix %s" % (word)) + + for prefix in checker.prefixes: + debug(" -> try %s" % (prefix)) + if word.lower().startswith(prefix.lower()): + root = word[len(prefix):] + if root != '': + debug(" -> check %s" % (root)) + root_err = checker.check(root) + if not root_err: + debug(" -> ok") + return [] + + for suffix in checker.suffixes: + if word.lower().endswith(suffix.lower()): + root = word[:-len(suffix)] + if root != '': + debug(" -> try %s" % (root)) + root_err = checker.check(root) + if not root_err: + debug(" -> ok") + return [] + + return [err] # Find occurrences of the regex within comment and replace the numbered # matching group with spaces. If secondary is defined, the matching # group must also match secondary to be masked. def mask_with_regex(comment, regex, group, secondary=None): - found = False - for m in regex.finditer(comment): - if secondary and secondary.search(m.group(group)) is None: - continue + found = False + for m in regex.finditer(comment): + if secondary and secondary.search(m.group(group)) is None: + continue - start = m.start(group) - end = m.end(group) + start = m.start(group) + end = m.end(group) - comment = comment[:start] + (' ' * (end - start)) + comment[end:] - found = True + comment = comment[:start] + (' ' * (end - start)) + comment[end:] + found = True - return (comment, found) + return (comment, found) # Checks the comment at offset against the spell checker. Result is an array # of tuples where each tuple is the misspelled word, it's offset from the # start of the line, and an array of possible replacements. def check_comment(checker, offset, comment): - # Strip smart quotes which cause problems sometimes. - for sq, q in SMART_QUOTES.items(): - comment = comment.replace(sq, q) + # Strip smart quotes which cause problems sometimes. + for sq, q in SMART_QUOTES.items(): + comment = comment.replace(sq, q) - # Replace TODO comments with spaces to preserve string offsets. 
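The masking helper above underpins everything check_comment does next: a matched span is overwritten with spaces, so the string keeps its length and the offsets aspell later reports still point at the right columns of the original line. A small self-contained illustration of that idea, using a made-up hex pattern rather than the module's real regexes:

import re

# Overwrite a matched group with spaces so the string keeps its length and
# all later offsets remain valid (the same trick mask_with_regex uses).
def mask(text, regex, group=0):
    for m in regex.finditer(text):
        start, end = m.start(group), m.end(group)
        text = text[:start] + " " * (end - start) + text[end:]
    return text

HEX_EXAMPLE = re.compile(r"0x[0-9a-fA-F]+")  # Illustrative pattern only.
line = "Retry after 0xDEADBEEF millis"
masked = mask(line, HEX_EXAMPLE)
assert len(masked) == len(line)
print(repr(masked))  # Same length; the hex constant is blanked out.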
- comment, _ = mask_with_regex(comment, TODO, 0) + # Replace TODO comments with spaces to preserve string offsets. + comment, _ = mask_with_regex(comment, TODO, 0) - # Ignore @param varname - comment, _ = mask_with_regex(comment, METHOD_DOC, 0) + # Ignore @param varname + comment, _ = mask_with_regex(comment, METHOD_DOC, 0) - # Similarly, look for base64 sequences, but they must have at least one - # digit. - comment, _ = mask_with_regex(comment, BASE64, 1, NUMBER) + # Similarly, look for base64 sequences, but they must have at least one + # digit. + comment, _ = mask_with_regex(comment, BASE64, 1, NUMBER) - # Various hex constants: - comment, _ = mask_with_regex(comment, HEX, 1) - comment, _ = mask_with_regex(comment, HEX_SIG, 1) - comment, _ = mask_with_regex(comment, PREFIXED_HEX, 0) - comment, _ = mask_with_regex(comment, BIT_FIELDS, 0) - comment, _ = mask_with_regex(comment, AB_FIELDS, 1) - comment, _ = mask_with_regex(comment, UUID, 0) - comment, _ = mask_with_regex(comment, IPV6_ADDR, 1) + # Various hex constants: + comment, _ = mask_with_regex(comment, HEX, 1) + comment, _ = mask_with_regex(comment, HEX_SIG, 1) + comment, _ = mask_with_regex(comment, PREFIXED_HEX, 0) + comment, _ = mask_with_regex(comment, BIT_FIELDS, 0) + comment, _ = mask_with_regex(comment, AB_FIELDS, 1) + comment, _ = mask_with_regex(comment, UUID, 0) + comment, _ = mask_with_regex(comment, IPV6_ADDR, 1) - # Single words in quotes: - comment, _ = mask_with_regex(comment, QUOTED_WORD, 0) + # Single words in quotes: + comment, _ = mask_with_regex(comment, QUOTED_WORD, 0) - # RST inline literals: - comment, _ = mask_with_regex(comment, RST_LITERAL, 0) + # RST inline literals: + comment, _ = mask_with_regex(comment, RST_LITERAL, 0) - # Mask the reference part of an RST link (but not the link text). Otherwise, check for a quoted - # code-like expression (which would mask the link text if not guarded). - comment, found = mask_with_regex(comment, RST_LINK, 0) - if not found: - comment, _ = mask_with_regex(comment, QUOTED_EXPR, 0) + # Mask the reference part of an RST link (but not the link text). Otherwise, check for a quoted + # code-like expression (which would mask the link text if not guarded). + comment, found = mask_with_regex(comment, RST_LINK, 0) + if not found: + comment, _ = mask_with_regex(comment, QUOTED_EXPR, 0) - comment, _ = mask_with_regex(comment, TUPLE_EXPR, 0) + comment, _ = mask_with_regex(comment, TUPLE_EXPR, 0) - # Command flags: - comment, _ = mask_with_regex(comment, FLAG, 1) + # Command flags: + comment, _ = mask_with_regex(comment, FLAG, 1) - # Github user refs: - comment, _ = mask_with_regex(comment, USER, 1) + # Github user refs: + comment, _ = mask_with_regex(comment, USER, 1) - # Absolutew paths and references to source files. - comment, _ = mask_with_regex(comment, ABSPATH, 1) - comment, _ = mask_with_regex(comment, FILEREF, 1) + # Absolutew paths and references to source files. + comment, _ = mask_with_regex(comment, ABSPATH, 1) + comment, _ = mask_with_regex(comment, FILEREF, 1) - # Ordinals (1st, 2nd...) - comment, _ = mask_with_regex(comment, ORDINALS, 0) + # Ordinals (1st, 2nd...) 
+ comment, _ = mask_with_regex(comment, ORDINALS, 0) - if checker.prefix_re is not None: - comment, _ = mask_with_regex(comment, checker.prefix_re, 1) + if checker.prefix_re is not None: + comment, _ = mask_with_regex(comment, checker.prefix_re, 1) - if checker.suffix_re is not None: - comment, _ = mask_with_regex(comment, checker.suffix_re, 1) + if checker.suffix_re is not None: + comment, _ = mask_with_regex(comment, checker.suffix_re, 1) - # Everything got masked, return early. - if comment == "" or comment.strip() == "": - return [] + # Everything got masked, return early. + if comment == "" or comment.strip() == "": + return [] - # Mask leading punctuation. - if not comment[0].isalnum(): - comment = ' ' + comment[1:] + # Mask leading punctuation. + if not comment[0].isalnum(): + comment = ' ' + comment[1:] - errors = checker.check(comment) + errors = checker.check(comment) - # Fix up offsets relative to the start of the line vs start of the comment. - errors = [(w, o + offset, s) for (w, o, s) in errors] + # Fix up offsets relative to the start of the line vs start of the comment. + errors = [(w, o + offset, s) for (w, o, s) in errors] - # CamelCase words get split and re-checked - errors = [*chain.from_iterable(map(lambda err: check_camel_case(checker, err), errors))] + # CamelCase words get split and re-checked + errors = [*chain.from_iterable(map(lambda err: check_camel_case(checker, err), errors))] - errors = [*chain.from_iterable(map(lambda err: check_affix(checker, err), errors))] + errors = [*chain.from_iterable(map(lambda err: check_affix(checker, err), errors))] - return errors + return errors def print_error(file, line_offset, lines, errors): - # Highlight misspelled words. - line = lines[line_offset] - prefix = "%s:%d:" % (file, line_offset + 1) - for (word, offset, suggestions) in reversed(errors): - line = line[:offset] + red(word) + line[offset + len(word):] - - print("%s%s" % (prefix, line.rstrip())) - - if MARK: - # Print a caret at the start of each misspelled word. - marks = ' ' * len(prefix) - last = 0 - for (word, offset, suggestions) in errors: - marks += (' ' * (offset - last)) + '^' - last = offset + 1 - print(marks) + # Highlight misspelled words. + line = lines[line_offset] + prefix = "%s:%d:" % (file, line_offset + 1) + for (word, offset, suggestions) in reversed(errors): + line = line[:offset] + red(word) + line[offset + len(word):] + + print("%s%s" % (prefix, line.rstrip())) + + if MARK: + # Print a caret at the start of each misspelled word. 
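The CamelCase re-check above splits a flagged word with the CAMEL_CASE pattern defined near the top of this file (outside this hunk) and runs each piece through aspell again. The pattern below is only a plausible stand-in, shown purely to illustrate the splitting step:

import re

# Hypothetical stand-in for the module's CAMEL_CASE pattern: all-caps acronym
# runs, optionally-capitalized lowercase runs, or digit runs.
CAMEL_CASE_EXAMPLE = re.compile(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+")

def split_camel_case(word):
    return CAMEL_CASE_EXAMPLE.findall(word)

print(split_camel_case("SpeeledKorrectly"))        # ['Speeled', 'Korrectly']
print(split_camel_case("HTTPConnectionManager"))   # ['HTTP', 'Connection', 'Manager']
print(split_camel_case("lowercase"))               # ['lowercase'] -> not camel case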
+ marks = ' ' * len(prefix) + last = 0 + for (word, offset, suggestions) in errors: + marks += (' ' * (offset - last)) + '^' + last = offset + 1 + print(marks) def print_fix_options(word, suggestions): - print("%s:" % (word)) - print(" a: accept and add to dictionary") - print(" A: accept and add to dictionary as ALLCAPS (for acronyms)") - print(" f : replace with the given word without modifying dictionary") - print(" i: ignore") - print(" r : replace with given word and add to dictionary") - print(" R : replace with given word and add to dictionary as ALLCAPS (for acronyms)") - print(" x: abort") + print("%s:" % (word)) + print(" a: accept and add to dictionary") + print(" A: accept and add to dictionary as ALLCAPS (for acronyms)") + print(" f : replace with the given word without modifying dictionary") + print(" i: ignore") + print(" r : replace with given word and add to dictionary") + print(" R : replace with given word and add to dictionary as ALLCAPS (for acronyms)") + print(" x: abort") - if not suggestions: - return + if not suggestions: + return - col_width = max(len(word) for word in suggestions) - opt_width = int(math.log(len(suggestions), 10)) + 1 - padding = 2 # Two spaces of padding. - delim = 2 # Colon and space after number. - num_cols = int(78 / (col_width + padding + opt_width + delim)) - num_rows = int(len(suggestions) / num_cols + 1) - rows = [""] * num_rows + col_width = max(len(word) for word in suggestions) + opt_width = int(math.log(len(suggestions), 10)) + 1 + padding = 2 # Two spaces of padding. + delim = 2 # Colon and space after number. + num_cols = int(78 / (col_width + padding + opt_width + delim)) + num_rows = int(len(suggestions) / num_cols + 1) + rows = [""] * num_rows - indent = " " * padding - for idx, sugg in enumerate(suggestions): - row = idx % len(rows) - row_data = "%d: %s" % (idx, sugg) + indent = " " * padding + for idx, sugg in enumerate(suggestions): + row = idx % len(rows) + row_data = "%d: %s" % (idx, sugg) - rows[row] += indent + row_data.ljust(col_width + opt_width + delim) + rows[row] += indent + row_data.ljust(col_width + opt_width + delim) - for row in rows: - print(row) + for row in rows: + print(row) def fix_error(checker, file, line_offset, lines, errors): - print_error(file, line_offset, lines, errors) - - fixed = {} - replacements = [] - additions = [] - for (word, offset, suggestions) in errors: - if word in fixed: - # Same typo was repeated in a line, so just reuse the previous choice. - replacements += [fixed[word]] - continue - - print_fix_options(word, suggestions) - - replacement = "" - while replacement == "": - try: - choice = input("> ") - except EOFError: - choice = "x" - - add = None - if choice == "x": - print("Spell checking aborted.") - sys.exit(2) - elif choice == "a": - replacement = word - add = word - elif choice == "A": - replacement = word - add = word.upper() - elif choice[:1] == "f": - replacement = choice[1:].strip() - if replacement == "": - print("Invalid choice: '%s'. Must specify a replacement (e.g. 'f corrected')." % (choice)) - continue - elif choice == "i": - replacement = word - elif choice[:1] == "r" or choice[:1] == "R": - replacement = choice[1:].strip() - if replacement == "": - print("Invalid choice: '%s'. Must specify a replacement (e.g. 'r corrected')." 
% (choice)) - continue - - if choice[:1] == "R": - if replacement.upper() not in suggestions: - add = replacement.upper() - elif replacement not in suggestions: - add = replacement - else: - try: - idx = int(choice) - except ValueError: - idx = -1 - if idx >= 0 and idx < len(suggestions): - replacement = suggestions[idx] - else: - print("Invalid choice: '%s'" % (choice)) - - fixed[word] = replacement - replacements += [replacement] - if add: - if re.match(DICTIONARY_WORD, add): - additions += [add] - else: - print("Cannot add %s to the dictionary: it may only contain letter and apostrophes" % add) + print_error(file, line_offset, lines, errors) - if len(errors) != len(replacements): - print("Internal error %d errors with %d replacements" % (len(errors), len(replacements))) - sys.exit(2) + fixed = {} + replacements = [] + additions = [] + for (word, offset, suggestions) in errors: + if word in fixed: + # Same typo was repeated in a line, so just reuse the previous choice. + replacements += [fixed[word]] + continue + + print_fix_options(word, suggestions) + + replacement = "" + while replacement == "": + try: + choice = input("> ") + except EOFError: + choice = "x" + + add = None + if choice == "x": + print("Spell checking aborted.") + sys.exit(2) + elif choice == "a": + replacement = word + add = word + elif choice == "A": + replacement = word + add = word.upper() + elif choice[:1] == "f": + replacement = choice[1:].strip() + if replacement == "": + print("Invalid choice: '%s'. Must specify a replacement (e.g. 'f corrected')." % + (choice)) + continue + elif choice == "i": + replacement = word + elif choice[:1] == "r" or choice[:1] == "R": + replacement = choice[1:].strip() + if replacement == "": + print("Invalid choice: '%s'. Must specify a replacement (e.g. 'r corrected')." % + (choice)) + continue + + if choice[:1] == "R": + if replacement.upper() not in suggestions: + add = replacement.upper() + elif replacement not in suggestions: + add = replacement + else: + try: + idx = int(choice) + except ValueError: + idx = -1 + if idx >= 0 and idx < len(suggestions): + replacement = suggestions[idx] + else: + print("Invalid choice: '%s'" % (choice)) + + fixed[word] = replacement + replacements += [replacement] + if add: + if re.match(DICTIONARY_WORD, add): + additions += [add] + else: + print( + "Cannot add %s to the dictionary: it may only contain letter and apostrophes" % + add) + + if len(errors) != len(replacements): + print("Internal error %d errors with %d replacements" % (len(errors), len(replacements))) + sys.exit(2) - # Perform replacements on the line. - line = lines[line_offset] - for idx in range(len(replacements) - 1, -1, -1): - word, offset, _ = errors[idx] - replacement = replacements[idx] - if word == replacement: - continue + # Perform replacements on the line. + line = lines[line_offset] + for idx in range(len(replacements) - 1, -1, -1): + word, offset, _ = errors[idx] + replacement = replacements[idx] + if word == replacement: + continue - line = line[:offset] + replacement + line[offset + len(word):] - lines[line_offset] = line + line = line[:offset] + replacement + line[offset + len(word):] + lines[line_offset] = line - # Update the dictionary. - checker.add_words(additions) + # Update the dictionary. 
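fix_error applies its replacements from the highest offset down (the reversed range above) so that earlier offsets stay valid while the line grows or shrinks. The same idea in isolation:

# Apply (word, offset, replacement) edits from right to left so earlier
# offsets remain correct after a longer or shorter replacement shifts the text.
def apply_replacements(line, edits):
    for word, offset, replacement in sorted(edits, key=lambda e: e[1], reverse=True):
        line = line[:offset] + replacement + line[offset + len(word):]
    return line

line = "teh speeling is wrong"
edits = [("teh", 0, "the"), ("speeling", 4, "spelling")]
print(apply_replacements(line, edits))  # the spelling is wrong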
+ checker.add_words(additions) class Comment: - """Comment represents a comment at a location within a file.""" + """Comment represents a comment at a location within a file.""" - def __init__(self, line, col, text, last_on_line): - self.line = line - self.col = col - self.text = text - self.last_on_line = last_on_line + def __init__(self, line, col, text, last_on_line): + self.line = line + self.col = col + self.text = text + self.last_on_line = last_on_line # Extract comments from lines. Returns an array of Comment. def extract_comments(lines): - in_comment = False - comments = [] - for line_idx, line in enumerate(lines): - line_comments = [] - last = 0 - if in_comment: - mc_end = MULTI_COMMENT_END.search(line) - if mc_end is None: - # Full line is within a multi-line comment. - line_comments.append((0, line)) - else: - # Start of line is the end of a multi-line comment. - line_comments.append((0, mc_end.group(1))) - last = mc_end.end() - in_comment = False - - if not in_comment: - for inline in INLINE_COMMENT.finditer(line, last): - # Single-line comment. - m = inline.lastindex # 1 is //, 2 is /* ... */ - line_comments.append((inline.start(m), inline.group(m))) - last = inline.end(m) - - if last < len(line): - mc_start = MULTI_COMMENT_START.search(line, last) - if mc_start is not None: - # New multi-lie comment starts at end of line. - line_comments.append((mc_start.start(1), mc_start.group(1))) - in_comment = True - - for idx, line_comment in enumerate(line_comments): - col, text = line_comment - last_on_line = idx + 1 >= len(line_comments) - comments.append(Comment(line=line_idx, col=col, text=text, last_on_line=last_on_line)) - - # Handle control statements and filter out comments that are part of - # RST code block directives. - result = [] - n = 0 - nc = len(comments) - - while n < nc: - text = comments[n].text - - if SPELLCHECK_SKIP_FILE in text: - # Skip the file: just don't return any comments. - return [] - - pos = text.find(SPELLCHECK_ON) - if pos != -1: - # Ignored because spellchecking isn't disabled. Just mask out the command. - comments[n].text = text[:pos] + ' ' * len(SPELLCHECK_ON) + text[pos + len(SPELLCHECK_ON):] - result.append(comments[n]) - n += 1 - elif SPELLCHECK_OFF in text or SPELLCHECK_SKIP_BLOCK in text: - skip_block = SPELLCHECK_SKIP_BLOCK in text - last_line = n - n += 1 - while n < nc: - if skip_block: - if comments[n].line - last_line > 1: - # Gap in comments. We've skipped the block. - break - line = lines[comments[n].line] - if line[:comments[n].col].strip() != "": - # Some code here. We've skipped the block. - break - elif SPELLCHECK_ON in comments[n].text: - # Turn checking back on. - n += 1 - break - - n += 1 - elif text.strip().startswith(RST_CODE_BLOCK): - # Start of a code block. - indent = len(INDENT.search(text).group(1)) - last_line = comments[n].line - n += 1 - - while n < nc: - if comments[n].line - last_line > 1: - # Gap in comments. Code block is finished. - break - last_line = comments[n].line - - if comments[n].text.strip() != "": - # Blank lines are ignored in code blocks. - if len(INDENT.search(comments[n].text).group(1)) <= indent: - # Back to original indent, or less. The code block is done. - break - n += 1 - else: - result.append(comments[n]) - n += 1 - - return result + in_comment = False + comments = [] + for line_idx, line in enumerate(lines): + line_comments = [] + last = 0 + if in_comment: + mc_end = MULTI_COMMENT_END.search(line) + if mc_end is None: + # Full line is within a multi-line comment. 
+ line_comments.append((0, line)) + else: + # Start of line is the end of a multi-line comment. + line_comments.append((0, mc_end.group(1))) + last = mc_end.end() + in_comment = False + + if not in_comment: + for inline in INLINE_COMMENT.finditer(line, last): + # Single-line comment. + m = inline.lastindex # 1 is //, 2 is /* ... */ + line_comments.append((inline.start(m), inline.group(m))) + last = inline.end(m) + + if last < len(line): + mc_start = MULTI_COMMENT_START.search(line, last) + if mc_start is not None: + # New multi-lie comment starts at end of line. + line_comments.append((mc_start.start(1), mc_start.group(1))) + in_comment = True + + for idx, line_comment in enumerate(line_comments): + col, text = line_comment + last_on_line = idx + 1 >= len(line_comments) + comments.append(Comment(line=line_idx, col=col, text=text, last_on_line=last_on_line)) + + # Handle control statements and filter out comments that are part of + # RST code block directives. + result = [] + n = 0 + nc = len(comments) + + while n < nc: + text = comments[n].text + + if SPELLCHECK_SKIP_FILE in text: + # Skip the file: just don't return any comments. + return [] + + pos = text.find(SPELLCHECK_ON) + if pos != -1: + # Ignored because spellchecking isn't disabled. Just mask out the command. + comments[n].text = text[:pos] + ' ' * len(SPELLCHECK_ON) + text[pos + + len(SPELLCHECK_ON):] + result.append(comments[n]) + n += 1 + elif SPELLCHECK_OFF in text or SPELLCHECK_SKIP_BLOCK in text: + skip_block = SPELLCHECK_SKIP_BLOCK in text + last_line = n + n += 1 + while n < nc: + if skip_block: + if comments[n].line - last_line > 1: + # Gap in comments. We've skipped the block. + break + line = lines[comments[n].line] + if line[:comments[n].col].strip() != "": + # Some code here. We've skipped the block. + break + elif SPELLCHECK_ON in comments[n].text: + # Turn checking back on. + n += 1 + break + + n += 1 + elif text.strip().startswith(RST_CODE_BLOCK): + # Start of a code block. + indent = len(INDENT.search(text).group(1)) + last_line = comments[n].line + n += 1 + + while n < nc: + if comments[n].line - last_line > 1: + # Gap in comments. Code block is finished. + break + last_line = comments[n].line + + if comments[n].text.strip() != "": + # Blank lines are ignored in code blocks. + if len(INDENT.search(comments[n].text).group(1)) <= indent: + # Back to original indent, or less. The code block is done. + break + n += 1 + else: + result.append(comments[n]) + n += 1 + + return result def check_file(checker, file, lines, error_handler): - in_code_block = 0 - code_block_indent = 0 - num_errors = 0 + in_code_block = 0 + code_block_indent = 0 + num_errors = 0 - comments = extract_comments(lines) - errors = [] - for comment in comments: - errors += check_comment(checker, comment.col, comment.text) - if comment.last_on_line and len(errors) > 0: - # Handle all the errors in a line. - num_errors += len(errors) - error_handler(file, comment.line, lines, errors) - errors = [] + comments = extract_comments(lines) + errors = [] + for comment in comments: + errors += check_comment(checker, comment.col, comment.text) + if comment.last_on_line and len(errors) > 0: + # Handle all the errors in a line. 
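extract_comments above depends on INLINE_COMMENT, MULTI_COMMENT_START and MULTI_COMMENT_END, which are defined earlier in the file and not visible in this hunk. The deliberately simplified, single-line-only sketch below uses assumed patterns just to show how a comment's text and starting column get captured:

import re

# Assumed, simplified stand-in for the module's comment regexes: it only
# handles '//' and a '/* ... */' that opens and closes on the same line.
INLINE = re.compile(r"//(.*)$|/\*(.*?)\*/")

def single_line_comments(line):
    found = []
    for m in INLINE.finditer(line):
        group = m.lastindex  # 1 is '//', 2 is '/* ... */'.
        found.append((m.start(group), m.group(group)))
    return found

print(single_line_comments("int x = 0; /* speeling */ foo(); // trailing coment"))
# [(13, ' speeling '), (35, ' trailing coment')]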
+ num_errors += len(errors) + error_handler(file, comment.line, lines, errors) + errors = [] - return (len(comments), num_errors) + return (len(comments), num_errors) def execute(files, dictionary_file, fix): - checker = SpellChecker(dictionary_file) - checker.start() + checker = SpellChecker(dictionary_file) + checker.start() - handler = print_error - if fix: - handler = partial(fix_error, checker) + handler = print_error + if fix: + handler = partial(fix_error, checker) - total_files = 0 - total_comments = 0 - total_errors = 0 - for path in files: - with open(path, 'r') as f: - lines = f.readlines() - total_files += 1 - (num_comments, num_errors) = check_file(checker, path, lines, handler) - total_comments += num_comments - total_errors += num_errors + total_files = 0 + total_comments = 0 + total_errors = 0 + for path in files: + with open(path, 'r') as f: + lines = f.readlines() + total_files += 1 + (num_comments, num_errors) = check_file(checker, path, lines, handler) + total_comments += num_comments + total_errors += num_errors - if fix and num_errors > 0: - with open(path, 'w') as f: - f.writelines(lines) + if fix and num_errors > 0: + with open(path, 'w') as f: + f.writelines(lines) - checker.stop() + checker.stop() - print("Checked %d file(s) and %d comment(s), found %d error(s)." % - (total_files, total_comments, total_errors)) + print("Checked %d file(s) and %d comment(s), found %d error(s)." % + (total_files, total_comments, total_errors)) - return total_errors == 0 + return total_errors == 0 if __name__ == "__main__": - # Force UTF-8 across all open and popen calls. Fallback to 'C' as the - # language to handle hosts where en_US is not recognized (e.g. CI). - try: - locale.setlocale(locale.LC_ALL, 'en_US.UTF-8') - except: - locale.setlocale(locale.LC_ALL, 'C.UTF-8') - - default_dictionary = os.path.join(CURR_DIR, 'spelling_dictionary.txt') - - parser = argparse.ArgumentParser(description="Check comment spelling.") - parser.add_argument('operation_type', - type=str, - choices=['check', 'fix'], - help="specify if the run should 'check' or 'fix' spelling.") - parser.add_argument('target_paths', - type=str, - nargs="*", - help="specify the files for the script to process.") - parser.add_argument('-d', - '--debug', - action='count', - default=0, - help="Debug spell checker subprocess.") - parser.add_argument('--mark', - action='store_true', - help="Emits extra output to mark misspelled words.") - parser.add_argument('--dictionary', - type=str, - default=default_dictionary, - help="specify a location for Envoy-specific dictionary words") - parser.add_argument('--color', - type=str, - choices=['on', 'off', 'auto'], - default="auto", - help="Controls colorized output. Auto limits color to TTY devices.") - parser.add_argument('--test-ignore-exts', - dest='test_ignore_exts', - action='store_true', - help="For testing, ignore file extensions.") - args = parser.parse_args() - - COLOR = args.color == "on" or (args.color == "auto" and sys.stdout.isatty()) - DEBUG = args.debug - MARK = args.mark - - paths = args.target_paths - if not paths: - paths = ['./api', './include', './source', './test', './tools'] - - # Exclude ./third_party/ directory from spell checking, even when requested through arguments. - # Otherwise git pre-push hook checks it for merged commits. 
- paths = [ - path for path in paths - if not path.startswith('./third_party/') and not path.startswith('./third_party/') - ] - - exts = ['.cc', '.h', '.proto'] - if args.test_ignore_exts: - exts = None - target_paths = [] - for p in paths: - if os.path.isdir(p): - for root, _, files in os.walk(p): - target_paths += [ - os.path.join(root, f) for f in files if (exts is None or os.path.splitext(f)[1] in exts) - ] - if os.path.isfile(p) and (exts is None or os.path.splitext(p)[1] in exts): - target_paths += [p] - - rv = execute(target_paths, args.dictionary, args.operation_type == 'fix') - - if args.operation_type == 'check': - if not rv: - print( - "ERROR: spell check failed. Run 'tools/spelling/check_spelling_pedantic.py fix and/or add new " - "words to tools/spelling/spelling_dictionary.txt'") - sys.exit(1) - - print("PASS") + # Force UTF-8 across all open and popen calls. Fallback to 'C' as the + # language to handle hosts where en_US is not recognized (e.g. CI). + try: + locale.setlocale(locale.LC_ALL, 'en_US.UTF-8') + except: + locale.setlocale(locale.LC_ALL, 'C.UTF-8') + + default_dictionary = os.path.join(CURR_DIR, 'spelling_dictionary.txt') + + parser = argparse.ArgumentParser(description="Check comment spelling.") + parser.add_argument('operation_type', + type=str, + choices=['check', 'fix'], + help="specify if the run should 'check' or 'fix' spelling.") + parser.add_argument('target_paths', + type=str, + nargs="*", + help="specify the files for the script to process.") + parser.add_argument('-d', + '--debug', + action='count', + default=0, + help="Debug spell checker subprocess.") + parser.add_argument('--mark', + action='store_true', + help="Emits extra output to mark misspelled words.") + parser.add_argument('--dictionary', + type=str, + default=default_dictionary, + help="specify a location for Envoy-specific dictionary words") + parser.add_argument('--color', + type=str, + choices=['on', 'off', 'auto'], + default="auto", + help="Controls colorized output. Auto limits color to TTY devices.") + parser.add_argument('--test-ignore-exts', + dest='test_ignore_exts', + action='store_true', + help="For testing, ignore file extensions.") + args = parser.parse_args() + + COLOR = args.color == "on" or (args.color == "auto" and sys.stdout.isatty()) + DEBUG = args.debug + MARK = args.mark + + paths = args.target_paths + if not paths: + paths = ['./api', './include', './source', './test', './tools'] + + # Exclude ./third_party/ directory from spell checking, even when requested through arguments. + # Otherwise git pre-push hook checks it for merged commits. + paths = [ + path for path in paths + if not path.startswith('./third_party/') and not path.startswith('./third_party/') + ] + + exts = ['.cc', '.h', '.proto'] + if args.test_ignore_exts: + exts = None + target_paths = [] + for p in paths: + if os.path.isdir(p): + for root, _, files in os.walk(p): + target_paths += [ + os.path.join(root, f) + for f in files + if (exts is None or os.path.splitext(f)[1] in exts) + ] + if os.path.isfile(p) and (exts is None or os.path.splitext(p)[1] in exts): + target_paths += [p] + + rv = execute(target_paths, args.dictionary, args.operation_type == 'fix') + + if args.operation_type == 'check': + if not rv: + print( + "ERROR: spell check failed. 
Run 'tools/spelling/check_spelling_pedantic.py fix and/or add new " + "words to tools/spelling/spelling_dictionary.txt'") + sys.exit(1) + + print("PASS") diff --git a/tools/spelling/check_spelling_pedantic_test.py b/tools/spelling/check_spelling_pedantic_test.py index 959b36a912db..d6c25e9e8397 100755 --- a/tools/spelling/check_spelling_pedantic_test.py +++ b/tools/spelling/check_spelling_pedantic_test.py @@ -20,85 +20,86 @@ # printing the comamnd run and the status code as well as the stdout, # and returning all of that to the caller. def run_check_format(operation, filename): - command = check_spelling + " --test-ignore-exts " + operation + " " + filename - status, stdout, stderr = run_command(command) - return (command, status, stdout + stderr) + command = check_spelling + " --test-ignore-exts " + operation + " " + filename + status, stdout, stderr = run_command(command) + return (command, status, stdout + stderr) def get_input_file(filename): - return os.path.join(src, filename) + return os.path.join(src, filename) def emit_stdout_as_error(stdout): - logging.error("\n".join(stdout)) + logging.error("\n".join(stdout)) def expect_error(filename, status, stdout, expected_substrings): - if status == 0: - logging.error("%s: Expected %d errors, but succeeded" % (filename, len(expected_substrings))) - return 1 - errors = 0 - for expected_substring in expected_substrings: - found = False - for line in stdout: - if expected_substring in line: - found = True - break - if not found: - logging.error("%s: Could not find '%s' in:\n" % (filename, expected_substring)) - emit_stdout_as_error(stdout) - errors += 1 - - return errors + if status == 0: + logging.error("%s: Expected %d errors, but succeeded" % + (filename, len(expected_substrings))) + return 1 + errors = 0 + for expected_substring in expected_substrings: + found = False + for line in stdout: + if expected_substring in line: + found = True + break + if not found: + logging.error("%s: Could not find '%s' in:\n" % (filename, expected_substring)) + emit_stdout_as_error(stdout) + errors += 1 + + return errors def check_file_expecting_errors(filename, expected_substrings): - command, status, stdout = run_check_format("check", get_input_file(filename)) - return expect_error(filename, status, stdout, expected_substrings) + command, status, stdout = run_check_format("check", get_input_file(filename)) + return expect_error(filename, status, stdout, expected_substrings) def check_file_path_expecting_ok(filename): - command, status, stdout = run_check_format("check", filename) - if status != 0: - logging.error("Expected %s to have no errors; status=%d, output:\n" % (filename, status)) - emit_stdout_as_error(stdout) - return status + command, status, stdout = run_check_format("check", filename) + if status != 0: + logging.error("Expected %s to have no errors; status=%d, output:\n" % (filename, status)) + emit_stdout_as_error(stdout) + return status def check_file_expecting_ok(filename): - return check_file_path_expecting_ok(get_input_file(filename)) + return check_file_path_expecting_ok(get_input_file(filename)) def run_checks(): - errors = 0 + errors = 0 - errors += check_file_expecting_ok("valid") - errors += check_file_expecting_ok("skip_file") - errors += check_file_expecting_ok("exclusions") + errors += check_file_expecting_ok("valid") + errors += check_file_expecting_ok("skip_file") + errors += check_file_expecting_ok("exclusions") - errors += check_file_expecting_ok("third_party/something/file.cc") - errors += 
check_file_expecting_ok("./third_party/something/file.cc") + errors += check_file_expecting_ok("third_party/something/file.cc") + errors += check_file_expecting_ok("./third_party/something/file.cc") - errors += check_file_expecting_errors( - "typos", ["spacific", "reelistic", "Awwful", "combeenations", "woork"]) - errors += check_file_expecting_errors( - "skip_blocks", ["speelinga", "speelingb", "speelingc", "speelingd", "speelinge"]) - errors += check_file_expecting_errors("on_off", ["speelinga", "speelingb"]) - errors += check_file_expecting_errors("rst_code_block", ["speelinga", "speelingb"]) - errors += check_file_expecting_errors("word_splitting", ["Speeled", "Korrectly"]) + errors += check_file_expecting_errors( + "typos", ["spacific", "reelistic", "Awwful", "combeenations", "woork"]) + errors += check_file_expecting_errors( + "skip_blocks", ["speelinga", "speelingb", "speelingc", "speelingd", "speelinge"]) + errors += check_file_expecting_errors("on_off", ["speelinga", "speelingb"]) + errors += check_file_expecting_errors("rst_code_block", ["speelinga", "speelingb"]) + errors += check_file_expecting_errors("word_splitting", ["Speeled", "Korrectly"]) - return errors + return errors if __name__ == "__main__": - parser = argparse.ArgumentParser(description='tester for check_format.py.') - parser.add_argument('--log', choices=['INFO', 'WARN', 'ERROR'], default='INFO') - args = parser.parse_args() - logging.basicConfig(format='%(message)s', level=args.log) + parser = argparse.ArgumentParser(description='tester for check_format.py.') + parser.add_argument('--log', choices=['INFO', 'WARN', 'ERROR'], default='INFO') + args = parser.parse_args() + logging.basicConfig(format='%(message)s', level=args.log) - errors = run_checks() + errors = run_checks() - if errors != 0: - logging.error("%d FAILURES" % errors) - exit(1) - logging.warning("PASS") + if errors != 0: + logging.error("%d FAILURES" % errors) + exit(1) + logging.warning("PASS") diff --git a/tools/stack_decode.py b/tools/stack_decode.py index 169d94ffd85f..c0f18a5ba029 100755 --- a/tools/stack_decode.py +++ b/tools/stack_decode.py @@ -22,43 +22,43 @@ # and line information. Output appended to end of original backtrace line. Output # any nonmatching lines unmodified. End when EOF received. def decode_stacktrace_log(object_file, input_source, address_offset=0): - traces = {} - # Match something like: - # [backtrace] [bazel-out/local-dbg/bin/source/server/_virtual_includes/backtrace_lib/server/backtrace.h:84] - backtrace_marker = "\[backtrace\] [^\s]+" - # Match something like: - # ${backtrace_marker} #10: SYMBOL [0xADDR] - # or: - # ${backtrace_marker} #10: [0xADDR] - stackaddr_re = re.compile("%s #\d+:(?: .*)? 
\[(0x[0-9a-fA-F]+)\]$" % backtrace_marker) - # Match something like: - # #10 0xLOCATION (BINARY+0xADDR) - asan_re = re.compile(" *#\d+ *0x[0-9a-fA-F]+ *\([^+]*\+(0x[0-9a-fA-F]+)\)") - - try: - while True: - line = input_source.readline() - if line == "": - return # EOF - stackaddr_match = stackaddr_re.search(line) - if not stackaddr_match: - stackaddr_match = asan_re.search(line) - if stackaddr_match: - address = stackaddr_match.groups()[0] - if address_offset != 0: - address = hex(int(address, 16) - address_offset) - file_and_line_number = run_addr2line(object_file, address) - file_and_line_number = trim_proc_cwd(file_and_line_number) - if address_offset != 0: - sys.stdout.write("%s->[%s] %s" % (line.strip(), address, file_and_line_number)) - else: - sys.stdout.write("%s %s" % (line.strip(), file_and_line_number)) - continue - else: - # Pass through print all other log lines: - sys.stdout.write(line) - except KeyboardInterrupt: - return + traces = {} + # Match something like: + # [backtrace] [bazel-out/local-dbg/bin/source/server/_virtual_includes/backtrace_lib/server/backtrace.h:84] + backtrace_marker = "\[backtrace\] [^\s]+" + # Match something like: + # ${backtrace_marker} #10: SYMBOL [0xADDR] + # or: + # ${backtrace_marker} #10: [0xADDR] + stackaddr_re = re.compile("%s #\d+:(?: .*)? \[(0x[0-9a-fA-F]+)\]$" % backtrace_marker) + # Match something like: + # #10 0xLOCATION (BINARY+0xADDR) + asan_re = re.compile(" *#\d+ *0x[0-9a-fA-F]+ *\([^+]*\+(0x[0-9a-fA-F]+)\)") + + try: + while True: + line = input_source.readline() + if line == "": + return # EOF + stackaddr_match = stackaddr_re.search(line) + if not stackaddr_match: + stackaddr_match = asan_re.search(line) + if stackaddr_match: + address = stackaddr_match.groups()[0] + if address_offset != 0: + address = hex(int(address, 16) - address_offset) + file_and_line_number = run_addr2line(object_file, address) + file_and_line_number = trim_proc_cwd(file_and_line_number) + if address_offset != 0: + sys.stdout.write("%s->[%s] %s" % (line.strip(), address, file_and_line_number)) + else: + sys.stdout.write("%s %s" % (line.strip(), file_and_line_number)) + continue + else: + # Pass through print all other log lines: + sys.stdout.write(line) + except KeyboardInterrupt: + return # Execute addr2line with a particular object file and input string of addresses @@ -66,36 +66,37 @@ def decode_stacktrace_log(object_file, input_source, address_offset=0): # # Returns list of result lines def run_addr2line(obj_file, addr_to_resolve): - return subprocess.check_output(["addr2line", "-Cpie", obj_file, addr_to_resolve]).decode('utf-8') + return subprocess.check_output(["addr2line", "-Cpie", obj_file, + addr_to_resolve]).decode('utf-8') # Because of how bazel compiles, addr2line reports file names that begin with # "/proc/self/cwd/" and sometimes even "/proc/self/cwd/./". This isn't particularly # useful information, so trim it out and make a perfectly useful relative path. def trim_proc_cwd(file_and_line_number): - trim_regex = r'/proc/self/cwd/(\./)?' - return re.sub(trim_regex, '', file_and_line_number) + trim_regex = r'/proc/self/cwd/(\./)?' + return re.sub(trim_regex, '', file_and_line_number) # Execute pmap with a pid to calculate the addr offset # # Returns list of extended process memory information. def run_pmap(pid): - return subprocess.check_output(['pmap', '-qX', str(pid)]).decode('utf-8')[1:] + return subprocess.check_output(['pmap', '-qX', str(pid)]).decode('utf-8')[1:] # Find the virtual address offset of the process. This may be needed due ASLR. 
# # Returns the virtual address offset as an integer, or 0 if unable to determine. def find_address_offset(pid): - try: - proc_memory = run_pmap(pid) - match = re.search(r'([a-f0-9]+)\s+r-xp', proc_memory) - if match is None: - return 0 - return int(match.group(1), 16) - except (subprocess.CalledProcessError, PermissionError): - return 0 + try: + proc_memory = run_pmap(pid) + match = re.search(r'([a-f0-9]+)\s+r-xp', proc_memory) + if match is None: + return 0 + return int(match.group(1), 16) + except (subprocess.CalledProcessError, PermissionError): + return 0 # When setting the logging level to trace, it's possible that we'll bump @@ -103,31 +104,31 @@ def find_address_offset(pid): # ignore these and keep going (instead of giving up and exiting # while possibly bringing Envoy down). def ignore_decoding_errors(io_wrapper): - # Only avail since 3.7. - # https://docs.python.org/3/library/io.html#io.TextIOWrapper.reconfigure - if hasattr(io_wrapper, 'reconfigure'): - try: - io_wrapper.reconfigure(errors='ignore') - except: - pass + # Only avail since 3.7. + # https://docs.python.org/3/library/io.html#io.TextIOWrapper.reconfigure + if hasattr(io_wrapper, 'reconfigure'): + try: + io_wrapper.reconfigure(errors='ignore') + except: + pass - return io_wrapper + return io_wrapper if __name__ == "__main__": - if len(sys.argv) > 2 and sys.argv[1] == '-s': - decode_stacktrace_log(sys.argv[2], ignore_decoding_errors(sys.stdin)) - sys.exit(0) - elif len(sys.argv) > 1: - rununder = subprocess.Popen(sys.argv[1:], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True) - offset = find_address_offset(rununder.pid) - decode_stacktrace_log(sys.argv[1], ignore_decoding_errors(rununder.stdout), offset) - rununder.wait() - sys.exit(rununder.returncode) # Pass back test pass/fail result - else: - print("Usage (execute subprocess): stack_decode.py executable_file [additional args]") - print("Usage (read from stdin): stack_decode.py -s executable_file") - sys.exit(1) + if len(sys.argv) > 2 and sys.argv[1] == '-s': + decode_stacktrace_log(sys.argv[2], ignore_decoding_errors(sys.stdin)) + sys.exit(0) + elif len(sys.argv) > 1: + rununder = subprocess.Popen(sys.argv[1:], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True) + offset = find_address_offset(rununder.pid) + decode_stacktrace_log(sys.argv[1], ignore_decoding_errors(rununder.stdout), offset) + rununder.wait() + sys.exit(rununder.returncode) # Pass back test pass/fail result + else: + print("Usage (execute subprocess): stack_decode.py executable_file [additional args]") + print("Usage (read from stdin): stack_decode.py -s executable_file") + sys.exit(1) diff --git a/tools/type_whisperer/file_descriptor_set_text_gen.py b/tools/type_whisperer/file_descriptor_set_text_gen.py index fb89dc457e25..1501bed1656c 100644 --- a/tools/type_whisperer/file_descriptor_set_text_gen.py +++ b/tools/type_whisperer/file_descriptor_set_text_gen.py @@ -11,15 +11,15 @@ def decode(path): - with open(path, 'rb') as f: - file_set = descriptor_pb2.FileDescriptorSet() - file_set.ParseFromString(f.read()) - return str(file_set) + with open(path, 'rb') as f: + file_set = descriptor_pb2.FileDescriptorSet() + file_set.ParseFromString(f.read()) + return str(file_set) if __name__ == '__main__': - output_path = sys.argv[1] - input_paths = sys.argv[2:] - pb_text = '\n'.join(decode(path) for path in input_paths) - with open(output_path, 'w') as f: - f.write(pb_text) + output_path = sys.argv[1] + input_paths = sys.argv[2:] + pb_text = 
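Back in stack_decode.py, find_address_offset reads the executable's load address out of pmap so that ASLR-shifted backtrace addresses can be rebased before addr2line resolves them. A sketch of just that arithmetic, with a fabricated pmap line and made-up addresses:

import re

# Illustrative pmap -qX style line; the real text comes from run_pmap().
pmap_line = "55d2c3a00000  r-xp  ... envoy"
load_base = int(re.search(r"([a-f0-9]+)\s+r-xp", pmap_line).group(1), 16)

# An ASLR-shifted address from a backtrace, rebased to an offset that
# addr2line can resolve against the unrelocated binary.
backtrace_addr = 0x55d2c3b2f4a0
print(hex(backtrace_addr - load_base))  # 0x12f4a0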
'\n'.join(decode(path) for path in input_paths) + with open(output_path, 'w') as f: + f.write(pb_text) diff --git a/tools/type_whisperer/proto_build_targets_gen.py b/tools/type_whisperer/proto_build_targets_gen.py index 35adb33f1482..dd3a2d21443d 100644 --- a/tools/type_whisperer/proto_build_targets_gen.py +++ b/tools/type_whisperer/proto_build_targets_gen.py @@ -70,64 +70,64 @@ def load_type_db(type_db_path): - type_db = TypeDb() - with open(type_db_path, 'r') as f: - text_format.Merge(f.read(), type_db) - return type_db + type_db = TypeDb() + with open(type_db_path, 'r') as f: + text_format.Merge(f.read(), type_db) + return type_db # Key sort function to achieve consistent results with buildifier. def build_order_key(key): - return key.replace(':', '!') + return key.replace(':', '!') # Remove any packages that are definitely non-root, e.g. annotations. def filter_pkgs(pkgs): - def allowed_pkg(pkg): - return not pkg.startswith('envoy.annotations') + def allowed_pkg(pkg): + return not pkg.startswith('envoy.annotations') - return filter(allowed_pkg, pkgs) + return filter(allowed_pkg, pkgs) def deps_format(pkgs): - return '\n'.join(' "//%s:pkg",' % p.replace('.', '/') - for p in sorted(filter_pkgs(pkgs), key=build_order_key)) + return '\n'.join(' "//%s:pkg",' % p.replace('.', '/') + for p in sorted(filter_pkgs(pkgs), key=build_order_key)) def is_v2_package(pkg): - for regex in V2_REGEXES: - if regex.match(pkg): - return True - return False + for regex in V2_REGEXES: + if regex.match(pkg): + return True + return False def accidental_v3_package(pkg): - return pkg in ACCIDENTAL_V3_PKGS + return pkg in ACCIDENTAL_V3_PKGS def is_v3_package(pkg): - return V3_REGEX.match(pkg) is not None + return V3_REGEX.match(pkg) is not None if __name__ == '__main__': - type_db_path, output_path = sys.argv[1:] - type_db = load_type_db(type_db_path) - # TODO(htuch): generalize to > 2 versions - v2_packages = set([]) - v3_packages = set([]) - for desc in type_db.types.values(): - pkg = desc.qualified_package - if is_v3_package(pkg): - v3_packages.add(pkg) - continue - if is_v2_package(pkg): - v2_packages.add(pkg) - # Special case for v2 packages that are part of v3 (still active) - if accidental_v3_package(pkg): - v3_packages.add(pkg) - # Generate BUILD file. - build_file_contents = API_BUILD_FILE_TEMPLATE.substitute(v2_deps=deps_format(v2_packages), - v3_deps=deps_format(v3_packages)) - with open(output_path, 'w') as f: - f.write(build_file_contents) + type_db_path, output_path = sys.argv[1:] + type_db = load_type_db(type_db_path) + # TODO(htuch): generalize to > 2 versions + v2_packages = set([]) + v3_packages = set([]) + for desc in type_db.types.values(): + pkg = desc.qualified_package + if is_v3_package(pkg): + v3_packages.add(pkg) + continue + if is_v2_package(pkg): + v2_packages.add(pkg) + # Special case for v2 packages that are part of v3 (still active) + if accidental_v3_package(pkg): + v3_packages.add(pkg) + # Generate BUILD file. 
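In proto_build_targets_gen.py above, deps_format renders the collected proto packages as Bazel deps entries, sorting with build_order_key so the output order matches buildifier. A compressed illustration of that rendering (indentation approximated; filter_pkgs drops envoy.annotations.*):

def deps_format_example(pkgs):
    # Drop annotation-only packages that are never root targets.
    allowed = [p for p in pkgs if not p.startswith("envoy.annotations")]
    # Swapping ':' for '!' in the sort key keeps ordering consistent with buildifier.
    return "\n".join(
        '    "//%s:pkg",' % p.replace(".", "/")
        for p in sorted(allowed, key=lambda k: k.replace(":", "!")))

print(deps_format_example([
    "envoy.config.cluster.v3",
    "envoy.annotations",
    "envoy.admin.v3",
]))
#     "//envoy/admin/v3:pkg",
#     "//envoy/config/cluster/v3:pkg",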
+ build_file_contents = API_BUILD_FILE_TEMPLATE.substitute(v2_deps=deps_format(v2_packages), + v3_deps=deps_format(v3_packages)) + with open(output_path, 'w') as f: + f.write(build_file_contents) diff --git a/tools/type_whisperer/proto_cc_source_gen.py b/tools/type_whisperer/proto_cc_source_gen.py index fcc586d8ff89..f6ab4f262f31 100644 --- a/tools/type_whisperer/proto_cc_source_gen.py +++ b/tools/type_whisperer/proto_cc_source_gen.py @@ -16,9 +16,9 @@ """) if __name__ == '__main__': - constant_name = sys.argv[1] - output_path = sys.argv[2] - input_paths = sys.argv[3:] - pb_text = '\n'.join(pathlib.Path(path).read_text() for path in input_paths) - with open(output_path, 'w') as f: - f.write(CC_SOURCE_TEMPLATE.substitute(constant=constant_name, pb_text=pb_text)) + constant_name = sys.argv[1] + output_path = sys.argv[2] + input_paths = sys.argv[3:] + pb_text = '\n'.join(pathlib.Path(path).read_text() for path in input_paths) + with open(output_path, 'w') as f: + f.write(CC_SOURCE_TEMPLATE.substitute(constant=constant_name, pb_text=pb_text)) diff --git a/tools/type_whisperer/type_whisperer.py b/tools/type_whisperer/type_whisperer.py index fe1068665758..7d1f824e2482 100755 --- a/tools/type_whisperer/type_whisperer.py +++ b/tools/type_whisperer/type_whisperer.py @@ -11,57 +11,58 @@ class TypeWhispererVisitor(visitor.Visitor): - """Visitor to compute type information from a FileDescriptor proto. + """Visitor to compute type information from a FileDescriptor proto. See visitor.Visitor for visitor method docs comments. """ - def __init__(self): - super(TypeWhispererVisitor, self).__init__() - self._types = Types() + def __init__(self): + super(TypeWhispererVisitor, self).__init__() + self._types = Types() - def visit_service(self, service_proto, type_context): - pass + def visit_service(self, service_proto, type_context): + pass - def visit_enum(self, enum_proto, type_context): - type_desc = self._types.types[type_context.name] - type_desc.next_version_upgrade = any(v.options.deprecated for v in enum_proto.value) - type_desc.deprecated_type = type_context.deprecated + def visit_enum(self, enum_proto, type_context): + type_desc = self._types.types[type_context.name] + type_desc.next_version_upgrade = any(v.options.deprecated for v in enum_proto.value) + type_desc.deprecated_type = type_context.deprecated - def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): - type_desc = self._types.types[type_context.name] - type_desc.map_entry = msg_proto.options.map_entry - type_desc.deprecated_type = type_context.deprecated - type_deps = set([]) - for f in msg_proto.field: - if f.type_name.startswith('.'): - type_deps.add(f.type_name[1:]) - if f.options.deprecated: - type_desc.next_version_upgrade = True - type_desc.type_dependencies.extend(type_deps) + def visit_message(self, msg_proto, type_context, nested_msgs, nested_enums): + type_desc = self._types.types[type_context.name] + type_desc.map_entry = msg_proto.options.map_entry + type_desc.deprecated_type = type_context.deprecated + type_deps = set([]) + for f in msg_proto.field: + if f.type_name.startswith('.'): + type_deps.add(f.type_name[1:]) + if f.options.deprecated: + type_desc.next_version_upgrade = True + type_desc.type_dependencies.extend(type_deps) - def visit_file(self, file_proto, type_context, services, msgs, enums): - next_version_package = '' - if file_proto.options.HasExtension(migrate_pb2.file_migrate): - next_version_package = file_proto.options.Extensions[migrate_pb2.file_migrate].move_to_package - for t in 
self._types.types.values(): - t.qualified_package = file_proto.package - t.proto_path = file_proto.name - t.active = file_proto.options.Extensions[ - status_pb2.file_status].package_version_status == status_pb2.ACTIVE - if next_version_package: - t.next_version_package = next_version_package - t.next_version_upgrade = True - # Return in text proto format. This makes things easier to debug, these - # don't need to be compact as they are only interim build artifacts. - return str(self._types) + def visit_file(self, file_proto, type_context, services, msgs, enums): + next_version_package = '' + if file_proto.options.HasExtension(migrate_pb2.file_migrate): + next_version_package = file_proto.options.Extensions[ + migrate_pb2.file_migrate].move_to_package + for t in self._types.types.values(): + t.qualified_package = file_proto.package + t.proto_path = file_proto.name + t.active = file_proto.options.Extensions[ + status_pb2.file_status].package_version_status == status_pb2.ACTIVE + if next_version_package: + t.next_version_package = next_version_package + t.next_version_upgrade = True + # Return in text proto format. This makes things easier to debug, these + # don't need to be compact as they are only interim build artifacts. + return str(self._types) def main(): - plugin.plugin([ - plugin.direct_output_descriptor('.types.pb_text', TypeWhispererVisitor), - ]) + plugin.plugin([ + plugin.direct_output_descriptor('.types.pb_text', TypeWhispererVisitor), + ]) if __name__ == '__main__': - main() + main() diff --git a/tools/type_whisperer/typedb_gen.py b/tools/type_whisperer/typedb_gen.py index 5a28ce4cf70c..479be44eeec6 100644 --- a/tools/type_whisperer/typedb_gen.py +++ b/tools/type_whisperer/typedb_gen.py @@ -44,40 +44,40 @@ def upgraded_package(type_desc): - """Determine upgrade package for a type.""" - if type_desc.next_version_package: - return type_desc.next_version_package + """Determine upgrade package for a type.""" + if type_desc.next_version_package: + return type_desc.next_version_package - for pattern, repl in TYPE_UPGRADE_REGEXES: - s = re.sub(pattern, repl, type_desc.qualified_package) - if s != type_desc.qualified_package: - return s - raise ValueError('{} is not upgradable'.format(type_desc.qualified_package)) + for pattern, repl in TYPE_UPGRADE_REGEXES: + s = re.sub(pattern, repl, type_desc.qualified_package) + if s != type_desc.qualified_package: + return s + raise ValueError('{} is not upgradable'.format(type_desc.qualified_package)) def upgraded_type(type_name, type_desc): - """Determine upgraded type name.""" - _upgraded_package = upgraded_package(type_desc) - return type_name.replace(type_desc.qualified_package, _upgraded_package) + """Determine upgraded type name.""" + _upgraded_package = upgraded_package(type_desc) + return type_name.replace(type_desc.qualified_package, _upgraded_package) def upgraded_path(proto_path, upgraded_package): - """Determine upgraded API .proto path.""" - return '/'.join([upgraded_package.replace('.', '/'), proto_path.split('/')[-1]]) + """Determine upgraded API .proto path.""" + return '/'.join([upgraded_package.replace('.', '/'), proto_path.split('/')[-1]]) def upgraded_type_with_description(type_name, type_desc): - upgrade_type_desc = TypeDescription() - upgrade_type_desc.qualified_package = upgraded_package(type_desc) - upgrade_type_desc.proto_path = upgraded_path(type_desc.proto_path, - upgrade_type_desc.qualified_package) - upgrade_type_desc.deprecated_type = type_desc.deprecated_type - upgrade_type_desc.map_entry = type_desc.map_entry - 
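upgraded_package above walks TYPE_UPGRADE_REGEXES, a pattern/replacement table defined earlier in typedb_gen.py and not shown in this hunk, and returns the first substitution that actually changes the package. The sketch below uses an assumed v2-to-v3 pattern purely for illustration:

import re

# Assumed stand-in for one TYPE_UPGRADE_REGEXES entry: rewrite ".v2"-style
# packages (optionally ".v2alphaN") to a ".v3" successor.
UPGRADE_EXAMPLE = [(r'^(envoy\.[\w.]*\.)v2(alpha\d?)?$', r'\1v3')]

def upgraded_package_example(package):
    for pattern, repl in UPGRADE_EXAMPLE:
        replaced = re.sub(pattern, repl, package)
        if replaced != package:
            return replaced
    raise ValueError('{} is not upgradable'.format(package))

print(upgraded_package_example("envoy.api.v2"))                 # envoy.api.v3
print(upgraded_package_example("envoy.config.trace.v2alpha"))   # envoy.config.trace.v3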
return (upgraded_type(type_name, type_desc), upgrade_type_desc) + upgrade_type_desc = TypeDescription() + upgrade_type_desc.qualified_package = upgraded_package(type_desc) + upgrade_type_desc.proto_path = upgraded_path(type_desc.proto_path, + upgrade_type_desc.qualified_package) + upgrade_type_desc.deprecated_type = type_desc.deprecated_type + upgrade_type_desc.map_entry = type_desc.map_entry + return (upgraded_type(type_name, type_desc), upgrade_type_desc) def load_types(path): - """Load a tools.type_whisperer.Types proto from the filesystem. + """Load a tools.type_whisperer.Types proto from the filesystem. Args: path: filesystem path for a file in text proto format. @@ -85,14 +85,14 @@ def load_types(path): Returns: tools.type_whisperer.Types proto loaded from path. """ - types = Types() - with open(path, 'r') as f: - text_format.Merge(f.read(), types) - return types + types = Types() + with open(path, 'r') as f: + text_format.Merge(f.read(), types) + return types def next_version_upgrade(type_name, type_map, next_version_upgrade_memo, visited=None): - """Does a given type require upgrade between major version? + """Does a given type require upgrade between major version? Performs depth-first search through type dependency graph for any upgraded types that will force type_name to be upgraded. @@ -107,82 +107,82 @@ def next_version_upgrade(type_name, type_map, next_version_upgrade_memo, visited Returns: A boolean indicating whether the type requires upgrade. """ - if not visited: - visited = set([]) - # Ignore non-API types. - if not type_name.startswith('envoy'): - return False - # If we have a loop, we can't learn anything new by circling around again. - if type_name in visited: - return False - visited = visited.union(set([type_name])) - # If we have seen this type in a previous next_version_upgrade(), use that - # result. - if type_name in next_version_upgrade_memo: - return next_version_upgrade_memo[type_name] - type_desc = type_map[type_name] - # Force upgrade packages that we enumerate. - if type_desc.qualified_package in PKG_FORCE_UPGRADE: - return True - # Recurse and memoize. - should_upgrade = type_desc.next_version_upgrade or any( - next_version_upgrade(d, type_map, next_version_upgrade_memo, visited) - for d in type_desc.type_dependencies) - next_version_upgrade_memo[type_name] = should_upgrade - return should_upgrade + if not visited: + visited = set([]) + # Ignore non-API types. + if not type_name.startswith('envoy'): + return False + # If we have a loop, we can't learn anything new by circling around again. + if type_name in visited: + return False + visited = visited.union(set([type_name])) + # If we have seen this type in a previous next_version_upgrade(), use that + # result. + if type_name in next_version_upgrade_memo: + return next_version_upgrade_memo[type_name] + type_desc = type_map[type_name] + # Force upgrade packages that we enumerate. + if type_desc.qualified_package in PKG_FORCE_UPGRADE: + return True + # Recurse and memoize. + should_upgrade = type_desc.next_version_upgrade or any( + next_version_upgrade(d, type_map, next_version_upgrade_memo, visited) + for d in type_desc.type_dependencies) + next_version_upgrade_memo[type_name] = should_upgrade + return should_upgrade if __name__ == '__main__': - # Output path for type database. - out_path = sys.argv[1] - - # Load type descriptors for each type whisper - type_desc_paths = sys.argv[2:] - type_whispers = map(load_types, type_desc_paths) - - # Aggregate type descriptors to a single type map. 
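next_version_upgrade above is a memoized depth-first search over the type dependency graph, with a visited set to cut cycles. A stripped-down version of the same pattern over a toy dependency map (the package special cases are omitted):

# Toy dependency graph: a type "needs upgrade" if it is flagged itself or if
# anything it references (transitively) is flagged.
DEPS = {
    "envoy.A": ["envoy.B"],
    "envoy.B": ["envoy.C", "envoy.A"],  # Cycle back to A.
    "envoy.C": [],
    "envoy.D": [],
}
FLAGGED = {"envoy.C"}

def needs_upgrade(name, memo, visited=frozenset()):
    if name in visited:
        return False  # Looping around again teaches us nothing new.
    if name in memo:
        return memo[name]
    visited = visited | {name}
    result = name in FLAGGED or any(
        needs_upgrade(dep, memo, visited) for dep in DEPS[name])
    memo[name] = result
    return result

memo = {}
print({t: needs_upgrade(t, memo) for t in DEPS})
# {'envoy.A': True, 'envoy.B': True, 'envoy.C': True, 'envoy.D': False}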
- type_map = dict(sum([list(t.types.items()) for t in type_whispers], [])) - all_pkgs = set([type_desc.qualified_package for type_desc in type_map.values()]) - - # Determine via DFS on each type descriptor and its deps which packages require upgrade. - next_version_upgrade_memo = {} - next_versions_pkgs = set([ - type_desc.qualified_package - for type_name, type_desc in type_map.items() - if next_version_upgrade(type_name, type_map, next_version_upgrade_memo) - ]).union(set(['envoy.config.retry.previous_priorities', 'envoy.config.cluster.redis'])) - - # Generate type map entries for upgraded types. We run this twice to allow - # things like a v2 deprecated map field's synthesized map entry to forward - # propagate to v4alpha (for shadowing purposes). - for _ in range(2): - type_map.update([ - upgraded_type_with_description(type_name, type_desc) + # Output path for type database. + out_path = sys.argv[1] + + # Load type descriptors for each type whisper + type_desc_paths = sys.argv[2:] + type_whispers = map(load_types, type_desc_paths) + + # Aggregate type descriptors to a single type map. + type_map = dict(sum([list(t.types.items()) for t in type_whispers], [])) + all_pkgs = set([type_desc.qualified_package for type_desc in type_map.values()]) + + # Determine via DFS on each type descriptor and its deps which packages require upgrade. + next_version_upgrade_memo = {} + next_versions_pkgs = set([ + type_desc.qualified_package for type_name, type_desc in type_map.items() - if type_desc.qualified_package in next_versions_pkgs and - (type_desc.active or type_desc.deprecated_type or type_desc.map_entry) - ]) - - # Generate the type database proto. To provide some stability across runs, in - # terms of the emitted proto binary blob that we track in git, we sort before - # loading the map entries in the proto. This seems to work in practice, but - # has no guarantees. - type_db = TypeDb() - next_proto_info = {} - for t in sorted(type_map): - type_desc = type_db.types[t] - type_desc.qualified_package = type_map[t].qualified_package - type_desc.proto_path = type_map[t].proto_path - if type_desc.qualified_package in next_versions_pkgs: - type_desc.next_version_type_name = upgraded_type(t, type_map[t]) - assert (type_desc.next_version_type_name != t) - next_proto_info[type_map[t].proto_path] = ( - type_map[type_desc.next_version_type_name].proto_path, - type_map[type_desc.next_version_type_name].qualified_package) - for proto_path, (next_proto_path, next_package) in sorted(next_proto_info.items()): - type_db.next_version_protos[proto_path].proto_path = next_proto_path - type_db.next_version_protos[proto_path].qualified_package = next_package - - # Write out proto text. - with open(out_path, 'w') as f: - f.write(str(type_db)) + if next_version_upgrade(type_name, type_map, next_version_upgrade_memo) + ]).union(set(['envoy.config.retry.previous_priorities', 'envoy.config.cluster.redis'])) + + # Generate type map entries for upgraded types. We run this twice to allow + # things like a v2 deprecated map field's synthesized map entry to forward + # propagate to v4alpha (for shadowing purposes). + for _ in range(2): + type_map.update([ + upgraded_type_with_description(type_name, type_desc) + for type_name, type_desc in type_map.items() + if type_desc.qualified_package in next_versions_pkgs and + (type_desc.active or type_desc.deprecated_type or type_desc.map_entry) + ]) + + # Generate the type database proto. 
To provide some stability across runs, in + # terms of the emitted proto binary blob that we track in git, we sort before + # loading the map entries in the proto. This seems to work in practice, but + # has no guarantees. + type_db = TypeDb() + next_proto_info = {} + for t in sorted(type_map): + type_desc = type_db.types[t] + type_desc.qualified_package = type_map[t].qualified_package + type_desc.proto_path = type_map[t].proto_path + if type_desc.qualified_package in next_versions_pkgs: + type_desc.next_version_type_name = upgraded_type(t, type_map[t]) + assert (type_desc.next_version_type_name != t) + next_proto_info[type_map[t].proto_path] = ( + type_map[type_desc.next_version_type_name].proto_path, + type_map[type_desc.next_version_type_name].qualified_package) + for proto_path, (next_proto_path, next_package) in sorted(next_proto_info.items()): + type_db.next_version_protos[proto_path].proto_path = next_proto_path + type_db.next_version_protos[proto_path].qualified_package = next_package + + # Write out proto text. + with open(out_path, 'w') as f: + f.write(str(type_db)) diff --git a/tools/vscode/generate_debug_config.py b/tools/vscode/generate_debug_config.py index 32a99bd7c09a..894d9a2007bc 100755 --- a/tools/vscode/generate_debug_config.py +++ b/tools/vscode/generate_debug_config.py @@ -12,112 +12,113 @@ def bazel_info(name, bazel_extra_options=[]): - return subprocess.check_output(["bazel", "info", name] + BAZEL_OPTIONS + - bazel_extra_options).decode().strip() + return subprocess.check_output(["bazel", "info", name] + BAZEL_OPTIONS + + bazel_extra_options).decode().strip() def get_workspace(): - return bazel_info("workspace") + return bazel_info("workspace") def get_execution_root(workspace): - # If compilation database exists, use its execution root, this allows setting - # breakpoints with clangd navigation easier. - try: - compdb = pathlib.Path(workspace, "compile_commands.json").read_text() - return json.loads(compdb)[0]['directory'] - except: - return bazel_info("execution_root") + # If compilation database exists, use its execution root, this allows setting + # breakpoints with clangd navigation easier. 
+ try: + compdb = pathlib.Path(workspace, "compile_commands.json").read_text() + return json.loads(compdb)[0]['directory'] + except: + return bazel_info("execution_root") def binary_path(bazel_bin, target): - return pathlib.Path( - bazel_bin, - *[s for s in target.replace('@', 'external/').replace(':', '/').split('/') if s != '']) + return pathlib.Path( + bazel_bin, + *[s for s in target.replace('@', 'external/').replace(':', '/').split('/') if s != '']) def build_binary_with_debug_info(target): - targets = [target, target + ".dwp"] - subprocess.check_call(["bazel", "build", "-c", "dbg"] + BAZEL_OPTIONS + targets) + targets = [target, target + ".dwp"] + subprocess.check_call(["bazel", "build", "-c", "dbg"] + BAZEL_OPTIONS + targets) - bazel_bin = bazel_info("bazel-bin", ["-c", "dbg"]) - return binary_path(bazel_bin, target) + bazel_bin = bazel_info("bazel-bin", ["-c", "dbg"]) + return binary_path(bazel_bin, target) def get_launch_json(workspace): - try: - return json.loads(pathlib.Path(workspace, ".vscode", "launch.json").read_text()) - except: - return {"version": "0.2.0"} + try: + return json.loads(pathlib.Path(workspace, ".vscode", "launch.json").read_text()) + except: + return {"version": "0.2.0"} def write_launch_json(workspace, launch): - launch_json = pathlib.Path(workspace, ".vscode", "launch.json") - backup_launch_json = pathlib.Path(workspace, ".vscode", "launch.json.bak") - if launch_json.exists(): - shutil.move(str(launch_json), str(backup_launch_json)) + launch_json = pathlib.Path(workspace, ".vscode", "launch.json") + backup_launch_json = pathlib.Path(workspace, ".vscode", "launch.json.bak") + if launch_json.exists(): + shutil.move(str(launch_json), str(backup_launch_json)) - launch_json.write_text(json.dumps(launch, indent=4)) + launch_json.write_text(json.dumps(launch, indent=4)) def gdb_config(target, binary, workspace, execroot, arguments): - return { - "name": "gdb " + target, - "request": "launch", - "arguments": arguments, - "type": "gdb", - "target": str(binary), - "debugger_args": ["--directory=" + execroot], - "cwd": "${workspaceFolder}", - "valuesFormatting": "disabled" - } + return { + "name": "gdb " + target, + "request": "launch", + "arguments": arguments, + "type": "gdb", + "target": str(binary), + "debugger_args": ["--directory=" + execroot], + "cwd": "${workspaceFolder}", + "valuesFormatting": "disabled" + } def lldb_config(target, binary, workspace, execroot, arguments): - return { - "name": "lldb " + target, - "program": str(binary), - "sourceMap": { - "/proc/self/cwd": workspace, - "/proc/self/cwd/external": execroot + "/external", - "/proc/self/cwd/bazel-out": execroot + "/bazel-out" - }, - "cwd": "${workspaceFolder}", - "args": shlex.split(arguments), - "type": "lldb", - "request": "launch" - } + return { + "name": "lldb " + target, + "program": str(binary), + "sourceMap": { + "/proc/self/cwd": workspace, + "/proc/self/cwd/external": execroot + "/external", + "/proc/self/cwd/bazel-out": execroot + "/bazel-out" + }, + "cwd": "${workspaceFolder}", + "args": shlex.split(arguments), + "type": "lldb", + "request": "launch" + } def add_to_launch_json(target, binary, workspace, execroot, arguments, debugger_type): - launch = get_launch_json(workspace) - new_config = {} - if debugger_type == "lldb": - new_config = lldb_config(target, binary, workspace, execroot, arguments) - else: - new_config = gdb_config(target, binary, workspace, execroot, arguments) - - configurations = launch.get("configurations", []) - for config in configurations: - if config.get("name", 
None) == new_config["name"]: - config.clear() - config.update(new_config) - break - else: - configurations.append(new_config) - - launch["configurations"] = configurations - write_launch_json(workspace, launch) + launch = get_launch_json(workspace) + new_config = {} + if debugger_type == "lldb": + new_config = lldb_config(target, binary, workspace, execroot, arguments) + else: + new_config = gdb_config(target, binary, workspace, execroot, arguments) + + configurations = launch.get("configurations", []) + for config in configurations: + if config.get("name", None) == new_config["name"]: + config.clear() + config.update(new_config) + break + else: + configurations.append(new_config) + + launch["configurations"] = configurations + write_launch_json(workspace, launch) if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Build and generate launch config for VSCode') - parser.add_argument('--debugger', default="gdb") - parser.add_argument('--args', default='') - parser.add_argument('target') - args = parser.parse_args() - - workspace = get_workspace() - execution_root = get_execution_root(workspace) - debug_binary = build_binary_with_debug_info(args.target) - add_to_launch_json(args.target, debug_binary, workspace, execution_root, args.args, args.debugger) + parser = argparse.ArgumentParser(description='Build and generate launch config for VSCode') + parser.add_argument('--debugger', default="gdb") + parser.add_argument('--args', default='') + parser.add_argument('target') + args = parser.parse_args() + + workspace = get_workspace() + execution_root = get_execution_root(workspace) + debug_binary = build_binary_with_debug_info(args.target) + add_to_launch_json(args.target, debug_binary, workspace, execution_root, args.args, + args.debugger)
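For readers following the tools/type_whisperer/typedb_gen.py hunks above: next_version_upgrade() is a memoized depth-first search over the type dependency graph, with a visited set to cut cycles. The sketch below reduces that pattern to a toy graph; the names and data are hypothetical and are not part of the reformatted sources.

def needs_upgrade(type_name, deps, forced, memo, visited=None):
    # Toy version of the next_version_upgrade() pattern: a type needs upgrading
    # if it is forced directly, or if any type it depends on needs upgrading.
    visited = visited or set()
    if type_name in visited:
        # Dependency cycle: circling around again teaches us nothing new.
        return False
    if type_name in memo:
        # Reuse the answer computed by an earlier traversal.
        return memo[type_name]
    visited = visited | {type_name}
    result = type_name in forced or any(
        needs_upgrade(dep, deps, forced, memo, visited) for dep in deps.get(type_name, []))
    memo[type_name] = result
    return result


if __name__ == '__main__':
    deps = {'a': ['b'], 'b': ['c'], 'c': ['a']}  # note the a -> b -> c -> a cycle
    assert needs_upgrade('a', deps, forced={'c'}, memo={})  # 'c' forces 'a' via 'b'
    assert not needs_upgrade('a', {'a': ['b'], 'b': []}, forced=set(), memo={})

As in the real script, the memo dictionary is shared across top-level calls, so repeated lookups of the same type stay cheap.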
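Similarly, add_to_launch_json() in tools/vscode/generate_debug_config.py updates an existing .vscode/launch.json entry in place when one with the same name already exists, and appends a new entry otherwise, using a for/else loop. A minimal sketch of that behaviour, with hypothetical configuration data rather than anything produced by the script:

def upsert_configuration(launch, new_config):
    # Overwrite an existing configuration with the same "name" in place;
    # the for/else only appends when no matching entry was found.
    configurations = launch.get("configurations", [])
    for config in configurations:
        if config.get("name") == new_config["name"]:
            config.clear()
            config.update(new_config)
            break
    else:
        configurations.append(new_config)
    launch["configurations"] = configurations
    return launch


if __name__ == '__main__':
    launch = {"version": "0.2.0", "configurations": [{"name": "gdb //source/exe:envoy-static"}]}
    # Regenerating a config for the same target replaces the entry instead of duplicating it.
    upsert_configuration(launch, {"name": "gdb //source/exe:envoy-static", "cwd": "${workspaceFolder}"})
    assert len(launch["configurations"]) == 1

Replacing the entry in place preserves the order of any hand-written configurations already present in launch.json, which matches how the generator behaves when it is re-run for the same target.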