From ae0401e2d47c9642ab0b0439974e0cd490a906ea Mon Sep 17 00:00:00 2001
From: Sebastian Peralta
Date: Sat, 16 Dec 2023 12:12:11 -0500
Subject: [PATCH 1/3] inference server

---
 README.md                              |  32 ++++++-
 robo_transformers/inference.py         | 121 -------------------------
 robo_transformers/inference_server.py  | 117 ++++++++++++++++++++++++
 robo_transformers/rt1/rt1_inference.py |  88 +++++++++++-------
 4 files changed, 198 insertions(+), 160 deletions(-)
 delete mode 100644 robo_transformers/inference.py
 create mode 100644 robo_transformers/inference_server.py

diff --git a/README.md b/README.md
index c0e5364..ef8b2d4 100644
--- a/README.md
+++ b/README.md
@@ -7,11 +7,12 @@ Requirements:
 python >= 3.9
 
-### Using PyPI
+### Recommended: Using PyPI
 `pip install robo-transformers`
 
 ### From Source
-Clone this repo:
+Clone this repo:
+
 `git clone https://github.com/sebbyjp/robo_transformers.git`
 
 `cd robo_transformers`
@@ -20,18 +21,39 @@ Use poetry
 `pip install poetry && poetry config virtualenvs.in-project true`
 
-**Install dependencies:**
+**Install dependencies**
+
 `poetry install`
 
+Poetry installs the dependencies into a virtualenv, so we need to activate it:
+
 `source .venv/bin/activate`
 
 ## Run RT-1 Inference On Demo Images.
 `python -m robo_transformers.rt1.rt1_inference`
 
-## See options:
+## See usage:
+You can specify a custom checkpoint path or one of the model_keys for the three models mentioned in the RT-1 paper, as well as RT-X.
+
 `python -m robo_transformers.rt1.rt1_inference --help`
+
+## Run Inference Server
+The inference server takes care of all the internal state, so all you need to specify is an instruction and an image. You may also pass in a reward and termination signal. Batching is also supported.
+```
+from robo_transformers.inference_server import InferenceServer
+import numpy as np
+
+# Somewhere in your robot control stack code...
+
+instruction = "pick block"
+img = np.random.randn(256, 320, 3)  # Height, Width, RGB
+inference = InferenceServer()
+
+action = inference(instruction, img)
+```
+
 
-## Notes
+## Data Types
 `action, next_policy_state = model.act(time_step, curr_policy_state)`
 ### policy state is internal state of network: In this case it is a 6-frame window of past observations,actions and the index in time.
