1515
1616#include "include/experimental.h"
1717#include "include/xnnpack.h"
18+ #include "src/subgraph/subgraph-utils.h"
1819#include "src/xnnpack/allocation-type.h"
1920#include "src/xnnpack/allocator.h"
2021#include "src/xnnpack/common.h"
@@ -2812,33 +2813,16 @@ static enum xnn_status optimize_common_subgraphs_static_reshapes(
28122813 return xnn_status_success ;
28132814 }
28142815
2815- // Set the shape of the static-shaped value.
2816- struct xnn_shape new_shape ;
2817- if (node -> type == xnn_node_type_static_reshape ) {
2818- // Replace the old shape with the new shape, filling any gaps from the input
2819- // shape.
2820- new_shape = node -> params .static_reshape .new_shape ;
2821- XNN_RETURN_IF_ERROR (xnn_shape_fill_gaps (& input_value -> shape , & new_shape ));
2822- } else if (node -> type == xnn_node_type_static_expand_dims ) {
2823- const struct xnn_shape * new_dims = & node -> params .static_reshape .new_shape ;
2824- new_shape .num_dims = input_value -> shape .num_dims + new_dims -> num_dims ;
2825- for (uint32_t idx_new = 0 , idx_old = 0 , k = 0 ; k < new_shape .num_dims ;
2826- k ++ ) {
2827- if (idx_new < new_dims -> num_dims && new_dims -> dim [idx_new ] == k ) {
2828- new_shape .dim [k ] = 1 ;
2829- idx_new ++ ;
2830- } else {
2831- new_shape .dim [k ] = input_value -> shape .dim [idx_old ++ ];
2832- }
2833- }
2834- }
2835-
28362816 // If the input is a static value, apply the new shape to it directly.
28372817 bool elide = true;
28382818 if (xnn_value_is_static (input_value -> allocation_type )) {
2839- input_value -> shape = new_shape ;
2840- } else {
2841- elide = xnn_shape_match (& new_shape , & input_value -> shape );
2819+ input_value -> shape = output_value -> shape ;
2820+ }
2821+
2822+ // Otherwise, if the new shape is the old shape, do away with the reshape
2823+ // entirely.
2824+ else {
2825+ elide = xnn_shape_match (& input_value -> shape , & output_value -> shape );
28422826 }
28432827
28442828 if (elide ) {
@@ -3427,6 +3411,12 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34273411 }
34283412 }
34293413
3414+ // If we don't know the shape of the constant, then we can't really guarantee
3415+ // anything.
3416+ if ((const_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC ) == 0 ) {
3417+ return xnn_status_success ;
3418+ }
3419+
34303420 const enum xnn_binary_operator binary_operator = node -> binary_operator ;
34313421 const bool const_is_zero = (const_value -> flags & XNN_VALUE_FLAG_IS_ZERO ) != 0 ;
34323422 const bool const_is_one = (const_value -> flags & XNN_VALUE_FLAG_IS_ONE ) != 0 ;
@@ -3453,32 +3443,29 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34533443 (const_is_zero &&
34543444 (binary_operator == xnn_binary_add ||
34553445 (binary_operator == xnn_binary_subtract && const_is_rhs )))) {
3456- if (short_circuit (subgraph , input_value -> id , node -> outputs [0 ])) {
3457- xnn_log_info ("Elided spurious %s[#%u](v%03u, %s)." ,
3458- binary_operator == xnn_binary_multiply ? "mul"
3459- : binary_operator == xnn_binary_divide ? "div"
3460- : binary_operator == xnn_binary_add ? "add"
3461- : "sub" ,
3462- node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" );
3463- xnn_node_clear (node );
3464- (* changes )++ ;
3465- } else if (input_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC &&
3466- const_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC ) {
3467- // If this node cannot be elided, and both input shapes are static, then
3468- // try to replace it with a `copy` or `broadcast` of the input value.
3469- struct xnn_shape * output_shape = & subgraph -> values [node -> outputs [0 ]].shape ;
3470- XNN_RETURN_IF_ERROR (
3471- xnn_shape_binary_broadcast (& input_value -> shape , & const_value -> shape ,
3472- output_shape ),
3473- "Incompatible input shapes for %s[#%u](v%03u, %s)." ,
3474- binary_operator == xnn_binary_multiply ? "mul"
3475- : binary_operator == xnn_binary_divide ? "div"
3476- : binary_operator == xnn_binary_add ? "add"
3477- : "sub" ,
3478- node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" );
3479-
3480- // If the output shape matches the input shape, just copy the input.
3481- if (xnn_shape_match (& input_value -> shape , output_shape )) {
3446+ // We can safely elide this operation if we know it will not change the
3447+ // shape of the output, e.g. if the constant is a scalar or the shapes are
3448+ // static and equal.
3449+ if ((xnn_shape_multiply_all_dims (& const_value -> shape ) == 1 &&
3450+ const_value -> shape .num_dims <= input_value -> shape .num_dims ) ||
3451+ ((input_value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC ) &&
3452+ xnn_shape_match (& input_value -> shape , & output_value -> shape ))) {
3453+ // If the node be elided (not load-bearing), then just remove it.
3454+ if (short_circuit (subgraph , input_value -> id , node -> outputs [0 ])) {
3455+ xnn_log_info ("Elided spurious %s[#%u](v%03u, %s)." ,
3456+ binary_operator == xnn_binary_multiply ? "mul"
3457+ : binary_operator == xnn_binary_divide ? "div"
3458+ : binary_operator == xnn_binary_add ? "add"
3459+ : "sub" ,
3460+ node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" );
3461+ xnn_node_clear (node );
3462+ (* changes )++ ;
3463+ }
3464+
3465+ // Otherwise, replace it with a copy.
3466+ else {
3467+ // If the constant is a scalar, then it won't affect the shape of the
3468+ // output.
34823469 XNN_RETURN_IF_ERROR (xnn_define_copy (subgraph , input_value -> id ,
34833470 node -> outputs [0 ], node -> flags ),
34843471 "Failed to create new `Copy` node." );
@@ -3492,26 +3479,6 @@ static enum xnn_status optimize_common_subgraphs_binary_const_noop(
34923479 node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" , node_id ,
34933480 input_value -> id );
34943481 }
3495-
3496- // Otherwise, we need to broadcast the input to the output shape.
3497- else {
3498- XNN_RETURN_IF_ERROR (
3499- xnn_define_static_broadcast (subgraph , output_shape -> num_dims ,
3500- output_shape -> dim , input_value -> id ,
3501- node -> outputs [0 ],
3502- node -> flags ),
3503- "Failed to create new `Broadcast` node." );
3504- node = move_last_node_to (subgraph , node_id );
3505- xnn_log_info (
3506- "Replaced spurious %s[#%u](v%03u, %s) with "
3507- "static_broadcast[#%u](v%03u)." ,
3508- binary_operator == xnn_binary_multiply ? "mul"
3509- : binary_operator == xnn_binary_divide ? "div"
3510- : binary_operator == xnn_binary_add ? "add"
3511- : "sub" ,
3512- node -> id , input_value -> id , const_is_zero ? "0.0" : "1.0" , node_id ,
3513- input_value -> id );
3514- }
35153482 (* changes )++ ;
35163483 }
35173484 }
@@ -3623,6 +3590,67 @@ static enum xnn_status optimize_common_subgraphs_gemm_rhs_transpose(
36233590
36243591static enum xnn_status optimize_common_subgraphs_iter (
36253592 xnn_subgraph_t subgraph , uint32_t optimization_flags , size_t * changes ) {
3593+ // Replace non-static constant values with constant shapes with static values
3594+ // if their size is less than 16k.
3595+ for (uint32_t value_id = 0 ; value_id < subgraph -> num_values ; value_id ++ ) {
3596+ struct xnn_value * value = & subgraph -> values [value_id ];
3597+ // Skip values that are external, static, or non-constant.
3598+ if (xnn_value_is_external (value -> flags ) ||
3599+ xnn_value_is_static (value -> allocation_type ) ||
3600+ !xnn_value_is_const (value -> flags ) ||
3601+ !(value -> flags & XNN_VALUE_FLAG_SHAPE_IS_STATIC )) {
3602+ continue ;
3603+ }
3604+
3605+ // Don't create a constant if the value is more than 16k.
3606+ size_t value_bytes = xnn_tensor_get_size (value );
3607+ if (value_bytes + XNN_EXTRA_BYTES > 16 * 1024 ) {
3608+ continue ;
3609+ }
3610+
3611+ // Delete this value's producer.
3612+ assert (value -> producer != XNN_INVALID_NODE_ID );
3613+ xnn_node_clear (& subgraph -> nodes [value -> producer ]);
3614+
3615+ // Convert this value to a static value.
3616+ assert (value -> data == NULL );
3617+ value -> allocation_type = xnn_allocation_type_static ;
3618+ value -> producer = XNN_INVALID_NODE_ID ;
3619+ value -> data = xnn_allocate_zero_memory (value_bytes + XNN_EXTRA_BYTES );
3620+ (* changes )++ ;
3621+
3622+ switch (value -> datatype ) {
3623+ case xnn_datatype_fp32 :
3624+ if (value -> flags & XNN_VALUE_FLAG_IS_ONE ) {
3625+ for (float * finger = value -> data ; value_bytes > 0 ;
3626+ value_bytes -= sizeof (float )) {
3627+ * finger = 1.0f ;
3628+ }
3629+ }
3630+ break ;
3631+
3632+ case xnn_datatype_fp16 :
3633+ assert (value -> data == NULL );
3634+ value -> allocation_type = xnn_allocation_type_static ;
3635+ value -> data = xnn_allocate_zero_memory (value_bytes + XNN_EXTRA_BYTES );
3636+ if (value -> flags & XNN_VALUE_FLAG_IS_ONE ) {
3637+ for (xnn_float16 * finger = value -> data ; value_bytes > 0 ;
3638+ value_bytes -= sizeof (xnn_float16 )) {
3639+ * finger = xnn_float16_from_float (1.0f );
3640+ }
3641+ }
3642+ break ;
3643+
3644+ default :
3645+ XNN_UNREACHABLE ;
3646+ }
3647+
3648+ xnn_log_info (
3649+ "Replaced static-shaped constant %s value v%03u with a constant static "
3650+ "value." ,
3651+ value -> flags & XNN_VALUE_FLAG_IS_ZERO ? "0.0" : "1.0" , value -> id );
3652+ }
3653+
36263654 // Loop over the nodes in this subgraph.
36273655 for (uint32_t node_id = 0 ; node_id < subgraph -> num_nodes ; node_id ++ ) {
36283656 struct xnn_node * node = & subgraph -> nodes [node_id ];
@@ -3637,9 +3665,53 @@ static enum xnn_status optimize_common_subgraphs_iter(
36373665 XNN_VALUE_FLAG_SHAPE_IS_STATIC ) != 0 ;
36383666 }
36393667 if (all_input_shapes_are_static ) {
3640- for (int k = 0 ; k < node -> num_outputs ; k ++ ) {
3641- subgraph -> values [node -> outputs [k ]].flags |=
3642- XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3668+ switch (node -> type ) {
3669+ case xnn_node_type_unary_elementwise :
3670+ subgraph -> values [node -> outputs [0 ]].shape =
3671+ subgraph -> values [node -> inputs [0 ]].shape ;
3672+ subgraph -> values [node -> outputs [0 ]].flags |=
3673+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3674+ break ;
3675+
3676+ case xnn_node_type_binary_elementwise :
3677+ xnn_shape_binary_broadcast (& subgraph -> values [node -> inputs [0 ]].shape ,
3678+ & subgraph -> values [node -> inputs [1 ]].shape ,
3679+ & subgraph -> values [node -> outputs [0 ]].shape );
3680+ subgraph -> values [node -> outputs [0 ]].flags |=
3681+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3682+ break ;
3683+
3684+ case xnn_node_type_static_transpose :
3685+ // Apply the transpose to the output shape.
3686+ for (int k = 0 ; k < node -> params .transpose .num_dims ; k ++ ) {
3687+ subgraph -> values [node -> outputs [0 ]].shape .dim [k ] =
3688+ subgraph -> values [node -> inputs [0 ]]
3689+ .shape .dim [node -> params .transpose .perm [k ]];
3690+ }
3691+ subgraph -> values [node -> outputs [0 ]].flags |=
3692+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ;
3693+ break ;
3694+
3695+ case xnn_node_type_static_expand_dims : {
3696+ const struct xnn_shape * new_dims =
3697+ & node -> params .static_reshape .new_shape ;
3698+ subgraph -> values [node -> outputs [0 ]].shape .num_dims =
3699+ subgraph -> values [node -> inputs [0 ]].shape .num_dims +
3700+ new_dims -> num_dims ;
3701+ for (uint32_t idx_new = 0 , idx_in = 0 , k = 0 ;
3702+ k < subgraph -> values [node -> outputs [0 ]].shape .num_dims ; k ++ ) {
3703+ if (idx_new < new_dims -> num_dims && new_dims -> dim [idx_new ] == k ) {
3704+ subgraph -> values [node -> outputs [0 ]].shape .dim [k ] = 1 ;
3705+ idx_new ++ ;
3706+ } else {
3707+ subgraph -> values [node -> outputs [0 ]].shape .dim [k ] =
3708+ subgraph -> values [node -> inputs [0 ]].shape .dim [idx_in ++ ];
3709+ }
3710+ }
3711+ } break ;
3712+
3713+ default :
3714+ break ;
36433715 }
36443716 }
36453717
@@ -3711,7 +3783,9 @@ static enum xnn_status optimize_common_subgraphs_iter(
37113783 case xnn_node_type_static_reshape :
37123784 // If the reshape is fully defined (no zeros), then the output shape
37133785 // is static.
3714- if (xnn_shape_multiply_all_dims (
3786+ if (!(subgraph -> values [node -> outputs [0 ]].flags &
3787+ XNN_VALUE_FLAG_SHAPE_IS_STATIC ) &&
3788+ xnn_shape_multiply_all_dims (
37153789 & node -> params .static_reshape .new_shape ) != 0 ) {
37163790 xnn_log_info (
37173791 "Marking output of static_reshape[#%u](v%03u) as static shaped." ,
0 commit comments