Skip to content

Commit f9dbd88

Browse files
gonnet authored and xnnpack-bot committed
Replace small constant tensors with static values.
PiperOrigin-RevId: 831405603
1 parent 4065278 commit f9dbd88

File tree

2 files changed

+154
-81
lines changed

2 files changed

+154
-81
lines changed

src/subgraph.c

Lines changed: 148 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "include/experimental.h"
1717
#include "include/xnnpack.h"
18+
#include "src/subgraph/subgraph-utils.h"
1819
#include "src/xnnpack/allocation-type.h"
1920
#include "src/xnnpack/allocator.h"
2021
#include "src/xnnpack/common.h"
@@ -2812,33 +2813,16 @@ static enum xnn_status optimize_common_subgraphs_static_reshapes(
28122813
return xnn_status_success;
28132814
}
28142815

2815-
// Set the shape of the static-shaped value.
2816-
struct xnn_shape new_shape;
2817-
if (node->type == xnn_node_type_static_reshape) {
2818-
// Replace the old shape with the new shape, filling any gaps from the input
2819-
// shape.
2820-
new_shape = node->params.static_reshape.new_shape;
2821-
XNN_RETURN_IF_ERROR(xnn_shape_fill_gaps(&input_value->shape, &new_shape));
2822-
} else if (node->type == xnn_node_type_static_expand_dims) {
2823-
const struct xnn_shape* new_dims = &node->params.static_reshape.new_shape;
2824-
new_shape.num_dims = input_value->shape.num_dims + new_dims->num_dims;
2825-
for (uint32_t idx_new = 0, idx_old = 0, k = 0; k < new_shape.num_dims;
2826-
k++) {
2827-
if (idx_new < new_dims->num_dims && new_dims->dim[idx_new] == k) {
2828-
new_shape.dim[k] = 1;
2829-
idx_new++;
2830-
} else {
2831-
new_shape.dim[k] = input_value->shape.dim[idx_old++];
2832-
}
2833-
}
2834-
}
2835-
28362816
// If the input is a static value, apply the new shape to it directly.
28372817
bool elide = true;
28382818
if (xnn_value_is_static(input_value->allocation_type)) {
2839-
input_value->shape = new_shape;
2840-
} else {
2841-
elide = xnn_shape_match(&new_shape, &input_value->shape);
2819+
input_value->shape = output_value->shape;
2820+
}
2821+
2822+
// Otherwise, if the new shape is the old shape, do away with the reshape
2823+
// entirely.
2824+
else {
2825+
elide = xnn_shape_match(&input_value->shape, &output_value->shape);
28422826
}
28432827

28442828
if (elide) {
@@ -3427,6 +3411,12 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34273411
}
34283412
}
34293413