diff --git a/robo_transformers/inference.py b/robo_transformers/inference.py
deleted file mode 100644
index 33af122..0000000
--- a/robo_transformers/inference.py
+++ /dev/null
@@ -1,121 +0,0 @@
-from typing import Optional
-import PIL
-
-
-class Observation(object):
-    pass
-class State(object):
-    pass
-class Supervision(object):
-    pass
-class Config(object):
-    pass
-
-
-class Action(object):
-    pass
-class Actor(object):
-    def act(self, observation: Observation, state: State, config: Config, supervision: Optional[Supervision] = None) -> Action:
-        pass
-
-class Src(object):
-    '''The source data type from a sensor or web interface. A camera image, a point cloud, etc.'''
-    pass
-
-class InternalState(object):
-    ''' The internal state of the agent. This is produced internally in contrast to Observation.'''
-    pass
-
-# This is different than the Observer in tf-agents. That observer simply records from the environment.
-# This Observer is responsible from taking in data from the environment and converting it into an Observation.
-# Essentially it is a tf-agents observer + a post-processor.
-
-# In the case of RT1, the Observer would take in a camera image and a natural language instruction and convert it into an Observation.
-class Observer(object):
-    def observe(self, srcs: list[Src]) -> Observation:
-        pass
-
-class Rt1Observer(Observer):
-    def observe(self, srcs: list[Src(PIL.Image), Src(str)]) -> Observation:
-        pass
-
-
-def inference(
-    model: any,
-    internal_state: dict,
-    observation: dict,
-    supervision: dict,
-    config: dict,
-) -> dict:
-    """Infer action from observation.
-
-    Args:
-        cgn (CGN): ContactGraspNet model
-        pcd (np.ndarray): point cloud
-        threshold (float, optional): Success threshol. Defaults to 0.5.
-        visualize (bool, optional): Whether or not to visualize output. Defaults to False.
-        max_grasps (int, optional): Maximum grasps. Zero means unlimited. Defaults to 0.
-        obj_mask (np.ndarray, optional): Object mask. Defaults to None.
-
-    Returns:
-        tuple[np.ndarray, np.ndarray, np.ndarray]: The grasps, confidence and indices of the points used for inference.
-    """
-    # cgn.eval()
-    # pcd = torch.Tensor(pcd).to(dtype=torch.float32).to(cgn.device)
-    # if pcd.shape[0] > 20000:
-    #     downsample_idxs = np.array(random.sample(range(pcd.shape[0] - 1), 20000))
-    # else:
-    #     downsample_idxs = np.arange(pcd.shape[0])
-    # pcd = pcd[downsample_idxs, :]
-
-    # batch = torch.zeros(pcd.shape[0]).to(dtype=torch.int64).to(cgn.device)
-    # fps_idxs = farthest_point_sample(pcd, batch, 2048 / pcd.shape[0])
-
-    # if obj_mask is not None:
-    #     obj_mask = torch.Tensor(obj_mask[downsample_idxs])
-    #     obj_mask = obj_mask[fps_idxs]
-    # else:
-    #     obj_mask = torch.ones(fps_idxs.shape[0])
-    # points, pred_grasps, confidence, pred_widths, _, _ = cgn(
-    #     pcd[:, 3:],
-    #     pcd_poses=pcd[:, :3],
-    #     batch=batch,
-    #     idxs=fps_idxs,
-    #     gripper_depth=gripper_depth,
-    #     gripper_width=gripper_width,
-    # )
-
-    # sig = torch.nn.Sigmoid()
-    # confidence = sig(confidence)
-    # confidence = confidence.reshape(-1)
-    # pred_grasps = (
-    #     torch.flatten(pred_grasps, start_dim=0, end_dim=1).detach().cpu().numpy()
-    # )
-
-    # confidence = (
-    #     obj_mask.detach().cpu().numpy() * confidence.detach().cpu().numpy()
-    # ).reshape(-1)
-    # pred_widths = (
-    #     torch.flatten(pred_widths, start_dim=0, end_dim=1).detach().cpu().numpy()
-    # )
-    # points = torch.flatten(points, start_dim=0, end_dim=1).detach().cpu().numpy()
-
-    # success_mask = (confidence > threshold).nonzero()[0]
-    # if len(success_mask) == 0:
-    #     print("failed to find successful grasps")
-    #     return None, None, None
-
-    # success_grasps = pred_grasps[success_mask]
-    # success_confidence = confidence[success_mask]
-    # print("Found {} grasps".format(success_grasps.shape[0]))
-    # if max_grasps > 0 and success_grasps.shape[0] > max_grasps:
-    #     success_grasps = success_grasps[:max_grasps]
-    #     success_confidence = success_confidence[:max_grasps]
-    # if visualize:
-    #     visualize_grasps(
-    #         pcd.detach().cpu().numpy(),
-    #         success_grasps,
-    #         gripper_depth=gripper_depth,
-    #         gripper_width=gripper_width,
-    #     )
-    # return success_grasps, success_confidence, downsample_idxs
\ No newline at end of file
diff --git a/robo_transformers/inference_server.py b/robo_transformers/inference_server.py
new file mode 100644
index 0000000..31f6fd9
--- /dev/null
+++ b/robo_transformers/inference_server.py
@@ -0,0 +1,117 @@
+from tf_agents.policies.py_policy import PyPolicy
+from tf_agents.policies.tf_policy import TFPolicy
+from tf_agents.trajectories import policy_step as ps
+from robo_transformers.rt1.rt1_inference import load_rt1, inference as rt1_inference
+import numpy as np
+
+
+class InferenceServer:
+
+    def __init__(self,
+                 model: PyPolicy | TFPolicy = None,
+                 verbose: bool = False):
+        self.model = model
+        if self.model is None:
+            self.model = load_rt1()
+
+        self.policy_state = None
+        self.verbose = verbose
+        self.step = 0
+
+    def __call__(self,
+                 instructions: list[str] | str,
+                 imgs: list[np.ndarray] | np.ndarray,
+                 reward: list[float] | float = None,
+                 terminate: bool = False) -> ps.ActionType:
+        action, state, _ = rt1_inference(instructions, imgs, self.step, reward,
+                                         self.model, self.policy_state, terminate,
+                                         self.verbose)
+        self.policy_state = state
+        self.step += 1
+        return action
+
+
+
+# class Rt1Observer(Observer):
+#     def observe(self, srcs: list[Src(PIL.Image), Src(str)]) -> Observation:
+#         pass
+
+# def inference(
+#     model: any,
+#     internal_state: dict,
+#     observation: dict,
+#     supervision: dict,
+#     config: dict,
+# ) -> dict:
+#     """Infer action from observation.
+
+#     Args:
+#         cgn (CGN): ContactGraspNet model
+#         pcd (np.ndarray): point cloud
+#         threshold (float, optional): Success threshol. Defaults to 0.5.
+#         visualize (bool, optional): Whether or not to visualize output. Defaults to False.
+#         max_grasps (int, optional): Maximum grasps. Zero means unlimited. Defaults to 0.
+#         obj_mask (np.ndarray, optional): Object mask. Defaults to None.
+
+#     Returns:
+#         tuple[np.ndarray, np.ndarray, np.ndarray]: The grasps, confidence and indices of the points used for inference.
+#     """
+#     cgn.eval()
+#     pcd = torch.Tensor(pcd).to(dtype=torch.float32).to(cgn.device)
+#     if pcd.shape[0] > 20000:
+#         downsample_idxs = np.array(random.sample(range(pcd.shape[0] - 1), 20000))
+#     else:
+#         downsample_idxs = np.arange(pcd.shape[0])
+#     pcd = pcd[downsample_idxs, :]
+
+#     batch = torch.zeros(pcd.shape[0]).to(dtype=torch.int64).to(cgn.device)
+#     fps_idxs = farthest_point_sample(pcd, batch, 2048 / pcd.shape[0])
+
+#     if obj_mask is not None:
+#         obj_mask = torch.Tensor(obj_mask[downsample_idxs])
+#         obj_mask = obj_mask[fps_idxs]
+#     else:
+#         obj_mask = torch.ones(fps_idxs.shape[0])
+#     points, pred_grasps, confidence, pred_widths, _, _ = cgn(
+#         pcd[:, 3:],
+#         pcd_poses=pcd[:, :3],
+#         batch=batch,
+#         idxs=fps_idxs,
+#         gripper_depth=gripper_depth,
+#         gripper_width=gripper_width,
+#     )
+
+#     sig = torch.nn.Sigmoid()
+#     confidence = sig(confidence)
+#     confidence = confidence.reshape(-1)
+#     pred_grasps = (
+#         torch.flatten(pred_grasps, start_dim=0, end_dim=1).detach().cpu().numpy()
+#     )
+
+#     confidence = (
+#         obj_mask.detach().cpu().numpy() * confidence.detach().cpu().numpy()
+#     ).reshape(-1)
+#     pred_widths = (
+#         torch.flatten(pred_widths, start_dim=0, end_dim=1).detach().cpu().numpy()
+#     )
+#     points = torch.flatten(points, start_dim=0, end_dim=1).detach().cpu().numpy()
+
+#     success_mask = (confidence > threshold).nonzero()[0]
+#     if len(success_mask) == 0:
+#         print("failed to find successful grasps")
+#         return None, None, None
+
+#     success_grasps = pred_grasps[success_mask]
+#     success_confidence = confidence[success_mask]
+#     print("Found {} grasps".format(success_grasps.shape[0]))
+#     if max_grasps > 0 and success_grasps.shape[0] > max_grasps:
+#         success_grasps = success_grasps[:max_grasps]
+#         success_confidence = success_confidence[:max_grasps]
+#     if visualize:
+#         visualize_grasps(
+#             pcd.detach().cpu().numpy(),
+#             success_grasps,
+#             gripper_depth=gripper_depth,
+#             gripper_width=gripper_width,
+#         )
+#     return success_grasps, success_confidence, downsample_idxs
diff --git a/robo_transformers/rt1/rt1_inference.py b/robo_transformers/rt1/rt1_inference.py
index 26de313..3b245de 100644
--- a/robo_transformers/rt1/rt1_inference.py
+++ b/robo_transformers/rt1/rt1_inference.py
@@ -4,8 +4,9 @@
 import PIL.Image as Image
 import tensorflow_hub as hub
 from tf_agents.policies.py_tf_eager_policy import SavedModelPyTFEagerPolicy as LoadedPolicy
