Skip to content

Commit

Permalink
planning: update feature generator
Browse files Browse the repository at this point in the history
  • Loading branch information
jmtao authored and xiaoxq committed Feb 19, 2020
1 parent d1d175a commit 332b2cf
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 45 deletions.
2 changes: 1 addition & 1 deletion modules/planning/pipeline/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ cc_library(
"//modules/canbus/proto:canbus_proto",
"//modules/common/adapters:adapter_gflags",
"//modules/localization/proto:localization_proto",
"//modules/planning/proto:instance_proto",
"//modules/planning/proto:learning_data_proto",
"//modules/prediction/util:data_extraction",
"//third_party:boost",
"@com_google_absl//absl/strings",
Expand Down
78 changes: 42 additions & 36 deletions modules/planning/pipeline/feature_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
#include "modules/common/adapters/adapter_gflags.h"

DEFINE_string(planning_data_dir, "/apollo/modules/planning/data/",
"Prefix of files to store instance data");
"Prefix of files to store learning_data_frame data");

DEFINE_int32(
instance_label_sample_interval, 100,
"total number of localization msgs to generate one instance label data.");
learning_data_frame_label_sample_interval, 100,
"total number of localization msgs to generate one "
"learning_data_frame label data.");

DEFINE_int32(instance_num_per_file, 100,
"number of instance to write out in one data file.");
DEFINE_int32(learning_data_frame_num_per_file, 100,
"number of learning_data_farame to write out in one data file.");

