diff --git a/src/gallium/frontends/teflon/tfl_device.c b/src/gallium/frontends/teflon/tfl_device.c index 583a71b6e91..863344227ad 100644 --- a/src/gallium/frontends/teflon/tfl_device.c +++ b/src/gallium/frontends/teflon/tfl_device.c @@ -116,7 +116,8 @@ fill_operation(struct teflon_delegate *delegate, TfLiteContext *tf_context, TfLi TfLiteConvParams* params = (TfLiteConvParams*)node->builtin_data; assert(params->activation == kTfLiteActNone || - params->activation == kTfLiteActRelu); + params->activation == kTfLiteActRelu || + params->activation == kTfLiteActRelu6); if (node_registration->version >= 2) { assert(params->dilation_width_factor == 1); assert(params->dilation_height_factor == 1); @@ -125,12 +126,14 @@ fill_operation(struct teflon_delegate *delegate, TfLiteContext *tf_context, TfLi operation->conv.stride_y = params->stride_height; operation->conv.padding_same = params->padding == kTfLitePaddingSame; operation->conv.depthwise = false; - operation->conv.relu = params->activation == kTfLiteActRelu; + operation->conv.relu = params->activation == kTfLiteActRelu || + params->activation == kTfLiteActRelu6; } else { TfLiteDepthwiseConvParams* params = (TfLiteDepthwiseConvParams*)node->builtin_data; assert(params->activation == kTfLiteActNone || - params->activation == kTfLiteActRelu); + params->activation == kTfLiteActRelu || + params->activation == kTfLiteActRelu6); if (node_registration->version >= 2) { assert(params->dilation_width_factor == 1); assert(params->dilation_height_factor == 1); @@ -139,7 +142,8 @@ fill_operation(struct teflon_delegate *delegate, TfLiteContext *tf_context, TfLi operation->conv.stride_y = params->stride_height; operation->conv.padding_same = params->padding == kTfLitePaddingSame; operation->conv.depthwise = true; - operation->conv.relu = params->activation == kTfLiteActRelu; + operation->conv.relu = params->activation == kTfLiteActRelu || + params->activation == kTfLiteActRelu6; } operation->conv.pointwise = operation->conv.weight_tensor->dims[1] == 1 && \ operation->conv.weight_tensor->dims[2] == 1; @@ -405,6 +409,48 @@ tensor_quantization_supported(TfLiteTensor *tensor) return false; } +static bool +fused_relu6_supported(TfLiteTensor *tensor) +{ + TfLiteAffineQuantization *affine; + int quantized_max; + + switch (tensor->type) { + case kTfLiteInt8: + quantized_max = INT8_MAX; + break; + case kTfLiteUInt8: + quantized_max = UINT8_MAX; + break; + default: + return false; + } + + assert(tensor->quantization.type == kTfLiteAffineQuantization); + affine = (TfLiteAffineQuantization *)tensor->quantization.params; + + assert(affine->scale->size == affine->zero_point->size); + for (int i = 0; i < affine->zero_point->size; i++) { + if ((quantized_max - affine->zero_point->data[i]) * affine->scale->data[i] > 6.0f) + return false; + } + return true; +} + +static bool +fused_activation_supported(TfLiteFusedActivation activation, TfLiteTensor *tensor) +{ + switch (activation) { + case kTfLiteActNone: + case kTfLiteActRelu: + return true; + case kTfLiteActRelu6: + return fused_relu6_supported(tensor); + default: + return false; + } +} + static TfLiteStatus PrepareDelegate(TfLiteContext *context, TfLiteDelegate *delegate) { @@ -436,8 +482,7 @@ PrepareDelegate(TfLiteContext *context, TfLiteDelegate *delegate) tensor_quantization_supported(weight_tensor) && tensor_quantization_supported(bias_tensor) && tensor_quantization_supported(output_tensor) && - (params->activation == kTfLiteActNone || - params->activation == kTfLiteActRelu) && + fused_activation_supported(params->activation, output_tensor) && (registration->version < 2 || (params->dilation_width_factor == 1 && params->dilation_height_factor == 1))) { @@ -457,8 +502,7 @@ PrepareDelegate(TfLiteContext *context, TfLiteDelegate *delegate) tensor_quantization_supported(weight_tensor) && tensor_quantization_supported(bias_tensor) && tensor_quantization_supported(output_tensor) && - (params->activation == kTfLiteActNone || - params->activation == kTfLiteActRelu) && + fused_activation_supported(params->activation, output_tensor) && (registration->version < 2 || (params->dilation_width_factor == 1 && params->dilation_height_factor == 1))) {