updated readme, started inference server
Showing 17 changed files with 309 additions and 30 deletions.
`.gitmodules` (new file):
```
[submodule "third_party/t2r/tensor2robot"]
	path = third_party/t2r/tensor2robot
	url = https://github.com/sebbyjp/tensor2robot.git
[submodule "third_party/rt1/robotics_transformer"]
	path = third_party/rt1/robotics_transformer
	url = https://www.github.com/google-research/robotics_transformer.git
```
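If the repository was cloned without `--recurse-submodules`, running `git submodule update --init --recursive` fetches both submodules after the fact.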
`README.md`:
[![Code Coverage](https://codecov.io/gh/sebbyjp/dgl_ros/branch/code_cov/graph/badge.svg?token=9225d677-c4f2-4607-a9dd-8c22446f13bc)](https://codecov.io/gh/sebbyjp/dgl_ros)
# Library for Robotic Transformers: RT-1 and RT-X-1
## Install:

Requirements:
- python3.11

Clone this repo:

`git clone https://github.com/sebbyjp/robo_transformers.git --recurse-submodules`

`cd robo_transformers`
Use Poetry:

`pip install poetry`

Create the requirements.txt:

`poetry install && poetry export --output requirements.txt`
**Install dependencies:**

- *Optional*: Activate a virtual environment:

  - `poetry config virtualenvs.in-project true`
  - `source .venv/bin/activate` (run `poetry env info` to find the environment path)

`python3 -m pip install -r requirements.txt`
## Run Inference on Demo Images

`python robo_transformers/inference/rt1/rt1_inference.py`
### Optional
To install the checkpoints from the robotics_transformer git repo, you will need git-lfs.
- Install git-lfs (via brew or apt on macOS or Linux), then:

`cd third_party/rt1/robotics_transformer && git lfs install`

`git lfs pull https://www.github.com/google-research/robotics_transformer.git`
### Optional: Download RT-1-X model from the Open-X Embodiment paper.
- Install gsutil: `pip install gsutil`
- Run: `gsutil -m cp -r gs://gdm-robotics-open-x-embodiment/open_x_embodiment_and_rt_x_oss/rt_1_x_tf_trained_for_002272480_step.zip ./checkpoints/`
- Unzip: `cd checkpoints && unzip rt_1_x_tf_trained_for_002272480_step.zip`
## Notes

`action, next_policy_state = model.act(time_step, curr_policy_state)`

### policy_state is the internal state of the network:
In this case, it is a 6-frame window of past observations and actions, plus the index in time.
```
{'action_tokens': ArraySpec(shape=(6, 11, 1, 1), dtype=dtype('int32'), name='action_tokens'),
 'image': ArraySpec(shape=(6, 256, 320, 3), dtype=dtype('uint8'), name='image'),
 'step_num': ArraySpec(shape=(1, 1, 1, 1), dtype=dtype('int32'), name='step_num'),
 't': ArraySpec(shape=(1, 1, 1, 1), dtype=dtype('int32'), name='t')}
```
### time_step is the input from the environment:
```
{'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),
 'observation': {'base_pose_tool_reached': ArraySpec(shape=(7,), dtype=dtype('float32'), name='base_pose_tool_reached'),
  'gripper_closed': ArraySpec(shape=(1,), dtype=dtype('float32'), name='gripper_closed'),
  'gripper_closedness_commanded': ArraySpec(shape=(1,), dtype=dtype('float32'), name='gripper_closedness_commanded'),
  'height_to_bottom': ArraySpec(shape=(1,), dtype=dtype('float32'), name='height_to_bottom'),
  'image': ArraySpec(shape=(256, 320, 3), dtype=dtype('uint8'), name='image'),
  'natural_language_embedding': ArraySpec(shape=(512,), dtype=dtype('float32'), name='natural_language_embedding'),
  'natural_language_instruction': ArraySpec(shape=(), dtype=dtype('O'), name='natural_language_instruction'),
  'orientation_box': ArraySpec(shape=(2, 3), dtype=dtype('float32'), name='orientation_box'),
  'orientation_start': ArraySpec(shape=(4,), dtype=dtype('float32'), name='orientation_in_camera_space'),
  'robot_orientation_positions_box': ArraySpec(shape=(3, 3), dtype=dtype('float32'), name='robot_orientation_positions_box'),
  'rotation_delta_to_go': ArraySpec(shape=(3,), dtype=dtype('float32'), name='rotation_delta_to_go'),
  'src_rotation': ArraySpec(shape=(4,), dtype=dtype('float32'), name='transform_camera_robot'),
  'vector_to_go': ArraySpec(shape=(3,), dtype=dtype('float32'), name='vector_to_go'),
  'workspace_bounds': ArraySpec(shape=(3, 3), dtype=dtype('float32'), name='workspace_bounds')},
 'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),
 'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')}
```
### action:
```
{'base_displacement_vector': BoundedArraySpec(shape=(2,), dtype=dtype('float32'), name='base_displacement_vector', minimum=-1.0, maximum=1.0),
 'base_displacement_vertical_rotation': BoundedArraySpec(shape=(1,), dtype=dtype('float32'), name='base_displacement_vertical_rotation', minimum=-3.1415927410125732, maximum=3.1415927410125732),
 'gripper_closedness_action': BoundedArraySpec(shape=(1,), dtype=dtype('float32'), name='gripper_closedness_action', minimum=-1.0, maximum=1.0),
 'rotation_delta': BoundedArraySpec(shape=(3,), dtype=dtype('float32'), name='rotation_delta', minimum=-1.5707963705062866, maximum=1.5707963705062866),
 'terminate_episode': BoundedArraySpec(shape=(3,), dtype=dtype('int32'), name='terminate_episode', minimum=0, maximum=1),
 'world_vector': BoundedArraySpec(shape=(3,), dtype=dtype('float32'), name='world_vector', minimum=-1.0, maximum=1.0)}
```
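For orientation, here is a minimal, self-contained sketch of the loop implied by `model.act`. The `StubModel` class and the zeroed inputs are illustrative assumptions that only mirror the spec shapes above; they are not part of this repo.

```
import numpy as np


# A toy stand-in for the RT-1 policy so the loop runs end to end; the real
# model also maintains its 6-frame window in next_policy_state.
class StubModel:
    def act(self, time_step, policy_state):
        action = {
            'world_vector': np.zeros(3, dtype=np.float32),
            'rotation_delta': np.zeros(3, dtype=np.float32),
            'gripper_closedness_action': np.zeros(1, dtype=np.float32),
            'terminate_episode': np.zeros(3, dtype=np.int32),
        }
        return action, policy_state


model = StubModel()
policy_state = None  # the real policy supplies the initial window state

# Build a dummy time_step matching the observation spec shapes above.
observation = {
    'image': np.zeros((256, 320, 3), dtype=np.uint8),
    'natural_language_embedding': np.zeros(512, dtype=np.float32),
}
time_step = {'observation': observation, 'reward': np.float32(0.0),
             'step_type': np.int32(0), 'discount': np.float32(1.0)}

action, next_policy_state = model.act(time_step, policy_state)
print(action['world_vector'])
```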
## TODO:
- Render the action, policy_state, and observation specs in something prettier, like a pandas DataFrame.
A new Python module (the start of the inference server):
```
from typing import Any, Optional

import PIL.Image


class Observation(object):
    pass


class State(object):
    pass


class Supervision(object):
    pass


class Config(object):
    pass


class Action(object):
    pass


class Actor(object):

    def act(self,
            observation: Observation,
            state: State,
            config: Config,
            supervision: Optional[Supervision] = None) -> Action:
        pass


class Src(object):
    '''The source data type from a sensor or web interface: a camera image, a point cloud, etc.'''
    pass


class InternalState(object):
    '''The internal state of the agent. This is produced internally, in contrast to an Observation.'''
    pass


# This is different from the Observer in tf-agents. That observer simply records from the
# environment. This Observer is responsible for taking in data from the environment and
# converting it into an Observation. Essentially, it is a tf-agents observer plus a post-processor.

# In the case of RT-1, the Observer takes in a camera image and a natural language
# instruction and converts them into an Observation.
class Observer(object):

    def observe(self, srcs: list[Src]) -> Observation:
        pass


class Rt1Observer(Observer):

    def observe(self, srcs: list[Src]) -> Observation:
        # Expects two sources: a Src wrapping a PIL.Image.Image and a Src wrapping a str.
        pass


def inference(
    model: Any,
    internal_state: dict,
    observation: dict,
    supervision: dict,
    config: dict,
) -> dict:
    """Infer an action from an observation.

    Args:
        model: The policy model to run.
        internal_state (dict): The agent's internal state (e.g. the past frame window).
        observation (dict): The current observation from the environment.
        supervision (dict): Supervision signal, if any.
        config (dict): Inference configuration.

    Returns:
        dict: The predicted action.
    """
    # The commented-out body below is grasp-inference code adapted from a
    # ContactGraspNet (CGN) pipeline, kept here as a reference while the
    # server is being built.
    # cgn.eval()
    # pcd = torch.Tensor(pcd).to(dtype=torch.float32).to(cgn.device)
    # if pcd.shape[0] > 20000:
    #     downsample_idxs = np.array(random.sample(range(pcd.shape[0] - 1), 20000))
    # else:
    #     downsample_idxs = np.arange(pcd.shape[0])
    # pcd = pcd[downsample_idxs, :]

    # batch = torch.zeros(pcd.shape[0]).to(dtype=torch.int64).to(cgn.device)
    # fps_idxs = farthest_point_sample(pcd, batch, 2048 / pcd.shape[0])

    # if obj_mask is not None:
    #     obj_mask = torch.Tensor(obj_mask[downsample_idxs])
    #     obj_mask = obj_mask[fps_idxs]
    # else:
    #     obj_mask = torch.ones(fps_idxs.shape[0])
    # points, pred_grasps, confidence, pred_widths, _, _ = cgn(
    #     pcd[:, 3:],
    #     pcd_poses=pcd[:, :3],
    #     batch=batch,
    #     idxs=fps_idxs,
    #     gripper_depth=gripper_depth,
    #     gripper_width=gripper_width,
    # )

    # sig = torch.nn.Sigmoid()
    # confidence = sig(confidence)
    # confidence = confidence.reshape(-1)
    # pred_grasps = (
    #     torch.flatten(pred_grasps, start_dim=0, end_dim=1).detach().cpu().numpy()
    # )

    # confidence = (
    #     obj_mask.detach().cpu().numpy() * confidence.detach().cpu().numpy()
    # ).reshape(-1)
    # pred_widths = (
    #     torch.flatten(pred_widths, start_dim=0, end_dim=1).detach().cpu().numpy()
    # )
    # points = torch.flatten(points, start_dim=0, end_dim=1).detach().cpu().numpy()

    # success_mask = (confidence > threshold).nonzero()[0]
    # if len(success_mask) == 0:
    #     print("failed to find successful grasps")
    #     return None, None, None

    # success_grasps = pred_grasps[success_mask]
    # success_confidence = confidence[success_mask]
    # print("Found {} grasps".format(success_grasps.shape[0]))
    # if max_grasps > 0 and success_grasps.shape[0] > max_grasps:
    #     success_grasps = success_grasps[:max_grasps]
    #     success_confidence = success_confidence[:max_grasps]
    # if visualize:
    #     visualize_grasps(
    #         pcd.detach().cpu().numpy(),
    #         success_grasps,
    #         gripper_depth=gripper_depth,
    #         gripper_width=gripper_width,
    #     )
    # return success_grasps, success_confidence, downsample_idxs
```
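To make the intended data flow concrete, here is a minimal sketch of how an observer like `Rt1Observer` could be fleshed out. The dataclass wrappers and field names are illustrative assumptions, not part of this module, since the skeleton classes above carry no payloads yet.

```
from dataclasses import dataclass

import PIL.Image


# Hypothetical concrete sources: a camera frame and a natural language instruction.
@dataclass
class ImageSrc:
    image: PIL.Image.Image


@dataclass
class TextSrc:
    text: str


@dataclass
class DemoObservation:
    image: PIL.Image.Image
    instruction: str


class DemoRt1Observer:
    def observe(self, srcs: list) -> DemoObservation:
        # Pair the camera frame with the instruction into one Observation.
        image_src, text_src = srcs
        return DemoObservation(image=image_src.image, instruction=text_src.text)


if __name__ == "__main__":
    obs = DemoRt1Observer().observe(
        [ImageSrc(PIL.Image.new("RGB", (320, 256))), TextSrc("pick up the apple")]
    )
    print(obs.instruction, obs.image.size)
```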