DEFINE_int32(
localization_sample_interval_for_trajectory_point, 10,
Expand All @@ -42,7 +43,7 @@ DEFINE_int32(localization_move_window_step, 5,
"number of localization msgs to skip after generating one label "
"trajectory point.");

DEFINE_bool(enable_binary_instance, true,
DEFINE_bool(enable_binary_learning_data, true,
"True to generate protobuf binary data file.");

namespace apollo {
Expand All @@ -53,37 +54,40 @@ using apollo::cyber::record::RecordMessage;
using apollo::cyber::record::RecordReader;
using apollo::localization::LocalizationEstimate;

void FeatureGenerator::Init() { instance_ = instances_.add_instances(); }
void FeatureGenerator::Init() {
learning_data_frame_ = learning_data_.add_learning_data(); }

void FeatureGenerator::WriteOutInstances(const Instances& instances,
const std::string& file_name) {
if (FLAGS_enable_binary_instance) {
cyber::common::SetProtoToBinaryFile(instances, file_name);
void FeatureGenerator::WriteOutLearningData(const LearningData& learning_data,
const std::string& file_name) {
if (FLAGS_enable_binary_learning_data) {
cyber::common::SetProtoToBinaryFile(learning_data, file_name);
} else {
cyber::common::SetProtoToASCIIFile(instances, file_name);
cyber::common::SetProtoToASCIIFile(learning_data, file_name);
}
}

void FeatureGenerator::Close() {
const std::string file_name = absl::StrCat(
FLAGS_planning_data_dir, "/instances.", instance_file_index_, ".bin");
total_instance_num_ += instances_.instances_size();
WriteOutInstances(instances_, file_name);
++instance_file_index_;
AINFO << "Total instance number:" << total_instance_num_;
FLAGS_planning_data_dir, "/learning_data.",
learning_data_frame_file_index_, ".bin");
total_learning_data_frame_num_ += learning_data_.learning_data_size();
WriteOutLearningData(learning_data_, file_name);
++learning_data_frame_file_index_;
AINFO << "Total learning_data_frame number:"
<< total_learning_data_frame_num_;
}

void FeatureGenerator::GenerateTrajectoryLabel(
const std::list<apollo::localization::LocalizationEstimate>&
localization_for_label,
Instance* instance) {
LearningDataFrame* learning_data_frame) {
int i = -1;
for (const auto& le : localization_for_label) {
++i;
if ((i % FLAGS_localization_sample_interval_for_trajectory_point) != 0) {
continue;
}
auto trajectory_point = instance->add_label_trajectory_points();
auto trajectory_point = learning_data_frame->add_label_trajectory_points();
auto& pose = le.pose();
trajectory_point->mutable_path_point()->set_x(pose.position().x());
trajectory_point->mutable_path_point()->set_y(pose.position().y());
Expand All @@ -101,12 +105,12 @@ void FeatureGenerator::GenerateTrajectoryLabel(

void FeatureGenerator::OnLocalization(
const apollo::localization::LocalizationEstimate& le) {
if (instance_ == nullptr) {
AERROR << "instance pointer is nullptr";
if (learning_data_frame_ == nullptr) {
AERROR << "learning_data_frame_ pointer is nullptr";
return;
}

auto features = instance_->mutable_localization_feature();
auto features = learning_data_frame_->mutable_localization_feature();
const auto& pose = le.pose();
features->mutable_position()->CopyFrom(pose.position());
features->set_heading(pose.heading());
Expand All @@ -116,31 +120,33 @@ void FeatureGenerator::OnLocalization(
localization_for_label_.push_back(le);

if (static_cast<int>(localization_for_label_.size()) >=
FLAGS_instance_label_sample_interval) {
GenerateTrajectoryLabel(localization_for_label_, instance_);
instance_ = instances_.add_instances();
FLAGS_learning_data_frame_label_sample_interval) {
GenerateTrajectoryLabel(localization_for_label_, learning_data_frame_);
learning_data_frame_ = learning_data_.add_learning_data();
for (int i = 0; i < FLAGS_localization_move_window_step; ++i) {
localization_for_label_.pop_front();
}
}

if (instances_.instances_size() >= FLAGS_instance_num_per_file) {
const std::string file_name = absl::StrCat(
FLAGS_planning_data_dir, "/instances.", instance_file_index_, ".bin");
WriteOutInstances(instances_, file_name);
total_instance_num_ += instances_.instances_size();
instances_.Clear();
++instance_file_index_;
instance_ = instances_.add_instances();
if (learning_data_.learning_data_size() >=
FLAGS_learning_data_frame_num_per_file) {
const std::string file_name =
absl::StrCat(FLAGS_planning_data_dir, "/learning_data.",
learning_data_frame_file_index_, ".bin");
WriteOutLearningData(learning_data_, file_name);
total_learning_data_frame_num_ += learning_data_.learning_data_size();
learning_data_.Clear();
++learning_data_frame_file_index_;
learning_data_frame_ = learning_data_.add_learning_data();
}
}

void FeatureGenerator::OnChassis(const apollo::canbus::Chassis& chassis) {
if (instance_ == nullptr) {
AERROR << "instance pointer is nullptr";
if (learning_data_frame_ == nullptr) {
AERROR << "learning_data_frame_ pointer is nullptr";
return;
}
auto features = instance_->mutable_chassis_feature();
auto features = learning_data_frame_->mutable_chassis_feature();
features->set_speed_mps(chassis.speed_mps());
features->set_throttle_percentage(chassis.throttle_percentage());
features->set_brake_percentage(chassis.brake_percentage());
Expand Down
16 changes: 8 additions & 8 deletions modules/planning/pipeline/feature_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "cyber/common/file.h"
#include "modules/canbus/proto/chassis.pb.h"
#include "modules/localization/proto/localization.pb.h"
#include "modules/planning/proto/instance.pb.h"
#include "modules/planning/proto/learning_data.pb.h"

namespace apollo {
namespace planning {
Expand All @@ -36,19 +36,19 @@ class FeatureGenerator {
private:
void OnLocalization(const apollo::localization::LocalizationEstimate& le);
void OnChassis(const apollo::canbus::Chassis& chassis);
void WriteOutInstances(const Instances& instances,
const std::string& file_name);
void WriteOutLearningData(const LearningData& learning_data,
const std::string& file_name);
void GenerateTrajectoryLabel(
const std::list<apollo::localization::LocalizationEstimate>&
localization_for_label,
Instance* instance);
LearningDataFrame* learning_data_frame);

Instance* instance_ = nullptr; // not owned
Instances instances_;
int instance_file_index_ = 0;
LearningDataFrame* learning_data_frame_ = nullptr; // not owned
LearningData learning_data_;
int learning_data_frame_file_index_ = 0;
std::list<apollo::localization::LocalizationEstimate>
localization_for_label_;
int total_instance_num_ = 0;
int total_learning_data_frame_num_ = 0;
};

} // namespace planning
Expand Down

0 comments on commit 332b2cf

Please sign in to comment.