-from tf_agents.trajectories import time_step as ts
+from tf_agents.trajectories import time_step as ts, policy_step as ps
 from tf_agents import specs
+from tf_agents.typing import types
 from importlib.resources import files
 import absl.logging
 import os
@@ -23,10 +24,10 @@
         "https://drive.google.com/drive/folders/1_nudHVmGuGUpGcrLlswg9O-aWy27Cjg0?usp=drive_link",
     'rt1multirobot':
         "https://drive.google.com/drive/folders/1EWjKSnfvD-ANPTLxugpCVP5zU6ADy8km?usp=drive_link",
-    'xlatest':
+    'rtx1':
         "https://drive.google.com/drive/folders/1LjTizUsqM88-5uHAIczTrObB3_z4OlgE?usp=drive_link",
-    'xgresearch':
-        "https://drive.google.com/drive/folders/185nP-a8z-1Pm6Zc3yU2qZ01hoszyYx51?usp=drive_link"
+    # 'xgresearch':
+    #     "https://drive.google.com/drive/folders/185nP-a8z-1Pm6Zc3yU2qZ01hoszyYx51?usp=drive_link"
 }
@@ -46,7 +47,6 @@ def download_checkpoint(key: str, output: str = None):
                           output=downloads_folder,
                           quiet=True,
                           use_cookies=False)
