Commit ac3913e
Add deduce_jagged_tensor_with_graph_analysis flag for batch dim distinguish
Summary:
For vdd, it seems that the jagged tensor batch dim is identical to dense tensor batch dim, which caused issue in bmm kernel, that it cannot handle batch size as large as 2^16.
This fix adds a flag `deduce_jagged_tensor_with_graph_analysis` so that when it is turnt on, we depend on graph analysis, i.e. `try_getting_jagged_tensor_map`, to deduce batch dim for jagged tensor. This can be more reliable than deducing based on value.
Differential Revision: D492624221 parent 4ca3435 commit ac3913e
1 file changed
+4
-17
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
306 | 306 | | |
307 | 307 | | |
308 | 308 | | |
309 | | - | |
310 | | - | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
311 | 312 | | |
312 | 313 | | |
313 | 314 | | |
| |||
371 | 372 | | |
372 | 373 | | |
373 | 374 | | |
| 375 | + | |
374 | 376 | | |
375 | 377 | | |
376 | 378 | | |
| |||
385 | 387 | | |
386 | 388 | | |
387 | 389 | | |
388 | | - | |
389 | | - | |
390 | | - | |
391 | | - | |
392 | | - | |
393 | | - | |
394 | | - | |
395 | | - | |
396 | | - | |
397 | | - | |
398 | | - | |
399 | | - | |
400 | | - | |
401 | | - | |
402 | | - | |
403 | 390 | | |
404 | 391 | | |
405 | 392 | | |
| |||
0 commit comments