Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented vision_msgs #40

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion object_detection/config/params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ object_detection:
input_img_topic: color_camera/image_raw
output_bb_topic: object_detection/img_bb
output_img_topic: object_detection/img
output_vision_topic: object_detection/vision_msg
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should rename the topic to object_detection/detection_info as it might be more self explanatory.

model_params:
detector_type: YOLOv5
model_dir_path: models/
weight_file_name: auto_final.onnx
weight_file_name: yolov5.onnx
confidence_threshold : 0.7
show_fps : 1
2 changes: 1 addition & 1 deletion object_detection/object_detection/Detectors/YOLOv5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class YOLOv5(DetectorBase):
def __init__(self, conf_threshold = 0.7, score_threshold = 0.4, nms_threshold = 0.25, is_cuda = 0):
def __init__(self, conf_threshold = 0.7, score_threshold = 0.4, nms_threshold = 0.25, is_cuda = 1):

super().__init__()

Expand Down
13 changes: 5 additions & 8 deletions object_detection/object_detection/Detectors/YOLOv8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
import time

class YOLOv8:
def __init__(self, model_dir_path, weight_file_name, conf_threshold = 0.7,
def __init__(self, conf_threshold = 0.7,
score_threshold = 0.4, nms_threshold = 0.25,
show_fps = 1, is_cuda = 0):

self.model_dir_path = model_dir_path
self.weight_file_name = weight_file_name


self.conf_threshold = conf_threshold
Expand All @@ -30,20 +27,20 @@ def __init__(self, model_dir_path, weight_file_name, conf_threshold = 0.7,
self.load_classes()


def build_model(self) :
def build_model(self,model_dir_path,weight_file_name) :

try :
model_path = os.path.join(self.model_dir_path, self.weight_file_name)
model_path = os.path.join(model_dir_path, weight_file_name)
self.model = YOLO(model_path)

except :
raise Exception("Error loading given model from path: {}. Maybe the file doesn't exist?".format(model_path))

def load_classes(self):
def load_classes(self, model_dir_path):

self.class_list = []

with open(self.model_dir_path + "/classes.txt", "r") as f:
with open(model_dir_path + "/classes.txt", "r") as f:
self.class_list = [cname.strip() for cname in f.readlines()]

return self.class_list
Expand Down
42 changes: 39 additions & 3 deletions object_detection/object_detection/ObjectDetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from rclpy.node import Node

from sensor_msgs.msg import Image
#from vision_msgs.msg import BoundingBox2D

from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose, Pose2D, Point2D, ObjectHypothesis
from cv_bridge import CvBridge
import cv2

Expand All @@ -30,6 +29,7 @@ def __init__(self):
('input_img_topic', ""),
('output_bb_topic', ""),
('output_img_topic', ""),
('output_vision_topic', ""),
('model_params.detector_type', ""),
('model_params.model_dir_path', ""),
('model_params.weight_file_name', ""),
Expand All @@ -42,6 +42,7 @@ def __init__(self):
self.input_img_topic = self.get_parameter('input_img_topic').value
self.output_bb_topic = self.get_parameter('output_bb_topic').value
self.output_img_topic = self.get_parameter('output_img_topic').value
self.output_vision_topic = self.get_parameter('output_vision_topic').value

# model params
self.detector_type = self.get_parameter('model_params.detector_type').value
Expand All @@ -50,6 +51,9 @@ def __init__(self):
self.confidence_threshold = self.get_parameter('model_params.confidence_threshold').value
self.show_fps = self.get_parameter('model_params.show_fps').value

print(f"Model dir: {self.model_dir_path}")
print(f"Model: {self.weight_file_name}")

# raise an exception if specified detector was not found
if self.detector_type not in self.available_detectors:
raise ModuleNotFoundError(self.detector_type + " Detector specified in config was not found. " +
Expand All @@ -62,6 +66,8 @@ def __init__(self):
self.bb_pub = None
self.img_sub = self.create_subscription(Image, self.input_img_topic, self.detection_cb, 10)

self.vision_msg_pub = self.create_publisher(Detection2DArray, self.output_vision_topic, 10)

self.bridge = CvBridge()


Expand Down Expand Up @@ -91,6 +97,9 @@ def detection_cb(self, img_msg):
cv_image = self.bridge.imgmsg_to_cv2(img_msg, "bgr8")

predictions = self.detector.get_predictions(cv_image=cv_image)


detection_arr = Detection2DArray()

if predictions == None :
print("Image input from topic : {} is empty".format(self.input_img_topic))
Expand All @@ -100,12 +109,39 @@ def detection_cb(self, img_msg):
right = left + width
bottom = top + height

class_id = str(prediction['class_id'])
conf = float(prediction['confidence'])

detection_msg = Detection2D()
detection_msg.bbox.size_x = float(width)
detection_msg.bbox.size_y = float(height)

position_msg = Point2D()
position_msg.x = float((left + right) / 2)
position_msg.y = float((bottom + top) / 2)

center_msg = Pose2D()
center_msg.position = position_msg

detection_msg.bbox.center = center_msg

results_msg = ObjectHypothesisWithPose()
hypothesis_msg = ObjectHypothesis()
hypothesis_msg.class_id = class_id
hypothesis_msg.score = conf

results_msg.hypothesis = hypothesis_msg
detection_msg.results.append(results_msg)

detection_arr.detections.append(detection_msg)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments to everyline to make code understandable to others.

#Draw the bounding box
cv_image = cv2.rectangle(cv_image,(left,top),(right, bottom),(0,255,0),1)

output = self.bridge.cv2_to_imgmsg(cv_image, "bgr8")
self.img_pub.publish(output)
print(predictions)
self.vision_msg_pub.publish(detection_arr)


def main():
Expand Down
2 changes: 2 additions & 0 deletions object_detection/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
<maintainer email="[email protected]">singh</maintainer>
<license>TODO: License declaration</license>

<depend>vision_msgs</depend>

<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
Expand Down
2 changes: 1 addition & 1 deletion object_detection/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
(os.path.join('share', package_name, 'launch'), glob('launch/*.launch.py')),

],
install_requires=['setuptools'],
install_requires=['setuptools', 'vision_msgs'],
zip_safe=True,
maintainer='singh',
maintainer_email='[email protected]',
Expand Down