-    # quiet=True)
     return output
@@ -88,7 +88,8 @@ def load_rt1(model_key: str = 'rt1simreal',
 
 def embed_text(input: list[str] | str, batch_size: int = 1) -> tf.Tensor:
-    '''Embeds a string using the Universal Sentence Encoder.
+    '''Embeds a string using the Universal Sentence Encoder. Copies the string
+    to fill the batch dimension.
 
     Args:
         input (str): The string to embed.
@@ -104,8 +105,8 @@ def embed_text(input: list[str] | str, batch_size: int = 1) -> tf.Tensor:
                      (batch_size, 512))
 
-def get_demo_imgs() -> tf.Tensor:
-    '''Loads a demo video from the ./demo_vids/ directory.
+def get_demo_imgs(output=None) -> tf.Tensor:
+    '''Loads a demo video from the directory.
 
     Returns:
         list[tf.Tensor]: A list of tensors of shape (batch_size, HEIGHT, WIDTH, 3).
@@ -121,6 +122,8 @@ def get_demo_imgs() -> tf.Tensor:
     ]
     for fn in filenames:
         img = Image.open(fn)
+        if output is not None:
+            img.save(os.path.join(output, fn.name))
         img = np.array(img.resize((WIDTH, HEIGHT)).convert('RGB'))
         img = np.expand_dims(img, axis=0)
         img = tf.reshape(tf.convert_to_tensor(img, dtype=tf.uint8),
@@ -129,30 +132,38 @@
     return tf.concat(imgs, 0)
 
 
-def inference(instructions: list[str] | str,
-              imgs: list[np.ndarray] | np.ndarray,
-              reward: list[float] | float = None,
-              policy: LoadedPolicy = None,
-              state=None,
-              verbose: bool = False,
-              step: int = 0,
-              done: bool = False):
+def inference(
+    instructions: list[str] | str,
+    imgs: list[np.ndarray] | np.ndarray,
+    step: int,
+    reward: list[float] | float = None,
+    policy: LoadedPolicy = None,
+    policy_state=types.NestedArray,
+    terminate=False,
+    verbose: bool = False,
+) -> tuple[ps.ActionType, types.NestedSpecTensorOrArray,
+           types.NestedSpecTensorOrArray]:
     '''Runs inference on a list of images and instructions.
 
     Args:
         instructions (list[str]): A list of instructions. E.g. ["pick up the block"]
         imgs (list[np.ndarray]): A list of images with shape[(HEIGHT, WIDTH, 3)]
+        step (int): The current time step.
         reward (list[float], optional): Defaults to None.
         policy (tf_agents.policies.tf_policy.TFPolicy, optional): Defaults to None.
-        state (_type_, optional). Defaults to None.
+        state (, optional). The internal network state. See 'policy state' in the "Data Types" section
+          of README.md. Defaults to None.
+        terminate (bool, optional): Whether or not to terminate the episode. Defaults to False.
        verbose (bool, optional): Whether or not to print debugging information. Defaults to False.
 
     Returns:
-        _type_: _description_
+        tuple[Action, State, Info]: The action, state, and info from the policy. Again, see the
+        "Data Types" section of README.md.
     '''
     if policy is None:
         policy = load_rt1()
 
+    # Calculate batch size from instructions shape.
     if isinstance(instructions, str):
         batch_size = 1
         imgs = np.expand_dims(imgs, axis=0)
@@ -161,11 +172,10 @@ def inference(instructions: list[str] | str,
     else:
         batch_size = len(instructions)
 
-    reward = tf.constant(reward, dtype=tf.float32)
     imgs = tf.constant(imgs, dtype=tf.uint8)
 
-    if state is None:
-        state = policy.get_initial_state(batch_size)
+    if policy_state is None:
+        policy_state = policy.get_initial_state(batch_size)
 
     if reward is None:
         reward = tf.zeros((batch_size,), dtype=tf.float32)
@@ -177,13 +187,15 @@
     observation['image'] = imgs
     observation['natural_language_embedding'] = embed_text(
         instructions, batch_size)
+
     if step == 0:
         time_step = ts.restart(observation, batch_size)
-    elif done:
+    elif terminate:
         time_step = ts.termination(observation, reward)
     else:
         time_step = ts.transition(observation, reward)
-    action, next_state, info = policy.action(time_step, state)
+
+    action, next_state, info = policy.action(time_step, policy_state)
 
     if verbose:
         writer = tf.summary.create_file_writer("logs")
@@ -199,7 +211,7 @@
             action['gripper_closedness_action'][0, 0],
             step=step)
         writer.flush()
-    return action, next_state
+    return action, next_state, info
 
 
 def run_on_demo_imgs(policy: LoadedPolicy = None, verbose: bool = False):
@@ -211,14 +223,14 @@
     for i in range(3):
         Image.fromarray(imgs[i].numpy().astype(np.uint8)).save(
             './demo_img{}.png'.format(i))
-        action, state = inference(instructions,
-                                  imgs[i],
-                                  rewards[i],
-                                  policy,
-                                  state,
-                                  verbose=True,
-                                  step=i,
-                                  done=(i == 2))
+        action, state, _ = inference(instructions,
+                                     imgs[i],
+                                     rewards[i],
+                                     step=i,
+                                     policy=policy,
+                                     policy_state=state,
+                                     verbose=verbose,
+                                     terminate=(i == 2))
         pprint(action)
@@ -226,12 +238,12 @@
     parser = argparse.ArgumentParser(
         description=
         'Run inference on demo images. Print the action and for three time steps and'
-        'save the demo images to ./demo_{i}.png')
+        'save the demo images to ./demo_imgs if requested.')
     parser.add_argument('-m',
                         '--model_key',
                         type=str,
                         choices=REGISTRY.keys(),
-                        default='xlatest',
+                        default='rtx1',
                         help='Which model to load.')
     parser.add_argument('-c',
                         '--checkpoint_path',
@@ -242,10 +254,18 @@ def run_on_demo_imgs(policy: LoadedPolicy = None, verbose: bool = False):
                         '--verbose',
                         action='store_true',
                         help='Whether or not to print debugging information.')
