Add/add v2 op (#11)
* add AddV2 op
tokusumi authored Aug 17, 2020
1 parent f67795a commit 5c130bf
Showing 3 changed files with 21 additions and 11 deletions.
keras_flops/flops_registory.py: 15 changes (12 additions & 3 deletions)
@@ -1,7 +1,10 @@
 import numpy as np
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import graph_util
-from tensorflow.python.profiler.internal.flops_registry import _reduction_op_flops
+from tensorflow.python.profiler.internal.flops_registry import (
+    _reduction_op_flops,
+    _binary_per_element_op_flops,
+)


 @ops.RegisterStatistics("FusedBatchNormV3", "flops")
@@ -18,8 +21,8 @@ def _flops_fused_batch_norm_v3(graph, node):
         raise ValueError("Only supports inference mode")

     num_flops = (
-        in_shape.num_elements()
-        + 4 * variance_shape.num_elements()
+        2 * in_shape.num_elements()
+        + 5 * variance_shape.num_elements()
         + mean_shape.num_elements()
     )
     return ops.OpStats("flops", num_flops)
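
The revised count matches the breakdown in the updated test docstring below: rsqrt(var + eps) costs 3 ops per variance element, the gamma scale and the mean * inv product cost 1 op each (5 * |var| in total), the beta - mean * inv shift costs |mean|, and scaling and shifting the input costs 2 * |x|. A quick arithmetic check, using hypothetical shapes rather than anything from the repo:

# Hypothetical NHWC input (1, 32, 32, 3) with per-channel statistics.
in_elems = 1 * 32 * 32 * 3   # |x|
var_elems = mean_elems = 3   # |var| = |mean| = channel count

# 5 * |var|: rsqrt(var + eps) = 3 ops, inv *= gamma = 1 op, mean * inv = 1 op
# |mean|: the beta - mean * inv shift term
# 2 * |x|: x' = inv * x + shift, one multiply and one add per element
num_flops = 2 * in_elems + 5 * var_elems + mean_elems
print(num_flops)  # 6162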
@@ -31,3 +34,9 @@ def _flops_max(graph, node):
     # reduction - comparison, no finalization
     return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+
+@ops.RegisterStatistics("AddV2", "flops")
+def _flops_add(graph, node):
+    """inference is supported"""
+    return _binary_per_element_op_flops(graph, node)
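
_binary_per_element_op_flops charges one FLOP per output element, so an element-wise add of two tensors of N elements is reported as N FLOPs. A minimal usage sketch, assuming keras-flops 0.1.2 on TF 2.2+ (the shapes here are illustrative):

from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Add
from keras_flops import get_flops

a = Input(shape=(64,))
b = Input(shape=(64,))
model = Model(inputs=[a, b], outputs=Add()([a, b]))  # lowers to an AddV2 node

# With the AddV2 registration above, the profiler should report
# one FLOP per output element, i.e. 64 here.
print(get_flops(model, batch_size=1))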

pyproject.toml: 4 changes (2 additions & 2 deletions)
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "keras-flops"
-version = "0.1.1"
-description = "FLOPs calculator with tf.profiler for neural network architecture written in tensorflow 2.x (tf.keras)"
+version = "0.1.2"
+description = "FLOPs calculator with tf.profiler for neural network architecture written in tensorflow 2.2+ (tf.keras)"
 authors = ["tokusumi <[email protected]>"]
 license = "MIT"
 readme = "README.md"
tests/test_flops.py: 13 changes (7 additions & 6 deletions)
@@ -315,11 +315,11 @@ def test_conv1dtranspose():
 def test_batchnormalization():
     """
     batch normalization is calculated as follows,
-    1. (2 ops * |var|) inv = rsqrt(var + eps)
+    1. (3 ops * |var|) inv = rsqrt(var + eps)
     2. (1 ops * |var|) inv *= gamma (scale)
-    3. (|x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
+    3. (2 * |x| + |mean| + |var| ops) x' = inv * x + beta (shift) - mean * inv
     , where |var| = |mean| = channel size by default
-    Thus, total FLOPs = 5 * channel size + input element size.
+    Thus, total FLOPs = 6 * channel size + 2 * input element size.
     """
     in_w = 32
     in_h = 32
Expand All @@ -334,7 +334,7 @@ def test_batchnormalization():
)
flops = get_flops(model, batch_size=1)
assert (
flops == 5 * in_ch + in_w * in_ch
flops == 6 * in_ch + 2 * in_w * in_ch
), "fused is False. see nn_impl.batch_normalization"

model = Sequential(
Expand All @@ -346,7 +346,7 @@ def test_batchnormalization():
)
flops = get_flops(model, batch_size=1)
assert (
flops == 5 * in_ch + in_w * in_h * in_ch
flops == 6 * in_ch + 2 * in_w * in_h * in_ch
), "fused is True, see gen_nn.fused_batch_norm_v3"


Expand All @@ -355,7 +355,7 @@ def test_additive_attention():
Bahdanau-style attention. query (batch, Tq, dim), key (batch, Tv, dim) and value (batch, Tv, dim) are inputs.
following computations is processed.
1. reshape query as shape [batch, Tq, 1, dim] and value as shape [batch, 1, Tv, dim]
2. broadcasting multiply between both of above as output shape [batch, Tq, Tv, dim]
2. broadcasting multiply between additive of above as output shape [batch, Tq, Tv, dim]
3. reduce_sum above with dim axis as output shape [batch, Tq, Tv]
4. softmax of above
5. MatMul between 4. and value as output shape [batch, Tq, dim]
Expand All @@ -375,6 +375,7 @@ def test_additive_attention():
assert (
flops
== Tq * Tv * dim # No.2 (multiply)
+ Tq * Tv * dim # No.3 (add)
+ Tq * Tv * (dim - 1) # No.3 (reduce_sum)
+ 5 * Tq * Tv # No.4 (softmax)
+ 2 * Tv * Tq * dim # No.5 (MatMul)
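
Summing the five terms for stand-in sizes (Tq, Tv, and dim are defined in the collapsed lines above; 8, 8, and 16 are hypothetical values):

Tq = Tv = 8
dim = 16  # stand-in sizes; the test's actual constants are collapsed above

flops = (
    Tq * Tv * dim          # No.2 (multiply)
    + Tq * Tv * dim        # No.2 (add)
    + Tq * Tv * (dim - 1)  # No.3 (reduce_sum)
    + 5 * Tq * Tv          # No.4 (softmax)
    + 2 * Tv * Tq * dim    # No.5 (MatMul)
)
print(flops)  # 5376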
