diff --git a/lib_nn/api/nn_layers.h b/lib_nn/api/nn_layers.h index ca3640e8..50550c10 100644 --- a/lib_nn/api/nn_layers.h +++ b/lib_nn/api/nn_layers.h @@ -308,4 +308,23 @@ void mean_int16(const int16_t *input, int16_t *output, const int start_dim_size, const int mean_dim_size, const int end_dim_size, const float scale_mul); +typedef struct { + float lhs_zp; + float rhs_zp; + float in_zp_sum; // channel_size * lhs_zp * rhs_zp + float out_zp; + float scale; // lhs_scale*rhs_scale/out_scale + uint32_t lhs_row_size; + uint32_t channel_size; // lhs col size & rhs row size + uint32_t rhs_col_size; +} nn_mat_mul_real_params_t; +/** + * @brief Execute real matrix multiplication + * @param vpu_buf int8_t vpu_buf[64], need word align + */ +void mat_mul_real_int8( + nn_mat_mul_real_params_t *p, + int8_t *vpu_buf0, int8_t *vpu_buf1, int8_t *vpu_buf2, + int8_t *lhs, int8_t* rhs, int8_t *output); + #endif // LAYERS_H_