diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index bea46de0677..d7aaa6121a2 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -307,6 +307,7 @@ def _get_dtype_string(self): "fp16": "ck_tile::fp16_t", "fp8": "ck_tile::fp8_t", "bf16": "ck_tile::bf16_t", + "bf8": "ck_tile::bf8_t", "fp32": "float", "fp64": "double", } @@ -776,7 +777,7 @@ def main(): parser.add_argument( "--datatype", required=True, - choices=["fp16", "fp8", "bf16", "fp32", "fp64"], + choices=["fp16", "fp8", "bf16", "bf8", "fp32", "fp64"], help="Data type", ) parser.add_argument(