Commit

Merge branch 'main' into FB35F7423633BEEDB6F5751B3449E9F7

Qwlouse authored Apr 4, 2024
2 parents 3a200da + 21a0ab7 commit b3675ef
Showing 3 changed files with 93 additions and 22 deletions.
challenges/point_tracking/dataset.py: 85 additions, 11 deletions
@@ -479,6 +479,7 @@ def track_points(
     sampling_stride=4,
     max_seg_id=25,
     max_sampled_frac=0.1,
+    snap_to_occluder=False,
 ):
   """Track points in 2D using Kubric data.
@@ -508,6 +509,11 @@ def track_points(
     max_seg_id: The maximum segment id in the video.
     max_sampled_frac: The maximum fraction of points to sample from each
       object, out of all points that lie on the sampling grid.
+    snap_to_occluder: If true, query points within 1 pixel of occlusion
+      boundaries will track the occluding surface rather than the background.
+      This results in models which are biased to track foreground objects
+      instead of background. Whether this is desirable depends on downstream
+      applications.
 
   Returns:
     A set of queries, randomly sampled from the video (with a bias toward
@@ -556,13 +562,37 @@ def extract_box(x):
           start_vec[2]:window[3]:sampling_stride]
     return x
 
+  def erode_segmentations(seg, depth):
+    # Mask out points that are near to being occluded by some other object, as
+    # measured by nearby depth discontinuities within a 1px radius.
+    sz = seg.shape
+    pad_depth = tf.pad(
+        depth, [(0, 0), (1, 1), (1, 1), (0, 0)], mode='SYMMETRIC'
+    )
+    invalid = False
+    for x in range(0, 3):
+      for y in range(0, 3):
+        if x == 1 and y == 1:
+          continue
+        wind_depth = pad_depth[:, y : y + sz[1], x : x + sz[2], :]
+        invalid = tf.logical_or(invalid, tf.math.less(wind_depth, depth * 0.95))
+    seg = tf.where(invalid, tf.zeros_like(seg) - 1, seg)
+    return seg
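
The nested loop above compares every pixel's depth against its eight neighbors and invalidates any pixel with a neighbor at least 5% nearer the camera, i.e. a pixel sitting on a depth discontinuity. A minimal single-frame NumPy sketch of the same test (erode_demo and ratio are illustrative names, not part of this file):

import numpy as np

def erode_demo(seg, depth, ratio=0.95):
  # Mark a pixel invalid (-1) when any 8-neighbor is markedly nearer the camera.
  h, w = depth.shape
  pad = np.pad(depth, 1, mode='symmetric')
  invalid = np.zeros((h, w), dtype=bool)
  for dy in range(3):
    for dx in range(3):
      if dy == 1 and dx == 1:
        continue  # the center pixel is not its own neighbor
      invalid |= pad[dy:dy + h, dx:dx + w] < depth * ratio
  return np.where(invalid, -1, seg)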

+  if snap_to_occluder:
+    segmentations = erode_segmentations(
+        tf.cast(segmentations, tf.int32), depth_map
+    )
+
   segmentations_box = extract_box(segmentations)
   object_coordinates_box = extract_box(object_coordinates)
 
   # Next, get the number of points to sample from each object. First count
   # how many points are available for each object.
 
-  cnt = tf.math.bincount(tf.cast(tf.reshape(segmentations_box, [-1]), tf.int32))
+  cnt = tf.math.bincount(
+      tf.cast(tf.reshape(segmentations_box, [-1]) + 1, tf.int32)
+  )[1:]
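
Erosion writes -1 into the segmentation map, so the ids are shifted up by one before counting; eroded pixels land in bin 0, which the [1:] slice then discards, keeping the remaining bins aligned with segment ids. A toy check (values are illustrative):

import tensorflow as tf

seg = tf.constant([-1, 0, 0, 2, 2, 2, -1])  # -1 marks eroded pixels
cnt = tf.math.bincount(seg + 1)[1:]         # bin 0 collects the -1 labels
print(cnt.numpy())                          # [2 0 3]: counts for ids 0, 1, 2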
   num_to_sample = get_num_to_sample(
       cnt,
       max_seg_id,
@@ -633,8 +663,6 @@ def get_camera(fr=None):
     normals = tf.gather(normals, idx)
     trust_sn_gather = tf.gather(trust_sn_mask, idx)
 
-    pixel_to_raster = tf.constant([0.0, 0.5, 0.5])[tf.newaxis, :]
-
     if obj_id == -1:
       # For the background object, no bounding box is available. However,
       # this doesn't move, so we use the depth map to backproject these points
@@ -651,9 +679,7 @@ def get_camera(fr=None):
         pt_3d.append(
             unproject(pt_coords_chunk[:, 1:], get_camera(fr), depth_map[fr]))
       pt = tf.concat(pt_3d, axis=0)
-      chosen_points.append(
-          tf.cast(tf.concat(pt_coords_reorder, axis=0), tf.float32) +
-          pixel_to_raster)
+      chosen_points.append(tf.concat(pt_coords_reorder, axis=0))
       bbox = None
       quat = None
       frame_for_pt = None
@@ -662,7 +688,7 @@ def get_camera(fr=None):
       # kubric.
       pt = tf.gather(pt, idx)
       pt = pt / np.iinfo(np.uint16).max - .5
-      chosen_points.append(tf.cast(pt_coords, tf.float32) + pixel_to_raster)
+      chosen_points.append(pt_coords)
       # if obj_id>num_objects, then we won't have a box. We also won't have
       # points, so just use a dummy to prevent tf from crashing.
       bbox = tf.cond(obj_id >= tf.shape(bboxes_3d)[0], lambda: bboxes_3d[0, :],
@@ -732,7 +758,42 @@ def get_camera(fr=None):
   # chosen_points is [num_points, (z,y,x)]
   chosen_points = tf.concat(chosen_points, axis=0)
 
-  chosen_points = tf.cast(chosen_points, tf.float32)
+  if snap_to_occluder:
+    # For query points that are near to an occlusion boundary, occasionally
+    # jitter the query point onto the occluded object.
+    random_perturb = chosen_points[:, 1:] + tf.random.uniform(
+        tf.shape(chosen_points[:, 1:]), -1, 2, dtype=tf.int32
+    )
+    random_perturb = tf.minimum(
+        tf.shape(depth_map)[1:3] - 1, tf.maximum(0, random_perturb)
+    )
+    random_idx = (
+        random_perturb[:, 1]
+        + random_perturb[:, 0] * tf.shape(depth_map)[2]
+        + chosen_points[:, 0] * tf.shape(depth_map)[1] * tf.shape(depth_map)[2]
+    )
+    chosen_points_idx = (
+        chosen_points[:, 2]
+        + chosen_points[:, 1] * tf.shape(depth_map)[2]
+        + chosen_points[:, 0] * tf.shape(depth_map)[1] * tf.shape(depth_map)[2]
+    )
+    random_depth = tf.gather(tf.reshape(depth_map, [-1]), random_idx)
+    chosen_points_depth = tf.gather(
+        tf.reshape(depth_map, [-1]), chosen_points_idx
+    )
+    swap = tf.logical_and(
+        chosen_points_depth < random_depth * 0.95,
+        tf.random.uniform(tf.shape(chosen_points_depth)) < 0.5,
+    )
+    random_perturb = tf.concat([chosen_points[:, 0:1], random_perturb], axis=-1)
+    chosen_points = tf.where(
+        tf.logical_and(swap[:, tf.newaxis], tf.ones([1, 3], dtype=bool)),
+        random_perturb,
+        chosen_points,
+    )

+  pixel_to_raster = tf.constant([0.0, 0.5, 0.5])[tf.newaxis, :]
+  chosen_points = tf.cast(chosen_points, tf.float32) + pixel_to_raster
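
Both gathers above read depths by flattening the [num_frames, height, width] map and computing flat indices by hand; pixel_to_raster then leaves the frame index untouched while shifting y and x by half a pixel, from the pixel's corner to its center. A sketch of the indexing convention (flat_index is an illustrative helper, not part of this file):

import tensorflow as tf

def flat_index(tyx, height, width):
  # (t, y, x) -> position in a [T, H, W] tensor reshaped to [-1].
  t, y, x = tyx[:, 0], tyx[:, 1], tyx[:, 2]
  return x + y * width + t * height * width

depth = tf.random.uniform([2, 4, 5])  # stand-in for a [T, H, W] depth map
pts = tf.constant([[0, 1, 2], [1, 3, 4]])
vals = tf.gather(tf.reshape(depth, [-1]), flat_index(pts, 4, 5))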

   # renormalize so the box corners are at [-1,1]
   chosen_points = (chosen_points - wd[:, 0, :3]) / (wd[:, 0, 3:] - wd[:, 0, :3])
@@ -778,7 +839,8 @@ def add_tracks(data,
                tracks_to_sample=256,
                sampling_stride=4,
                max_seg_id=25,
-               max_sampled_frac=0.1):
+               max_sampled_frac=0.1,
+               snap_to_occluder=False):
   """Track points in 2D using Kubric data.
 
   Args:
@@ -794,6 +856,11 @@
     max_seg_id: The maximum segment id in the video.
     max_sampled_frac: The maximum fraction of points to sample from each
       object, out of all points that lie on the sampling grid.
+    snap_to_occluder: If true, query points within 1 pixel of occlusion
+      boundaries will track the occluding surface rather than the background.
+      This results in models which are biased to track foreground objects
+      instead of background. Whether this is desirable depends on downstream
+      applications.
 
   Returns:
     A dict with the following keys:
@@ -843,7 +910,7 @@ def add_tracks(data,
       data['camera']['focal_length'],
       data['camera']['positions'], data['camera']['quaternions'],
       data['camera']['sensor_width'], crop_window, tracks_to_sample,
-      sampling_stride, max_seg_id, max_sampled_frac)
+      sampling_stride, max_seg_id, max_sampled_frac, snap_to_occluder)
   video = data['video']
 
   shp = video.shape.as_list()
@@ -890,6 +957,7 @@ def create_point_tracking_dataset(
     max_seg_id=25,
     max_sampled_frac=0.1,
     num_parallel_point_extraction_calls=16,
+    snap_to_occluder=False,
     **kwargs):
   """Construct a dataset for point tracking using Kubric.
@@ -912,6 +980,11 @@
       object, out of all points that lie on the sampling grid.
     num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
       map function for point extraction.
+    snap_to_occluder: If true, query points within 1 pixel of occlusion
+      boundaries will track the occluding surface rather than the background.
+      This results in models which are biased to track foreground objects
+      instead of background. Whether this is desirable depends on downstream
+      applications.
     **kwargs: additional args to pass to tfds.load.
 
   Returns:
@@ -935,7 +1008,8 @@
           tracks_to_sample=tracks_to_sample,
           sampling_stride=sampling_stride,
           max_seg_id=max_seg_id,
-          max_sampled_frac=max_sampled_frac),
+          max_sampled_frac=max_sampled_frac,
+          snap_to_occluder=snap_to_occluder),
       num_parallel_calls=num_parallel_point_extraction_calls)
   if shuffle_buffer_size is not None:
     ds = ds.shuffle(shuffle_buffer_size)
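
A hypothetical call that threads the new flag through the pipeline; the keyword arguments mirror the defaults shown above, and the printed keys are an assumption for illustration:

ds = create_point_tracking_dataset(
    tracks_to_sample=256,
    sampling_stride=4,
    max_seg_id=25,
    max_sampled_frac=0.1,
    snap_to_occluder=True,  # jitter boundary queries as described above
)
for example in ds.take(1):
  print({k: v.shape for k, v in example.items()})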
kubric/renderer/blender.py: 7 additions, 9 deletions
@@ -13,27 +13,25 @@
 # limitations under the License.
 
 import collections
+from contextlib import redirect_stdout
 import functools
 import io
 import logging
 import os
 import sys
-from contextlib import redirect_stdout
+from typing import Any, Dict, Optional, Sequence, Union
 import tempfile
 
+from kubric.safeimport.bpy import bpy
 
+import numpy as np
+import tensorflow as tf
-from typing import Any, Dict, Optional, Sequence, Union
 
 import kubric as kb
 from kubric import core
+from kubric import file_io
 from kubric.core.assets import UndefinedAsset
+from kubric.file_io import PathLike
 from kubric.redirect_io import RedirectStream
 from kubric.renderer import blender_utils
-from kubric import file_io
-from kubric.file_io import PathLike
-from kubric.safeimport.bpy import bpy
-import numpy as np
-import tensorflow as tf
 
 logger = logging.getLogger(__name__)
kubric/simulator/pybullet.py: 1 addition, 2 deletions
@@ -22,10 +22,9 @@
 import tempfile
 from typing import Dict, List, Optional, Tuple, Union
 
-import tensorflow as tf
-
 from kubric import core
 from kubric.redirect_io import RedirectStream
+import tensorflow as tf
 
 # --- hides the "pybullet build time: May 26 2021 18:52:36" message on import
 with RedirectStream(stream=sys.stderr):
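
RedirectStream silences output that C extensions such as pybullet write directly to the underlying file descriptor, which bypasses Python's sys.stderr object. A rough sketch of that kind of fd-level redirection (an assumed mechanism for illustration, not kubric's actual implementation):

import contextlib
import os
import sys

@contextlib.contextmanager
def silence(stream=sys.stderr):
  fd = stream.fileno()
  saved = os.dup(fd)                # remember where the stream pointed
  with open(os.devnull, 'wb') as devnull:
    os.dup2(devnull.fileno(), fd)   # point the descriptor at /dev/null
    try:
      yield
    finally:
      os.dup2(saved, fd)            # restore the original target
      os.close(saved)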