Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,7 +2259,7 @@ def convert_slice(self, op):
# Create axes list for all dimensions being sliced
axes = list(range(input_tensor_rank))
begin = [int(v) for v in begin]
end = [int(v) for v in end]
end = [int(v) for v in end]
out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end)
return out

Expand Down Expand Up @@ -2840,9 +2840,7 @@ def convert_batch_matmul(self, op):
new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b]
max_rank = max(rank_a, rank_b)

batch_shape = [
max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)
]
batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)]

a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
Expand Down Expand Up @@ -2987,21 +2985,11 @@ def convert_prelu(self, op):

input_tensor = input_tensors[0]
alpha_tensor = input_tensors[1]
if self.has_expr(alpha_tensor.tensor_idx):
alpha_expr = self.get_expr(alpha_tensor.tensor_idx)
else:
alpha_tensor_type = alpha_tensor.tensor.Type()
alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
alpha_expr = self.exp_tab.new_const(
self.get_tensor_value(alpha_tensor),
dtype=alpha_tensor_type_str,
source_name=alpha_tensor.tensor.Name(),
)
in_expr = self.get_expr(input_tensor.tensor_idx)
data_shape = to_int_list(self.get_tensor_shape(input_tensor))

alpha_expr = relax.op.broadcast_to(alpha_expr, data_shape)
alpha_expr = relax.op.reshape(alpha_expr, [-1])
alpha_expr = self.get_tensor_expr(alpha_tensor)
alpha_expr = self.bb.normalize(relax.op.broadcast_to(alpha_expr, data_shape))
alpha_expr = self.bb.normalize(relax.op.reshape(alpha_expr, [-1]))
in_expr = self.get_tensor_expr(input_tensor)
out = relax.op.nn.prelu(_op.reshape(in_expr, [-1]), alpha_expr, axis=0)
out = relax.op.reshape(out, data_shape)
return out
Expand Down
70 changes: 64 additions & 6 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def func(self, x):

verify(Tile)


def test_concat_v2():
class ConcatV2(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])
Expand Down Expand Up @@ -804,6 +805,7 @@ def func(self, data, kernel):

verify(TransposeConv)


def test_l2_pool2d():
class L2Pool2D(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)])
Expand All @@ -815,9 +817,9 @@ def func(self, data):
@I.ir_module
class Expected:
@R.function
def main(
data: R.Tensor((1, 8, 8, 2), dtype="float32")
) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
def main(data: R.Tensor((1, 8, 8, 2), dtype="float32")) -> R.Tensor(
(1, 8, 8, 2), dtype="float32"
):
R.func_attr({"num_input": 1})
with R.dataflow():
squared = R.power(data, R.const(2.0, "float32"))
Expand Down Expand Up @@ -883,6 +885,7 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3

verify(ReverseV2, Expected)


def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
class Conv2DModule(tf.Module):
@tf.function(
Expand Down Expand Up @@ -1590,9 +1593,7 @@ def test_nms_v5_ir():
"build_kwargs,expected_topk_count,expected_keep_background",
_DETECTION_POSTPROCESS_SMOKE_CASES,
)
def test_detection_postprocess_smoke(
build_kwargs, expected_topk_count, expected_keep_background
):
def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, expected_keep_background):
mod = _build_detection_postprocess_mod(**build_kwargs)
ir = mod.script()

Expand Down Expand Up @@ -1649,6 +1650,7 @@ def test_detection_postprocess_shape_variations(build_kwargs):
),
)


def _make_resize_expected(
input_shape, output_size, method, coordinate_transformation_mode, rounding_method
):
Expand Down Expand Up @@ -2109,5 +2111,61 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="floa
verify(ReLU_N1_to_1, Expected)


@pytest.mark.parametrize(
    "shared_axes",
    [
        pytest.param([1, 2], id="channelwise_shared_axes"),
        pytest.param([1, 2, 3], id="scalar_shared_axes"),
        pytest.param(None, id="elementwise_no_shared_axes"),
    ],
)
def test_prelu(shared_axes):
    """Check TFLite PReLU conversion for channel-wise, scalar, and element-wise alpha.

    Builds a one-layer Keras PReLU model (constant alpha 0.25), converts it
    through the TFLite converter, imports it with ``from_tflite``, and asserts
    structural equality against a hand-written Relax module.
    """
    # Single-batch NHWC input: full tensor shape is (1, 4, 4, 3).
    inputs = tf.keras.Input(shape=(4, 4, 3), batch_size=1, dtype=tf.float32)
    prelu_kwargs = {
        "alpha_initializer": tf.initializers.constant(0.25),
    }
    # shared_axes=None means Keras learns one alpha per element (no sharing).
    if shared_axes is not None:
        prelu_kwargs["shared_axes"] = shared_axes
    outputs = tf.keras.layers.PReLU(**prelu_kwargs)(inputs)
    keras_model = tf.keras.Model(inputs, outputs)

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    tflite_model_buf = converter.convert()
    # The tflite python package moved GetRootAsModel between versions; support both layouts.
    if hasattr(tflite.Model, "Model"):
        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
    else:
        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

    mod = from_tflite(tflite_model)
    # Drop the "params" attribute so the structural comparison below sees only the IR.
    mod["main"] = mod["main"].without_attr("params")

    # Expected alpha constant shape mirrors how Keras collapses shared axes to size 1.
    if shared_axes == [1, 2]:
        alpha_const = np.full((1, 1, 3), 0.25, dtype=np.float32)
    elif shared_axes == [1, 2, 3]:
        alpha_const = np.full((1, 1, 1), 0.25, dtype=np.float32)
    else:
        alpha_const = np.full((4, 4, 3), 0.25, dtype=np.float32)

    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((1, 4, 4, 3), dtype="float32")) -> R.Tensor(
            (1, 4, 4, 3), dtype="float32"
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # Frontend broadcasts alpha to the data shape, flattens both data and
                # alpha, applies prelu over axis 0, then restores the original shape.
                lv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.broadcast_to(
                    R.const(alpha_const), R.shape([1, 4, 4, 3])
                )
                lv1: R.Tensor((48,), dtype="float32") = R.reshape(x, R.shape([48]))
                lv2: R.Tensor((48,), dtype="float32") = R.reshape(lv, R.shape([48]))
                lv3: R.Tensor((48,), dtype="float32") = R.nn.prelu(lv1, lv2, axis=0)
                gv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.reshape(lv3, R.shape([1, 4, 4, 3]))
                R.output(gv)
            return gv

    tvm.ir.assert_structural_equal(mod, Expected)


# Allow running this test file directly; "-s" disables pytest's output capture.
if __name__ == "__main__":
    pytest.main(["-s", __file__])
Loading