Skip to content
This repository has been archived by the owner on Jul 8, 2024. It is now read-only.

Commit

Permalink
Clean up trajoptrust type conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul committed Jun 26, 2024
1 parent 5f237b7 commit f9f3130
Showing 1 changed file with 76 additions and 113 deletions.
189 changes: 76 additions & 113 deletions src/trajoptlibrust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,75 +15,35 @@
#include "trajopt/constraint/LinearVelocityDirectionConstraint.hpp"
#include "trajopt/constraint/LinearVelocityMaxMagnitudeConstraint.hpp"
#include "trajopt/constraint/PointAtConstraint.hpp"
#include "trajopt/drivetrain/SwerveDrivetrain.hpp"
#include "trajopt/geometry/Translation2.hpp"
#include "trajopt/drivetrain/SwerveModule.hpp"
#include "trajopt/trajectory/HolonomicTrajectory.hpp"
#include "trajopt/trajectory/HolonomicTrajectorySample.hpp"
#include "trajoptlib/src/lib.rs.h"

namespace trajoptlibrust {

template <typename FromType, typename ToType, typename FromVecType,
typename ToVecType, ToType (*Converter)(const FromType&)>
ToVecType _convert_generic_vec(const FromVecType& fromVec) {
ToVecType toVec;
toVec.reserve(fromVec.size());
for (const FromType& item : fromVec) {
toVec.emplace_back(Converter(item));
}
return toVec;
}

template <typename RustType, typename CppType,
CppType (*Converter)(const RustType&)>
std::vector<CppType> _rust_vec_to_cpp_vector(
const rust::Vec<RustType>& rustVec) {
return _convert_generic_vec<RustType, CppType, rust::Vec<RustType>,
std::vector<CppType>, Converter>(rustVec);
}

template <typename CppType, typename RustType,
RustType (*Converter)(const CppType&)>
rust::Vec<RustType> _cpp_vector_to_rust_vec(
const std::vector<CppType>& cppVec) {
return _convert_generic_vec<CppType, RustType, std::vector<CppType>,
rust::Vec<RustType>, Converter>(cppVec);
}

trajopt::SwerveModule _convert_swerve_module(const SwerveModule& swerveModule) {
return trajopt::SwerveModule{
trajopt::Translation2d{swerveModule.x, swerveModule.y},
swerveModule.wheel_radius, swerveModule.wheel_max_angular_velocity,
swerveModule.wheel_max_torque};
}

trajopt::SwerveDrivetrain _convert_swerve_drivetrain(
const SwerveDrivetrain& drivetrain) {
return trajopt::SwerveDrivetrain{
.mass = drivetrain.mass,
.moi = drivetrain.moi,
.modules =
_rust_vec_to_cpp_vector<SwerveModule, trajopt::SwerveModule,
&_convert_swerve_module>(drivetrain.modules)};
}

trajopt::Pose2d _convert_initial_guess_point(const Pose2d& initialGuessPoint) {
return {initialGuessPoint.x, initialGuessPoint.y, initialGuessPoint.heading};
}

void SwervePathBuilderImpl::set_drivetrain(const SwerveDrivetrain& drivetrain) {
path.SetDrivetrain(_convert_swerve_drivetrain(drivetrain));
}
std::vector<trajopt::SwerveModule> cppModules;
for (const auto& module : drivetrain.modules) {
cppModules.push_back(
trajopt::SwerveModule{{module.x, module.y},
module.wheel_radius,
module.wheel_max_angular_velocity,
module.wheel_max_torque});
}

size_t _convert_count(const size_t& count) {
return count;
path.SetDrivetrain(trajopt::SwerveDrivetrain{drivetrain.mass, drivetrain.moi,
std::move(cppModules)});
}

void SwervePathBuilderImpl::set_control_interval_counts(
const rust::Vec<size_t> counts) {
std::vector<size_t> converted_counts =
_rust_vec_to_cpp_vector<size_t, size_t, &_convert_count>(counts);
path.ControlIntervalCounts(std::move(converted_counts));
std::vector<size_t> cppCounts;
for (const auto& count : counts) {
cppCounts.emplace_back(count);
}

path.ControlIntervalCounts(std::move(cppCounts));
}

void SwervePathBuilderImpl::set_bumpers(double length, double width) {
Expand Down Expand Up @@ -111,10 +71,13 @@ void SwervePathBuilderImpl::empty_wpt(size_t index, double x_guess,

void SwervePathBuilderImpl::sgmt_initial_guess_points(
size_t from_index, const rust::Vec<Pose2d>& guess_points) {
std::vector<trajopt::Pose2d> convertedGuessPoints =
_rust_vec_to_cpp_vector<Pose2d, trajopt::Pose2d,
&_convert_initial_guess_point>(guess_points);
path.SgmtInitialGuessPoints(from_index, convertedGuessPoints);
std::vector<trajopt::Pose2d> cppGuessPoints;
for (const auto& guess_point : guess_points) {
cppGuessPoints.emplace_back(guess_point.x, guess_point.y,
guess_point.heading);
}

path.SgmtInitialGuessPoints(from_index, std::move(cppGuessPoints));
}

void SwervePathBuilderImpl::wpt_linear_velocity_direction(size_t index,
Expand Down Expand Up @@ -181,74 +144,58 @@ void SwervePathBuilderImpl::sgmt_point_at(size_t from_index, size_t to_index,
double field_point_y,
double heading_tolerance) {
path.SgmtConstraint(from_index, to_index,
trajopt::PointAtConstraint{
trajopt::Translation2d{field_point_x, field_point_y},
heading_tolerance});
trajopt::PointAtConstraint{{field_point_x, field_point_y},
heading_tolerance});
}

void SwervePathBuilderImpl::sgmt_circle_obstacle(size_t from_index,
size_t to_index, double x,
double y, double radius) {
auto obstacle =
trajopt::Obstacle{.safetyDistance = radius, .points = {{x, y}}};
path.SgmtObstacle(from_index, to_index, obstacle);
path.SgmtObstacle(from_index, to_index, {radius, {{x, y}}});
}

void SwervePathBuilderImpl::sgmt_polygon_obstacle(size_t from_index,
size_t to_index,
const rust::Vec<double> x,
const rust::Vec<double> y,
double radius) {
std::vector<trajopt::Translation2d> points;
if (x.size() != y.size()) {
if (x.size() != y.size()) [[unlikely]] {
return;
}
for (size_t i = 0; i < x.size(); i++) {
points.push_back({x.at(i), y.at(i)});
}
auto obstacle = trajopt::Obstacle{.safetyDistance = radius, .points = points};
path.SgmtObstacle(from_index, to_index, obstacle);
}

HolonomicTrajectorySample _convert_holonomic_trajectory_sample(
const trajopt::HolonomicTrajectorySample& sample) {
// copy data into rust vecs
rust::Vec<double> fx;
std::copy(sample.moduleForcesX.begin(), sample.moduleForcesX.end(),
std::back_inserter(fx));

rust::Vec<double> fy;
std::copy(sample.moduleForcesY.begin(), sample.moduleForcesY.end(),
std::back_inserter(fy));

return HolonomicTrajectorySample{
.timestamp = sample.timestamp,
.x = sample.x,
.y = sample.y,
.heading = sample.heading,
.velocity_x = sample.velocityX,
.velocity_y = sample.velocityY,
.angular_velocity = sample.angularVelocity,
.module_forces_x = std::move(fx),
.module_forces_y = std::move(fy),
};
}
std::vector<trajopt::Translation2d> cppPoints;
for (size_t i = 0; i < x.size(); ++i) {
cppPoints.emplace_back(x.at(i), y.at(i));
}

HolonomicTrajectory _convert_holonomic_trajectory(
const trajopt::HolonomicTrajectory& trajectory) {
return HolonomicTrajectory{
.samples = _cpp_vector_to_rust_vec<trajopt::HolonomicTrajectorySample,
HolonomicTrajectorySample,
&_convert_holonomic_trajectory_sample>(
trajectory.samples)};
path.SgmtObstacle(from_index, to_index,
trajopt::Obstacle{.safetyDistance = radius,
.points = std::move(cppPoints)});
}

HolonomicTrajectory SwervePathBuilderImpl::generate(bool diagnostics,
int64_t handle) const {
trajopt::SwerveTrajectoryGenerator generator{path, handle};
if (auto sol = generator.Generate(diagnostics); sol.has_value()) {
return _convert_holonomic_trajectory(
trajopt::HolonomicTrajectory{sol.value()});
trajopt::HolonomicTrajectory cppTrajectory{sol.value()};

rust::Vec<HolonomicTrajectorySample> rustSamples;
for (const auto& cppSample : cppTrajectory.samples) {
rust::Vec<double> fx;
std::copy(cppSample.moduleForcesX.begin(), cppSample.moduleForcesX.end(),
std::back_inserter(fx));

rust::Vec<double> fy;
std::copy(cppSample.moduleForcesY.begin(), cppSample.moduleForcesY.end(),
std::back_inserter(fy));

rustSamples.push_back(HolonomicTrajectorySample{
cppSample.timestamp, cppSample.x, cppSample.y, cppSample.heading,
cppSample.velocityX, cppSample.velocityY, cppSample.angularVelocity,
std::move(fx), std::move(fy)});
}

return HolonomicTrajectory{std::move(rustSamples)};
} else {
throw std::runtime_error{sol.error()};
}
Expand All @@ -265,12 +212,28 @@ HolonomicTrajectory SwervePathBuilderImpl::generate(bool diagnostics,
*/
void SwervePathBuilderImpl::add_progress_callback(
rust::Fn<void(HolonomicTrajectory, int64_t)> callback) {
path.AddIntermediateCallback([=](trajopt::SwerveSolution& solution,
int64_t handle) {
callback(
_convert_holonomic_trajectory(trajopt::HolonomicTrajectory{solution}),
handle);
});
path.AddIntermediateCallback(
[=](trajopt::SwerveSolution& solution, int64_t handle) {
trajopt::HolonomicTrajectory cppTrajectory{solution};

rust::Vec<HolonomicTrajectorySample> rustSamples;
for (const auto& cppSample : cppTrajectory.samples) {
rust::Vec<double> fx;
std::copy(cppSample.moduleForcesX.begin(),
cppSample.moduleForcesX.end(), std::back_inserter(fx));

rust::Vec<double> fy;
std::copy(cppSample.moduleForcesY.begin(),
cppSample.moduleForcesY.end(), std::back_inserter(fy));

rustSamples.push_back(HolonomicTrajectorySample{
cppSample.timestamp, cppSample.x, cppSample.y, cppSample.heading,
cppSample.velocityX, cppSample.velocityY,
cppSample.angularVelocity, std::move(fx), std::move(fy)});
}

callback(HolonomicTrajectory{rustSamples}, handle);
});
}

void SwervePathBuilderImpl::cancel_all() {
Expand Down

0 comments on commit f9f3130

Please sign in to comment.