diff --git a/src/daemon/absorber.cc b/src/daemon/absorber.cc index a68d844e..537e2c32 100644 --- a/src/daemon/absorber.cc +++ b/src/daemon/absorber.cc @@ -13,17 +13,6 @@ using namespace std::placeholders; namespace dist_clang { namespace daemon { -namespace { - -bool ValidateClientVersion(ui32 client_version) { - return client_version >= Absorber::kMinGoodClientVersion; -} - -} // namespace - -// static -const ui32 Absorber::kMinGoodClientVersion = 100; - Absorber::Absorber(const proto::Configuration& configuration) : CompilationDaemon(configuration) { using Worker = base::WorkerPool::SimpleWorker; @@ -78,11 +67,12 @@ bool Absorber::HandleNewMessage(net::ConnectionPtr connection, if (message->HasExtension(proto::Remote::extension)) { Message execute(message->ReleaseExtension(proto::Remote::extension)); - if (!ValidateClientVersion(execute->client_version())) { - LOG(WARNING) << "Client sent a bad version: " - << execute->client_version(); + auto config = conf(); + if (execute->emitter_version() < config->absorber().min_emitter_version()) { + LOG(WARNING) << "Emitter sent a bad version: " + << execute->emitter_version(); net::proto::Status bad_version_status; - bad_version_status.set_code(net::proto::Status::BAD_CLIENT_VERSION); + bad_version_status.set_code(net::proto::Status::BAD_EMITTER_VERSION); return connection->ReportStatus(bad_version_status); } diff --git a/src/daemon/absorber.h b/src/daemon/absorber.h index 86c1828e..dc66ce6d 100644 --- a/src/daemon/absorber.h +++ b/src/daemon/absorber.h @@ -9,8 +9,6 @@ namespace daemon { class Absorber : public CompilationDaemon { public: - static const ui32 kMinGoodClientVersion; - explicit Absorber(const proto::Configuration& configuration); virtual ~Absorber(); diff --git a/src/daemon/absorber_test.cc b/src/daemon/absorber_test.cc index 71c832be..68b1ed17 100644 --- a/src/daemon/absorber_test.cc +++ b/src/daemon/absorber_test.cc @@ -8,6 +8,10 @@ namespace daemon { namespace { +static const ui32 kEmptyEmitterVersion = 0u; +static const ui32 kBadEmitterVersion = 1u; +static const ui32 kGoodEmitterVersion = 100u; + // S1..4 types can be one of the following: |String|, |Immutable|, |Literal| template net::Connection::ScopedMessage CreateMessage( @@ -15,11 +19,13 @@ net::Connection::ScopedMessage CreateMessage( const S2& action, const S3& compiler_version, const S4& language = S4(), - ui32 client_version = Absorber::kMinGoodClientVersion) { + ui32 emitter_version = kEmptyEmitterVersion) { net::Connection::ScopedMessage message(new net::Connection::Message); auto* extension = message->MutableExtension(proto::Remote::extension); extension->set_source(source); - extension->set_client_version(client_version); + if (emitter_version != kEmptyEmitterVersion) + extension->set_emitter_version(emitter_version); + auto* compiler = extension->mutable_flags()->mutable_compiler(); compiler->set_version(compiler_version); extension->mutable_flags()->set_action(action); @@ -377,7 +383,7 @@ TEST_F(AbsorberTest, BadMessageStatus) { << "Daemon must not store references to the connection"; } -TEST_F(AbsorberTest, BadClientVersion) { +TEST_F(AbsorberTest, BadEmitterVersion) { const String expected_host = "fake_host"; const ui16 expected_port = 12345; const String compiler_version("compiler_version"); @@ -385,6 +391,7 @@ TEST_F(AbsorberTest, BadClientVersion) { conf.mutable_absorber()->mutable_local()->set_host(expected_host); conf.mutable_absorber()->mutable_local()->set_port(expected_port); + conf.mutable_absorber()->set_min_emitter_version(kGoodEmitterVersion); auto* version = conf.add_versions(); version->set_version(compiler_version); version->set_path(compiler_path); @@ -399,7 +406,7 @@ TEST_F(AbsorberTest, BadClientVersion) { connection->CallOnSend([](const net::Connection::Message& message) { EXPECT_TRUE(message.HasExtension(net::proto::Status::extension)); const auto& status = message.GetExtension(net::proto::Status::extension); - EXPECT_EQ(net::proto::Status::BAD_CLIENT_VERSION, status.code()); + EXPECT_EQ(net::proto::Status::BAD_EMITTER_VERSION, status.code()); EXPECT_FALSE(message.HasExtension(proto::Result::extension)); }); @@ -413,7 +420,7 @@ TEST_F(AbsorberTest, BadClientVersion) { SharedPtr test_connection = std::static_pointer_cast(connection1); - auto message(CreateMessage(""_l, ""_l, ""_l, ""_l, 1)); + auto message(CreateMessage(""_l, ""_l, ""_l, ""_l, kBadEmitterVersion)); EXPECT_TRUE( test_connection->TriggerReadAsync(std::move(message), StatusOK())); @@ -437,7 +444,9 @@ TEST_F(AbsorberTest, BadClientVersion) { }); auto message(CreateMessage("source"_l, "action"_l, compiler_version)); auto* extension = message->MutableExtension(proto::Remote::extension); - extension->clear_client_version(); + + // Force to use default value from .proto file + extension->clear_emitter_version(); EXPECT_TRUE( test_connection->TriggerReadAsync(std::move(message), StatusOK())); @@ -460,5 +469,62 @@ TEST_F(AbsorberTest, BadClientVersion) { << "Daemon must not store references to the connection"; } +TEST_F(AbsorberTest, GoodCustomEmitterVersion) { + const String expected_host = "fake_host"; + const ui16 expected_port = 12345; + const String compiler_version("compiler_version"); + const String compiler_path("compiler_path"); + const ui32 custom_emitter_version = kGoodEmitterVersion + 1000; + + conf.mutable_absorber()->mutable_local()->set_host(expected_host); + conf.mutable_absorber()->mutable_local()->set_port(expected_port); + conf.mutable_absorber()->set_min_emitter_version(custom_emitter_version); + auto* version = conf.add_versions(); + version->set_version(compiler_version); + version->set_path(compiler_path); + + listen_callback = + [&expected_host, expected_port](const String& host, ui16 port, String*) { + EXPECT_EQ(expected_host, host); + EXPECT_EQ(expected_port, port); + return true; + }; + connect_callback = [](net::TestConnection* connection) { + connection->CallOnSend([](const net::Connection::Message& message) { + EXPECT_TRUE(message.HasExtension(net::proto::Status::extension)); + const auto& status = message.GetExtension(net::proto::Status::extension); + EXPECT_EQ(net::proto::Status::OK, status.code()); + + EXPECT_TRUE(message.HasExtension(proto::Result::extension)); + EXPECT_TRUE(message.GetExtension(proto::Result::extension).has_obj()); + }); + }; + + absorber.reset(new Absorber(conf)); + ASSERT_TRUE(absorber->Initialize()); + + auto connection = test_service->TriggerListen(expected_host, expected_port); + { + SharedPtr test_connection = + std::static_pointer_cast(connection); + + auto message(CreateMessage("source"_l, "action"_l, compiler_version, ""_l, + custom_emitter_version)); + EXPECT_TRUE( + test_connection->TriggerReadAsync(std::move(message), StatusOK())); + + absorber.reset(); + } + + EXPECT_EQ(1u, run_count); + EXPECT_EQ(1u, listen_count); + EXPECT_EQ(1u, connect_count); + EXPECT_EQ(1u, connections_created); + EXPECT_EQ(1u, read_count); + EXPECT_EQ(1u, send_count); + EXPECT_EQ(1, connection.use_count()) + << "Daemon must not store references to the connection"; +} + } // namespace daemon } // namespace dist_clang diff --git a/src/daemon/configuration.proto b/src/daemon/configuration.proto index 8463674a..85fa4288 100644 --- a/src/daemon/configuration.proto +++ b/src/daemon/configuration.proto @@ -57,7 +57,8 @@ message Configuration { } message Absorber { - required Host local = 1; + required Host local = 1; + optional uint32 min_emitter_version = 2 [ default = 100 ]; } message Collector { diff --git a/src/daemon/remote.proto b/src/daemon/remote.proto index f0fccfca..27240503 100644 --- a/src/daemon/remote.proto +++ b/src/daemon/remote.proto @@ -7,7 +7,7 @@ package dist_clang.daemon.proto; message Remote { optional base.proto.Flags flags = 1; optional string source = 2; - optional uint32 client_version = 3 [ default = 100 ]; + optional uint32 emitter_version = 3 [ default = 100 ]; extend net.proto.Universal { optional Remote extension = 6; diff --git a/src/net/universal.proto b/src/net/universal.proto index 341ecf95..0c230f6a 100644 --- a/src/net/universal.proto +++ b/src/net/universal.proto @@ -9,15 +9,15 @@ message Universal { message Status { enum Code { - OK = 0; - INCONSEQUENT = 1; - NETWORK = 2; - BAD_MESSAGE = 3; - EMPTY_MESSAGE = 4; - EXECUTION = 5; - OVERLOAD = 6; - NO_VERSION = 7; - BAD_CLIENT_VERSION = 8; + OK = 0; + INCONSEQUENT = 1; + NETWORK = 2; + BAD_MESSAGE = 3; + EMPTY_MESSAGE = 4; + EXECUTION = 5; + OVERLOAD = 6; + NO_VERSION = 7; + BAD_EMITTER_VERSION = 8; } required Code code = 1 [ default = OK ];