+    parser.add_argument('-s',
+                        '--show-demo-imgs',
+                        action='store_true',
+                        help='Whether or not to save the demo images to ./demo_imgs.')
     args = parser.parse_args()
     if args.verbose:
         tf.debugging.experimental.enable_dump_debug_info(
             './logs', tensor_debug_mode='FULL_HEALTH')
+    if args.show_demo_imgs:
+        os.makedirs('./demo_imgs', exist_ok=True)
+        get_demo_imgs('./demo_imgs')
+
     run_on_demo_imgs(load_rt1(args.model_key, args.checkpoint_path),
                      verbose=args.verbose)

From 7a9af58bdbd5a97b9c34ace0d226de1071037ac1 Mon Sep 17 00:00:00 2001
From: Sebastian Peralta
Date: Sat, 16 Dec 2023 12:16:36 -0500
Subject: [PATCH 2/3] fixed reward bug

---
 robo_transformers/rt1/rt1_inference.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/robo_transformers/rt1/rt1_inference.py b/robo_transformers/rt1/rt1_inference.py
index 3b245de..123536b 100644
--- a/robo_transformers/rt1/rt1_inference.py
+++ b/robo_transformers/rt1/rt1_inference.py
@@ -168,7 +168,7 @@ def inference(
         batch_size = 1
         imgs = np.expand_dims(imgs, axis=0)
         if reward is not None:
-            reward = [reward]
+            reward = reward * tf.constant((batch_size,), dtype=tf.float32)
     else:
         batch_size = len(instructions)
@@ -213,24 +213,23 @@ def inference(
         writer.flush()
     return action, next_state, info
-
 
 def run_on_demo_imgs(policy: LoadedPolicy = None, verbose: bool = False):
     instructions = "pick block"
     imgs = get_demo_imgs()
     rewards = [0, 0.5, 0.9]
     state = None
-    for i in range(3):
-        Image.fromarray(imgs[i].numpy().astype(np.uint8)).save(
-            './demo_img{}.png'.format(i))
+    for step in range(3):
+        Image.fromarray(imgs[step].numpy().astype(np.uint8)).save(
+            './demo_img{}.png'.format(step))
         action, state, _ = inference(instructions,
-                                     imgs[i],
-                                     rewards[i],
-                                     step=i,
-                                     policy=policy,
-                                     policy_state=state,
-                                     verbose=verbose,
-                                     terminate=(i == 2))
+                                     imgs[step],
+                                     step,
+                                     rewards[step],
+                                     policy,
+                                     state,
+                                     verbose=verbose,
+                                     terminate=(step == 2))
         pprint(action)

From dba3cdc2cb115e55b2478469e0b1f52e1c1e5b80 Mon Sep 17 00:00:00 2001
From: Sebastian Peralta
Date: Sat, 16 Dec 2023 12:17:26 -0500
Subject: [PATCH 3/3] fixed reward bug

---
 robo_transformers/rt1/rt1_inference.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/robo_transformers/rt1/rt1_inference.py b/robo_transformers/rt1/rt1_inference.py
index 123536b..7dec787 100644
--- a/robo_transformers/rt1/rt1_inference.py
+++ b/robo_transformers/rt1/rt1_inference.py
@@ -220,8 +220,6 @@ def run_on_demo_imgs(policy: LoadedPolicy = None, verbose: bool = False):
     state = None
     for step in range(3):
-        Image.fromarray(imgs[step].numpy().astype(np.uint8)).save(
-            './demo_img{}.png'.format(step))
         action, state, _ = inference(instructions,
                                      imgs[step],
                                      step,