Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 29, check for bad client version #40

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/daemon/absorber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ using namespace std::placeholders;
namespace dist_clang {
namespace daemon {

namespace {

bool ValidateClientVersion(ui32 client_version) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this function only I can say that this approach is not anticipated. The backward-compatible version should be the field in the Absorber config. Since the usual situation is when the clients (Emitters) are updated much more often than Absorbers, it's convenient to update this field without redistributing the new package.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, will do.

return client_version >= Absorber::kMinGoodClientVersion;
}

} // namespace

// static
const ui32 Absorber::kMinGoodClientVersion = 100;

Absorber::Absorber(const proto::Configuration& configuration)
: CompilationDaemon(configuration) {
using Worker = base::WorkerPool::SimpleWorker;
Expand Down Expand Up @@ -67,6 +78,14 @@ bool Absorber::HandleNewMessage(net::ConnectionPtr connection,

if (message->HasExtension(proto::Remote::extension)) {
Message execute(message->ReleaseExtension(proto::Remote::extension));
if (!ValidateClientVersion(execute->client_version())) {
LOG(WARNING) << "Client sent a bad version: "
<< execute->client_version();
net::proto::Status bad_version_status;
bad_version_status.set_code(net::proto::Status::BAD_CLIENT_VERSION);
return connection->ReportStatus(bad_version_status);
}

DCHECK(!execute->flags().compiler().has_path());
if (execute->has_source()) {
return tasks_->Push(Task{connection, std::move(execute)});
Expand Down
2 changes: 2 additions & 0 deletions src/daemon/absorber.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace daemon {

class Absorber : public CompilationDaemon {
public:
static const ui32 kMinGoodClientVersion;

explicit Absorber(const proto::Configuration& configuration);
virtual ~Absorber();

Expand Down
94 changes: 90 additions & 4 deletions src/daemon/absorber_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ namespace {

// S1..4 types can be one of the following: |String|, |Immutable|, |Literal|
template <typename S1, typename S2, typename S3, typename S4 = String>
net::Connection::ScopedMessage CreateMessage(const S1& source,
const S2& action,
const S3& compiler_version,
const S4& language = S4()) {
net::Connection::ScopedMessage CreateMessage(
const S1& source,
const S2& action,
const S3& compiler_version,
const S4& language = S4(),
ui32 client_version = Absorber::kMinGoodClientVersion) {
net::Connection::ScopedMessage message(new net::Connection::Message);
auto* extension = message->MutableExtension(proto::Remote::extension);
extension->set_source(source);
extension->set_client_version(client_version);
auto* compiler = extension->mutable_flags()->mutable_compiler();
compiler->set_version(compiler_version);
extension->mutable_flags()->set_action(action);
Expand Down Expand Up @@ -374,5 +377,88 @@ TEST_F(AbsorberTest, BadMessageStatus) {
<< "Daemon must not store references to the connection";
}

TEST_F(AbsorberTest, BadClientVersion) {
const String expected_host = "fake_host";
const ui16 expected_port = 12345;
const String compiler_version("compiler_version");
const String compiler_path("compiler_path");

conf.mutable_absorber()->mutable_local()->set_host(expected_host);
conf.mutable_absorber()->mutable_local()->set_port(expected_port);
auto* version = conf.add_versions();
version->set_version(compiler_version);
version->set_path(compiler_path);

listen_callback =
[&expected_host, expected_port](const String& host, ui16 port, String*) {
EXPECT_EQ(expected_host, host);
EXPECT_EQ(expected_port, port);
return true;
};
connect_callback = [](net::TestConnection* connection) {
connection->CallOnSend([](const net::Connection::Message& message) {
EXPECT_TRUE(message.HasExtension(net::proto::Status::extension));
const auto& status = message.GetExtension(net::proto::Status::extension);
EXPECT_EQ(net::proto::Status::BAD_CLIENT_VERSION, status.code());

EXPECT_FALSE(message.HasExtension(proto::Result::extension));
});
};

absorber.reset(new Absorber(conf));
ASSERT_TRUE(absorber->Initialize());

auto connection1 = test_service->TriggerListen(expected_host, expected_port);
{
SharedPtr<net::TestConnection> test_connection =
std::static_pointer_cast<net::TestConnection>(connection1);

auto message(CreateMessage(""_l, ""_l, ""_l, ""_l, 1));
EXPECT_TRUE(
test_connection->TriggerReadAsync(std::move(message), StatusOK()));

UniqueLock lock(send_mutex);
ASSERT_TRUE(send_condition.wait_for(lock, std::chrono::seconds(1),
[this] { return send_count == 1; }));
}

auto connection2 = test_service->TriggerListen(expected_host, expected_port);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part doesn't correspond to the test's purpose - according to its name. It's better to leave only the first connection.

{
SharedPtr<net::TestConnection> test_connection =
std::static_pointer_cast<net::TestConnection>(connection2);

test_connection->CallOnSend([](const net::Connection::Message& message) {
EXPECT_TRUE(message.HasExtension(net::proto::Status::extension));
const auto& status = message.GetExtension(net::proto::Status::extension);
EXPECT_EQ(net::proto::Status::OK, status.code());

EXPECT_TRUE(message.HasExtension(proto::Result::extension));
EXPECT_TRUE(message.GetExtension(proto::Result::extension).has_obj());
});
auto message(CreateMessage("source"_l, "action"_l, compiler_version));
auto* extension = message->MutableExtension(proto::Remote::extension);
extension->clear_client_version();
EXPECT_TRUE(
test_connection->TriggerReadAsync(std::move(message), StatusOK()));

UniqueLock lock(send_mutex);
ASSERT_TRUE(send_condition.wait_for(lock, std::chrono::seconds(1),
[this] { return send_count == 2; }));
}

absorber.reset();

EXPECT_EQ(1u, run_count);
EXPECT_EQ(1u, listen_count);
EXPECT_EQ(2u, connect_count);
EXPECT_EQ(2u, connections_created);
EXPECT_EQ(2u, read_count);
EXPECT_EQ(2u, send_count);
EXPECT_EQ(1, connection1.use_count())
<< "Daemon must not store references to the connection";
EXPECT_EQ(1, connection2.use_count())
<< "Daemon must not store references to the connection";
}

} // namespace daemon
} // namespace dist_clang
1 change: 1 addition & 0 deletions src/daemon/remote.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package dist_clang.daemon.proto;
message Remote {
optional base.proto.Flags flags = 1;
optional string source = 2;
optional uint32 client_version = 3 [ default = 100 ];

extend net.proto.Universal {
optional Remote extension = 6;
Expand Down
1 change: 1 addition & 0 deletions src/net/universal.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ message Status {
EXECUTION = 5;
OVERLOAD = 6;
NO_VERSION = 7;
BAD_CLIENT_VERSION = 8;
}

required Code code = 1 [ default = OK ];
Expand Down