3414+
// If we don't know the shape of the constant, then we can't really guarantee
3415+
// anything.
3416+
if ((const_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC) == 0) {
3417+
return xnn_status_success;
3418+
}
3419+
34303420
const enum xnn_binary_operator binary_operator = node->binary_operator;
34313421
const bool const_is_zero = (const_value->flags & XNN_VALUE_FLAG_IS_ZERO) != 0;
34323422
const bool const_is_one = (const_value->flags & XNN_VALUE_FLAG_IS_ONE) != 0;
@@ -3453,32 +3443,29 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34533443
(const_is_zero &&
34543444
(binary_operator == xnn_binary_add ||
34553445
(binary_operator == xnn_binary_subtract && const_is_rhs)))) {
3456-
if (short_circuit(subgraph, input_value->id, node->outputs[0])) {
3457-
xnn_log_info("Elided spurious %s[#%u](v%03u, %s).",
3458-
binary_operator == xnn_binary_multiply ? "mul"
3459-
: binary_operator == xnn_binary_divide ? "div"
3460-
: binary_operator == xnn_binary_add ? "add"
3461-
: "sub",
3462-
node->id, input_value->id, const_is_zero ? "0.0" : "1.0");
3463-
xnn_node_clear(node);
3464-
(*changes)++;
3465-
} else if (input_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC &&
3466-
const_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC) {
3467-
// If this node cannot be elided, and both input shapes are static, then
3468-
// try to replace it with a `copy` or `broadcast` of the input value.
3469-
struct xnn_shape* output_shape = &subgraph->values[node->outputs[0]].shape;
3470-
XNN_RETURN_IF_ERROR(
3471-
xnn_shape_binary_broadcast(&input_value->shape, &const_value->shape,
3472-
output_shape),
3473-
"Incompatible input shapes for %s[#%u](v%03u, %s).",
3474-
binary_operator == xnn_binary_multiply ? "mul"
3475-
: binary_operator == xnn_binary_divide ? "div"
3476-
: binary_operator == xnn_binary_add ? "add"
3477-
: "sub",
3478-
node->id, input_value->id, const_is_zero ? "0.0" : "1.0");
3479-
3480-
// If the output shape matches the input shape, just copy the input.
3481-
if (xnn_shape_match(&input_value->shape, output_shape)) {
3446+
// We can safely elide this operation if we know it will not change the
3447+
// shape of the output, e.g. if the constant is a scalar or the shapes are
3448+
// static and equal.
3449+
if ((xnn_shape_multiply_all_dims(&const_value->shape) == 1 &&
3450+
const_value->shape.num_dims <= input_value->shape.num_dims) ||
3451+
((input_value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC) &&
3452+
xnn_shape_match(&input_value->shape, &output_value->shape))) {
3453+
// If the node be elided (not load-bearing), then just remove it.
3454+
if (short_circuit(subgraph, input_value->id, node->outputs[0])) {
3455+
xnn_log_info("Elided spurious %s[#%u](v%03u, %s).",
3456+
binary_operator == xnn_binary_multiply ? "mul"
3457+
: binary_operator == xnn_binary_divide ? "div"
3458+
: binary_operator == xnn_binary_add ? "add"
3459+
: "sub",
3460+
node->id, input_value->id, const_is_zero ? "0.0" : "1.0");
3461+
xnn_node_clear(node);
3462+
(*changes)++;
3463+
}
3464+
3465+
// Otherwise, replace it with a copy.
3466+
else {
3467+
// If the constant is a scalar, then it won't affect the shape of the
3468+
// output.
34823469
XNN_RETURN_IF_ERROR(xnn_define_copy(subgraph, input_value->id,
34833470
node->outputs[0], node->flags),
34843471
"Failed to create new `Copy` node.");
@@ -3492,26 +3479,6 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34923479
node->id, input_value->id, const_is_zero ? "0.0" : "1.0", node_id,
34933480
input_value->id);
34943481
}
3495-
3496-
// Otherwise, we need to broadcast the input to the output shape.
3497-
else {
3498-
XNN_RETURN_IF_ERROR(
3499-
xnn_define_static_broadcast(subgraph, output_shape->num_dims,
3500-
output_shape->dim, input_value->id,
3501-
node->outputs[0],
3502-
node->flags),
3503-
"Failed to create new `Broadcast` node.");
3504-
node = move_last_node_to(subgraph, node_id);
3505-
xnn_log_info(
3506-
"Replaced spurious %s[#%u](v%03u, %s) with "
3507-
"static_broadcast[#%u](v%03u).",
3508-
binary_operator == xnn_binary_multiply ? "mul"
3509-
: binary_operator == xnn_binary_divide ? "div"
3510-
: binary_operator == xnn_binary_add ? "add"
3511-
: "sub",
3512-
node->id, input_value->id, const_is_zero ? "0.0" : "1.0", node_id,
3513-
input_value->id);
3514-
}
35153482
(*changes)++;
35163483
}
35173484
}
@@ -3623,6 +3590,67 @@ static enum xnn_status optimize_common_subgraphs_gemm_rhs_transpose(
36233590

36243591
static enum xnn_status optimize_common_subgraphs_iter(
36253592
xnn_subgraph_t subgraph, uint32_t optimization_flags, size_t* changes) {
3593+
// Replace non-static constant values with constant shapes with static values
3594+
// if their size is less than 16k.
3595+
for (uint32_t value_id = 0; value_id < subgraph->num_values; value_id++) {
3596+
struct xnn_value* value = &subgraph->values[value_id];
3597+
// Skip values that are external, static, or non-constant.
3598+
if (xnn_value_is_external(value->flags) ||
3599+
xnn_value_is_static(value->allocation_type) ||
3600+
!xnn_value_is_const(value->flags) ||
3601+
!(value->flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC)) {
3602+
continue;
3603+
}
3604+
3605+
// Don't create a constant if the value is more than 16k.
3606+
size_t value_bytes = xnn_tensor_get_size(value);
3607+
if (value_bytes + XNN_EXTRA_BYTES > 16 * 1024) {
3608+
continue;
3609+
}
3610+
3611+
// Delete this value's producer.
3612+
assert(value->producer != XNN_INVALID_NODE_ID);
3613+
xnn_node_clear(&subgraph->nodes[value->producer]);
3614+
3615+
// Convert this value to a static value.
3616+
assert(value->data == NULL);
3617+
value->allocation_type = xnn_allocation_type_static;
3618+
value->producer = XNN_INVALID_NODE_ID;
3619+
value->data = xnn_allocate_zero_memory(value_bytes + XNN_EXTRA_BYTES);
3620+
(*changes)++;
3621+
3622+
switch (value->datatype) {
3623+
case xnn_datatype_fp32:
3624+
if (value->flags & XNN_VALUE_FLAG_IS_ONE) {
3625+
for (float* finger = value->data; value_bytes > 0;
3626+
value_bytes -= sizeof(float)) {
3627+
*finger = 1.0f;
3628+
}
3629+
}
3630+
break;
3631+
3632+
case xnn_datatype_fp16:
3633+
assert(value->data == NULL);
3634+
value->allocation_type = xnn_allocation_type_static;
3635+
value->data = xnn_allocate_zero_memory(value_bytes + XNN_EXTRA_BYTES);
3636+
if (value->flags & XNN_VALUE_FLAG_IS_ONE) {
3637+
for (xnn_float16* finger = value->data; value_bytes > 0;
3638+
value_bytes -= sizeof(xnn_float16)) {
3639+
*finger = xnn_float16_from_float(1.0f);
3640+
}
3641+
}
3642+
break;
3643+
3644+
default:
3645+
XNN_UNREACHABLE;
3646+
}
3647+
3648+
xnn_log_info(
3649+
"Replaced static-shaped constant %s value v%03u with a constant static "
3650+
"value.",
3651+
value->flags & XNN_VALUE_FLAG_IS_ZERO ? "0.0" : "1.0", value->id);
3652+
}
3653+
36263654
// Loop over the nodes in this subgraph.
36273655
for (uint32_t node_id = 0; node_id < subgraph->num_nodes; node_id++) {
36283656
struct xnn_node* node = &subgraph->nodes[node_id];
@@ -3637,9 +3665,53 @@ static enum xnn_status optimize_common_subgraphs_iter(
36373665
XNN_VALUE_FLAG_SHAPE_IS_STATIC) != 0;
36383666
}
36393667
if (all_input_shapes_are_static) {
3640-
for (int k = 0; k < node->num_outputs; k++) {
3641-
subgraph->values[node->outputs[k]].flags |=
3642-
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3668+
switch (node->type) {
3669+
case xnn_node_type_unary_elementwise:
3670+
subgraph->values[node->outputs[0]].shape =
3671+
subgraph->values[node->inputs[0]].shape;
3672+
subgraph->values[node->outputs[0]].flags |=
3673+
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3674+
break;
3675+
3676+
case xnn_node_type_binary_elementwise:
3677+
xnn_shape_binary_broadcast(&subgraph->values[node->inputs[0]].shape,
3678+
&subgraph->values[node->inputs[1]].shape,
3679+
&subgraph->values[node->outputs[0]].shape);
3680+
subgraph->values[node->outputs[0]].flags |=
3681+
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3682+
break;
3683+
3684+
case xnn_node_type_static_transpose:
3685+
// Apply the transpose to the output shape.
3686+
for (int k = 0; k < node->params.transpose.num_dims; k++) {
3687+
subgraph->values[node->outputs[0]].shape.dim[k] =
3688+
subgraph->values[node->inputs[0]]
3689+
.shape.dim[node->params.transpose.perm[k]];
3690+
}
3691+
subgraph->values[node->outputs[0]].flags |=
3692+
XNN_VALUE_FLAG_SHAPE_IS_STATIC;
3693+
break;
3694+
3695+
case xnn_node_type_static_expand_dims: {
3696+
const struct xnn_shape* new_dims =
3697+
&node->params.static_reshape.new_shape;
3698+
subgraph->values[node->outputs[0]].shape.num_dims =
3699+
subgraph->values[node->inputs[0]].shape.num_dims +
3700+
new_dims->num_dims;
3701+
for (uint32_t idx_new = 0, idx_in = 0, k = 0;
3702+
k < subgraph->values[node->outputs[0]].shape.num_dims; k++) {
3703+
if (idx_new < new_dims->num_dims && new_dims->dim[idx_new] == k) {
3704+
subgraph->values[node->outputs[0]].shape.dim[k] = 1;
3705+
idx_new++;
3706+
} else {
3707+
subgraph->values[node->outputs[0]].shape.dim[k] =
3708+
subgraph->values[node->inputs[0]].shape.dim[idx_in++];
3709+
}
3710+
}
3711+
} break;
3712+
3713+
default:
3714+
break;
36433715
}
36443716
}
36453717

@@ -3711,7 +3783,9 @@ static enum xnn_status optimize_common_subgraphs_iter(
37113783
case xnn_node_type_static_reshape:
37123784
// If the reshape is fully defined (no zeros), then the output shape
37133785
// is static.
3714-
if (xnn_shape_multiply_all_dims(
3786+
if (!(subgraph->values[node->outputs[0]].flags &
3787+
XNN_VALUE_FLAG_SHAPE_IS_STATIC) &&
3788+
xnn_shape_multiply_all_dims(
37153789
&node->params.static_reshape.new_shape) != 0) {
37163790
xnn_log_info(
37173791
"Marking output of static_reshape[#%u](v%03u) as static shaped.",

test/subgraph/rewrites.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
#include <gmock/gmock.h>
2121
#include <gtest/gtest.h>
22-
#include "include/experimental.h"
2322
#include "include/xnnpack.h"
2423
#include "src/subgraph/subgraph-utils.h"
2524
#include "src/xnnpack/buffer.h"
@@ -844,7 +843,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeMul) {
844843
// Add a scalar static tensor with the value `1.0`.
845844
uint32_t static_one_value_id;
846845
std::tie(static_one_tensor, static_one_value_id) =
847-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 1.0, 1.0);
846+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 1.0, 1.0);
848847

849848
// Add the binary `multiply` op with the constant 1.0.
850849
auto inputs =
@@ -934,7 +933,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeDiv) {
934933
// Add a scalar static tensor with the value `1.0`.
935934
uint32_t static_one_value_id;
936935
std::tie(static_one_tensor, static_one_value_id) =
937-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 1.0, 1.0);
936+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 1.0, 1.0);
938937

939938
// Add the binary `divide` op by the constant 1.0.
940939
subgraph.AddBinary(xnn_binary_divide, /*params=*/nullptr,
@@ -1022,7 +1021,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeAdd) {
10221021
// Add a scalar static tensor with the value `0.0`.
10231022
uint32_t static_zero_value_id;
10241023
std::tie(static_zero_tensor, static_zero_value_id) =
1025-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 0.0, 0.0);
1024+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 0.0, 0.0);
10261025

10271026
// Add the binary `add` op with the constant 0.0.
10281027
auto inputs =
@@ -1112,7 +1111,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpStaticShapeSub) {
11121111
// Add a scalar static tensor with the value `0.0`.
11131112
uint32_t static_zero_value_id;
11141113
std::tie(static_zero_tensor, static_zero_value_id) =
1115-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 0.0, 0.0);
1114+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 0.0, 0.0);
11161115

11171116
// Add the binary `subtract` op with the constant 0.0.
11181117
subgraph.AddBinary(xnn_binary_subtract, /*params=*/nullptr,
@@ -1247,7 +1246,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpChainOfStaticShapeMulZeroAdd) {
12471246
// Add a scalar static tensor with the value `0.0`.
12481247
uint32_t static_zero_value_id;
12491248
std::tie(static_zero_tensor, static_zero_value_id) =
1250-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 0.0, 0.0);
1249+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 0.0, 0.0);
12511250

12521251
// Add the binary `multiply` op with the constant 0.0.
12531252
uint32_t dynamic_zero_value_id =
@@ -1361,7 +1360,7 @@ TEST_P(RewriteArithmeticTest, ElidesNoOpChainOfStaticShapeDivOneMul) {
13611360
// Add a scalar static tensor with the value `1.0`.
13621361
uint32_t static_one_value_id;
13631362
std::tie(static_one_tensor, static_one_value_id) =
1364-
add_static_tensor<float>(rng, subgraph, /*shape=*/{1}, 1.0, 1.0);
1363+
add_static_tensor<float>(rng, subgraph, /*shape=*/{}, 1.0, 1.0);
13651364

13661365
// Add the static `1.0` to the absolute value of the inputs to make sure
13671366
// they are non-negative

0 commit comments

Comments
 (0)