Skip to content

Commit

Permalink
Let PID controller parse command joint names instead of joints
Browse files Browse the repository at this point in the history
  • Loading branch information
christophfroehlich committed Oct 23, 2023
1 parent 05ca67c commit 0c2f441
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ controller_interface::CallbackReturn JointTrajectoryController::on_configure(
params_.controller_plugin.c_str(), ex.what());
return CallbackReturn::FAILURE;
}
if (traj_contr_->initialize(get_node()) == false)
if (traj_contr_->initialize(get_node(), command_joint_names_) == false)
{
RCLCPP_FATAL(
logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define JOINT_TRAJECTORY_CONTROLLER_PLUGINS__PID_TRAJECTORY_PLUGIN_HPP_

#include <memory>
#include <string>
#include <vector>

#include "control_toolbox/pid.hpp"
Expand All @@ -32,7 +33,9 @@ namespace joint_trajectory_controller_plugins
class PidTrajectoryPlugin : public TrajectoryControllerBase
{
public:
bool initialize(rclcpp_lifecycle::LifecycleNode::SharedPtr node) override;
bool initialize(
rclcpp_lifecycle::LifecycleNode::SharedPtr node,
std::vector<std::string> command_joint_names) override;

bool computeGains(const trajectory_msgs::msg::JointTrajectory trajectory) override;

Expand All @@ -45,8 +48,10 @@ class PidTrajectoryPlugin : public TrajectoryControllerBase
void reset() override;

protected:
// degree of freedom
size_t dof_;
// number of command joints
size_t num_cmd_joints_;
// name of the command joints
std::vector<std::string> command_joint_names_;
// PID controllers
std::vector<PidPtr> pids_;
// Feed-forward velocity weight factor when calculating closed loop pid adapter's command
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef JOINT_TRAJECTORY_CONTROLLER_PLUGINS__TRAJECTORY_CONTROLLER_BASE_HPP_
#define JOINT_TRAJECTORY_CONTROLLER_PLUGINS__TRAJECTORY_CONTROLLER_BASE_HPP_

#include <string>
#include <vector>

#include "rclcpp/rclcpp.hpp"
Expand All @@ -39,7 +40,9 @@ class TrajectoryControllerBase
/**
*/
JOINT_TRAJECTORY_CONTROLLER_PLUGINS_PUBLIC
virtual bool initialize(rclcpp_lifecycle::LifecycleNode::SharedPtr node) = 0;
virtual bool initialize(
rclcpp_lifecycle::LifecycleNode::SharedPtr node,
std::vector<std::string> command_joint_names) = 0;

/**
*/
Expand Down
31 changes: 22 additions & 9 deletions joint_trajectory_controller_plugins/src/pid_trajectory_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,36 @@
namespace joint_trajectory_controller_plugins
{

bool PidTrajectoryPlugin::initialize(rclcpp_lifecycle::LifecycleNode::SharedPtr node)
bool PidTrajectoryPlugin::initialize(
rclcpp_lifecycle::LifecycleNode::SharedPtr node, std::vector<std::string> command_joint_names)
{
node_ = node;
command_joint_names_ = command_joint_names;

try
{
// Create the parameter listener and get the parameters
param_listener_ = std::make_shared<ParamListener>(node_);
params_ = param_listener_->get_params();
dof_ = params_.joints.size();
}
catch (const std::exception & e)
{
fprintf(stderr, "Exception thrown during init stage with message: %s \n", e.what());
return false;
}
RCLCPP_INFO(node_->get_logger(), "[PidTrajectoryPlugin] Initialized with %lu joints.", dof_);

// parse read-only params
num_cmd_joints_ = command_joint_names_.size();
if (num_cmd_joints_ == 0)
{
RCLCPP_ERROR(node_->get_logger(), "[PidTrajectoryPlugin] No command joints specified.");
return false;
}
pids_.resize(num_cmd_joints_);
ff_velocity_scale_.resize(num_cmd_joints_);

RCLCPP_INFO(
node_->get_logger(), "[PidTrajectoryPlugin] Initialized with %lu joints.", num_cmd_joints_);
return true;
}

Expand All @@ -45,13 +58,13 @@ bool PidTrajectoryPlugin::computeGains(const trajectory_msgs::msg::JointTrajecto
RCLCPP_DEBUG(node_->get_logger(), "[PidTrajectoryPlugin] Updated parameters");
}

pids_.resize(dof_);
ff_velocity_scale_.resize(dof_);
pids_.resize(num_cmd_joints_);
ff_velocity_scale_.resize(num_cmd_joints_);

// Init PID gains from ROS parameters
for (size_t i = 0; i < dof_; ++i)
for (size_t i = 0; i < num_cmd_joints_; ++i)
{
const auto & gains = params_.gains.joints_map.at(params_.joints[i]);
const auto & gains = params_.gains.joints_map.at(command_joint_names_[i]);
pids_[i] = std::make_shared<control_toolbox::Pid>(
gains.p, gains.i, gains.d, gains.i_clamp, -gains.i_clamp);

Expand All @@ -60,7 +73,7 @@ bool PidTrajectoryPlugin::computeGains(const trajectory_msgs::msg::JointTrajecto

RCLCPP_INFO(
node_->get_logger(),
"[PidTrajectoryPlugin] Loaded PID gains from ROS parameters for %lu joints.", dof_);
"[PidTrajectoryPlugin] Loaded PID gains from ROS parameters for %lu joints.", num_cmd_joints_);
return true;
}

Expand All @@ -71,7 +84,7 @@ void PidTrajectoryPlugin::computeCommands(
const rclcpp::Duration & period)
{
// Update PIDs
for (auto i = 0ul; i < dof_; ++i)
for (auto i = 0ul; i < num_cmd_joints_; ++i)
{
tmp_command[i] = (desired.velocities[i] * ff_velocity_scale_[i]) +
pids_[i]->computeCommand(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ TEST_F(PidTrajectoryTest, TestEmptySetup)
std::shared_ptr<TestableJointTrajectoryControllerPlugin> traj_contr =
std::make_shared<TestableJointTrajectoryControllerPlugin>();

ASSERT_FALSE(traj_contr->initialize(node_));
ASSERT_FALSE(traj_contr->initialize(node_, std::vector<std::string>()));
}

TEST_F(PidTrajectoryTest, TestSingleJoint)
Expand All @@ -33,7 +33,7 @@ TEST_F(PidTrajectoryTest, TestSingleJoint)
// override read_only parameter
node_->declare_parameter("joints", joint_names_paramv);

ASSERT_TRUE(traj_contr->initialize(node_));
ASSERT_TRUE(traj_contr->initialize(node_, joint_names));

// set dynamic parameters
traj_contr->trigger_declare_parameters();
Expand Down Expand Up @@ -63,7 +63,7 @@ TEST_F(PidTrajectoryTest, TestMultipleJoints)
// override read_only parameter
node_->declare_parameter("joints", joint_names_paramv);

ASSERT_TRUE(traj_contr->initialize(node_));
ASSERT_TRUE(traj_contr->initialize(node_, joint_names));

// set dynamic parameters
traj_contr->trigger_declare_parameters();
Expand Down

0 comments on commit 0c2f441

Please sign in to comment.