Skip to content

Commit

Permalink
Fixed rpc tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelalonsojr committed Oct 2, 2023
1 parent 44d80de commit c4a04b6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 34 deletions.
14 changes: 7 additions & 7 deletions ml-agents-envs/mlagents_envs/rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def process_pixels(
image = Image.open(image_fp)
# Normally Image loads lazily, load() forces it to do loading in the timer scope.
image.load()
image_arrays.append(np.array(image, dtype=np.float32) / 255.0)
image_arrays.append(np.moveaxis(np.array(image, dtype=np.float32) / 255.0, -1, 0))

# Look for the next header, starting from the current stream location
try:
Expand All @@ -142,7 +142,7 @@ def _process_images_mapping(image_arrays, mappings):
"""
Helper function for processing decompressed images with compressed channel mappings.
"""
image_arrays = np.concatenate(image_arrays, axis=2).transpose((2, 0, 1))
image_arrays = np.concatenate(image_arrays, axis=0).transpose((0, 1, 2))

if len(mappings) != len(image_arrays):
raise UnityObservationException(
Expand Down Expand Up @@ -178,15 +178,15 @@ def _process_images_num_channels(image_arrays, expected_channels):
"""
if expected_channels == 1:
# Convert to grayscale
img = np.mean(image_arrays[0], axis=2)
img = np.reshape(img, [img.shape[0], img.shape[1], 1])
img = np.mean(image_arrays[0], axis=0)
img = np.reshape(img, [1, img.shape[0], img.shape[1]])
else:
img = np.concatenate(image_arrays, axis=2)
img = np.concatenate(image_arrays, axis=0)
# We can drop additional channels since they may need to be added to include
# numbers of observation channels not divisible by 3.
actual_channels = list(img.shape)[2]
actual_channels = list(img.shape)[0]
if actual_channels > expected_channels:
img = img[..., 0:expected_channels]
img = img[0:expected_channels, ...]
return img


Expand Down
55 changes: 28 additions & 27 deletions ml-agents/mlagents/trainers/tests/test_rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,18 @@ def generate_compressed_data(in_array: np.ndarray) -> bytes:
image_arr = (in_array * 255).astype(np.uint8)
bytes_out = bytes()

num_channels = in_array.shape[2]
num_channels = in_array.shape[0]
num_images = (num_channels + 2) // 3
# Split the input image into batches of 3 channels.
for i in range(num_images):
sub_image = image_arr[..., 3 * i : 3 * i + 3]
sub_image = image_arr[3 * i: 3 * i + 3, ...]
if (i == num_images - 1) and (num_channels % 3) != 0:
# Pad zeros
zero_shape = list(in_array.shape)
zero_shape[2] = 3 - (num_channels % 3)
zero_shape[0] = 3 - (num_channels % 3)
z = np.zeros(zero_shape, dtype=np.uint8)
sub_image = np.concatenate([sub_image, z], axis=2)
sub_image = np.concatenate([sub_image, z], axis=0)
sub_image = np.moveaxis(sub_image, 0, -1)
im = Image.fromarray(sub_image, "RGB")
byteIO = io.BytesIO()
im.save(byteIO, format="PNG")
Expand All @@ -92,7 +93,7 @@ def generate_compressed_proto_obs(
obs_proto.compression_type = PNG
if grayscale:
# grayscale flag is only used for old API without mapping
expected_shape = [in_array.shape[0], in_array.shape[1], 1]
expected_shape = [1, in_array.shape[1], in_array.shape[2]]
obs_proto.shape.extend(expected_shape)
else:
obs_proto.shape.extend(in_array.shape)
Expand All @@ -109,9 +110,9 @@ def generate_compressed_proto_obs_with_mapping(
if mapping is not None:
obs_proto.compressed_channel_mapping.extend(mapping)
expected_shape = [
in_array.shape[0],
in_array.shape[1],
len({m for m in mapping if m >= 0}),
in_array.shape[1],
in_array.shape[2],
]
obs_proto.shape.extend(expected_shape)
else:
Expand Down Expand Up @@ -233,10 +234,10 @@ def proto_from_steps_and_action(


def test_process_pixels():
in_array = np.random.rand(128, 64, 3)
in_array = np.random.rand(3, 128, 64)
byte_arr = generate_compressed_data(in_array)
out_array = process_pixels(byte_arr, 3)
assert out_array.shape == (128, 64, 3)
assert out_array.shape == (3, 128, 64)
assert np.sum(in_array - out_array) / np.prod(in_array.shape) < 0.01
assert np.allclose(in_array, out_array, atol=0.01)

Expand All @@ -245,21 +246,21 @@ def test_process_pixels_multi_png():
height = 128
width = 64
num_channels = 7
in_array = np.random.rand(height, width, num_channels)
in_array = np.random.rand(num_channels, height, width)
byte_arr = generate_compressed_data(in_array)
out_array = process_pixels(byte_arr, num_channels)
assert out_array.shape == (height, width, num_channels)
assert out_array.shape == (num_channels, height, width)
assert np.sum(in_array - out_array) / np.prod(in_array.shape) < 0.01
assert np.allclose(in_array, out_array, atol=0.01)


def test_process_pixels_gray():
in_array = np.random.rand(128, 64, 3)
in_array = np.random.rand(3, 128, 64)
byte_arr = generate_compressed_data(in_array)
out_array = process_pixels(byte_arr, 1)
assert out_array.shape == (128, 64, 1)
assert np.mean(in_array.mean(axis=2, keepdims=True) - out_array) < 0.01
assert np.allclose(in_array.mean(axis=2, keepdims=True), out_array, atol=0.01)
assert out_array.shape == (1, 128, 64)
assert np.mean(in_array.mean(axis=0, keepdims=True) - out_array) < 0.01
assert np.allclose(in_array.mean(axis=0, keepdims=True), out_array, atol=0.01)


def test_vector_observation():
Expand All @@ -276,7 +277,7 @@ def test_vector_observation():


def test_process_visual_observation():
shape = (128, 64, 3)
shape = (3, 128, 64)
in_array_1 = np.random.rand(*shape)
proto_obs_1 = generate_compressed_proto_obs(in_array_1)
in_array_2 = np.random.rand(*shape)
Expand All @@ -292,51 +293,51 @@ def test_process_visual_observation():
ap_list = [ap1, ap2]
obs_spec = create_observation_specs_with_shapes([shape])[0]
arr = _process_maybe_compressed_observation(0, obs_spec, ap_list)
assert list(arr.shape) == [2, 128, 64, 3]
assert list(arr.shape) == [2, 3, 128, 64]
assert np.allclose(arr[0, :, :, :], in_array_1, atol=0.01)
assert np.allclose(arr[1, :, :, :], in_array_2, atol=0.01)


def test_process_visual_observation_grayscale():
in_array_1 = np.random.rand(128, 64, 3)
in_array_1 = np.random.rand(3, 128, 64)
proto_obs_1 = generate_compressed_proto_obs(in_array_1, grayscale=True)
expected_out_array_1 = np.mean(in_array_1, axis=2, keepdims=True)
in_array_2 = np.random.rand(128, 64, 3)
expected_out_array_1 = np.mean(in_array_1, axis=0, keepdims=True)
in_array_2 = np.random.rand(3, 128, 64)
in_array_2_mapping = [0, 0, 0]
proto_obs_2 = generate_compressed_proto_obs_with_mapping(
in_array_2, in_array_2_mapping
)
expected_out_array_2 = np.mean(in_array_2, axis=2, keepdims=True)
expected_out_array_2 = np.mean(in_array_2, axis=0, keepdims=True)

ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap2 = AgentInfoProto()
ap2.observations.extend([proto_obs_2])
ap_list = [ap1, ap2]
shape = (128, 64, 1)
shape = (1, 128, 64)
obs_spec = create_observation_specs_with_shapes([shape])[0]
arr = _process_maybe_compressed_observation(0, obs_spec, ap_list)
assert list(arr.shape) == [2, 128, 64, 1]
assert list(arr.shape) == [2, 1, 128, 64]
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)
assert np.allclose(arr[1, :, :, :], expected_out_array_2, atol=0.01)


def test_process_visual_observation_padded_channels():
in_array_1 = np.random.rand(128, 64, 12)
in_array_1 = np.random.rand(12, 128, 64)
in_array_1_mapping = [0, 1, 2, 3, -1, -1, 4, 5, 6, 7, -1, -1]
proto_obs_1 = generate_compressed_proto_obs_with_mapping(
in_array_1, in_array_1_mapping
)
expected_out_array_1 = np.take(in_array_1, [0, 1, 2, 3, 6, 7, 8, 9], axis=2)
expected_out_array_1 = np.take(in_array_1, [0, 1, 2, 3, 6, 7, 8, 9], axis=0)

ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap_list = [ap1]
shape = (128, 64, 8)
shape = (8, 128, 64)
obs_spec = create_observation_specs_with_shapes([shape])[0]

arr = _process_maybe_compressed_observation(0, obs_spec, ap_list)
assert list(arr.shape) == [1, 128, 64, 8]
assert list(arr.shape) == [1, 8, 128, 64]
assert np.allclose(arr[0, :, :, :], expected_out_array_1, atol=0.01)


Expand Down

0 comments on commit c4a04b6

Please sign in to comment.